From 3d291cc463a8abc4b323cb304a708a5847857d4a Mon Sep 17 00:00:00 2001 From: Eugene Ostroukhov Date: Tue, 9 May 2023 08:50:43 -0700 Subject: [PATCH] [xDS] Implement cluster locking by ClusterSelectionFilter (#32938) Part of the work needed for [gRFC A60](https://github.com/grpc/proposal/blob/master/A60-xds-stateful-session-affinity-weighted-clusters.md). --- src/core/BUILD | 4 + .../resolver/xds/xds_resolver.cc | 177 +++++++++++++++--- 2 files changed, 150 insertions(+), 31 deletions(-) diff --git a/src/core/BUILD b/src/core/BUILD index 62fba64e0f6..a53c9117b85 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -5202,8 +5202,10 @@ grpc_cc_library( language = "c++", deps = [ "arena", + "arena_promise", "channel_args", "channel_fwd", + "context", "dual_ref_counted", "grpc_lb_policy_ring_hash", "grpc_resolver_xds_header", @@ -5212,6 +5214,7 @@ grpc_cc_library( "iomgr_fwd", "match", "pollset_set", + "ref_counted", "slice", "time", "unique_type_name", @@ -5224,6 +5227,7 @@ grpc_cc_library( "//:grpc_resolver", "//:grpc_service_config_impl", "//:grpc_trace", + "//:legacy_context", "//:orphanable", "//:ref_counted_ptr", "//:server_address", diff --git a/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc b/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc index c372ccb45e1..d4733e3ee26 100644 --- a/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc +++ b/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -46,7 +47,6 @@ #include -#include "src/core/ext/filters/client_channel/client_channel_internal.h" #include "src/core/lib/gprpp/unique_type_name.h" #include "src/core/lib/slice/slice.h" @@ -57,6 +57,7 @@ #include #include +#include "src/core/ext/filters/client_channel/client_channel_internal.h" #include "src/core/ext/filters/client_channel/config_selector.h" #include "src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h" 
#include "src/core/ext/filters/client_channel/resolver/xds/xds_resolver.h" @@ -69,6 +70,9 @@ #include "src/core/ext/xds/xds_routing.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_fwd.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/context.h" +#include "src/core/lib/channel/promise_based_filter.h" #include "src/core/lib/channel/status_util.h" #include "src/core/lib/config/core_configuration.h" #include "src/core/lib/debug/trace.h" @@ -76,11 +80,14 @@ #include "src/core/lib/gprpp/dual_ref_counted.h" #include "src/core/lib/gprpp/match.h" #include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/time.h" #include "src/core/lib/gprpp/work_serializer.h" #include "src/core/lib/iomgr/iomgr_fwd.h" #include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/promise/arena_promise.h" +#include "src/core/lib/promise/context.h" #include "src/core/lib/resolver/resolver.h" #include "src/core/lib/resolver/resolver_factory.h" #include "src/core/lib/resolver/server_address.h" @@ -88,6 +95,7 @@ #include "src/core/lib/service_config/service_config.h" #include "src/core/lib/service_config/service_config_impl.h" #include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/transport.h" #include "src/core/lib/uri/uri_parser.h" namespace grpc_core { @@ -240,9 +248,9 @@ class XdsResolver : public Resolver { // back into the WorkSerializer to remove the entry from the map. 
class ClusterState : public DualRefCounted { public: - ClusterState(RefCountedPtr resolver, std::string cluster_name) - : resolver_(std::move(resolver)), - cluster_name_(std::move(cluster_name)) {} + ClusterState(RefCountedPtr resolver, + absl::string_view cluster_name) + : resolver_(std::move(resolver)), cluster_name_(cluster_name) {} void Orphan() override { auto* resolver = resolver_.get(); @@ -260,6 +268,32 @@ class XdsResolver : public Resolver { std::string cluster_name_; }; + // A map containing cluster refs held by the XdsConfigSelector. A ref to + // this map will be taken by each call processed by the XdsConfigSelector, + // stored in the call's call attributes, and later unreffed + // by the ClusterSelection filter. + class XdsClusterMap : public RefCounted { + public: + explicit XdsClusterMap( + std::map> clusters) + : clusters_(std::move(clusters)) {} + + bool operator==(const XdsClusterMap& other) const { + return clusters_ == other.clusters_; + } + + RefCountedPtr Find(absl::string_view name) const { + auto it = clusters_.find(name); + if (it == clusters_.end()) { + return nullptr; + } + return it->second; + } + + private: + std::map> clusters_; + }; + class XdsConfigSelector : public ConfigSelector { public: XdsConfigSelector(RefCountedPtr resolver, @@ -272,7 +306,7 @@ class XdsResolver : public Resolver { const auto* other_xds = static_cast(other); // Don't need to compare resolver_, since that will always be the same. 
return route_table_ == other_xds->route_table_ && - clusters_ == other_xds->clusters_; + *cluster_map_ == *other_xds->cluster_map_; } absl::Status GetCallConfig(GetCallConfigArgs args) override; @@ -301,7 +335,6 @@ class XdsResolver : public Resolver { class RouteListIterator; - void MaybeAddCluster(const std::string& name); absl::StatusOr> CreateMethodConfig( const XdsRouteConfigResource::Route& route, const XdsRouteConfigResource::Route::RouteAction::ClusterWeight* @@ -309,10 +342,92 @@ class XdsResolver : public Resolver { RefCountedPtr resolver_; RouteTable route_table_; - std::map> clusters_; + RefCountedPtr cluster_map_; std::vector filters_; }; + class ClusterSelectionFilter : public ChannelFilter { + public: + const static grpc_channel_filter kFilter; + + static absl::StatusOr Create( + const ChannelArgs& /* unused */, ChannelFilter::Args filter_args) { + return ClusterSelectionFilter(filter_args); + } + + // Construct a promise for one call. + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) override { + auto* service_config_call_data = + static_cast( + GetContext() + [GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA] + .value); + GPR_ASSERT(service_config_call_data != nullptr); + auto* cluster_data = static_cast( + service_config_call_data->GetCallAttribute( + XdsClusterMapAttribute::TypeName())); + auto* cluster_name_attribute = static_cast( + service_config_call_data->GetCallAttribute( + XdsClusterAttribute::TypeName())); + if (cluster_data != nullptr && cluster_name_attribute != nullptr) { + auto cluster = + cluster_data->LockAndGetCluster(cluster_name_attribute->cluster()); + if (cluster != nullptr) { + service_config_call_data->SetOnCommit( + [cluster = std::move(cluster)]() mutable { cluster.reset(); }); + } + } + return next_promise_factory(std::move(call_args)); + } + + private: + explicit ClusterSelectionFilter(ChannelFilter::Args filter_args) + : filter_args_(filter_args) {} + + ChannelFilter::Args filter_args_; 
+ }; + + RefCountedPtr GetOrCreateClusterState( + absl::string_view cluster_name) { + auto it = cluster_state_map_.find(cluster_name); + if (it == cluster_state_map_.end()) { + auto cluster = MakeRefCounted(Ref(), cluster_name); + cluster_state_map_.emplace(cluster->cluster_name(), cluster->WeakRef()); + return cluster; + } + return it->second->Ref(); + } + + class XdsClusterMapAttribute + : public ServiceConfigCallData::CallAttributeInterface { + public: + static UniqueTypeName TypeName() { + static UniqueTypeName::Factory factory("xds_cluster_lb_data"); + return factory.Create(); + } + + explicit XdsClusterMapAttribute(RefCountedPtr cluster_map) + : cluster_map_(std::move(cluster_map)) {} + + // This method can be called only once. The first call will release the + // reference to the cluster map, and subsequent calls will return nullptr. + RefCountedPtr LockAndGetCluster( + absl::string_view cluster_name) { + if (cluster_map_ == nullptr) { + return nullptr; + } + auto cluster = cluster_map_->Find(cluster_name); + cluster_map_.reset(); + return cluster; + } + + UniqueTypeName type() const override { return TypeName(); } + + private: + RefCountedPtr cluster_map_; + }; + void OnListenerUpdate(XdsListenerResource listener); void OnRouteConfigUpdate(XdsRouteConfigResource rds_update); void OnError(absl::string_view context, absl::Status status); @@ -360,6 +475,11 @@ bool MethodConfigsEqual(const ServiceConfig* sc1, const ServiceConfig* sc2) { return sc1->json_string() == sc2->json_string(); } +const grpc_channel_filter XdsResolver::ClusterSelectionFilter::kFilter = + MakePromiseBasedFilter( + "cluster_selection_filter"); + bool XdsResolver::XdsConfigSelector::Route::ClusterWeightState::operator==( const ClusterWeightState& other) const { return range_end == other.range_end && cluster == other.cluster && @@ -413,6 +533,13 @@ XdsResolver::XdsConfigSelector::XdsConfigSelector( // moving the entry in a reallocation will cause the string_view to point to // invalid data. 
route_table_.reserve(resolver_->current_virtual_host_->routes.size()); + std::map> clusters; + auto maybe_add_cluster = [&](absl::string_view cluster_name) { + if (clusters.find(cluster_name) != clusters.end()) return; + auto cluster_state = resolver_->GetOrCreateClusterState(cluster_name); + absl::string_view name = cluster_state->cluster_name(); + clusters.emplace(name, std::move(cluster_state)); + }; for (auto& route : resolver_->current_virtual_host_->routes) { if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_resolver_trace)) { gpr_log(GPR_INFO, "[xds_resolver %p] XdsConfigSelector %p: route: %s", @@ -442,7 +569,7 @@ XdsResolver::XdsConfigSelector::XdsConfigSelector( return; } route_entry.method_config = std::move(*result); - MaybeAddCluster( + maybe_add_cluster( absl::StrCat("cluster:", cluster_name.cluster_name)); }, // WeightedClusters @@ -464,7 +591,8 @@ XdsResolver::XdsConfigSelector::XdsConfigSelector( cluster_weight_state.cluster = weighted_cluster.name; route_entry.weighted_cluster_state.push_back( std::move(cluster_weight_state)); - MaybeAddCluster(absl::StrCat("cluster:", weighted_cluster.name)); + maybe_add_cluster( + absl::StrCat("cluster:", weighted_cluster.name)); } }, // ClusterSpecifierPlugin @@ -476,13 +604,14 @@ XdsResolver::XdsConfigSelector::XdsConfigSelector( return; } route_entry.method_config = std::move(*result); - MaybeAddCluster(absl::StrCat( + maybe_add_cluster(absl::StrCat( "cluster_specifier_plugin:", cluster_specifier_plugin_name.cluster_specifier_plugin_name)); }); if (!status->ok()) return; } } + cluster_map_ = MakeRefCounted(std::move(clusters)); // Populate filter list. 
const auto& http_filter_registry = static_cast(resolver_->xds_client_->bootstrap()) @@ -499,6 +628,7 @@ XdsResolver::XdsConfigSelector::XdsConfigSelector( filters_.push_back(filter_impl->channel_filter()); } } + filters_.push_back(&ClusterSelectionFilter::kFilter); } XdsResolver::XdsConfigSelector::~XdsConfigSelector() { @@ -506,7 +636,7 @@ XdsResolver::XdsConfigSelector::~XdsConfigSelector() { gpr_log(GPR_INFO, "[xds_resolver %p] destroying XdsConfigSelector %p", resolver_.get(), this); } - clusters_.clear(); + cluster_map_.reset(); resolver_->MaybeRemoveUnusedClusters(); } @@ -592,21 +722,6 @@ XdsResolver::XdsConfigSelector::CreateMethodConfig( return nullptr; } -void XdsResolver::XdsConfigSelector::MaybeAddCluster(const std::string& name) { - if (clusters_.find(name) == clusters_.end()) { - auto it = resolver_->cluster_state_map_.find(name); - if (it == resolver_->cluster_state_map_.end()) { - auto new_cluster_state = MakeRefCounted(resolver_, name); - resolver_->cluster_state_map_.emplace(new_cluster_state->cluster_name(), - new_cluster_state->WeakRef()); - clusters_[new_cluster_state->cluster_name()] = - std::move(new_cluster_state); - } else { - clusters_[it->second->cluster_name()] = it->second->Ref(); - } - } -} - absl::optional HeaderHashHelper( const XdsRouteConfigResource::Route::RouteAction::HashPolicy::Header& header_policy, @@ -694,8 +809,8 @@ absl::Status XdsResolver::XdsConfigSelector::GetCallConfig( cluster_specifier_plugin_name.cluster_specifier_plugin_name); method_config = entry.method_config; }); - auto it = clusters_.find(cluster_name); - GPR_ASSERT(it != clusters_.end()); + auto cluster = cluster_map_->Find(cluster_name); + GPR_ASSERT(cluster != nullptr); // Generate a hash. 
absl::optional hash; for (const auto& hash_policy : route_action->hash_policies) { @@ -733,7 +848,7 @@ absl::Status XdsResolver::XdsConfigSelector::GetCallConfig( parsed_method_configs); } args.service_config_call_data->SetCallAttribute( - args.arena->New(it->first)); + args.arena->New(cluster->cluster_name())); std::string hash_string = absl::StrCat(hash.value()); char* hash_value = static_cast(args.arena->Alloc(hash_string.size() + 1)); @@ -741,8 +856,8 @@ absl::Status XdsResolver::XdsConfigSelector::GetCallConfig( hash_value[hash_string.size()] = '\0'; args.service_config_call_data->SetCallAttribute( args.arena->New(hash_value)); - args.service_config_call_data->SetOnCommit( - [cluster_state = it->second->Ref()]() mutable { cluster_state.reset(); }); + args.service_config_call_data->SetCallAttribute( + args.arena->ManagedNew(cluster_map_)); return absl::OkStatus(); }