[client-channel] fix use-after-free in promise based code (#34493)

Also eliminate a virtual because it's not super needed and was throwing
me off for a few hours.
pull/34536/head
Craig Tiller 1 year ago committed by GitHub
parent f01e6b7dee
commit 232611bfc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 30
      src/core/ext/filters/client_channel/client_channel.cc
  2. 11
      src/core/ext/filters/client_channel/client_channel.h

@ -2668,7 +2668,7 @@ ServiceConfigCallData::CallAttributeInterface*
ClientChannel::LoadBalancedCall::LbCallState::GetCallAttribute(
UniqueTypeName type) const {
auto* service_config_call_data =
GetServiceConfigCallData(lb_call_->call_context());
GetServiceConfigCallData(lb_call_->call_context_);
return service_config_call_data->GetCallAttribute(type);
}
@ -2723,14 +2723,13 @@ class ClientChannel::LoadBalancedCall::BackendMetricAccessor
namespace {
ClientCallTracer::CallAttemptTracer* CreateCallAttemptTracer(
grpc_call_context_element* context, bool is_transparent_retry) {
void CreateCallAttemptTracer(grpc_call_context_element* context,
bool is_transparent_retry) {
auto* call_tracer = static_cast<ClientCallTracer*>(
context[GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE].value);
if (call_tracer == nullptr) return nullptr;
if (call_tracer == nullptr) return;
auto* tracer = call_tracer->StartNewAttempt(is_transparent_retry);
context[GRPC_CONTEXT_CALL_TRACER].value = tracer;
return tracer;
}
} // namespace
@ -2743,7 +2742,8 @@ ClientChannel::LoadBalancedCall::LoadBalancedCall(
? "LoadBalancedCall"
: nullptr),
chand_(chand),
on_commit_(std::move(on_commit)) {
on_commit_(std::move(on_commit)),
call_context_(call_context) {
CreateCallAttemptTracer(call_context, is_transparent_retry);
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) {
gpr_log(GPR_INFO, "chand=%p lb_call=%p: created", chand_, this);
@ -3005,7 +3005,6 @@ ClientChannel::FilterBasedLoadBalancedCall::FilterBasedLoadBalancedCall(
is_transparent_retry),
deadline_(args.deadline),
arena_(args.arena),
call_context_(args.context),
owning_call_(args.call_stack),
call_combiner_(args.call_combiner),
pollent_(pollent),
@ -3442,7 +3441,7 @@ void ClientChannel::FilterBasedLoadBalancedCall::CreateSubchannelCall() {
deadline_, arena_,
// TODO(roth): When we implement hedging support, we will probably
// need to use a separate call context for each subchannel call.
call_context_, call_combiner_};
call_context(), call_combiner_};
grpc_error_handle error;
subchannel_call_ = SubchannelCall::Create(std::move(call_args), &error);
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) {
@ -3491,12 +3490,14 @@ ClientChannel::PromiseBasedLoadBalancedCall::MakeCallPromise(
}
// Extract peer name from server initial metadata.
call_args.server_initial_metadata->InterceptAndMap(
[this](ServerMetadataHandle metadata) {
if (call_attempt_tracer() != nullptr) {
call_attempt_tracer()->RecordReceivedInitialMetadata(metadata.get());
[self = RefCountedPtr<PromiseBasedLoadBalancedCall>(lb_call->Ref())](
ServerMetadataHandle metadata) {
if (self->call_attempt_tracer() != nullptr) {
self->call_attempt_tracer()->RecordReceivedInitialMetadata(
metadata.get());
}
Slice* peer_string = metadata->get_pointer(PeerString());
if (peer_string != nullptr) peer_string_ = peer_string->Ref();
if (peer_string != nullptr) self->peer_string_ = peer_string->Ref();
return metadata;
});
client_initial_metadata_ = std::move(call_args.client_initial_metadata);
@ -3587,11 +3588,6 @@ Arena* ClientChannel::PromiseBasedLoadBalancedCall::arena() const {
return GetContext<Arena>();
}
grpc_call_context_element*
ClientChannel::PromiseBasedLoadBalancedCall::call_context() const {
return GetContext<grpc_call_context_element>();
}
grpc_metadata_batch*
ClientChannel::PromiseBasedLoadBalancedCall::send_initial_metadata() const {
return client_initial_metadata_.get();

@ -407,7 +407,7 @@ class ClientChannel::LoadBalancedCall
ClientChannel* chand() const { return chand_; }
ClientCallTracer::CallAttemptTracer* call_attempt_tracer() const {
return static_cast<ClientCallTracer::CallAttemptTracer*>(
call_context()[GRPC_CONTEXT_CALL_TRACER].value);
call_context_[GRPC_CONTEXT_CALL_TRACER].value);
}
ConnectedSubchannel* connected_subchannel() const {
return connected_subchannel_.get();
@ -441,13 +441,14 @@ class ClientChannel::LoadBalancedCall
void RecordLatency();
grpc_call_context_element* call_context() const { return call_context_; }
private:
class LbCallState;
class Metadata;
class BackendMetricAccessor;
virtual Arena* arena() const = 0;
virtual grpc_call_context_element* call_context() const = 0;
virtual grpc_polling_entity* pollent() = 0;
virtual grpc_metadata_batch* send_initial_metadata() const = 0;
@ -473,6 +474,7 @@ class ClientChannel::LoadBalancedCall
const BackendMetricData* backend_metric_data_ = nullptr;
std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface>
lb_subchannel_call_tracker_;
grpc_call_context_element* const call_context_;
};
class ClientChannel::FilterBasedLoadBalancedCall
@ -509,9 +511,6 @@ class ClientChannel::FilterBasedLoadBalancedCall
using LoadBalancedCall::Commit;
Arena* arena() const override { return arena_; }
grpc_call_context_element* call_context() const override {
return call_context_;
}
grpc_polling_entity* pollent() override { return pollent_; }
grpc_metadata_batch* send_initial_metadata() const override {
return pending_batches_[0]
@ -568,7 +567,6 @@ class ClientChannel::FilterBasedLoadBalancedCall
// context. This will save per-call memory overhead.
Timestamp deadline_;
Arena* arena_;
grpc_call_context_element* call_context_;
grpc_call_stack* owning_call_;
CallCombiner* call_combiner_;
grpc_polling_entity* pollent_;
@ -618,7 +616,6 @@ class ClientChannel::PromiseBasedLoadBalancedCall
private:
Arena* arena() const override;
grpc_call_context_element* call_context() const override;
grpc_polling_entity* pollent() override { return &pollent_; }
grpc_metadata_batch* send_initial_metadata() const override;

Loading…
Cancel
Save