diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index fe1a5a2e4eb..cc34178d619 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -545,13 +545,6 @@ struct call_data { bool have_request = false; grpc_closure pick_closure; - // A closure to fork notifying the lb interceptor and run the original trailer - // interception callback. - grpc_closure recv_trailing_metadata_ready_for_lb; - // The original trailer interception callback. - grpc_closure* original_recv_trailing_metadata_ready = nullptr; - grpc_transport_stream_op_batch* recv_trailing_metadata_op_batch = nullptr; - grpc_polling_entity* pollent = nullptr; // Batches are added to this list when received from above. @@ -612,8 +605,6 @@ static void start_internal_recv_trailing_metadata(grpc_call_element* elem); static void on_complete(void* arg, grpc_error* error); static void start_retriable_subchannel_batches(void* arg, grpc_error* ignored); static void start_pick_locked(void* arg, grpc_error* ignored); -static void maybe_intercept_trailing_metadata_for_lb( - grpc_call_element* arg, grpc_transport_stream_op_batch* batch); // // send op data caching @@ -736,6 +727,25 @@ static void free_cached_send_op_data_for_completed_batch( } } +// +// LB recv_trailing_metadata_ready handling +// + +void maybe_inject_recv_trailing_metadata_ready_for_lb( + const grpc_core::LoadBalancingPolicy::PickState& pick, + grpc_transport_stream_op_batch* batch) { + if (pick.recv_trailing_metadata_ready != nullptr) { + *pick.original_recv_trailing_metadata_ready = + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + pick.recv_trailing_metadata_ready; + if (pick.recv_trailing_metadata != nullptr) { + *pick.recv_trailing_metadata = + batch->payload->recv_trailing_metadata.recv_trailing_metadata; + } + } +} + // // pending_batches management // @@ -860,6 +870,10 @@ static void pending_batches_fail(grpc_call_element* elem, grpc_error* error, pending_batch* pending = &calld->pending_batches[i]; grpc_transport_stream_op_batch* batch = pending->batch; if (batch != nullptr) { + if (batch->recv_trailing_metadata) { + maybe_inject_recv_trailing_metadata_ready_for_lb( + *calld->request->pick(), batch); + } batch->handler_private.extra_arg = calld; GRPC_CLOSURE_INIT(&batch->handler_private.closure, fail_pending_batch_in_call_combiner, batch, @@ -912,7 +926,10 @@ static void pending_batches_resume(grpc_call_element* elem) { pending_batch* pending = &calld->pending_batches[i]; grpc_transport_stream_op_batch* batch = pending->batch; if (batch != nullptr) { - maybe_intercept_trailing_metadata_for_lb(elem, batch); + if (batch->recv_trailing_metadata) { + maybe_inject_recv_trailing_metadata_ready_for_lb( + *calld->request->pick(), batch); + } batch->handler_private.extra_arg = calld->subchannel_call; GRPC_CLOSURE_INIT(&batch->handler_private.closure, resume_pending_batch_in_call_combiner, batch, @@ -1582,8 +1599,7 @@ static void run_closures_for_completed_call(subchannel_batch_data* batch_data, // Intercepts recv_trailing_metadata_ready callback for retries. // Commits the call and returns the trailing metadata up the stack. -static void recv_trailing_metadata_ready_for_retries( - void* arg, grpc_error* error) { +static void recv_trailing_metadata_ready(void* arg, grpc_error* error) { subchannel_batch_data* batch_data = static_cast(arg); grpc_call_element* elem = batch_data->elem; channel_data* chand = static_cast(elem->channel_data); @@ -1603,16 +1619,6 @@ static void recv_trailing_metadata_ready_for_retries( grpc_mdelem* server_pushback_md = nullptr; grpc_metadata_batch* md_batch = batch_data->batch.payload->recv_trailing_metadata.recv_trailing_metadata; - // If the lb policy asks for the trailing metadata, set its receiving ptr - if (calld->pick.recv_trailing_metadata != nullptr) { - *calld->pick.recv_trailing_metadata = md_batch; - } - // We use GRPC_CLOSURE_RUN synchronously on the callback. In the case of - // a retry, we would have already freed the metadata before returning from - // this function. - GRPC_CLOSURE_RUN( - calld->pick.recv_trailing_metadata_ready, - GRPC_ERROR_REF(error)); get_call_status(elem, md_batch, GRPC_ERROR_REF(error), &status, &server_pushback_md); if (grpc_client_channel_trace.enabled()) { @@ -1948,11 +1954,13 @@ static void add_retriable_recv_trailing_metadata_op( batch_data->batch.payload->recv_trailing_metadata.collect_stats = &retry_state->collect_stats; GRPC_CLOSURE_INIT(&retry_state->recv_trailing_metadata_ready, - recv_trailing_metadata_ready_for_retries, batch_data, + recv_trailing_metadata_ready, batch_data, grpc_schedule_on_exec_ctx); batch_data->batch.payload->recv_trailing_metadata .recv_trailing_metadata_ready = &retry_state->recv_trailing_metadata_ready; + maybe_inject_recv_trailing_metadata_ready_for_lb(*calld->request->pick(), + &batch_data->batch); } // Helper function used to start a recv_trailing_metadata batch. This @@ -2222,45 +2230,6 @@ static void start_retriable_subchannel_batches(void* arg, grpc_error* ignored) { // LB pick // -// The callback to intercept trailing metadata if retries is not enabled -static void recv_trailing_metadata_ready_for_lb(void* arg, grpc_error* error) { - grpc_call_element* elem = static_cast(arg); - call_data* calld = static_cast(elem->call_data); - if (calld->pick.recv_trailing_metadata != nullptr) { - *calld->pick.recv_trailing_metadata = - calld->recv_trailing_metadata_op_batch->payload - ->recv_trailing_metadata.recv_trailing_metadata; - } - GRPC_CLOSURE_SCHED( - calld->pick.recv_trailing_metadata_ready, - GRPC_ERROR_REF(error)); - GRPC_CLOSURE_SCHED( - calld->original_recv_trailing_metadata_ready, - GRPC_ERROR_REF(error)); - GRPC_ERROR_UNREF(error); -} - -// If needed, intercepts the recv_trailing_metadata_ready callback to return -// trailing metadata to the LB policy. -static void maybe_intercept_trailing_metadata_for_lb( - grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { - call_data* calld = static_cast(elem->call_data); - if (!batch->recv_trailing_metadata) { - return; - } - if (calld->pick.recv_trailing_metadata_ready != nullptr) { - calld->recv_trailing_metadata_op_batch = batch; - GRPC_CLOSURE_INIT(&calld->recv_trailing_metadata_ready_for_lb, - recv_trailing_metadata_ready_for_lb, - elem, - grpc_schedule_on_exec_ctx); - calld->original_recv_trailing_metadata_ready = - batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; - batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = - &calld->recv_trailing_metadata_ready_for_lb; - } -} - static void create_subchannel_call(grpc_call_element* elem, grpc_error* error) { channel_data* chand = static_cast(elem->channel_data); call_data* calld = static_cast(elem->call_data); diff --git a/src/core/ext/filters/client_channel/lb_policy.h b/src/core/ext/filters/client_channel/lb_policy.h index 709eee7de83..dea8f4fa69f 100644 --- a/src/core/ext/filters/client_channel/lb_policy.h +++ b/src/core/ext/filters/client_channel/lb_policy.h @@ -77,6 +77,11 @@ class LoadBalancingPolicy : public InternallyRefCounted { // Callback set by lb policy to be notified of trailing metadata. // The callback must be scheduled on grpc_schedule_on_exec_ctx. grpc_closure* recv_trailing_metadata_ready = nullptr; + // The address that will be set to point to the original + // recv_trailing_metadata_ready callback, to be invoked by the LB + // policy's recv_trailing_metadata_ready callback when complete. + // Must be non-null if recv_trailing_metadata_ready is non-null. + grpc_closure** original_recv_trailing_metadata_ready = nullptr; // If this is not nullptr, then the client channel will point it to the // call's trailing metadata before invoking recv_trailing_metadata_ready. // If this is nullptr, then the callback will still be called. diff --git a/test/cpp/end2end/client_lb_end2end_test.cc b/test/cpp/end2end/client_lb_end2end_test.cc index bdc4d8edf67..328f28e3db6 100644 --- a/test/cpp/end2end/client_lb_end2end_test.cc +++ b/test/cpp/end2end/client_lb_end2end_test.cc @@ -35,24 +35,25 @@ #include #include +#include "src/core/ext/filters/client_channel/lb_policy.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" #include "src/core/ext/filters/client_channel/parse_address.h" #include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" #include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/ext/filters/client_channel/subchannel_index.h" -#include "src/core/ext/filters/client_channel/lb_policy_registry.h" #include "src/core/lib/backoff/backoff.h" #include "src/core/lib/channel/channelz.h" -#include "src/core/lib/iomgr/closure.h" -#include "src/core/lib/iomgr/error.h" #include "src/core/lib/gpr/env.h" #include "src/core/lib/gprpp/debug_location.h" #include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/error.h" #include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" #include "src/core/lib/transport/connectivity_state.h" #include "src/core/lib/transport/static_metadata.h" #include "src/core/lib/transport/status_metadata.h" -#include "src/core/lib/security/credentials/fake/fake_credentials.h" #include "src/cpp/client/secure_credentials.h" #include "src/cpp/server/secure_server_credentials.h" @@ -61,7 +62,6 @@ #include "test/core/util/test_config.h" #include "test/cpp/end2end/test_service_impl.h" - #include using grpc::testing::EchoRequest; @@ -1231,22 +1231,32 @@ TEST_F(ClientLbEnd2endTest, RoundRobinWithHealthCheckingInhibitPerChannel) { EnableDefaultHealthCheckService(false); } +grpc_core::TraceFlag forwarding_lb_tracer(false, "forwarding_lb"); + // A minimal forwarding class to avoid implementing a standalone test LB. class ForwardingLoadBalancingPolicy : public grpc_core::LoadBalancingPolicy { public: - ForwardingLoadBalancingPolicy( - const Args& args, - const std::string& delegate_policy_name) - : grpc_core::LoadBalancingPolicy(args), args_{args} { - delegate_ = grpc_core::LoadBalancingPolicyRegistry - ::CreateLoadBalancingPolicy(delegate_policy_name.c_str(), args); - grpc_pollset_set_add_pollset_set( - delegate_->interested_parties(), - interested_parties()); + ForwardingLoadBalancingPolicy(const Args& args, + const std::string& delegate_policy_name) + : grpc_core::LoadBalancingPolicy(args) { + delegate_ = + grpc_core::LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy( + delegate_policy_name.c_str(), args); + grpc_pollset_set_add_pollset_set(delegate_->interested_parties(), + interested_parties()); + // Give re-resolution closure to delegate. + GRPC_CLOSURE_INIT(&on_delegate_request_reresolution_, + OnDelegateRequestReresolutionLocked, this, + grpc_combiner_scheduler(combiner())); + Ref().release(); // held by callback. + delegate_->SetReresolutionClosureLocked(&on_delegate_request_reresolution_); } - void UpdateLocked(const grpc_channel_args& args) override { - delegate_->UpdateLocked(args); + const char* name() const override { return delegate_->name(); } + + void UpdateLocked(const grpc_channel_args& args, + grpc_json* lb_config) override { + delegate_->UpdateLocked(args, lb_config); } bool PickLocked(PickState* pick, grpc_error** error) override { @@ -1260,10 +1270,8 @@ class ForwardingLoadBalancingPolicy : public grpc_core::LoadBalancingPolicy { void CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, uint32_t initial_metadata_flags_eq, grpc_error* error) override { - delegate_->CancelMatchingPicksLocked( - initial_metadata_flags_mask, - initial_metadata_flags_eq, - error); + delegate_->CancelMatchingPicksLocked(initial_metadata_flags_mask, + initial_metadata_flags_eq, error); } void NotifyOnStateChangeLocked(grpc_connectivity_state* state, @@ -1280,13 +1288,9 @@ class ForwardingLoadBalancingPolicy : public grpc_core::LoadBalancingPolicy { delegate_->HandOffPendingPicksLocked(new_policy); } - void ExitIdleLocked() override{ - delegate_->ExitIdleLocked(); - } + void ExitIdleLocked() override { delegate_->ExitIdleLocked(); } - void ResetBackoffLocked() override { - delegate_->ResetBackoffLocked(); - } + void ResetBackoffLocked() override { delegate_->ResetBackoffLocked(); } void FillChildRefsForChannelz( grpc_core::channelz::ChildRefsList* child_subchannels, @@ -1295,13 +1299,24 @@ class ForwardingLoadBalancingPolicy : public grpc_core::LoadBalancingPolicy { } protected: - void ShutdownLocked() override { - // noop - } - Args args_; + void ShutdownLocked() override { delegate_.reset(); } private: + static void OnDelegateRequestReresolutionLocked(void* arg, + grpc_error* error) { + ForwardingLoadBalancingPolicy* self = + static_cast(arg); + if (error != GRPC_ERROR_NONE || self->delegate_ == nullptr) { + self->Unref(); + return; + } + self->TryReresolutionLocked(&forwarding_lb_tracer, GRPC_ERROR_NONE); + self->delegate_->SetReresolutionClosureLocked( + &self->on_delegate_request_reresolution_); + } + grpc_core::OrphanablePtr delegate_; + grpc_closure on_delegate_request_reresolution_; }; class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest { @@ -1314,71 +1329,81 @@ class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest { grpc_core::New(this))); } - void TearDown() override { - ClientLbEnd2endTest::TearDown(); - } + void TearDown() override { ClientLbEnd2endTest::TearDown(); } class InterceptTrailingLb : public ForwardingLoadBalancingPolicy { public: - InterceptTrailingLb( - const Args& args, - const std::string& delegate_lb_policy_name, - ClientLbInterceptTrailingMetadataTest* test) + InterceptTrailingLb(const Args& args, + const std::string& delegate_lb_policy_name, + ClientLbInterceptTrailingMetadataTest* test) : ForwardingLoadBalancingPolicy(args, delegate_lb_policy_name), - test_{test} { - } + test_(test) {} bool PickLocked(PickState* pick, grpc_error** error) override { bool ret = ForwardingLoadBalancingPolicy::PickLocked(pick, error); - // If these asserts fail, then we will need to add code to - // proxy the results to the delegate LB. - GPR_ASSERT(pick->recv_trailing_metadata == nullptr); - GPR_ASSERT(pick->recv_trailing_metadata_ready == nullptr); - // OK to add add callbacks for test - GRPC_CLOSURE_INIT( - &recv_trailing_metadata_ready_, - InterceptTrailingLb::RecordRecvTrailingMetadata, - this, - grpc_schedule_on_exec_ctx); - pick->recv_trailing_metadata_ready = &recv_trailing_metadata_ready_; - pick->recv_trailing_metadata = &recv_trailing_metadata_; + // Note: This assumes that the delegate policy does not + // intercepting recv_trailing_metadata. If we ever need to use + // this with a delegate policy that does, then we'll need to + // handle async pick returns separately. + new TrailingMetadataHandler(pick, test_); // deletes itself return ret; } - static void RecordRecvTrailingMetadata(void* arg, grpc_error* err) { - InterceptTrailingLb* lb = static_cast(arg); - GPR_ASSERT(err == GRPC_ERROR_NONE); - GPR_ASSERT(lb->recv_trailing_metadata_ != nullptr); - // an simple check to make sure the trailing metadata is valid - GPR_ASSERT(grpc_get_status_code_from_metadata( - lb->recv_trailing_metadata_->idx.named.grpc_status->md) == - grpc_status_code::GRPC_STATUS_OK); - GRPC_ERROR_UNREF(err); - lb->test_->ReportTrailerIntercepted(); - } - private: - grpc_closure recv_trailing_metadata_ready_; - grpc_metadata_batch* recv_trailing_metadata_; + class TrailingMetadataHandler { + public: + TrailingMetadataHandler(PickState* pick, + ClientLbInterceptTrailingMetadataTest* test) + : test_(test) { + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, + RecordRecvTrailingMetadata, this, + grpc_schedule_on_exec_ctx); + pick->recv_trailing_metadata_ready = &recv_trailing_metadata_ready_; + pick->original_recv_trailing_metadata_ready = + &original_recv_trailing_metadata_ready_; + pick->recv_trailing_metadata = &recv_trailing_metadata_; + } + + private: + static void RecordRecvTrailingMetadata(void* arg, grpc_error* err) { + TrailingMetadataHandler* self = + static_cast(arg); + GPR_ASSERT(self->recv_trailing_metadata_ != nullptr); + // a simple check to make sure the trailing metadata is valid + GPR_ASSERT( + grpc_get_status_code_from_metadata( + self->recv_trailing_metadata_->idx.named.grpc_status->md) == + grpc_status_code::GRPC_STATUS_OK); + self->test_->ReportTrailerIntercepted(); + GRPC_CLOSURE_SCHED(self->original_recv_trailing_metadata_ready_, + GRPC_ERROR_REF(err)); + delete self; + } + + ClientLbInterceptTrailingMetadataTest* test_; + grpc_closure recv_trailing_metadata_ready_; + grpc_closure* original_recv_trailing_metadata_ready_ = nullptr; + grpc_metadata_batch* recv_trailing_metadata_ = nullptr; + }; + ClientLbInterceptTrailingMetadataTest* test_; }; // A factory for a test LB policy that intercepts trailing metadata. // The LB policy is implemented as a wrapper around a delegate LB policy. - class InterceptTrailingFactory : - public grpc_core::LoadBalancingPolicyFactory { + class InterceptTrailingFactory + : public grpc_core::LoadBalancingPolicyFactory { public: - InterceptTrailingFactory(ClientLbInterceptTrailingMetadataTest* test): - test_{test} {} + explicit InterceptTrailingFactory( + ClientLbInterceptTrailingMetadataTest* test) + : test_(test) {} grpc_core::OrphanablePtr CreateLoadBalancingPolicy( const grpc_core::LoadBalancingPolicy::Args& args) const override { return grpc_core::OrphanablePtr( grpc_core::New( - args, - /*delegate_lb_policy_name=*/ "pick_first", - test_)); + args, /*delegate_lb_policy_name=*/ "pick_first", test_)); } const char* name() const override { @@ -1394,14 +1419,14 @@ class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest { trailers_intercepted_++; } - uint32_t trailers_intercepted() { + int trailers_intercepted() { std::unique_lock lock(mu_); return trailers_intercepted_; } private: std::mutex mu_; - uint32_t trailers_intercepted_ = 0; + int trailers_intercepted_ = 0; }; TEST_F(ClientLbInterceptTrailingMetadataTest, InterceptsRetriesDisabled) { @@ -1418,9 +1443,8 @@ TEST_F(ClientLbInterceptTrailingMetadataTest, InterceptsRetriesDisabled) { CheckRpcSendOk(stub, DEBUG_LOCATION); } // Check LB policy name for the channel. - EXPECT_EQ( - "intercept_trailing_metadata_lb", - channel->GetLoadBalancingPolicyName()); + EXPECT_EQ("intercept_trailing_metadata_lb", + channel->GetLoadBalancingPolicyName()); EXPECT_EQ(kNumServers, trailers_intercepted()); }