diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index 9732b1753a8..5a74ccc2a05 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -937,6 +937,7 @@ typedef struct client_channel_call_data { grpc_closure recv_trailing_metadata_ready_for_lb; // The original trailer interception callback. grpc_closure* original_recv_trailing_metadata_ready; + grpc_transport_stream_op_batch* recv_trailing_metadata_op_batch; grpc_polling_entity* pollent; bool pollent_added_to_interested_parties; @@ -1000,8 +1001,7 @@ static void on_complete(void* arg, grpc_error* error); static void start_retriable_subchannel_batches(void* arg, grpc_error* ignored); static void start_pick_locked(void* arg, grpc_error* ignored); static void maybe_intercept_trailing_metadata_for_lb( - void* arg, grpc_transport_stream_op_batch* batch); -static void recv_trailing_metadata_ready_for_lb(void* arg, grpc_error* error); + grpc_call_element* arg, grpc_transport_stream_op_batch* batch); // // send op data caching @@ -1977,6 +1977,16 @@ static void recv_trailing_metadata_ready_for_retries( grpc_mdelem* server_pushback_md = nullptr; grpc_metadata_batch* md_batch = batch_data->batch.payload->recv_trailing_metadata.recv_trailing_metadata; + // If the lb policy asks for the trailing metadata, set its receiving ptr + if (calld->pick.recv_trailing_metadata != nullptr) { + *calld->pick.recv_trailing_metadata = md_batch; + } + // We use GRPC_CLOSURE_RUN synchronously on the callback. In the case of + // a retry, we would have already freed the metadata before returning from + // this function. + GRPC_CLOSURE_RUN( + calld->pick.recv_trailing_metadata_ready, + GRPC_ERROR_REF(error)); get_call_status(elem, md_batch, GRPC_ERROR_REF(error), &status, &server_pushback_md); if (grpc_client_channel_trace.enabled()) { @@ -2000,13 +2010,6 @@ static void recv_trailing_metadata_ready_for_retries( } // Not retrying, so commit the call. retry_commit(elem, retry_state); - // Now that the try is committed, give the trailer to the lb policy as needed - if (calld->pick.recv_trailing_metadata != nullptr) { - *calld->pick.recv_trailing_metadata = md_batch; - } - GRPC_CLOSURE_SCHED( - calld->pick.recv_trailing_metadata_ready, - GRPC_ERROR_REF(error)); // Run any necessary closures. run_closures_for_completed_call(batch_data, GRPC_ERROR_REF(error)); } @@ -2595,13 +2598,12 @@ static void start_retriable_subchannel_batches(void* arg, grpc_error* ignored) { // The callback to intercept trailing metadata if retries is not enabled static void recv_trailing_metadata_ready_for_lb(void* arg, grpc_error* error) { - subchannel_batch_data* batch_data = static_cast(arg); - grpc_call_element* elem = batch_data->elem; + grpc_call_element* elem = static_cast(arg); call_data* calld = static_cast(elem->call_data); if (calld->pick.recv_trailing_metadata != nullptr) { *calld->pick.recv_trailing_metadata = - batch_data->batch.payload->recv_trailing_metadata - .recv_trailing_metadata; + calld->recv_trailing_metadata_op_batch->payload + ->recv_trailing_metadata.recv_trailing_metadata; } GRPC_CLOSURE_SCHED( calld->pick.recv_trailing_metadata_ready, @@ -2611,19 +2613,22 @@ static void recv_trailing_metadata_ready_for_lb(void* arg, grpc_error* error) { GRPC_ERROR_REF(error)); } -// Installs a interceptor to inform the lb of the trailing metadata, if needed +// If needed, intercepts the recv_trailing_metadata_ready callback to return +// trailing metadata to the LB policy. static void maybe_intercept_trailing_metadata_for_lb( - void* arg, grpc_transport_stream_op_batch* batch) { - subchannel_batch_data* batch_data = static_cast(arg); - grpc_call_element* elem = batch_data->elem; + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { call_data* calld = static_cast(elem->call_data); - calld->original_recv_trailing_metadata_ready = - batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; - GRPC_CLOSURE_INIT(&calld->recv_trailing_metadata_ready_for_lb, - recv_trailing_metadata_ready_for_lb, elem, - grpc_schedule_on_exec_ctx); - batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = - &calld->recv_trailing_metadata_ready_for_lb; + if (!batch->recv_trailing_metadata) { + return; + } + if (calld->pick.recv_trailing_metadata_ready != nullptr) { + calld->recv_trailing_metadata_op_batch = batch; + GRPC_CLOSURE_INIT(&calld->recv_trailing_metadata_ready_for_lb, + recv_trailing_metadata_ready_for_lb, elem, + grpc_schedule_on_exec_ctx); + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &calld->recv_trailing_metadata_ready_for_lb; + } } static void create_subchannel_call(grpc_call_element* elem, grpc_error* error) { diff --git a/test/cpp/end2end/client_lb_end2end_test.cc b/test/cpp/end2end/client_lb_end2end_test.cc index a9d68ab0582..acd8ab46c59 100644 --- a/test/cpp/end2end/client_lb_end2end_test.cc +++ b/test/cpp/end2end/client_lb_end2end_test.cc @@ -36,12 +36,17 @@ #include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" #include "src/core/ext/filters/client_channel/subchannel_index.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" #include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channelz.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/error.h" #include "src/core/lib/gpr/env.h" #include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/tcp_client.h" - +#include "src/core/lib/transport/connectivity_state.h" #include "src/proto/grpc/testing/echo.grpc.pb.h" #include "test/core/util/port.h" #include "test/core/util/test_config.h" @@ -996,6 +1001,187 @@ TEST_F(ClientLbEnd2endTest, RoundRobinSingleReconnect) { WaitForServer(stub, 0, DEBUG_LOCATION); } + +const char intercept_trailing_name[] = "intercept_trailing_metadata"; + +// LoadBalancingPolicy implementations are not designed to be extended. +// A hacky forwarding class to avoid implementing a standalone test LB. +class InterceptTrailing : public grpc_core::LoadBalancingPolicy { + public: + InterceptTrailing(const Args& args) + : grpc_core::LoadBalancingPolicy(args) { + UpdateLocked(*args.args); + grpc_connectivity_state_init(&state_tracker_, GRPC_CHANNEL_IDLE, + intercept_trailing_name); + } + + bool PickLocked(PickState* pick, grpc_error** error) override { + GRPC_CLOSURE_INIT( + &recv_trailing_metadata_ready_, + InterceptTrailing::RecordRecvTrailingMetadata, + /*cb_arg=*/ nullptr, + grpc_schedule_on_exec_ctx); + pick->recv_trailing_metadata_ready = &recv_trailing_metadata_ready_; + pick->recv_trailing_metadata = &recv_trailing_metadata_; + pick->connected_subchannel = + grpc_subchannel_get_connected_subchannel(hardcoded_subchannel_); + + if (pick->connected_subchannel.get() != nullptr) { + *error = GRPC_ERROR_NONE; + return true; + } + + if (pick->on_complete == nullptr) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "No pick result available but synchronous result required."); + return true; + } else { + on_complete_ = pick->on_complete; + // TODO(zpencer): call on_completed_ at some point + return false; + } + } + + void UpdateLocked(const grpc_channel_args& args) override { + const grpc_arg* arg = grpc_channel_args_find(&args, GRPC_ARG_LB_ADDRESSES); + grpc_lb_addresses* addresses = + static_cast(arg->value.pointer.p); + grpc_arg addr_arg = + grpc_create_subchannel_address_arg(&addresses->addresses[0].address); + static const char* keys_to_remove[] = {GRPC_ARG_SUBCHANNEL_ADDRESS, + GRPC_ARG_LB_ADDRESSES}; + grpc_channel_args* new_args = grpc_channel_args_copy_and_add_and_remove( + &args, keys_to_remove, GPR_ARRAY_SIZE(keys_to_remove), &addr_arg, 1); + gpr_free(addr_arg.value.string); + grpc_subchannel_args sc_args; + memset(&sc_args, 0, sizeof(grpc_subchannel_args)); + sc_args.args = new_args; + if (hardcoded_subchannel_ != nullptr) { + GRPC_SUBCHANNEL_UNREF(hardcoded_subchannel_, "new pick"); + } + hardcoded_subchannel_ = grpc_client_channel_factory_create_subchannel( + client_channel_factory(), &sc_args); + grpc_channel_args_destroy(new_args); + } + + void CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, + uint32_t initial_metadata_flags_eq, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } + + void CancelPickLocked(PickState* pick, + grpc_error* error) override { + pick->connected_subchannel.reset(); + GRPC_CLOSURE_SCHED(pick->on_complete, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Pick Cancelled", &error, 1)); + + GRPC_ERROR_UNREF(error); + } + + grpc_connectivity_state CheckConnectivityLocked( + grpc_error** error) override { + return grpc_connectivity_state_get(&state_tracker_, error); + } + + void NotifyOnStateChangeLocked(grpc_connectivity_state* current, + grpc_closure* notify) override { + grpc_connectivity_state_notify_on_state_change(&state_tracker_, current, + notify); + } + + void ShutdownLocked() override { + grpc_error* error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Channel shutdown"); + grpc_connectivity_state_set( + &state_tracker_, + GRPC_CHANNEL_SHUTDOWN, + GRPC_ERROR_REF(error), + "intercept_trailing_shutdown"); + } + + ~InterceptTrailing() { + grpc_connectivity_state_destroy(&state_tracker_); + } + + private: + grpc_closure* on_complete_ = nullptr; + grpc_closure recv_trailing_metadata_ready_; + grpc_metadata_batch* recv_trailing_metadata_ = nullptr; + grpc_subchannel* hardcoded_subchannel_ = nullptr; + grpc_connectivity_state_tracker state_tracker_; + + static void RecordRecvTrailingMetadata( + void* ignored_arg, grpc_error* ignored_err) { + gpr_log(GPR_INFO, "trailer intercepted by lb"); + } +}; + +// A factory for a test LB policy that intercepts trailing metadata. +// The LB policy is implemented as a wrapper around a delegate LB policy. +class InterceptTrailingFactory : public grpc_core::LoadBalancingPolicyFactory { + public: + InterceptTrailingFactory(){} + + grpc_core::OrphanablePtr + CreateLoadBalancingPolicy( + const grpc_core::LoadBalancingPolicy::Args& args) const override { + return grpc_core::OrphanablePtr( + grpc_core::New(args)); + } + + const char* name() const override { + return intercept_trailing_name; + } +}; + +class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest { + protected: + void SetUp() override { + ClientLbEnd2endTest::SetUp(); + grpc_core::LoadBalancingPolicyRegistry::Builder:: + RegisterLoadBalancingPolicyFactory( + grpc_core::UniquePtr( + grpc_core::New())); + } + + void TearDown() override { + ClientLbEnd2endTest::TearDown(); + } +}; + +TEST_F(ClientLbInterceptTrailingMetadataTest, Intercepts_retries_disabled) { + const int kNumServers = 1; + StartServers(kNumServers); + auto channel = BuildChannel(intercept_trailing_name); + auto stub = BuildStub(channel); + std::vector ports; + for (size_t i = 0; i < servers_.size(); ++i) { + ports.emplace_back(servers_[i]->port_); + } + SetNextResolution(ports); + + for (size_t i = 0; i < servers_.size(); ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // All requests should have gone to a single server. + bool found = false; + for (size_t i = 0; i < servers_.size(); ++i) { + const int request_count = servers_[i]->service_.request_count(); + if (request_count == kNumServers) { + found = true; + } else { + EXPECT_EQ(0, request_count); + } + } + EXPECT_TRUE(found); + // Check LB policy name for the channel. + EXPECT_EQ( + intercept_trailing_name, + channel->GetLoadBalancingPolicyName()); +} + } // namespace } // namespace testing } // namespace grpc