[client-channel] fix use-after-free in promise based code (#34493)

Also eliminate a virtual because it's not super needed and was throwing
me off for a few hours.
pull/34536/head
Craig Tiller 1 year ago committed by GitHub
parent f01e6b7dee
commit 232611bfc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 30
      src/core/ext/filters/client_channel/client_channel.cc
  2. 11
      src/core/ext/filters/client_channel/client_channel.h

@ -2668,7 +2668,7 @@ ServiceConfigCallData::CallAttributeInterface*
ClientChannel::LoadBalancedCall::LbCallState::GetCallAttribute( ClientChannel::LoadBalancedCall::LbCallState::GetCallAttribute(
UniqueTypeName type) const { UniqueTypeName type) const {
auto* service_config_call_data = auto* service_config_call_data =
GetServiceConfigCallData(lb_call_->call_context()); GetServiceConfigCallData(lb_call_->call_context_);
return service_config_call_data->GetCallAttribute(type); return service_config_call_data->GetCallAttribute(type);
} }
@ -2723,14 +2723,13 @@ class ClientChannel::LoadBalancedCall::BackendMetricAccessor
namespace { namespace {
ClientCallTracer::CallAttemptTracer* CreateCallAttemptTracer( void CreateCallAttemptTracer(grpc_call_context_element* context,
grpc_call_context_element* context, bool is_transparent_retry) { bool is_transparent_retry) {
auto* call_tracer = static_cast<ClientCallTracer*>( auto* call_tracer = static_cast<ClientCallTracer*>(
context[GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE].value); context[GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE].value);
if (call_tracer == nullptr) return nullptr; if (call_tracer == nullptr) return;
auto* tracer = call_tracer->StartNewAttempt(is_transparent_retry); auto* tracer = call_tracer->StartNewAttempt(is_transparent_retry);
context[GRPC_CONTEXT_CALL_TRACER].value = tracer; context[GRPC_CONTEXT_CALL_TRACER].value = tracer;
return tracer;
} }
} // namespace } // namespace
@ -2743,7 +2742,8 @@ ClientChannel::LoadBalancedCall::LoadBalancedCall(
? "LoadBalancedCall" ? "LoadBalancedCall"
: nullptr), : nullptr),
chand_(chand), chand_(chand),
on_commit_(std::move(on_commit)) { on_commit_(std::move(on_commit)),
call_context_(call_context) {
CreateCallAttemptTracer(call_context, is_transparent_retry); CreateCallAttemptTracer(call_context, is_transparent_retry);
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) { if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) {
gpr_log(GPR_INFO, "chand=%p lb_call=%p: created", chand_, this); gpr_log(GPR_INFO, "chand=%p lb_call=%p: created", chand_, this);
@ -3005,7 +3005,6 @@ ClientChannel::FilterBasedLoadBalancedCall::FilterBasedLoadBalancedCall(
is_transparent_retry), is_transparent_retry),
deadline_(args.deadline), deadline_(args.deadline),
arena_(args.arena), arena_(args.arena),
call_context_(args.context),
owning_call_(args.call_stack), owning_call_(args.call_stack),
call_combiner_(args.call_combiner), call_combiner_(args.call_combiner),
pollent_(pollent), pollent_(pollent),
@ -3442,7 +3441,7 @@ void ClientChannel::FilterBasedLoadBalancedCall::CreateSubchannelCall() {
deadline_, arena_, deadline_, arena_,
// TODO(roth): When we implement hedging support, we will probably // TODO(roth): When we implement hedging support, we will probably
// need to use a separate call context for each subchannel call. // need to use a separate call context for each subchannel call.
call_context_, call_combiner_}; call_context(), call_combiner_};
grpc_error_handle error; grpc_error_handle error;
subchannel_call_ = SubchannelCall::Create(std::move(call_args), &error); subchannel_call_ = SubchannelCall::Create(std::move(call_args), &error);
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) { if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) {
@ -3491,12 +3490,14 @@ ClientChannel::PromiseBasedLoadBalancedCall::MakeCallPromise(
} }
// Extract peer name from server initial metadata. // Extract peer name from server initial metadata.
call_args.server_initial_metadata->InterceptAndMap( call_args.server_initial_metadata->InterceptAndMap(
[this](ServerMetadataHandle metadata) { [self = RefCountedPtr<PromiseBasedLoadBalancedCall>(lb_call->Ref())](
if (call_attempt_tracer() != nullptr) { ServerMetadataHandle metadata) {
call_attempt_tracer()->RecordReceivedInitialMetadata(metadata.get()); if (self->call_attempt_tracer() != nullptr) {
self->call_attempt_tracer()->RecordReceivedInitialMetadata(
metadata.get());
} }
Slice* peer_string = metadata->get_pointer(PeerString()); Slice* peer_string = metadata->get_pointer(PeerString());
if (peer_string != nullptr) peer_string_ = peer_string->Ref(); if (peer_string != nullptr) self->peer_string_ = peer_string->Ref();
return metadata; return metadata;
}); });
client_initial_metadata_ = std::move(call_args.client_initial_metadata); client_initial_metadata_ = std::move(call_args.client_initial_metadata);
@ -3587,11 +3588,6 @@ Arena* ClientChannel::PromiseBasedLoadBalancedCall::arena() const {
return GetContext<Arena>(); return GetContext<Arena>();
} }
grpc_call_context_element*
ClientChannel::PromiseBasedLoadBalancedCall::call_context() const {
return GetContext<grpc_call_context_element>();
}
grpc_metadata_batch* grpc_metadata_batch*
ClientChannel::PromiseBasedLoadBalancedCall::send_initial_metadata() const { ClientChannel::PromiseBasedLoadBalancedCall::send_initial_metadata() const {
return client_initial_metadata_.get(); return client_initial_metadata_.get();

@ -407,7 +407,7 @@ class ClientChannel::LoadBalancedCall
ClientChannel* chand() const { return chand_; } ClientChannel* chand() const { return chand_; }
ClientCallTracer::CallAttemptTracer* call_attempt_tracer() const { ClientCallTracer::CallAttemptTracer* call_attempt_tracer() const {
return static_cast<ClientCallTracer::CallAttemptTracer*>( return static_cast<ClientCallTracer::CallAttemptTracer*>(
call_context()[GRPC_CONTEXT_CALL_TRACER].value); call_context_[GRPC_CONTEXT_CALL_TRACER].value);
} }
ConnectedSubchannel* connected_subchannel() const { ConnectedSubchannel* connected_subchannel() const {
return connected_subchannel_.get(); return connected_subchannel_.get();
@ -441,13 +441,14 @@ class ClientChannel::LoadBalancedCall
void RecordLatency(); void RecordLatency();
grpc_call_context_element* call_context() const { return call_context_; }
private: private:
class LbCallState; class LbCallState;
class Metadata; class Metadata;
class BackendMetricAccessor; class BackendMetricAccessor;
virtual Arena* arena() const = 0; virtual Arena* arena() const = 0;
virtual grpc_call_context_element* call_context() const = 0;
virtual grpc_polling_entity* pollent() = 0; virtual grpc_polling_entity* pollent() = 0;
virtual grpc_metadata_batch* send_initial_metadata() const = 0; virtual grpc_metadata_batch* send_initial_metadata() const = 0;
@ -473,6 +474,7 @@ class ClientChannel::LoadBalancedCall
const BackendMetricData* backend_metric_data_ = nullptr; const BackendMetricData* backend_metric_data_ = nullptr;
std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface> std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface>
lb_subchannel_call_tracker_; lb_subchannel_call_tracker_;
grpc_call_context_element* const call_context_;
}; };
class ClientChannel::FilterBasedLoadBalancedCall class ClientChannel::FilterBasedLoadBalancedCall
@ -509,9 +511,6 @@ class ClientChannel::FilterBasedLoadBalancedCall
using LoadBalancedCall::Commit; using LoadBalancedCall::Commit;
Arena* arena() const override { return arena_; } Arena* arena() const override { return arena_; }
grpc_call_context_element* call_context() const override {
return call_context_;
}
grpc_polling_entity* pollent() override { return pollent_; } grpc_polling_entity* pollent() override { return pollent_; }
grpc_metadata_batch* send_initial_metadata() const override { grpc_metadata_batch* send_initial_metadata() const override {
return pending_batches_[0] return pending_batches_[0]
@ -568,7 +567,6 @@ class ClientChannel::FilterBasedLoadBalancedCall
// context. This will save per-call memory overhead. // context. This will save per-call memory overhead.
Timestamp deadline_; Timestamp deadline_;
Arena* arena_; Arena* arena_;
grpc_call_context_element* call_context_;
grpc_call_stack* owning_call_; grpc_call_stack* owning_call_;
CallCombiner* call_combiner_; CallCombiner* call_combiner_;
grpc_polling_entity* pollent_; grpc_polling_entity* pollent_;
@ -618,7 +616,6 @@ class ClientChannel::PromiseBasedLoadBalancedCall
private: private:
Arena* arena() const override; Arena* arena() const override;
grpc_call_context_element* call_context() const override;
grpc_polling_entity* pollent() override { return &pollent_; } grpc_polling_entity* pollent() override { return &pollent_; }
grpc_metadata_batch* send_initial_metadata() const override; grpc_metadata_batch* send_initial_metadata() const override;

Loading…
Cancel
Save