From ac228814a040201a49527b873fd67cea244c140c Mon Sep 17 00:00:00 2001 From: Eugene Ostroukhov Date: Mon, 1 May 2023 19:07:20 -0700 Subject: [PATCH] [core] Expand core attributes to hold values of any type (#32835) --- src/core/BUILD | 4 + .../filters/client_channel/client_channel.cc | 12 +- .../client_channel/client_channel_internal.h | 4 +- .../lb_policy/ring_hash/ring_hash.cc | 9 +- .../lb_policy/ring_hash/ring_hash.h | 19 ++- .../lb_policy/xds/xds_cluster_manager.cc | 15 +- .../lb_policy/xds/xds_override_host.cc | 7 +- .../resolver/xds/xds_resolver.cc | 11 +- .../resolver/xds/xds_resolver.h | 17 ++- .../stateful_session_filter.cc | 6 +- .../stateful_session_filter.h | 17 ++- .../service_config/service_config_call_data.h | 22 ++- .../lb_policy/lb_policy_test_lib.h | 54 ++++--- .../lb_policy/weighted_round_robin_test.cc | 1 - .../lb_policy/xds_override_host_test.cc | 137 +++++++++++------- 15 files changed, 222 insertions(+), 113 deletions(-) diff --git a/src/core/BUILD b/src/core/BUILD index d3553854041..08077dcf9f8 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -4553,6 +4553,7 @@ grpc_cc_library( "closure", "error", "grpc_lb_subchannel_list", + "grpc_service_config", "json", "json_args", "json_object_loader", @@ -5142,8 +5143,10 @@ grpc_cc_library( hdrs = [ "ext/filters/client_channel/resolver/xds/xds_resolver.h", ], + external_deps = ["absl/strings"], language = "c++", deps = [ + "grpc_service_config", "unique_type_name", "//:gpr_platform", ], @@ -5174,6 +5177,7 @@ grpc_cc_library( "channel_fwd", "dual_ref_counted", "grpc_lb_policy_ring_hash", + "grpc_resolver_xds_header", "grpc_service_config", "grpc_xds_client", "iomgr_fwd", diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index 6629fe52655..6128abed83d 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -2335,7 +2335,8 @@ class ClientChannel::LoadBalancedCall::LbCallState // Internal API to allow first-party LB policies to access per-call // attributes set by the ConfigSelector. - absl::string_view GetCallAttribute(UniqueTypeName type) override; + ServiceConfigCallData::CallAttributeInterface* GetCallAttribute( + UniqueTypeName type) const override; private: LoadBalancedCall* lb_call_; @@ -2420,15 +2421,12 @@ class ClientChannel::LoadBalancedCall::Metadata // ClientChannel::LoadBalancedCall::LbCallState // -absl::string_view +ServiceConfigCallData::CallAttributeInterface* ClientChannel::LoadBalancedCall::LbCallState::GetCallAttribute( - UniqueTypeName type) { + UniqueTypeName type) const { auto* service_config_call_data = static_cast( lb_call_->call_context()[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); - auto& call_attributes = service_config_call_data->call_attributes(); - auto it = call_attributes.find(type); - if (it == call_attributes.end()) return absl::string_view(); - return it->second; + return service_config_call_data->GetCallAttribute(type); } // diff --git a/src/core/ext/filters/client_channel/client_channel_internal.h b/src/core/ext/filters/client_channel/client_channel_internal.h index 9a103b9d779..bad212da4da 100644 --- a/src/core/ext/filters/client_channel/client_channel_internal.h +++ b/src/core/ext/filters/client_channel/client_channel_internal.h @@ -22,7 +22,6 @@ #include #include "absl/functional/any_invocable.h" -#include "absl/strings/string_view.h" #include "src/core/lib/channel/context.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" @@ -48,7 +47,8 @@ namespace grpc_core { // LB policies to access internal call attributes. class ClientChannelLbCallState : public LoadBalancingPolicy::CallState { public: - virtual absl::string_view GetCallAttribute(UniqueTypeName type) = 0; + virtual ServiceConfigCallData::CallAttributeInterface* GetCallAttribute( + UniqueTypeName type) const = 0; }; // Internal type for ServiceConfigCallData. Handles call commits. diff --git a/src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc b/src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc index 66262aefd2b..78c78609262 100644 --- a/src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc +++ b/src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc @@ -70,7 +70,7 @@ namespace grpc_core { TraceFlag grpc_lb_ring_hash_trace(false, "ring_hash_lb"); -UniqueTypeName RequestHashAttributeName() { +UniqueTypeName RequestHashAttribute::TypeName() { static UniqueTypeName::Factory kFactory("request_hash"); return kFactory.Create(); } @@ -345,7 +345,12 @@ class RingHash : public LoadBalancingPolicy { RingHash::PickResult RingHash::Picker::Pick(PickArgs args) { auto* call_state = static_cast(args.call_state); - auto hash = call_state->GetCallAttribute(RequestHashAttributeName()); + auto* hash_attribute = static_cast( + call_state->GetCallAttribute(RequestHashAttribute::TypeName())); + absl::string_view hash; + if (hash_attribute != nullptr) { + hash = hash_attribute->request_hash(); + } uint64_t h; if (!absl::SimpleAtoi(hash, &h)) { return PickResult::Fail( diff --git a/src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h b/src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h index 4d8137908cc..95b248ea1e1 100644 --- a/src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h +++ b/src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h @@ -21,15 +21,32 @@ #include +#include "absl/strings/string_view.h" + #include "src/core/lib/gprpp/unique_type_name.h" #include "src/core/lib/gprpp/validation_errors.h" #include "src/core/lib/json/json.h" #include "src/core/lib/json/json_args.h" #include "src/core/lib/json/json_object_loader.h" +#include "src/core/lib/service_config/service_config_call_data.h" namespace grpc_core { -UniqueTypeName RequestHashAttributeName(); +class RequestHashAttribute + : public ServiceConfigCallData::CallAttributeInterface { + public: + static UniqueTypeName TypeName(); + + explicit RequestHashAttribute(absl::string_view request_hash) + : request_hash_(request_hash) {} + + absl::string_view request_hash() const { return request_hash_; } + + private: + UniqueTypeName type() const override { return TypeName(); } + + absl::string_view request_hash_; +}; // Helper Parsing method to parse ring hash policy configs; for example, ring // hash size validity. diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc index d235724c646..13482aff9a5 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -123,8 +124,8 @@ class XdsClusterManagerLb : public LoadBalancingPolicy { class ClusterPicker : public SubchannelPicker { public: // Maintains a map of cluster names to pickers. - using ClusterMap = - std::map>; + using ClusterMap = std::map, std::less<>>; // It is required that the keys of cluster_map have to live at least as long // as the ClusterPicker instance. @@ -230,9 +231,13 @@ class XdsClusterManagerLb : public LoadBalancingPolicy { XdsClusterManagerLb::PickResult XdsClusterManagerLb::ClusterPicker::Pick( PickArgs args) { auto* call_state = static_cast(args.call_state); - auto cluster_name = - call_state->GetCallAttribute(XdsClusterAttributeTypeName()); - auto it = cluster_map_.find(std::string(cluster_name)); + auto* cluster_name_attribute = static_cast( + call_state->GetCallAttribute(XdsClusterAttribute::TypeName())); + absl::string_view cluster_name; + if (cluster_name_attribute != nullptr) { + cluster_name = cluster_name_attribute->cluster(); + } + auto it = cluster_map_.find(cluster_name); if (it != cluster_map_.end()) { return it->second->Pick(args); } diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_override_host.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_override_host.cc index f36e0171455..4fa540ad4b0 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds_override_host.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_override_host.cc @@ -370,8 +370,11 @@ XdsOverrideHostLb::Picker::PickOverridenHost(absl::string_view override_host) { LoadBalancingPolicy::PickResult XdsOverrideHostLb::Picker::Pick( LoadBalancingPolicy::PickArgs args) { auto* call_state = static_cast(args.call_state); - auto override_host = call_state->GetCallAttribute(XdsOverrideHostTypeName()); - auto overridden_host_pick = PickOverridenHost(override_host); + auto* override_host = static_cast( + call_state->GetCallAttribute(XdsOverrideHostAttribute::TypeName())); + auto overridden_host_pick = + PickOverridenHost(override_host != nullptr ? override_host->host_name() + : absl::string_view()); if (overridden_host_pick.has_value()) { return std::move(*overridden_host_pick); } diff --git a/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc b/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc index 752820d7e87..9470560a3d2 100644 --- a/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc +++ b/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc @@ -16,11 +16,11 @@ #include +#include #include #include #include -#include #include #include #include @@ -59,6 +59,7 @@ #include "src/core/ext/filters/client_channel/config_selector.h" #include "src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h" +#include "src/core/ext/filters/client_channel/resolver/xds/xds_resolver.h" #include "src/core/ext/xds/xds_bootstrap.h" #include "src/core/ext/xds/xds_bootstrap_grpc.h" #include "src/core/ext/xds/xds_client_grpc.h" @@ -93,7 +94,7 @@ namespace grpc_core { TraceFlag grpc_xds_resolver_trace(false, "xds_resolver"); -UniqueTypeName XdsClusterAttributeTypeName() { +UniqueTypeName XdsClusterAttribute::TypeName() { static UniqueTypeName::Factory kFactory("xds_cluster_name"); return kFactory.Create(); } @@ -731,13 +732,15 @@ XdsResolver::XdsConfigSelector::GetCallConfig(GetCallConfigArgs args) { method_config->GetMethodParsedConfigVector(grpc_empty_slice()); call_config.service_config = std::move(method_config); } - call_config.call_attributes[XdsClusterAttributeTypeName()] = it->first; + call_config.call_attributes[XdsClusterAttribute::TypeName()] = + args.arena->New(it->first); std::string hash_string = absl::StrCat(hash.value()); char* hash_value = static_cast(args.arena->Alloc(hash_string.size() + 1)); memcpy(hash_value, hash_string.c_str(), hash_string.size()); hash_value[hash_string.size()] = '\0'; - call_config.call_attributes[RequestHashAttributeName()] = hash_value; + call_config.call_attributes[RequestHashAttribute::TypeName()] = + args.arena->New(hash_value); call_config.on_commit = [cluster_state = it->second->Ref()]() mutable { cluster_state.reset(); }; diff --git a/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.h b/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.h index 1693962e896..ff6e4523781 100644 --- a/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.h +++ b/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.h @@ -19,12 +19,27 @@ #include +#include "absl/strings/string_view.h" + #include "src/core/lib/gprpp/unique_type_name.h" +#include "src/core/lib/service_config/service_config_call_data.h" namespace grpc_core { -UniqueTypeName XdsClusterAttributeTypeName(); +class XdsClusterAttribute + : public ServiceConfigCallData::CallAttributeInterface { + public: + static UniqueTypeName TypeName(); + + explicit XdsClusterAttribute(absl::string_view cluster) : cluster_(cluster) {} + + absl::string_view cluster() const { return cluster_; } + + private: + UniqueTypeName type() const override { return TypeName(); } + absl::string_view cluster_; +}; } // namespace grpc_core #endif // GRPC_SRC_CORE_EXT_FILTERS_CLIENT_CHANNEL_RESOLVER_XDS_XDS_RESOLVER_H diff --git a/src/core/ext/filters/stateful_session/stateful_session_filter.cc b/src/core/ext/filters/stateful_session/stateful_session_filter.cc index 381fbc5b930..66d6bad0cc9 100644 --- a/src/core/ext/filters/stateful_session/stateful_session_filter.cc +++ b/src/core/ext/filters/stateful_session/stateful_session_filter.cc @@ -59,7 +59,7 @@ namespace grpc_core { TraceFlag grpc_stateful_session_filter_trace(false, "stateful_session_filter"); -UniqueTypeName XdsOverrideHostTypeName() { +UniqueTypeName XdsOverrideHostAttribute::TypeName() { static UniqueTypeName::Factory kFactory("xds_override_host"); return kFactory.Create(); } @@ -160,8 +160,8 @@ ArenaPromise StatefulSessionFilter::MakeCallPromise( } // We have a valid cookie, so add the call attribute to be used by the // xds_override_host LB policy. - service_config_call_data->SetCallAttribute(XdsOverrideHostTypeName(), - *cookie_value); + service_config_call_data->SetCallAttribute( + GetContext()->New(*cookie_value)); } // Intercept server initial metadata. call_args.server_initial_metadata->InterceptAndMap( diff --git a/src/core/ext/filters/stateful_session/stateful_session_filter.h b/src/core/ext/filters/stateful_session/stateful_session_filter.h index 1d28960bdb8..c45d98db6ad 100644 --- a/src/core/ext/filters/stateful_session/stateful_session_filter.h +++ b/src/core/ext/filters/stateful_session/stateful_session_filter.h @@ -30,11 +30,26 @@ #include "src/core/lib/channel/promise_based_filter.h" #include "src/core/lib/gprpp/unique_type_name.h" #include "src/core/lib/promise/arena_promise.h" +#include "src/core/lib/service_config/service_config_call_data.h" #include "src/core/lib/transport/transport.h" namespace grpc_core { -UniqueTypeName XdsOverrideHostTypeName(); +class XdsOverrideHostAttribute + : public ServiceConfigCallData::CallAttributeInterface { + public: + static UniqueTypeName TypeName(); + + explicit XdsOverrideHostAttribute(absl::string_view host_name) + : host_name_(host_name) {} + + absl::string_view host_name() const { return host_name_; } + + private: + UniqueTypeName type() const override { return TypeName(); } + + absl::string_view host_name_; +}; // A filter to provide cookie-based stateful session affinity. class StatefulSessionFilter : public ChannelFilter { diff --git a/src/core/lib/service_config/service_config_call_data.h b/src/core/lib/service_config/service_config_call_data.h index 9d05ed98a9c..97d78635ec8 100644 --- a/src/core/lib/service_config/service_config_call_data.h +++ b/src/core/lib/service_config/service_config_call_data.h @@ -25,8 +25,6 @@ #include #include -#include "absl/strings/string_view.h" - #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/unique_type_name.h" #include "src/core/lib/service_config/service_config.h" @@ -40,7 +38,13 @@ namespace grpc_core { /// easily access method and global parameters for the call. class ServiceConfigCallData { public: - using CallAttributes = std::map; + class CallAttributeInterface { + public: + virtual ~CallAttributeInterface() = default; + virtual UniqueTypeName type() const = 0; + }; + + using CallAttributes = std::map; ServiceConfigCallData() : method_configs_(nullptr) {} @@ -63,12 +67,16 @@ class ServiceConfigCallData { return service_config_->GetGlobalParsedConfig(index); } - const CallAttributes& call_attributes() const { return call_attributes_; } - // Must be called when holding the call combiner (legacy filter) or from // inside the activity (promise-based filter). - void SetCallAttribute(UniqueTypeName name, absl::string_view value) { - call_attributes_[name] = value; + void SetCallAttribute(CallAttributeInterface* value) { + call_attributes_.emplace(value->type(), value); + } + + CallAttributeInterface* GetCallAttribute(UniqueTypeName name) const { + auto it = call_attributes_.find(name); + if (it == call_attributes_.end()) return nullptr; + return it->second; } private: diff --git a/test/core/client_channel/lb_policy/lb_policy_test_lib.h b/test/core/client_channel/lb_policy/lb_policy_test_lib.h index 317f035f93c..2298091d021 100644 --- a/test/core/client_channel/lb_policy/lb_policy_test_lib.h +++ b/test/core/client_channel/lb_policy/lb_policy_test_lib.h @@ -76,6 +76,7 @@ #include "src/core/lib/load_balancing/lb_policy_registry.h" #include "src/core/lib/load_balancing/subchannel_interface.h" #include "src/core/lib/resolver/server_address.h" +#include "src/core/lib/service_config/service_config_call_data.h" #include "src/core/lib/transport/connectivity_state.h" #include "src/core/lib/uri/uri_parser.h" @@ -84,6 +85,9 @@ namespace testing { class LoadBalancingPolicyTest : public ::testing::Test { protected: + using CallAttributes = std::vector< + std::unique_ptr>; + // Channel-level subchannel state for a specific address and channel args. // This is analogous to the real subchannel in the ClientChannel code. class SubchannelState { @@ -464,10 +468,9 @@ class LoadBalancingPolicyTest : public ::testing::Test { // A fake CallState implementation, for use in PickArgs. class FakeCallState : public ClientChannelLbCallState { public: - explicit FakeCallState( - const std::map& attributes) { + explicit FakeCallState(const CallAttributes& attributes) { for (const auto& p : attributes) { - attributes_.emplace(p.first, std::string(p.second)); + attributes_.emplace(p->type(), p.get()); } } @@ -484,12 +487,18 @@ class LoadBalancingPolicyTest : public ::testing::Test { return allocation; } - absl::string_view GetCallAttribute(UniqueTypeName type) override { - return attributes_[type]; + ServiceConfigCallData::CallAttributeInterface* GetCallAttribute( + UniqueTypeName type) const override { + auto it = attributes_.find(type); + if (it != attributes_.end()) { + return it->second; + } + return nullptr; } std::vector allocations_; - std::map attributes_; + std::map + attributes_; }; // A fake BackendMetricAccessor implementation, for passing to @@ -699,11 +708,11 @@ class LoadBalancingPolicyTest : public ::testing::Test { // the old list followed by one READY update where the picker is using the // new list. Returns a picker if the reported states match expectations. RefCountedPtr - WaitForRoundRobinListChange( - absl::Span old_addresses, - absl::Span new_addresses, - const std::map& call_attributes = {}, - size_t num_iterations = 3, SourceLocation location = SourceLocation()) { + WaitForRoundRobinListChange(absl::Span old_addresses, + absl::Span new_addresses, + const CallAttributes& call_attributes = {}, + size_t num_iterations = 3, + SourceLocation location = SourceLocation()) { gpr_log(GPR_INFO, "Waiting for expected RR addresses..."); RefCountedPtr retval; size_t num_picks = @@ -762,7 +771,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { // Does a pick and returns the result. LoadBalancingPolicy::PickResult DoPick( LoadBalancingPolicy::SubchannelPicker* picker, - const std::map& call_attributes = {}) { + const CallAttributes& call_attributes = {}) { ExecCtx exec_ctx; FakeMetadata metadata({}); FakeCallState call_state(call_attributes); @@ -770,10 +779,9 @@ class LoadBalancingPolicyTest : public ::testing::Test { } // Requests a pick on picker and expects a Queue result. - void ExpectPickQueued( - LoadBalancingPolicy::SubchannelPicker* picker, - const std::map& call_attributes = {}, - SourceLocation location = SourceLocation()) { + void ExpectPickQueued(LoadBalancingPolicy::SubchannelPicker* picker, + const CallAttributes call_attributes = {}, + SourceLocation location = SourceLocation()) { ASSERT_NE(picker, nullptr); auto pick_result = DoPick(picker, call_attributes); ASSERT_TRUE(absl::holds_alternative( @@ -791,7 +799,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { // automatically to represent a complete call with no backend metric data. absl::optional ExpectPickComplete( LoadBalancingPolicy::SubchannelPicker* picker, - const std::map& call_attributes = {}, + const CallAttributes& call_attributes = {}, std::unique_ptr* subchannel_call_tracker = nullptr, SourceLocation location = SourceLocation()) { @@ -827,7 +835,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { // list of addresses, or nullopt if a non-complete pick was returned. absl::optional> GetCompletePicks( LoadBalancingPolicy::SubchannelPicker* picker, size_t num_picks, - const std::map& call_attributes = {}, + const CallAttributes& call_attributes = {}, std::vector< std::unique_ptr>* subchannel_call_trackers = nullptr, @@ -874,11 +882,11 @@ class LoadBalancingPolicyTest : public ::testing::Test { // Checks that the picker has round-robin behavior over the specified // set of addresses. - void ExpectRoundRobinPicks( - LoadBalancingPolicy::SubchannelPicker* picker, - absl::Span addresses, - const std::map& call_attributes = {}, - size_t num_iterations = 3, SourceLocation location = SourceLocation()) { + void ExpectRoundRobinPicks(LoadBalancingPolicy::SubchannelPicker* picker, + absl::Span addresses, + const CallAttributes& call_attributes = {}, + size_t num_iterations = 3, + SourceLocation location = SourceLocation()) { auto picks = GetCompletePicks(picker, num_iterations * addresses.size(), call_attributes, nullptr, location); ASSERT_TRUE(picks.has_value()) << location.file() << ":" << location.line(); diff --git a/test/core/client_channel/lb_policy/weighted_round_robin_test.cc b/test/core/client_channel/lb_policy/weighted_round_robin_test.cc index 0aa03b43c51..83678266dc0 100644 --- a/test/core/client_channel/lb_policy/weighted_round_robin_test.cc +++ b/test/core/client_channel/lb_policy/weighted_round_robin_test.cc @@ -47,7 +47,6 @@ #include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/time.h" -#include "src/core/lib/gprpp/unique_type_name.h" #include "src/core/lib/json/json.h" #include "src/core/lib/json/json_writer.h" #include "src/core/lib/load_balancing/lb_policy.h" diff --git a/test/core/client_channel/lb_policy/xds_override_host_test.cc b/test/core/client_channel/lb_policy/xds_override_host_test.cc index a3d001299d6..0dffc7c14af 100644 --- a/test/core/client_channel/lb_policy/xds_override_host_test.cc +++ b/test/core/client_channel/lb_policy/xds_override_host_test.cc @@ -16,6 +16,7 @@ #include +#include #include #include #include @@ -35,7 +36,6 @@ #include "src/core/ext/xds/xds_health_status.h" #include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" -#include "src/core/lib/gprpp/unique_type_name.h" #include "src/core/lib/json/json.h" #include "src/core/lib/load_balancing/lb_policy.h" #include "src/core/lib/resolver/server_address.h" @@ -112,6 +112,13 @@ class XdsOverrideHostTest : public LoadBalancingPolicyTest { EXPECT_EQ(ApplyUpdate(update, policy_.get()), absl::OkStatus()); } + CallAttributes MakeOverrideHostAttribute(absl::string_view host) { + CallAttributes override_host_attributes; + override_host_attributes.emplace_back( + std::make_unique(host)); + return override_host_attributes; + } + OrphanablePtr policy_; }; @@ -134,13 +141,18 @@ TEST_F(XdsOverrideHostTest, OverrideHost) { auto picker = ExpectStartupWithRoundRobin(kAddresses); ASSERT_NE(picker, nullptr); // Check that the host is overridden - std::map call_attributes{ - {XdsOverrideHostTypeName(), kAddresses[1]}}; - EXPECT_EQ(ExpectPickComplete(picker.get(), call_attributes), kAddresses[1]); - EXPECT_EQ(ExpectPickComplete(picker.get(), call_attributes), kAddresses[1]); - call_attributes[XdsOverrideHostTypeName()] = kAddresses[0]; - EXPECT_EQ(ExpectPickComplete(picker.get(), call_attributes), kAddresses[0]); - EXPECT_EQ(ExpectPickComplete(picker.get(), call_attributes), kAddresses[0]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[1])), + kAddresses[1]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[1])), + kAddresses[1]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[0])), + kAddresses[0]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[0])), + kAddresses[0]); } TEST_F(XdsOverrideHostTest, SubchannelNotFound) { @@ -149,9 +161,8 @@ TEST_F(XdsOverrideHostTest, SubchannelNotFound) { "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; auto picker = ExpectStartupWithRoundRobin(kAddresses); ASSERT_NE(picker, nullptr); - std::map call_attributes{ - {XdsOverrideHostTypeName(), "no such host"}}; - ExpectRoundRobinPicks(picker.get(), kAddresses, call_attributes); + ExpectRoundRobinPicks(picker.get(), kAddresses, + MakeOverrideHostAttribute("no such host")); } TEST_F(XdsOverrideHostTest, SubchannelsComeAndGo) { @@ -160,9 +171,8 @@ TEST_F(XdsOverrideHostTest, SubchannelsComeAndGo) { auto picker = ExpectStartupWithRoundRobin(kAddresses); ASSERT_NE(picker, nullptr); // Check that the host is overridden - std::map call_attributes{ - {XdsOverrideHostTypeName(), kAddresses[1]}}; - ExpectRoundRobinPicks(picker.get(), {kAddresses[1]}, call_attributes); + ExpectRoundRobinPicks(picker.get(), {kAddresses[1]}, + MakeOverrideHostAttribute(kAddresses[1])); // Some other address is gone EXPECT_EQ(ApplyUpdate(BuildUpdate({kAddresses[0], kAddresses[1]}, MakeXdsOverrideHostConfig()), @@ -175,7 +185,8 @@ TEST_F(XdsOverrideHostTest, SubchannelsComeAndGo) { picker = WaitForRoundRobinListChange(kAddresses, {kAddresses[0], kAddresses[1]}); // Make sure host override still works. - ExpectRoundRobinPicks(picker.get(), {kAddresses[1]}, call_attributes); + ExpectRoundRobinPicks(picker.get(), {kAddresses[1]}, + MakeOverrideHostAttribute(kAddresses[1])); // "Our" address is gone so others get returned in round-robin order EXPECT_EQ(ApplyUpdate(BuildUpdate({kAddresses[0], kAddresses[2]}, MakeXdsOverrideHostConfig()), @@ -186,7 +197,8 @@ TEST_F(XdsOverrideHostTest, SubchannelsComeAndGo) { // checking again afterward, because the host override won't actually // be used. WaitForRoundRobinListChange({kAddresses[0], kAddresses[1]}, - {kAddresses[0], kAddresses[2]}, call_attributes); + {kAddresses[0], kAddresses[2]}, + MakeOverrideHostAttribute(kAddresses[1])); // And now it is back EXPECT_EQ(ApplyUpdate(BuildUpdate({kAddresses[1], kAddresses[2]}, MakeXdsOverrideHostConfig()), @@ -196,7 +208,8 @@ TEST_F(XdsOverrideHostTest, SubchannelsComeAndGo) { picker = WaitForRoundRobinListChange({kAddresses[0], kAddresses[2]}, {kAddresses[1], kAddresses[2]}); // Make sure host override works. - ExpectRoundRobinPicks(picker.get(), {kAddresses[1]}, call_attributes); + ExpectRoundRobinPicks(picker.get(), {kAddresses[1]}, + MakeOverrideHostAttribute(kAddresses[1])); } TEST_F(XdsOverrideHostTest, FailedSubchannelIsNotPicked) { @@ -206,9 +219,9 @@ TEST_F(XdsOverrideHostTest, FailedSubchannelIsNotPicked) { auto picker = ExpectStartupWithRoundRobin(kAddresses); ASSERT_NE(picker, nullptr); // Check that the host is overridden - std::map pick_arg{ - {XdsOverrideHostTypeName(), kAddresses[1]}}; - EXPECT_EQ(ExpectPickComplete(picker.get(), pick_arg), kAddresses[1]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[1])), + kAddresses[1]); auto subchannel = FindSubchannel(kAddresses[1]); ASSERT_NE(subchannel, nullptr); subchannel->SetConnectivityState(GRPC_CHANNEL_IDLE); @@ -222,7 +235,8 @@ TEST_F(XdsOverrideHostTest, FailedSubchannelIsNotPicked) { absl::ResourceExhaustedError("Hmmmm")); ExpectReresolutionRequest(); picker = ExpectState(GRPC_CHANNEL_READY); - ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[2]}, pick_arg); + ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[2]}, + MakeOverrideHostAttribute(kAddresses[1])); } TEST_F(XdsOverrideHostTest, ConnectingSubchannelIsQueued) { @@ -232,19 +246,25 @@ TEST_F(XdsOverrideHostTest, ConnectingSubchannelIsQueued) { auto picker = ExpectStartupWithRoundRobin(kAddresses); ASSERT_NE(picker, nullptr); // Check that the host is overridden - std::map pick_arg{ - {XdsOverrideHostTypeName(), kAddresses[1]}}; - EXPECT_EQ(ExpectPickComplete(picker.get(), pick_arg), kAddresses[1]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + { + MakeOverrideHostAttribute(kAddresses[1]), + }), + kAddresses[1]); auto subchannel = FindSubchannel(kAddresses[1]); ASSERT_NE(subchannel, nullptr); subchannel->SetConnectivityState(GRPC_CHANNEL_IDLE); ExpectReresolutionRequest(); EXPECT_TRUE(subchannel->ConnectionRequested()); picker = ExpectState(GRPC_CHANNEL_READY); - ExpectPickQueued(picker.get(), pick_arg); + ExpectPickQueued(picker.get(), { + MakeOverrideHostAttribute(kAddresses[1]), + }); subchannel->SetConnectivityState(GRPC_CHANNEL_CONNECTING); picker = ExpectState(GRPC_CHANNEL_READY); - ExpectPickQueued(picker.get(), pick_arg); + ExpectPickQueued(picker.get(), { + MakeOverrideHostAttribute(kAddresses[1]), + }); } TEST_F(XdsOverrideHostTest, DrainingState) { @@ -262,16 +282,17 @@ TEST_F(XdsOverrideHostTest, DrainingState) { ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[2]}); ExpectQueueEmpty(); // Draining subchannel is returned - std::map pick_arg{ - {XdsOverrideHostTypeName(), kAddresses[1]}}; - EXPECT_EQ(ExpectPickComplete(picker.get(), pick_arg), kAddresses[1]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[1])), + kAddresses[1]); ApplyUpdateWithHealthStatuses( {{kAddresses[0], XdsHealthStatus::HealthStatus::kUnknown}, {kAddresses[2], XdsHealthStatus::HealthStatus::kHealthy}}); picker = ExpectState(GRPC_CHANNEL_READY); ASSERT_NE(picker, nullptr); // Gone! - ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[2]}, pick_arg); + ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[2]}, + MakeOverrideHostAttribute(kAddresses[1])); } TEST_F(XdsOverrideHostTest, DrainingSubchannelIsConnecting) { @@ -281,9 +302,9 @@ TEST_F(XdsOverrideHostTest, DrainingSubchannelIsConnecting) { auto picker = ExpectStartupWithRoundRobin(kAddresses); ASSERT_NE(picker, nullptr); // Check that the host is overridden - std::map pick_arg{ - {XdsOverrideHostTypeName(), kAddresses[1]}}; - EXPECT_EQ(ExpectPickComplete(picker.get(), pick_arg), kAddresses[1]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[1])), + kAddresses[1]); ApplyUpdateWithHealthStatuses( {{kAddresses[0], XdsHealthStatus::HealthStatus::kUnknown}, {kAddresses[1], XdsHealthStatus::HealthStatus::kDraining}, @@ -294,21 +315,25 @@ TEST_F(XdsOverrideHostTest, DrainingSubchannelIsConnecting) { // There are two notifications - one from child policy and one from the parent // policy due to draining channel update picker = ExpectState(GRPC_CHANNEL_READY); - EXPECT_EQ(ExpectPickComplete(picker.get(), pick_arg), kAddresses[1]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[1])), + kAddresses[1]); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[2]}); subchannel->SetConnectivityState(GRPC_CHANNEL_IDLE); picker = ExpectState(GRPC_CHANNEL_READY); - ExpectPickQueued(picker.get(), pick_arg); + ExpectPickQueued(picker.get(), MakeOverrideHostAttribute(kAddresses[1])); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[2]}); EXPECT_TRUE(subchannel->ConnectionRequested()); ExpectQueueEmpty(); subchannel->SetConnectivityState(GRPC_CHANNEL_CONNECTING); picker = ExpectState(GRPC_CHANNEL_READY); - ExpectPickQueued(picker.get(), pick_arg); + ExpectPickQueued(picker.get(), MakeOverrideHostAttribute(kAddresses[1])); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[2]}); subchannel->SetConnectivityState(GRPC_CHANNEL_READY); picker = ExpectState(GRPC_CHANNEL_READY); - EXPECT_EQ(ExpectPickComplete(picker.get(), pick_arg), kAddresses[1]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[1])), + kAddresses[1]); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[2]}); } @@ -326,9 +351,9 @@ TEST_F(XdsOverrideHostTest, DrainingToHealthy) { ASSERT_NE(picker, nullptr); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[2]}); ExpectQueueEmpty(); - std::map pick_arg{ - {XdsOverrideHostTypeName(), kAddresses[1]}}; - EXPECT_EQ(ExpectPickComplete(picker.get(), pick_arg), kAddresses[1]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[1])), + kAddresses[1]); ApplyUpdateWithHealthStatuses( {{kAddresses[0], XdsHealthStatus::HealthStatus::kHealthy}, {kAddresses[1], XdsHealthStatus::HealthStatus::kHealthy}, @@ -336,8 +361,12 @@ TEST_F(XdsOverrideHostTest, DrainingToHealthy) { {"UNKNOWN", "HEALTHY", "DRAINING"}); picker = ExpectState(GRPC_CHANNEL_READY); ASSERT_NE(picker, nullptr); - EXPECT_EQ(ExpectPickComplete(picker.get(), pick_arg), kAddresses[1]); - EXPECT_EQ(ExpectPickComplete(picker.get(), pick_arg), kAddresses[1]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[1])), + kAddresses[1]); + EXPECT_EQ(ExpectPickComplete(picker.get(), + MakeOverrideHostAttribute(kAddresses[1])), + kAddresses[1]); } TEST_F(XdsOverrideHostTest, OverrideHostStatus) { @@ -353,13 +382,13 @@ TEST_F(XdsOverrideHostTest, OverrideHostStatus) { ASSERT_NE(picker, nullptr); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[1]}); EXPECT_EQ(ExpectPickComplete(picker.get(), - {{XdsOverrideHostTypeName(), kAddresses[0]}}), + MakeOverrideHostAttribute(kAddresses[0])), kAddresses[0]); EXPECT_EQ(ExpectPickComplete(picker.get(), - {{XdsOverrideHostTypeName(), kAddresses[1]}}), + MakeOverrideHostAttribute(kAddresses[1])), kAddresses[1]); EXPECT_EQ(ExpectPickComplete(picker.get(), - {{XdsOverrideHostTypeName(), kAddresses[2]}}), + MakeOverrideHostAttribute(kAddresses[2])), kAddresses[2]); // UNKNOWN excluded - first chanel does not get overridden ApplyUpdateWithHealthStatuses( @@ -371,12 +400,12 @@ TEST_F(XdsOverrideHostTest, OverrideHostStatus) { ASSERT_NE(picker, nullptr); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[1]}); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[1]}, - {{XdsOverrideHostTypeName(), kAddresses[0]}}); + MakeOverrideHostAttribute(kAddresses[0])); EXPECT_EQ(ExpectPickComplete(picker.get(), - {{XdsOverrideHostTypeName(), kAddresses[1]}}), + MakeOverrideHostAttribute(kAddresses[1])), kAddresses[1]); EXPECT_EQ(ExpectPickComplete(picker.get(), - {{XdsOverrideHostTypeName(), kAddresses[2]}}), + MakeOverrideHostAttribute(kAddresses[2])), kAddresses[2]); // HEALTHY excluded - second chanel does not get overridden ApplyUpdateWithHealthStatuses( @@ -388,13 +417,13 @@ TEST_F(XdsOverrideHostTest, OverrideHostStatus) { ASSERT_NE(picker, nullptr); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[1]}); EXPECT_EQ(ExpectPickComplete(picker.get(), - {{XdsOverrideHostTypeName(), kAddresses[0]}}), + MakeOverrideHostAttribute(kAddresses[0])), kAddresses[0]); EXPECT_EQ(ExpectPickComplete(picker.get(), - {{XdsOverrideHostTypeName(), kAddresses[1]}}), + MakeOverrideHostAttribute(kAddresses[1])), kAddresses[1]); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[1]}, - {{XdsOverrideHostTypeName(), kAddresses[2]}}); + MakeOverrideHostAttribute(kAddresses[2])); // DRAINING excluded - third chanel does not get overridden ApplyUpdateWithHealthStatuses( {{kAddresses[0], XdsHealthStatus::HealthStatus::kUnknown}, @@ -405,13 +434,13 @@ TEST_F(XdsOverrideHostTest, OverrideHostStatus) { ASSERT_NE(picker, nullptr); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[1]}); EXPECT_EQ(ExpectPickComplete(picker.get(), - {{XdsOverrideHostTypeName(), kAddresses[0]}}), + MakeOverrideHostAttribute(kAddresses[0])), kAddresses[0]); EXPECT_EQ(ExpectPickComplete(picker.get(), - {{XdsOverrideHostTypeName(), kAddresses[1]}}), + MakeOverrideHostAttribute(kAddresses[1])), kAddresses[1]); ExpectRoundRobinPicks(picker.get(), {kAddresses[0], kAddresses[1]}, - {{XdsOverrideHostTypeName(), kAddresses[2]}}); + MakeOverrideHostAttribute(kAddresses[2])); } } // namespace