From 232611bfc2421e0fb31540a34dc4be5bfa970f4a Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Wed, 27 Sep 2023 14:24:35 -0700 Subject: [PATCH] [client-channel] fix use-after-free in promise based code (#34493) Also eliminate a virtual because it's not super needed and was throwing me off for a few hours. --- .../filters/client_channel/client_channel.cc | 30 ++++++++----------- .../filters/client_channel/client_channel.h | 11 +++---- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index b315bf19050..5a3810fd9a2 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -2668,7 +2668,7 @@ ServiceConfigCallData::CallAttributeInterface* ClientChannel::LoadBalancedCall::LbCallState::GetCallAttribute( UniqueTypeName type) const { auto* service_config_call_data = - GetServiceConfigCallData(lb_call_->call_context()); + GetServiceConfigCallData(lb_call_->call_context_); return service_config_call_data->GetCallAttribute(type); } @@ -2723,14 +2723,13 @@ class ClientChannel::LoadBalancedCall::BackendMetricAccessor namespace { -ClientCallTracer::CallAttemptTracer* CreateCallAttemptTracer( - grpc_call_context_element* context, bool is_transparent_retry) { +void CreateCallAttemptTracer(grpc_call_context_element* context, + bool is_transparent_retry) { auto* call_tracer = static_cast( context[GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE].value); - if (call_tracer == nullptr) return nullptr; + if (call_tracer == nullptr) return; auto* tracer = call_tracer->StartNewAttempt(is_transparent_retry); context[GRPC_CONTEXT_CALL_TRACER].value = tracer; - return tracer; } } // namespace @@ -2743,7 +2742,8 @@ ClientChannel::LoadBalancedCall::LoadBalancedCall( ? "LoadBalancedCall" : nullptr), chand_(chand), - on_commit_(std::move(on_commit)) { + on_commit_(std::move(on_commit)), + call_context_(call_context) { CreateCallAttemptTracer(call_context, is_transparent_retry); if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) { gpr_log(GPR_INFO, "chand=%p lb_call=%p: created", chand_, this); @@ -3005,7 +3005,6 @@ ClientChannel::FilterBasedLoadBalancedCall::FilterBasedLoadBalancedCall( is_transparent_retry), deadline_(args.deadline), arena_(args.arena), - call_context_(args.context), owning_call_(args.call_stack), call_combiner_(args.call_combiner), pollent_(pollent), @@ -3442,7 +3441,7 @@ void ClientChannel::FilterBasedLoadBalancedCall::CreateSubchannelCall() { deadline_, arena_, // TODO(roth): When we implement hedging support, we will probably // need to use a separate call context for each subchannel call. - call_context_, call_combiner_}; + call_context(), call_combiner_}; grpc_error_handle error; subchannel_call_ = SubchannelCall::Create(std::move(call_args), &error); if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) { @@ -3491,12 +3490,14 @@ ClientChannel::PromiseBasedLoadBalancedCall::MakeCallPromise( } // Extract peer name from server initial metadata. call_args.server_initial_metadata->InterceptAndMap( - [this](ServerMetadataHandle metadata) { - if (call_attempt_tracer() != nullptr) { - call_attempt_tracer()->RecordReceivedInitialMetadata(metadata.get()); + [self = RefCountedPtr(lb_call->Ref())]( + ServerMetadataHandle metadata) { + if (self->call_attempt_tracer() != nullptr) { + self->call_attempt_tracer()->RecordReceivedInitialMetadata( + metadata.get()); } Slice* peer_string = metadata->get_pointer(PeerString()); - if (peer_string != nullptr) peer_string_ = peer_string->Ref(); + if (peer_string != nullptr) self->peer_string_ = peer_string->Ref(); return metadata; }); client_initial_metadata_ = std::move(call_args.client_initial_metadata); @@ -3587,11 +3588,6 @@ Arena* ClientChannel::PromiseBasedLoadBalancedCall::arena() const { return GetContext(); } -grpc_call_context_element* -ClientChannel::PromiseBasedLoadBalancedCall::call_context() const { - return GetContext(); -} - grpc_metadata_batch* ClientChannel::PromiseBasedLoadBalancedCall::send_initial_metadata() const { return client_initial_metadata_.get(); diff --git a/src/core/ext/filters/client_channel/client_channel.h b/src/core/ext/filters/client_channel/client_channel.h index 6597200cbc5..f54c5db4669 100644 --- a/src/core/ext/filters/client_channel/client_channel.h +++ b/src/core/ext/filters/client_channel/client_channel.h @@ -407,7 +407,7 @@ class ClientChannel::LoadBalancedCall ClientChannel* chand() const { return chand_; } ClientCallTracer::CallAttemptTracer* call_attempt_tracer() const { return static_cast( - call_context()[GRPC_CONTEXT_CALL_TRACER].value); + call_context_[GRPC_CONTEXT_CALL_TRACER].value); } ConnectedSubchannel* connected_subchannel() const { return connected_subchannel_.get(); @@ -441,13 +441,14 @@ class ClientChannel::LoadBalancedCall void RecordLatency(); + grpc_call_context_element* call_context() const { return call_context_; } + private: class LbCallState; class Metadata; class BackendMetricAccessor; virtual Arena* arena() const = 0; - virtual grpc_call_context_element* call_context() const = 0; virtual grpc_polling_entity* pollent() = 0; virtual grpc_metadata_batch* send_initial_metadata() const = 0; @@ -473,6 +474,7 @@ class ClientChannel::LoadBalancedCall const BackendMetricData* backend_metric_data_ = nullptr; std::unique_ptr lb_subchannel_call_tracker_; + grpc_call_context_element* const call_context_; }; class ClientChannel::FilterBasedLoadBalancedCall @@ -509,9 +511,6 @@ class ClientChannel::FilterBasedLoadBalancedCall using LoadBalancedCall::Commit; Arena* arena() const override { return arena_; } - grpc_call_context_element* call_context() const override { - return call_context_; - } grpc_polling_entity* pollent() override { return pollent_; } grpc_metadata_batch* send_initial_metadata() const override { return pending_batches_[0] @@ -568,7 +567,6 @@ class ClientChannel::FilterBasedLoadBalancedCall // context. This will save per-call memory overhead. Timestamp deadline_; Arena* arena_; - grpc_call_context_element* call_context_; grpc_call_stack* owning_call_; CallCombiner* call_combiner_; grpc_polling_entity* pollent_; @@ -618,7 +616,6 @@ class ClientChannel::PromiseBasedLoadBalancedCall private: Arena* arena() const override; - grpc_call_context_element* call_context() const override; grpc_polling_entity* pollent() override { return &pollent_; } grpc_metadata_batch* send_initial_metadata() const override;