diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index c4089151b26..e036610a6db 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -881,6 +881,9 @@ class CallData { // ChannelData::SubchannelWrapper // +using ServerAddressAttributeMap = + std::map>; + // This class is a wrapper for Subchannel that hides details of the // channel's implementation (such as the health check service name and // connected subchannel) from the LB policy API. @@ -892,11 +895,13 @@ class CallData { class ChannelData::SubchannelWrapper : public SubchannelInterface { public: SubchannelWrapper(ChannelData* chand, Subchannel* subchannel, - grpc_core::UniquePtr health_check_service_name) + grpc_core::UniquePtr health_check_service_name, + ServerAddressAttributeMap attributes) : SubchannelInterface(&grpc_client_channel_routing_trace), chand_(chand), subchannel_(subchannel), - health_check_service_name_(std::move(health_check_service_name)) { + health_check_service_name_(std::move(health_check_service_name)), + attributes_(std::move(attributes)) { if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { gpr_log(GPR_INFO, "chand=%p: creating subchannel wrapper %p for subchannel %p", @@ -974,14 +979,21 @@ class ChannelData::SubchannelWrapper : public SubchannelInterface { void ResetBackoff() override { subchannel_->ResetBackoff(); } - void ThrottleKeepaliveTime(int new_keepalive_time) { - subchannel_->ThrottleKeepaliveTime(new_keepalive_time); - } - const grpc_channel_args* channel_args() override { return subchannel_->channel_args(); } + const ServerAddress::AttributeInterface* GetAttribute( + const char* key) const override { + auto it = attributes_.find(key); + if (it == attributes_.end()) return nullptr; + return it->second.get(); + } + + void ThrottleKeepaliveTime(int new_keepalive_time) { + subchannel_->ThrottleKeepaliveTime(new_keepalive_time); + } + void UpdateHealthCheckServiceName( grpc_core::UniquePtr health_check_service_name) { if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { @@ -1175,6 +1187,7 @@ class ChannelData::SubchannelWrapper : public SubchannelInterface { ChannelData* chand_; Subchannel* subchannel_; grpc_core::UniquePtr health_check_service_name_; + ServerAddressAttributeMap attributes_; // Maps from the address of the watcher passed to us by the LB policy // to the address of the WrapperWatcher that we passed to the underlying // subchannel. This is needed so that when the LB policy calls @@ -1349,6 +1362,18 @@ class ChannelData::ConnectivityWatcherRemover { // ChannelData::ClientChannelControlHelper // +} // namespace + +// Allows accessing the attributes from a ServerAddress. +class ChannelServerAddressPeer { + public: + static ServerAddressAttributeMap GetAttributes(ServerAddress* address) { + return std::move(address->attributes_); + } +}; + +namespace { + class ChannelData::ClientChannelControlHelper : public LoadBalancingPolicy::ChannelControlHelper { public: @@ -1362,7 +1387,8 @@ class ChannelData::ClientChannelControlHelper } RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override { + ServerAddress address, const grpc_channel_args& args) override { + // Determine health check service name. bool inhibit_health_checking = grpc_channel_arg_get_bool( grpc_channel_args_find(&args, GRPC_ARG_INHIBIT_HEALTH_CHECKING), false); grpc_core::UniquePtr health_check_service_name; @@ -1370,21 +1396,37 @@ class ChannelData::ClientChannelControlHelper health_check_service_name.reset( gpr_strdup(chand_->health_check_service_name_.get())); } + // Remove channel args that should not affect subchannel uniqueness. static const char* args_to_remove[] = { GRPC_ARG_INHIBIT_HEALTH_CHECKING, GRPC_ARG_CHANNELZ_CHANNEL_NODE, }; - grpc_arg arg = SubchannelPoolInterface::CreateChannelArg( - chand_->subchannel_pool_.get()); + // Add channel args needed for the subchannel. + absl::InlinedVector args_to_add = { + Subchannel::CreateSubchannelAddressArg(&address.address()), + SubchannelPoolInterface::CreateChannelArg( + chand_->subchannel_pool_.get()), + }; + if (address.args() != nullptr) { + for (size_t j = 0; j < address.args()->num_args; ++j) { + args_to_add.emplace_back(address.args()->args[j]); + } + } grpc_channel_args* new_args = grpc_channel_args_copy_and_add_and_remove( - &args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), &arg, 1); + &args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), + args_to_add.data(), args_to_add.size()); + gpr_free(args_to_add[0].value.string); + // Create subchannel. Subchannel* subchannel = chand_->client_channel_factory_->CreateSubchannel(new_args); grpc_channel_args_destroy(new_args); if (subchannel == nullptr) return nullptr; + // Make sure the subchannel has updated keepalive time. subchannel->ThrottleKeepaliveTime(chand_->keepalive_time_); + // Create and return wrapper for the subchannel. return MakeRefCounted( - chand_, subchannel, std::move(health_check_service_name)); + chand_, subchannel, std::move(health_check_service_name), + ChannelServerAddressPeer::GetAttributes(&address)); } void UpdateState( @@ -1662,9 +1704,12 @@ ChannelData::ChannelData(grpc_channel_element_args* args, grpc_error** error) &new_args); target_uri_.reset(proxy_name != nullptr ? proxy_name : gpr_strdup(server_uri)); - channel_args_ = new_args != nullptr - ? new_args - : grpc_channel_args_copy(args->channel_args); + // Strip out service config channel arg, so that it doesn't affect + // subchannel uniqueness when the args flow down to that layer. + const char* arg_to_remove = GRPC_ARG_SERVICE_CONFIG; + channel_args_ = grpc_channel_args_copy_and_remove( + new_args != nullptr ? new_args : args->channel_args, &arg_to_remove, 1); + grpc_channel_args_destroy(new_args); keepalive_time_ = grpc_channel_args_find_integer( channel_args_, GRPC_ARG_KEEPALIVE_TIME_MS, {-1 /* default value, unset */, 1, INT_MAX}); diff --git a/src/core/ext/filters/client_channel/lb_policy.h b/src/core/ext/filters/client_channel/lb_policy.h index 81f5bb33b20..e132564bc38 100644 --- a/src/core/ext/filters/client_channel/lb_policy.h +++ b/src/core/ext/filters/client_channel/lb_policy.h @@ -279,7 +279,7 @@ class LoadBalancingPolicy : public InternallyRefCounted { /// Creates a new subchannel with the specified channel args. virtual RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) = 0; + ServerAddress address, const grpc_channel_args& args) = 0; /// Sets the connectivity state and returns a new picker to be used /// by the client channel. diff --git a/src/core/ext/filters/client_channel/lb_policy/child_policy_handler.cc b/src/core/ext/filters/client_channel/lb_policy/child_policy_handler.cc index 1464c718602..2ca1573af6a 100644 --- a/src/core/ext/filters/client_channel/lb_policy/child_policy_handler.cc +++ b/src/core/ext/filters/client_channel/lb_policy/child_policy_handler.cc @@ -39,10 +39,11 @@ class ChildPolicyHandler::Helper ~Helper() { parent_.reset(DEBUG_LOCATION, "Helper"); } RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override { + ServerAddress address, const grpc_channel_args& args) override { if (parent_->shutting_down_) return nullptr; if (!CalledByCurrentChild() && !CalledByPendingChild()) return nullptr; - return parent_->channel_control_helper()->CreateSubchannel(args); + return parent_->channel_control_helper()->CreateSubchannel( + std::move(address), args); } void UpdateState(grpc_connectivity_state state, const absl::Status& status, diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc index 005d913fb2c..da8794b7c61 100644 --- a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc @@ -302,7 +302,7 @@ class GrpcLb : public LoadBalancingPolicy { : parent_(std::move(parent)) {} RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override; + ServerAddress address, const grpc_channel_args& args) override; void UpdateState(grpc_connectivity_state state, const absl::Status& status, std::unique_ptr picker) override; void RequestReresolution() override; @@ -654,9 +654,10 @@ GrpcLb::PickResult GrpcLb::Picker::Pick(PickArgs args) { // RefCountedPtr GrpcLb::Helper::CreateSubchannel( - const grpc_channel_args& args) { + ServerAddress address, const grpc_channel_args& args) { if (parent_->shutting_down_) return nullptr; - return parent_->channel_control_helper()->CreateSubchannel(args); + return parent_->channel_control_helper()->CreateSubchannel(std::move(address), + args); } void GrpcLb::Helper::UpdateState(grpc_connectivity_state state, diff --git a/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc b/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc index 4a1f006cf64..784e964960f 100644 --- a/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc +++ b/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc @@ -84,9 +84,9 @@ class PickFirst : public LoadBalancingPolicy { PickFirstSubchannelData> { public: PickFirstSubchannelList(PickFirst* policy, TraceFlag* tracer, - const ServerAddressList& addresses, + ServerAddressList addresses, const grpc_channel_args& args) - : SubchannelList(policy, tracer, addresses, + : SubchannelList(policy, tracer, std::move(addresses), policy->channel_control_helper(), args) { // Need to maintain a ref to the LB policy as long as we maintain // any references to subchannels, since the subchannels' diff --git a/src/core/ext/filters/client_channel/lb_policy/priority/priority.cc b/src/core/ext/filters/client_channel/lb_policy/priority/priority.cc index 6b0489ceb16..07a7b2700cb 100644 --- a/src/core/ext/filters/client_channel/lb_policy/priority/priority.cc +++ b/src/core/ext/filters/client_channel/lb_policy/priority/priority.cc @@ -155,7 +155,7 @@ class PriorityLb : public LoadBalancingPolicy { ~Helper() { priority_.reset(DEBUG_LOCATION, "Helper"); } RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override; + ServerAddress address, const grpc_channel_args& args) override; void UpdateState(grpc_connectivity_state state, const absl::Status& status, std::unique_ptr picker) override; @@ -736,10 +736,10 @@ void PriorityLb::ChildPriority::Helper::RequestReresolution() { RefCountedPtr PriorityLb::ChildPriority::Helper::CreateSubchannel( - const grpc_channel_args& args) { + ServerAddress address, const grpc_channel_args& args) { if (priority_->priority_policy_->shutting_down_) return nullptr; return priority_->priority_policy_->channel_control_helper() - ->CreateSubchannel(args); + ->CreateSubchannel(std::move(address), args); } void PriorityLb::ChildPriority::Helper::UpdateState( diff --git a/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc b/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc index d1ad2ca35e7..6dc107702e3 100644 --- a/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc +++ b/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc @@ -112,9 +112,9 @@ class RoundRobin : public LoadBalancingPolicy { RoundRobinSubchannelData> { public: RoundRobinSubchannelList(RoundRobin* policy, TraceFlag* tracer, - const ServerAddressList& addresses, + ServerAddressList addresses, const grpc_channel_args& args) - : SubchannelList(policy, tracer, addresses, + : SubchannelList(policy, tracer, std::move(addresses), policy->channel_control_helper(), args) { // Need to maintain a ref to the LB policy as long as we maintain // any references to subchannels, since the subchannels' @@ -445,7 +445,7 @@ void RoundRobin::UpdateLocked(UpdateArgs args) { } } latest_pending_subchannel_list_ = MakeOrphanable( - this, &grpc_lb_round_robin_trace, args.addresses, *args.args); + this, &grpc_lb_round_robin_trace, std::move(args.addresses), *args.args); if (latest_pending_subchannel_list_->num_subchannels() == 0) { // If the new list is empty, immediately promote the new list to the // current list and transition to TRANSIENT_FAILURE. diff --git a/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h b/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h index 940bad02b9d..e6d0b546f77 100644 --- a/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h +++ b/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h @@ -200,7 +200,7 @@ class SubchannelList : public InternallyRefCounted { protected: SubchannelList(LoadBalancingPolicy* policy, TraceFlag* tracer, - const ServerAddressList& addresses, + ServerAddressList addresses, LoadBalancingPolicy::ChannelControlHelper* helper, const grpc_channel_args& args); @@ -350,8 +350,7 @@ void SubchannelData::ShutdownLocked() { template SubchannelList::SubchannelList( - LoadBalancingPolicy* policy, TraceFlag* tracer, - const ServerAddressList& addresses, + LoadBalancingPolicy* policy, TraceFlag* tracer, ServerAddressList addresses, LoadBalancingPolicy::ChannelControlHelper* helper, const grpc_channel_args& args) : InternallyRefCounted(tracer), @@ -363,50 +362,28 @@ SubchannelList::SubchannelList( tracer_->name(), policy, this, addresses.size()); } subchannels_.reserve(addresses.size()); - // We need to remove the LB addresses in order to be able to compare the - // subchannel keys of subchannels from a different batch of addresses. - // We remove the service config, since it will be passed into the - // subchannel via call context. - static const char* keys_to_remove[] = {GRPC_ARG_SUBCHANNEL_ADDRESS, - GRPC_ARG_SERVICE_CONFIG}; // Create a subchannel for each address. - for (size_t i = 0; i < addresses.size(); i++) { - absl::InlinedVector args_to_add; - const size_t subchannel_address_arg_index = args_to_add.size(); - args_to_add.emplace_back( - Subchannel::CreateSubchannelAddressArg(&addresses[i].address())); - if (addresses[i].args() != nullptr) { - for (size_t j = 0; j < addresses[i].args()->num_args; ++j) { - args_to_add.emplace_back(addresses[i].args()->args[j]); - } - } - grpc_channel_args* new_args = grpc_channel_args_copy_and_add_and_remove( - &args, keys_to_remove, GPR_ARRAY_SIZE(keys_to_remove), - args_to_add.data(), args_to_add.size()); - gpr_free(args_to_add[subchannel_address_arg_index].value.string); + for (const ServerAddress& address : addresses) { RefCountedPtr subchannel = - helper->CreateSubchannel(*new_args); - grpc_channel_args_destroy(new_args); + helper->CreateSubchannel(std::move(address), args); if (subchannel == nullptr) { // Subchannel could not be created. if (GRPC_TRACE_FLAG_ENABLED(*tracer_)) { gpr_log(GPR_INFO, - "[%s %p] could not create subchannel for address uri %s, " + "[%s %p] could not create subchannel for address %s, " "ignoring", - tracer_->name(), policy_, - grpc_sockaddr_to_uri(&addresses[i].address()).c_str()); + tracer_->name(), policy_, address.ToString().c_str()); } continue; } if (GRPC_TRACE_FLAG_ENABLED(*tracer_)) { gpr_log(GPR_INFO, "[%s %p] subchannel list %p index %" PRIuPTR - ": Created subchannel %p for address uri %s", + ": Created subchannel %p for address %s", tracer_->name(), policy_, this, subchannels_.size(), - subchannel.get(), - grpc_sockaddr_to_uri(&addresses[i].address()).c_str()); + subchannel.get(), address.ToString().c_str()); } - subchannels_.emplace_back(this, addresses[i], std::move(subchannel)); + subchannels_.emplace_back(this, address, std::move(subchannel)); } } diff --git a/src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc b/src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc index c1baca11258..0f99692da5c 100644 --- a/src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc +++ b/src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc @@ -145,7 +145,7 @@ class WeightedTargetLb : public LoadBalancingPolicy { ~Helper() { weighted_child_.reset(DEBUG_LOCATION, "Helper"); } RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override; + ServerAddress address, const grpc_channel_args& args) override; void UpdateState(grpc_connectivity_state state, const absl::Status& status, std::unique_ptr picker) override; @@ -590,10 +590,10 @@ void WeightedTargetLb::WeightedChild::OnDelayedRemovalTimerLocked( RefCountedPtr WeightedTargetLb::WeightedChild::Helper::CreateSubchannel( - const grpc_channel_args& args) { + ServerAddress address, const grpc_channel_args& args) { if (weighted_child_->weighted_target_policy_->shutting_down_) return nullptr; return weighted_child_->weighted_target_policy_->channel_control_helper() - ->CreateSubchannel(args); + ->CreateSubchannel(std::move(address), args); } void WeightedTargetLb::WeightedChild::Helper::UpdateState( diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc b/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc index 5f991f90910..9978b55cb4d 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc @@ -79,7 +79,7 @@ class CdsLb : public LoadBalancingPolicy { public: explicit Helper(RefCountedPtr parent) : parent_(std::move(parent)) {} RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override; + ServerAddress address, const grpc_channel_args& args) override; void UpdateState(grpc_connectivity_state state, const absl::Status& status, std::unique_ptr picker) override; void RequestReresolution() override; @@ -239,9 +239,10 @@ void CdsLb::ClusterWatcher::OnResourceDoesNotExist() { // RefCountedPtr CdsLb::Helper::CreateSubchannel( - const grpc_channel_args& args) { + ServerAddress address, const grpc_channel_args& args) { if (parent_->shutting_down_) return nullptr; - return parent_->channel_control_helper()->CreateSubchannel(args); + return parent_->channel_control_helper()->CreateSubchannel(std::move(address), + args); } void CdsLb::Helper::UpdateState(grpc_connectivity_state state, diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/eds.cc b/src/core/ext/filters/client_channel/lb_policy/xds/eds.cc index ef604c0318d..dfd9a065b81 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/eds.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/eds.cc @@ -133,7 +133,7 @@ class EdsLb : public LoadBalancingPolicy { ~Helper() { eds_policy_.reset(DEBUG_LOCATION, "Helper"); } RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override; + ServerAddress address, const grpc_channel_args& args) override; void UpdateState(grpc_connectivity_state state, const absl::Status& status, std::unique_ptr picker) override; // This is a no-op, because we get the addresses from the xds @@ -261,9 +261,10 @@ EdsLb::PickResult EdsLb::DropPicker::Pick(PickArgs args) { // RefCountedPtr EdsLb::Helper::CreateSubchannel( - const grpc_channel_args& args) { + ServerAddress address, const grpc_channel_args& args) { if (eds_policy_->shutting_down_) return nullptr; - return eds_policy_->channel_control_helper()->CreateSubchannel(args); + return eds_policy_->channel_control_helper()->CreateSubchannel( + std::move(address), args); } void EdsLb::Helper::UpdateState(grpc_connectivity_state state, diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/lrs.cc b/src/core/ext/filters/client_channel/lb_policy/xds/lrs.cc index e8553d5640e..297f83c4226 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/lrs.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/lrs.cc @@ -119,7 +119,7 @@ class LrsLb : public LoadBalancingPolicy { ~Helper() { lrs_policy_.reset(DEBUG_LOCATION, "Helper"); } RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override; + ServerAddress address, const grpc_channel_args& args) override; void UpdateState(grpc_connectivity_state state, const absl::Status& status, std::unique_ptr picker) override; void RequestReresolution() override; @@ -324,9 +324,10 @@ void LrsLb::UpdateChildPolicyLocked(ServerAddressList addresses, // RefCountedPtr LrsLb::Helper::CreateSubchannel( - const grpc_channel_args& args) { + ServerAddress address, const grpc_channel_args& args) { if (lrs_policy_->shutting_down_) return nullptr; - return lrs_policy_->channel_control_helper()->CreateSubchannel(args); + return lrs_policy_->channel_control_helper()->CreateSubchannel( + std::move(address), args); } void LrsLb::Helper::UpdateState(grpc_connectivity_state state, diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc index fff6c411a08..6f7279bff05 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc @@ -154,7 +154,7 @@ class XdsClusterManagerLb : public LoadBalancingPolicy { ~Helper() { xds_cluster_manager_child_.reset(DEBUG_LOCATION, "Helper"); } RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override; + ServerAddress address, const grpc_channel_args& args) override; void UpdateState(grpc_connectivity_state state, const absl::Status& status, std::unique_ptr picker) override; @@ -546,12 +546,12 @@ void XdsClusterManagerLb::ClusterChild::OnDelayedRemovalTimerLocked( RefCountedPtr XdsClusterManagerLb::ClusterChild::Helper::CreateSubchannel( - const grpc_channel_args& args) { + ServerAddress address, const grpc_channel_args& args) { if (xds_cluster_manager_child_->xds_cluster_manager_policy_->shutting_down_) return nullptr; return xds_cluster_manager_child_->xds_cluster_manager_policy_ ->channel_control_helper() - ->CreateSubchannel(args); + ->CreateSubchannel(std::move(address), args); } void XdsClusterManagerLb::ClusterChild::Helper::UpdateState( diff --git a/src/core/ext/filters/client_channel/resolving_lb_policy.cc b/src/core/ext/filters/client_channel/resolving_lb_policy.cc index 5d1d0f0eb4d..49512c5d5ab 100644 --- a/src/core/ext/filters/client_channel/resolving_lb_policy.cc +++ b/src/core/ext/filters/client_channel/resolving_lb_policy.cc @@ -109,9 +109,10 @@ class ResolvingLoadBalancingPolicy::ResolvingControlHelper : parent_(std::move(parent)) {} RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override { + ServerAddress address, const grpc_channel_args& args) override { if (parent_->resolver_ == nullptr) return nullptr; // Shutting down. - return parent_->channel_control_helper()->CreateSubchannel(args); + return parent_->channel_control_helper()->CreateSubchannel( + std::move(address), args); } void UpdateState(grpc_connectivity_state state, const absl::Status& status, diff --git a/src/core/ext/filters/client_channel/server_address.h b/src/core/ext/filters/client_channel/server_address.h index 7a188a0ce45..ddcf530d3c8 100644 --- a/src/core/ext/filters/client_channel/server_address.h +++ b/src/core/ext/filters/client_channel/server_address.h @@ -97,6 +97,10 @@ class ServerAddress { std::string ToString() const; private: + // Allows the channel to access the attributes without knowing the keys. + // (We intentionally do not allow LB policies to do this.) + friend class ChannelServerAddressPeer; + grpc_resolved_address address_; grpc_channel_args* args_; std::map> attributes_; diff --git a/src/core/ext/filters/client_channel/subchannel_interface.h b/src/core/ext/filters/client_channel/subchannel_interface.h index f7a788a9b94..61a999aab2a 100644 --- a/src/core/ext/filters/client_channel/subchannel_interface.h +++ b/src/core/ext/filters/client_channel/subchannel_interface.h @@ -21,6 +21,7 @@ #include +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" @@ -87,6 +88,11 @@ class SubchannelInterface : public RefCounted { // TODO(roth): Need a better non-grpc-specific abstraction here. virtual const grpc_channel_args* channel_args() = 0; + + // Allows accessing the attributes associated with the address for + // this subchannel. + virtual const ServerAddress::AttributeInterface* GetAttribute( + const char* key) const = 0; }; } // namespace grpc_core diff --git a/test/core/util/test_lb_policies.cc b/test/core/util/test_lb_policies.cc index 5aa5dabafa0..b6804fd132c 100644 --- a/test/core/util/test_lb_policies.cc +++ b/test/core/util/test_lb_policies.cc @@ -138,8 +138,9 @@ class TestPickArgsLb : public ForwardingLoadBalancingPolicy { : parent_(std::move(parent)), cb_(std::move(cb)) {} RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override { - return parent_->channel_control_helper()->CreateSubchannel(args); + ServerAddress address, const grpc_channel_args& args) override { + return parent_->channel_control_helper()->CreateSubchannel( + std::move(address), args); } void UpdateState(grpc_connectivity_state state, const absl::Status& status, @@ -248,8 +249,9 @@ class InterceptRecvTrailingMetadataLoadBalancingPolicy : parent_(std::move(parent)), cb_(std::move(cb)) {} RefCountedPtr CreateSubchannel( - const grpc_channel_args& args) override { - return parent_->channel_control_helper()->CreateSubchannel(args); + ServerAddress address, const grpc_channel_args& args) override { + return parent_->channel_control_helper()->CreateSubchannel( + std::move(address), args); } void UpdateState(grpc_connectivity_state state, const absl::Status& status, @@ -331,6 +333,87 @@ class InterceptTrailingFactory : public LoadBalancingPolicyFactory { InterceptRecvTrailingMetadataCallback cb_; }; +// +// AddressTestLoadBalancingPolicy +// + +constexpr char kAddressTestLbPolicyName[] = "address_test_lb"; + +class AddressTestLoadBalancingPolicy : public ForwardingLoadBalancingPolicy { + public: + AddressTestLoadBalancingPolicy(Args args, AddressTestCallback cb) + : ForwardingLoadBalancingPolicy( + absl::make_unique( + RefCountedPtr(this), + std::move(cb)), + std::move(args), + /*delegate_lb_policy_name=*/"pick_first", + /*initial_refcount=*/2) {} + + ~AddressTestLoadBalancingPolicy() override = default; + + const char* name() const override { return kAddressTestLbPolicyName; } + + private: + class Helper : public ChannelControlHelper { + public: + Helper(RefCountedPtr parent, + AddressTestCallback cb) + : parent_(std::move(parent)), cb_(std::move(cb)) {} + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override { + cb_(address); + return parent_->channel_control_helper()->CreateSubchannel( + std::move(address), args); + } + + void UpdateState(grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) override { + parent_->channel_control_helper()->UpdateState(state, status, + std::move(picker)); + } + + void RequestReresolution() override { + parent_->channel_control_helper()->RequestReresolution(); + } + + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override { + parent_->channel_control_helper()->AddTraceEvent(severity, message); + } + + private: + RefCountedPtr parent_; + AddressTestCallback cb_; + }; +}; + +class AddressTestConfig : public LoadBalancingPolicy::Config { + public: + const char* name() const override { return kAddressTestLbPolicyName; } +}; + +class AddressTestFactory : public LoadBalancingPolicyFactory { + public: + explicit AddressTestFactory(AddressTestCallback cb) : cb_(std::move(cb)) {} + + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args), cb_); + } + + const char* name() const override { return kAddressTestLbPolicyName; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& /*json*/, grpc_error** /*error*/) const override { + return MakeRefCounted(); + } + + private: + AddressTestCallback cb_; +}; + } // namespace void RegisterTestPickArgsLoadBalancingPolicy(TestPickArgsCallback cb) { @@ -344,4 +427,9 @@ void RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy( absl::make_unique(std::move(cb))); } +void RegisterAddressTestLoadBalancingPolicy(AddressTestCallback cb) { + LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory( + absl::make_unique(std::move(cb))); +} + } // namespace grpc_core diff --git a/test/core/util/test_lb_policies.h b/test/core/util/test_lb_policies.h index ffb079181ea..e125bb88a3b 100644 --- a/test/core/util/test_lb_policies.h +++ b/test/core/util/test_lb_policies.h @@ -45,11 +45,16 @@ using InterceptRecvTrailingMetadataCallback = std::function; // Registers an LB policy called "intercept_trailing_metadata_lb" that -// invokes cb with argument user_data when trailing metadata is received -// for each call. +// invokes cb when trailing metadata is received for each call. void RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy( InterceptRecvTrailingMetadataCallback cb); +using AddressTestCallback = std::function; + +// Registers an LB policy called "address_test_lb" that invokes cb for each +// address used to create a subchannel. +void RegisterAddressTestLoadBalancingPolicy(AddressTestCallback cb); + } // namespace grpc_core #endif // GRPC_TEST_CORE_UTIL_TEST_LB_POLICIES_H diff --git a/test/cpp/end2end/client_lb_end2end_test.cc b/test/cpp/end2end/client_lb_end2end_test.cc index fc4149aaa2a..7f4427cda46 100644 --- a/test/cpp/end2end/client_lb_end2end_test.cc +++ b/test/cpp/end2end/client_lb_end2end_test.cc @@ -161,11 +161,14 @@ class FakeResolverResponseGeneratorWrapper { response_generator_ = std::move(other.response_generator_); } - void SetNextResolution(const std::vector& ports, - const char* service_config_json = nullptr) { + void SetNextResolution( + const std::vector& ports, const char* service_config_json = nullptr, + const char* attribute_key = nullptr, + std::unique_ptr attribute = + nullptr) { grpc_core::ExecCtx exec_ctx; - response_generator_->SetResponse( - BuildFakeResults(ports, service_config_json)); + response_generator_->SetResponse(BuildFakeResults( + ports, service_config_json, attribute_key, std::move(attribute))); } void SetNextResolutionUponError(const std::vector& ports) { @@ -184,8 +187,10 @@ class FakeResolverResponseGeneratorWrapper { private: static grpc_core::Resolver::Result BuildFakeResults( - const std::vector& ports, - const char* service_config_json = nullptr) { + const std::vector& ports, const char* service_config_json = nullptr, + const char* attribute_key = nullptr, + std::unique_ptr attribute = + nullptr) { grpc_core::Resolver::Result result; for (const int& port : ports) { std::string lb_uri_str = absl::StrCat("ipv4:127.0.0.1:", port); @@ -193,8 +198,14 @@ class FakeResolverResponseGeneratorWrapper { GPR_ASSERT(lb_uri != nullptr); grpc_resolved_address address; GPR_ASSERT(grpc_parse_uri(lb_uri, &address)); + std::map> + attributes; + if (attribute != nullptr) { + attributes[attribute_key] = attribute->Copy(); + } result.addresses.emplace_back(address.addr, address.len, - nullptr /* args */); + nullptr /* args */, std::move(attributes)); grpc_uri_destroy(lb_uri); } if (service_config_json != nullptr) { @@ -1887,6 +1898,83 @@ TEST_F(ClientLbInterceptTrailingMetadataTest, BackendMetricData) { EXPECT_EQ(kNumRpcs, trailers_intercepted()); } +class ClientLbAddressTest : public ClientLbEnd2endTest { + protected: + static const char* kAttributeKey; + + class Attribute : public grpc_core::ServerAddress::AttributeInterface { + public: + explicit Attribute(const std::string& str) : str_(str) {} + + std::unique_ptr Copy() const override { + return absl::make_unique(str_); + } + + int Cmp(const AttributeInterface* other) const override { + return str_.compare(static_cast(other)->str_); + } + + std::string ToString() const override { return str_; } + + private: + std::string str_; + }; + + void SetUp() override { + ClientLbEnd2endTest::SetUp(); + current_test_instance_ = this; + } + + static void SetUpTestCase() { + grpc_init(); + grpc_core::RegisterAddressTestLoadBalancingPolicy(SaveAddress); + } + + static void TearDownTestCase() { grpc_shutdown_blocking(); } + + const std::vector& addresses_seen() { + grpc::internal::MutexLock lock(&mu_); + return addresses_seen_; + } + + private: + static void SaveAddress(const grpc_core::ServerAddress& address) { + ClientLbAddressTest* self = current_test_instance_; + grpc::internal::MutexLock lock(&self->mu_); + self->addresses_seen_.emplace_back(address.ToString()); + } + + static ClientLbAddressTest* current_test_instance_; + grpc::internal::Mutex mu_; + std::vector addresses_seen_; +}; + +const char* ClientLbAddressTest::kAttributeKey = "attribute_key"; + +ClientLbAddressTest* ClientLbAddressTest::current_test_instance_ = nullptr; + +TEST_F(ClientLbAddressTest, Basic) { + const int kNumServers = 1; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("address_test_lb", response_generator); + auto stub = BuildStub(channel); + // Addresses returned by the resolver will have attached attributes. + response_generator.SetNextResolution(GetServersPorts(), nullptr, + kAttributeKey, + absl::make_unique("foo")); + CheckRpcSendOk(stub, DEBUG_LOCATION); + // Check LB policy name for the channel. + EXPECT_EQ("address_test_lb", channel->GetLoadBalancingPolicyName()); + // Make sure that the attributes wind up on the subchannels. + std::vector expected; + for (const int port : GetServersPorts()) { + expected.emplace_back(absl::StrCat( + "127.0.0.1:", port, " args={} attributes={", kAttributeKey, "=foo}")); + } + EXPECT_EQ(addresses_seen(), expected); +} + } // namespace } // namespace testing } // namespace grpc