diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index be57bf8d4c7..fcc5d4f4ff6 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -570,7 +571,15 @@ class ClientChannel::SubchannelWrapper : public SubchannelInterface { static_cast( watcher.release())); internal_watcher->SetSubchannel(subchannel_.get()); - data_watchers_.push_back(std::move(internal_watcher)); + data_watchers_.insert(std::move(internal_watcher)); + } + + void CancelDataWatcher(DataWatcherInterface* watcher) override + ABSL_EXCLUSIVE_LOCKS_REQUIRED(*chand_->work_serializer_) { + auto* internal_watcher = + static_cast(watcher); + auto it = data_watchers_.find(internal_watcher); + if (it != data_watchers_.end()) data_watchers_.erase(it); } void ThrottleKeepaliveTime(int new_keepalive_time) { @@ -683,6 +692,29 @@ class ClientChannel::SubchannelWrapper : public SubchannelInterface { RefCountedPtr parent_; }; + // A heterogeneous lookup comparator for data watchers that allows unique_ptr + // keys to be found by raw pointer; std::set needs a strict weak order (<). + struct DataWatcherCompare { + using is_transparent = void; + bool operator()( + const std::unique_ptr& p1, + const std::unique_ptr& p2) + const { + return p1 < p2; + } + bool operator()( + const std::unique_ptr& p1, + const InternalSubchannelDataWatcherInterface* p2) const { + return p1.get() < p2; + } + bool operator()( + const InternalSubchannelDataWatcherInterface* p1, + const std::unique_ptr& p2) + const { + return p1 < p2.get(); + } + }; + ClientChannel* chand_; RefCountedPtr subchannel_; // Maps from the address of the watcher passed to us by the LB policy @@ -692,7 +724,8 @@ class ClientChannel::SubchannelWrapper : public SubchannelInterface { // corresponding WrapperWatcher to cancel on the underlying subchannel. 
std::map watcher_map_ ABSL_GUARDED_BY(*chand_->work_serializer_); - std::vector> + std::set, + DataWatcherCompare> data_watchers_ ABSL_GUARDED_BY(*chand_->work_serializer_); }; diff --git a/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h b/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h index d5a0ecfda71..0181b2eb9ca 100644 --- a/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h +++ b/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h @@ -173,6 +173,7 @@ class SubchannelData { // Will be non-null when the subchannel's state is being watched. SubchannelInterface::ConnectivityStateWatcherInterface* pending_watcher_ = nullptr; + SubchannelInterface::DataWatcherInterface* health_watcher_ = nullptr; // Data updated by the watcher. absl::optional connectivity_state_; absl::Status connectivity_status_; @@ -259,7 +260,7 @@ void SubchannelData::Watcher:: GPR_INFO, "[%s %p] subchannel list %p index %" PRIuPTR " of %" PRIuPTR " (subchannel %p): connectivity changed: old_state=%s, new_state=%s, " - "status=%s, shutting_down=%d, pending_watcher=%p", + "status=%s, shutting_down=%d, pending_watcher=%p, health_watcher=%p", subchannel_list_->tracer(), subchannel_list_->policy(), subchannel_list_.get(), subchannel_data_->Index(), subchannel_list_->num_subchannels(), @@ -268,10 +269,12 @@ void SubchannelData::Watcher:: ? 
ConnectivityStateName(*subchannel_data_->connectivity_state_) : "N/A"), ConnectivityStateName(new_state), status.ToString().c_str(), - subchannel_list_->shutting_down(), subchannel_data_->pending_watcher_); + subchannel_list_->shutting_down(), subchannel_data_->pending_watcher_, + subchannel_data_->health_watcher_); } if (!subchannel_list_->shutting_down() && - subchannel_data_->pending_watcher_ != nullptr) { + (subchannel_data_->pending_watcher_ != nullptr || + subchannel_data_->health_watcher_ != nullptr)) { absl::optional old_state = subchannel_data_->connectivity_state_; subchannel_data_->connectivity_state_ = new_state; @@ -336,14 +339,17 @@ void SubchannelDatahealth_check_service_name_.value_or("N/A").c_str()); } GPR_ASSERT(pending_watcher_ == nullptr); + GPR_ASSERT(health_watcher_ == nullptr); auto watcher = std::make_unique( this, subchannel_list()->WeakRef(DEBUG_LOCATION, "Watcher")); - pending_watcher_ = watcher.get(); if (subchannel_list()->health_check_service_name_.has_value()) { - subchannel_->AddDataWatcher(MakeHealthCheckWatcher( + auto health_watcher = MakeHealthCheckWatcher( subchannel_list_->work_serializer(), - *subchannel_list()->health_check_service_name_, std::move(watcher))); + *subchannel_list()->health_check_service_name_, std::move(watcher)); + health_watcher_ = health_watcher.get(); + subchannel_->AddDataWatcher(std::move(health_watcher)); } else { + pending_watcher_ = watcher.get(); subchannel_->WatchConnectivityState(std::move(watcher)); } } @@ -360,12 +366,19 @@ void SubchannelData:: subchannel_list_, Index(), subchannel_list_->num_subchannels(), subchannel_.get(), reason); } - // No need to cancel if using health checking, because the data - // watcher will be destroyed automatically when the subchannel is. 
- if (!subchannel_list()->health_check_service_name_.has_value()) { - subchannel_->CancelConnectivityStateWatch(pending_watcher_); - } + subchannel_->CancelConnectivityStateWatch(pending_watcher_); pending_watcher_ = nullptr; + } else if (health_watcher_ != nullptr) { + if (GPR_UNLIKELY(subchannel_list_->tracer() != nullptr)) { + gpr_log(GPR_INFO, + "[%s %p] subchannel list %p index %" PRIuPTR " of %" PRIuPTR + " (subchannel %p): canceling health watch (%s)", + subchannel_list_->tracer(), subchannel_list_->policy(), + subchannel_list_, Index(), subchannel_list_->num_subchannels(), + subchannel_.get(), reason); + } + subchannel_->CancelDataWatcher(health_watcher_); + health_watcher_ = nullptr; } } diff --git a/src/core/lib/load_balancing/subchannel_interface.h b/src/core/lib/load_balancing/subchannel_interface.h index fea852690bd..9a9e855546a 100644 --- a/src/core/lib/load_balancing/subchannel_interface.h +++ b/src/core/lib/load_balancing/subchannel_interface.h @@ -97,6 +97,9 @@ class SubchannelInterface : public DualRefCounted { // Registers a new data watcher. virtual void AddDataWatcher( std::unique_ptr watcher) = 0; + + // Cancels a data watch. 
+ virtual void CancelDataWatcher(DataWatcherInterface* watcher) = 0; }; // A class that delegates to another subchannel, to be used in cases @@ -125,6 +128,9 @@ class DelegatingSubchannel : public SubchannelInterface { void AddDataWatcher(std::unique_ptr watcher) override { wrapped_subchannel_->AddDataWatcher(std::move(watcher)); } + void CancelDataWatcher(DataWatcherInterface* watcher) override { + wrapped_subchannel_->CancelDataWatcher(watcher); + } private: RefCountedPtr wrapped_subchannel_; diff --git a/test/core/client_channel/lb_policy/lb_policy_test_lib.h b/test/core/client_channel/lb_policy/lb_policy_test_lib.h index e73ec073fa4..eecc3874ccb 100644 --- a/test/core/client_channel/lb_policy/lb_policy_test_lib.h +++ b/test/core/client_channel/lb_policy/lb_policy_test_lib.h @@ -175,6 +175,13 @@ class LoadBalancingPolicyTest : public ::testing::Test { state_->watchers_.insert(orca_watcher_.get()); } + void CancelDataWatcher(DataWatcherInterface* watcher) override { + MutexLock lock(&state_->backend_metric_watcher_mu_); + if (orca_watcher_.get() != static_cast(watcher)) return; + state_->watchers_.erase(orca_watcher_.get()); + orca_watcher_.reset(); + } + // Don't need this method, so it's a no-op. void ResetBackoff() override {}