From 7c3fefea3f79710d17e064395ccfe34434ba46e7 Mon Sep 17 00:00:00 2001 From: "Mark D. Roth" Date: Tue, 7 Jul 2020 15:25:36 -0700 Subject: [PATCH] Make request path more easily visible to LB policies. --- .../filters/client_channel/client_channel.cc | 14 +- .../ext/filters/client_channel/lb_policy.h | 2 + .../lb_policy/xds/xds_routing.cc | 4 +- test/core/util/test_lb_policies.cc | 182 +++++++++++++----- test/core/util/test_lb_policies.h | 24 ++- test/cpp/end2end/client_lb_end2end_test.cc | 98 +++++++++- 6 files changed, 266 insertions(+), 58 deletions(-) diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index 921de985972..b0788a9d2ba 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -413,7 +413,8 @@ class CallData { iterator begin() const override { static_assert(sizeof(grpc_linked_mdelem*) <= sizeof(intptr_t), "iterator size too large"); - return iterator(this, reinterpret_cast(batch_->list.head)); + return iterator( + this, reinterpret_cast(MaybeSkipEntry(batch_->list.head))); } iterator end() const override { static_assert(sizeof(grpc_linked_mdelem*) <= sizeof(intptr_t), @@ -430,11 +431,19 @@ class CallData { } private: + grpc_linked_mdelem* MaybeSkipEntry(grpc_linked_mdelem* entry) const { + if (entry != nullptr && batch_->idx.named.path == entry) { + return entry->next; + } + return entry; + } + intptr_t IteratorHandleNext(intptr_t handle) const override { grpc_linked_mdelem* linked_mdelem = reinterpret_cast(handle); - return reinterpret_cast(linked_mdelem->next); + return reinterpret_cast(MaybeSkipEntry(linked_mdelem->next)); } + std::pair IteratorHandleGet( intptr_t handle) const override { grpc_linked_mdelem* linked_mdelem = @@ -4024,6 +4033,7 @@ bool CallData::PickSubchannelLocked(grpc_call_element* elem, // subchannel's copy of the metadata batch (which is copied for each // attempt) to the LB policy instead the one from the parent channel. LoadBalancingPolicy::PickArgs pick_args; + pick_args.path = StringViewFromSlice(path_); pick_args.call_state = &lb_call_state_; Metadata initial_metadata(this, initial_metadata_batch); pick_args.initial_metadata = &initial_metadata; diff --git a/src/core/ext/filters/client_channel/lb_policy.h b/src/core/ext/filters/client_channel/lb_policy.h index 7a775af82f6..6a290a3f51b 100644 --- a/src/core/ext/filters/client_channel/lb_policy.h +++ b/src/core/ext/filters/client_channel/lb_policy.h @@ -190,6 +190,8 @@ class LoadBalancingPolicy : public InternallyRefCounted { /// Arguments used when picking a subchannel for a call. struct PickArgs { + /// The path of the call. Indicates the RPC service and method name. + absl::string_view path; /// Initial metadata associated with the picking call. /// The LB policy may use the existing metadata to influence its routing /// decision, and it may add new metadata elements to be sent with the diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_routing.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_routing.cc index 59c423f0317..9d9865563e2 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds_routing.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_routing.cc @@ -305,9 +305,7 @@ bool UnderFraction(const uint32_t fraction_per_million) { XdsRoutingLb::PickResult XdsRoutingLb::RoutePicker::Pick(PickArgs args) { for (const Route& route : route_table_) { // Path matching. - auto path = GetMetadataValue(":path", args.initial_metadata); - GPR_DEBUG_ASSERT(path.has_value()); - if (!PathMatch(path.value(), route.matchers->path_matcher)) continue; + if (!PathMatch(args.path, route.matchers->path_matcher)) continue; // Header Matching. if (!HeadersMatch(args, route.matchers->header_matchers)) continue; // Match fraction check diff --git a/test/core/util/test_lb_policies.cc b/test/core/util/test_lb_policies.cc index eae25bc4265..56978b97949 100644 --- a/test/core/util/test_lb_policies.cc +++ b/test/core/util/test_lb_policies.cc @@ -39,8 +39,6 @@ namespace grpc_core { -TraceFlag grpc_trace_forwarding_lb(false, "forwarding_lb"); - namespace { // @@ -80,6 +78,117 @@ class ForwardingLoadBalancingPolicy : public LoadBalancingPolicy { OrphanablePtr delegate_; }; +// +// CopyMetadataToVector() +// + +MetadataVector CopyMetadataToVector( + LoadBalancingPolicy::MetadataInterface* metadata) { + MetadataVector result; + for (const auto& p : *metadata) { + result.push_back({std::string(p.first), std::string(p.second)}); + } + return result; +} + +// +// TestPickArgsLb +// + +constexpr char kTestPickArgsLbPolicyName[] = "test_pick_args_lb"; + +class TestPickArgsLb : public ForwardingLoadBalancingPolicy { + public: + TestPickArgsLb(Args args, TestPickArgsCallback cb) + : ForwardingLoadBalancingPolicy( + absl::make_unique(RefCountedPtr(this), cb), + std::move(args), + /*delegate_lb_policy_name=*/"pick_first", + /*initial_refcount=*/2) {} + + ~TestPickArgsLb() override = default; + + const char* name() const override { return kTestPickArgsLbPolicyName; } + + private: + class Picker : public SubchannelPicker { + public: + Picker(std::unique_ptr delegate_picker, + TestPickArgsCallback cb) + : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {} + + PickResult Pick(PickArgs args) override { + // Report args seen. + PickArgsSeen args_seen; + args_seen.path = std::string(args.path); + args_seen.metadata = CopyMetadataToVector(args.initial_metadata); + cb_(args_seen); + // Do pick. + return delegate_picker_->Pick(args); + } + + private: + std::unique_ptr delegate_picker_; + TestPickArgsCallback cb_; + }; + + class Helper : public ChannelControlHelper { + public: + Helper(RefCountedPtr parent, TestPickArgsCallback cb) + : parent_(std::move(parent)), cb_(std::move(cb)) {} + + RefCountedPtr CreateSubchannel( + const grpc_channel_args& args) override { + return parent_->channel_control_helper()->CreateSubchannel(args); + } + + void UpdateState(grpc_connectivity_state state, + std::unique_ptr picker) override { + parent_->channel_control_helper()->UpdateState( + state, absl::make_unique(std::move(picker), cb_)); + } + + void RequestReresolution() override { + parent_->channel_control_helper()->RequestReresolution(); + } + + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override { + parent_->channel_control_helper()->AddTraceEvent(severity, message); + } + + private: + RefCountedPtr parent_; + TestPickArgsCallback cb_; + }; +}; + +class TestPickArgsLbConfig : public LoadBalancingPolicy::Config { + public: + const char* name() const override { return kTestPickArgsLbPolicyName; } +}; + +class TestPickArgsLbFactory : public LoadBalancingPolicyFactory { + public: + explicit TestPickArgsLbFactory(TestPickArgsCallback cb) + : cb_(std::move(cb)) {} + + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args), cb_); + } + + const char* name() const override { return kTestPickArgsLbPolicyName; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& /*json*/, grpc_error** /*error*/) const override { + return MakeRefCounted(); + } + + private: + TestPickArgsCallback cb_; +}; + // // InterceptRecvTrailingMetadataLoadBalancingPolicy // @@ -91,12 +200,12 @@ class InterceptRecvTrailingMetadataLoadBalancingPolicy : public ForwardingLoadBalancingPolicy { public: InterceptRecvTrailingMetadataLoadBalancingPolicy( - Args args, InterceptRecvTrailingMetadataCallback cb, void* user_data) + Args args, InterceptRecvTrailingMetadataCallback cb) : ForwardingLoadBalancingPolicy( - std::unique_ptr(new Helper( + absl::make_unique( RefCountedPtr( this), - cb, user_data)), + std::move(cb)), std::move(args), /*delegate_lb_policy_name=*/"pick_first", /*initial_refcount=*/2) {} @@ -110,24 +219,18 @@ class InterceptRecvTrailingMetadataLoadBalancingPolicy private: class Picker : public SubchannelPicker { public: - explicit Picker(std::unique_ptr delegate_picker, - InterceptRecvTrailingMetadataCallback cb, void* user_data) - : delegate_picker_(std::move(delegate_picker)), - cb_(cb), - user_data_(user_data) {} + Picker(std::unique_ptr delegate_picker, + InterceptRecvTrailingMetadataCallback cb) + : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {} PickResult Pick(PickArgs args) override { - // Check that we can read initial metadata. - gpr_log(GPR_INFO, "initial metadata:"); - InterceptRecvTrailingMetadataLoadBalancingPolicy::LogMetadata( - args.initial_metadata); // Do pick. PickResult result = delegate_picker_->Pick(args); // Intercept trailing metadata. if (result.type == PickResult::PICK_COMPLETE && result.subchannel != nullptr) { new (args.call_state->Alloc(sizeof(TrailingMetadataHandler))) - TrailingMetadataHandler(&result, cb_, user_data_); + TrailingMetadataHandler(&result, cb_); } return result; } @@ -135,15 +238,14 @@ class InterceptRecvTrailingMetadataLoadBalancingPolicy private: std::unique_ptr delegate_picker_; InterceptRecvTrailingMetadataCallback cb_; - void* user_data_; }; class Helper : public ChannelControlHelper { public: Helper( RefCountedPtr parent, - InterceptRecvTrailingMetadataCallback cb, void* user_data) - : parent_(std::move(parent)), cb_(cb), user_data_(user_data) {} + InterceptRecvTrailingMetadataCallback cb) + : parent_(std::move(parent)), cb_(std::move(cb)) {} RefCountedPtr CreateSubchannel( const grpc_channel_args& args) override { @@ -153,8 +255,7 @@ class InterceptRecvTrailingMetadataLoadBalancingPolicy void UpdateState(grpc_connectivity_state state, std::unique_ptr picker) override { parent_->channel_control_helper()->UpdateState( - state, std::unique_ptr( - new Picker(std::move(picker), cb_, user_data_))); + state, absl::make_unique(std::move(picker), cb_)); } void RequestReresolution() override { @@ -169,15 +270,13 @@ class InterceptRecvTrailingMetadataLoadBalancingPolicy private: RefCountedPtr parent_; InterceptRecvTrailingMetadataCallback cb_; - void* user_data_; }; class TrailingMetadataHandler { public: TrailingMetadataHandler(PickResult* result, - InterceptRecvTrailingMetadataCallback cb, - void* user_data) - : cb_(cb), user_data_(user_data) { + InterceptRecvTrailingMetadataCallback cb) + : cb_(std::move(cb)) { result->recv_trailing_metadata_ready = [this](grpc_error* error, MetadataInterface* metadata, CallState* call_state) { @@ -189,25 +288,16 @@ class InterceptRecvTrailingMetadataLoadBalancingPolicy void RecordRecvTrailingMetadata(grpc_error* /*error*/, MetadataInterface* recv_trailing_metadata, CallState* call_state) { + TrailingMetadataArgsSeen args_seen; + args_seen.backend_metric_data = call_state->GetBackendMetricData(); GPR_ASSERT(recv_trailing_metadata != nullptr); - gpr_log(GPR_INFO, "trailing metadata:"); - InterceptRecvTrailingMetadataLoadBalancingPolicy::LogMetadata( - recv_trailing_metadata); - cb_(user_data_, call_state->GetBackendMetricData()); + args_seen.metadata = CopyMetadataToVector(recv_trailing_metadata); + cb_(args_seen); this->~TrailingMetadataHandler(); } InterceptRecvTrailingMetadataCallback cb_; - void* user_data_; }; - - static void LogMetadata(MetadataInterface* metadata) { - for (const auto& p : *metadata) { - gpr_log(GPR_INFO, " \"%.*s\"=>\"%.*s\"", - static_cast(p.first.size()), p.first.data(), - static_cast(p.second.size()), p.second.data()); - } - } }; class InterceptTrailingConfig : public LoadBalancingPolicy::Config { @@ -219,14 +309,13 @@ class InterceptTrailingConfig : public LoadBalancingPolicy::Config { class InterceptTrailingFactory : public LoadBalancingPolicyFactory { public: - explicit InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb, - void* user_data) - : cb_(cb), user_data_(user_data) {} + explicit InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb) + : cb_(std::move(cb)) {} OrphanablePtr CreateLoadBalancingPolicy( LoadBalancingPolicy::Args args) const override { return MakeOrphanable( - std::move(args), cb_, user_data_); + std::move(args), cb_); } const char* name() const override { @@ -240,16 +329,19 @@ class InterceptTrailingFactory : public LoadBalancingPolicyFactory { private: InterceptRecvTrailingMetadataCallback cb_; - void* user_data_; }; } // namespace +void RegisterTestPickArgsLoadBalancingPolicy(TestPickArgsCallback cb) { + LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory( + absl::make_unique(std::move(cb))); +} + void RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy( - InterceptRecvTrailingMetadataCallback cb, void* user_data) { + InterceptRecvTrailingMetadataCallback cb) { LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory( - std::unique_ptr( - new InterceptTrailingFactory(cb, user_data))); + absl::make_unique(std::move(cb))); } } // namespace grpc_core diff --git a/test/core/util/test_lb_policies.h b/test/core/util/test_lb_policies.h index 3652515e57e..ffb079181ea 100644 --- a/test/core/util/test_lb_policies.h +++ b/test/core/util/test_lb_policies.h @@ -23,14 +23,32 @@ namespace grpc_core { -typedef void (*InterceptRecvTrailingMetadataCallback)( - void*, const LoadBalancingPolicy::BackendMetricData*); +using MetadataVector = std::vector>; + +struct PickArgsSeen { + std::string path; + MetadataVector metadata; +}; + +using TestPickArgsCallback = std::function; + +// Registers an LB policy called "test_pick_args_lb" that checks the args +// passed to SubchannelPicker::Pick(). +void RegisterTestPickArgsLoadBalancingPolicy(TestPickArgsCallback cb); + +struct TrailingMetadataArgsSeen { + const LoadBalancingPolicy::BackendMetricData* backend_metric_data; + MetadataVector metadata; +}; + +using InterceptRecvTrailingMetadataCallback = + std::function; // Registers an LB policy called "intercept_trailing_metadata_lb" that // invokes cb with argument user_data when trailing metadata is received // for each call. void RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy( - InterceptRecvTrailingMetadataCallback cb, void* user_data); + InterceptRecvTrailingMetadataCallback cb); } // namespace grpc_core diff --git a/test/cpp/end2end/client_lb_end2end_test.cc b/test/cpp/end2end/client_lb_end2end_test.cc index 174c09cd79d..0a4aa4a15a3 100644 --- a/test/cpp/end2end/client_lb_end2end_test.cc +++ b/test/cpp/end2end/client_lb_end2end_test.cc @@ -295,9 +295,13 @@ class ClientLbEnd2endTest : public ::testing::Test { if (local_response) response = new EchoResponse; EchoRequest request; request.set_message(kRequestMessage_); + request.mutable_param()->set_echo_metadata(true); ClientContext context; context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms)); if (wait_for_ready) context.set_wait_for_ready(true); + context.AddMetadata("foo", "1"); + context.AddMetadata("bar", "2"); + context.AddMetadata("baz", "3"); Status status = stub->Echo(&context, request, response); if (result != nullptr) *result = status; if (local_response) delete response; @@ -1632,19 +1636,82 @@ TEST_F(ClientLbEnd2endTest, ChannelIdleness) { EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); } -class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest { +class ClientLbPickArgsTest : public ClientLbEnd2endTest { protected: void SetUp() override { ClientLbEnd2endTest::SetUp(); current_test_instance_ = this; } - void TearDown() override { ClientLbEnd2endTest::TearDown(); } + static void SetUpTestCase() { + grpc_init(); + grpc_core::RegisterTestPickArgsLoadBalancingPolicy(SavePickArgs); + } + + static void TearDownTestCase() { grpc_shutdown_blocking(); } + + const std::vector& args_seen_list() { + grpc::internal::MutexLock lock(&mu_); + return args_seen_list_; + } + + private: + static void SavePickArgs(const grpc_core::PickArgsSeen& args_seen) { + ClientLbPickArgsTest* self = current_test_instance_; + grpc::internal::MutexLock lock(&self->mu_); + self->args_seen_list_.emplace_back(args_seen); + } + + static ClientLbPickArgsTest* current_test_instance_; + grpc::internal::Mutex mu_; + std::vector args_seen_list_; +}; + +ClientLbPickArgsTest* ClientLbPickArgsTest::current_test_instance_ = nullptr; + +TEST_F(ClientLbPickArgsTest, Basic) { + const int kNumServers = 1; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("test_pick_args_lb", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION, /*wait_for_ready=*/true); + // Check LB policy name for the channel. + EXPECT_EQ("test_pick_args_lb", channel->GetLoadBalancingPolicyName()); + // There will be two entries, one for the pick tried in state + // CONNECTING and another for the pick tried in state READY. + EXPECT_THAT(args_seen_list(), + ::testing::ElementsAre( + ::testing::AllOf( + ::testing::Field(&grpc_core::PickArgsSeen::path, + "/grpc.testing.EchoTestService/Echo"), + ::testing::Field(&grpc_core::PickArgsSeen::metadata, + ::testing::UnorderedElementsAre( + ::testing::Pair("foo", "1"), + ::testing::Pair("bar", "2"), + ::testing::Pair("baz", "3")))), + ::testing::AllOf( + ::testing::Field(&grpc_core::PickArgsSeen::path, + "/grpc.testing.EchoTestService/Echo"), + ::testing::Field(&grpc_core::PickArgsSeen::metadata, + ::testing::UnorderedElementsAre( + ::testing::Pair("foo", "1"), + ::testing::Pair("bar", "2"), + ::testing::Pair("baz", "3")))))); +} + +class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest { + protected: + void SetUp() override { + ClientLbEnd2endTest::SetUp(); + current_test_instance_ = this; + } static void SetUpTestCase() { grpc_init(); grpc_core::RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy( - ReportTrailerIntercepted, nullptr); + ReportTrailerIntercepted); } static void TearDownTestCase() { grpc_shutdown_blocking(); } @@ -1654,6 +1721,11 @@ class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest { return trailers_intercepted_; } + const grpc_core::MetadataVector& trailing_metadata() { + grpc::internal::MutexLock lock(&mu_); + return trailing_metadata_; + } + const udpa::data::orca::v1::OrcaLoadReport* backend_load_report() { grpc::internal::MutexLock lock(&mu_); return load_report_.get(); @@ -1661,11 +1733,12 @@ class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest { private: static void ReportTrailerIntercepted( - void* arg, const grpc_core::LoadBalancingPolicy::BackendMetricData* - backend_metric_data) { + const grpc_core::TrailingMetadataArgsSeen& args_seen) { + const auto* backend_metric_data = args_seen.backend_metric_data; ClientLbInterceptTrailingMetadataTest* self = current_test_instance_; grpc::internal::MutexLock lock(&self->mu_); self->trailers_intercepted_++; + self->trailing_metadata_ = args_seen.metadata; if (backend_metric_data != nullptr) { self->load_report_.reset(new udpa::data::orca::v1::OrcaLoadReport); self->load_report_->set_cpu_utilization( @@ -1689,6 +1762,7 @@ class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest { static ClientLbInterceptTrailingMetadataTest* current_test_instance_; grpc::internal::Mutex mu_; int trailers_intercepted_ = 0; + grpc_core::MetadataVector trailing_metadata_; std::unique_ptr load_report_; }; @@ -1711,6 +1785,13 @@ TEST_F(ClientLbInterceptTrailingMetadataTest, InterceptsRetriesDisabled) { EXPECT_EQ("intercept_trailing_metadata_lb", channel->GetLoadBalancingPolicyName()); EXPECT_EQ(kNumRpcs, trailers_intercepted()); + EXPECT_THAT(trailing_metadata(), + ::testing::UnorderedElementsAre( + // TODO(roth): Should grpc-status be visible here? + ::testing::Pair("grpc-status", "0"), + ::testing::Pair("user-agent", ::testing::_), + ::testing::Pair("foo", "1"), ::testing::Pair("bar", "2"), + ::testing::Pair("baz", "3"))); EXPECT_EQ(nullptr, backend_load_report()); } @@ -1746,6 +1827,13 @@ TEST_F(ClientLbInterceptTrailingMetadataTest, InterceptsRetriesEnabled) { EXPECT_EQ("intercept_trailing_metadata_lb", channel->GetLoadBalancingPolicyName()); EXPECT_EQ(kNumRpcs, trailers_intercepted()); + EXPECT_THAT(trailing_metadata(), + ::testing::UnorderedElementsAre( + // TODO(roth): Should grpc-status be visible here? + ::testing::Pair("grpc-status", "0"), + ::testing::Pair("user-agent", ::testing::_), + ::testing::Pair("foo", "1"), ::testing::Pair("bar", "2"), + ::testing::Pair("baz", "3"))); EXPECT_EQ(nullptr, backend_load_report()); }