diff --git a/BUILD b/BUILD index 4e24f7fedf7..8893aee2bb1 100644 --- a/BUILD +++ b/BUILD @@ -3674,6 +3674,7 @@ grpc_cc_library( ], external_deps = [ "absl/base:core_headers", + "absl/cleanup", "absl/memory", "absl/status", "absl/status:statusor", diff --git a/src/core/BUILD b/src/core/BUILD index 7c273dce60e..f39d235d711 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -4455,6 +4455,7 @@ grpc_cc_library( ], external_deps = [ "absl/base:core_headers", + "absl/cleanup", "absl/functional:bind_front", "absl/memory", "absl/random", @@ -4664,6 +4665,7 @@ grpc_cc_library( "//:ref_counted_ptr", "//:sockaddr_utils", "//:uri_parser", + "//:xds_client", ], ) @@ -4732,6 +4734,7 @@ grpc_cc_library( "//:orphanable", "//:ref_counted_ptr", "//:work_serializer", + "//:xds_client", ], ) diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc b/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc index 857bdf5c473..2aeae2e7a6b 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc @@ -40,6 +40,7 @@ #include "src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.h" #include "src/core/ext/xds/certificate_provider_store.h" #include "src/core/ext/xds/xds_certificate_provider.h" +#include "src/core/ext/xds/xds_client.h" #include "src/core/ext/xds/xds_client_grpc.h" #include "src/core/ext/xds/xds_cluster.h" #include "src/core/ext/xds/xds_common_types.h" @@ -123,26 +124,32 @@ class CdsLb : public LoadBalancingPolicy { : parent_(std::move(parent)), name_(std::move(name)) {} void OnResourceChanged( - std::shared_ptr cluster_data) override { + std::shared_ptr cluster_data, + RefCountedPtr read_delay_handle) override { parent_->work_serializer()->Run( [self = RefAsSubclass(), - cluster_data = std::move(cluster_data)]() mutable { + cluster_data = std::move(cluster_data), + read_handle = std::move(read_delay_handle)]() mutable { self->parent_->OnClusterChanged(self->name_, std::move(cluster_data)); }, DEBUG_LOCATION); } - void OnError(absl::Status status) override { + void OnError( + absl::Status status, + RefCountedPtr read_delay_handle) override { parent_->work_serializer()->Run( - [self = RefAsSubclass(), - status = std::move(status)]() mutable { + [self = RefAsSubclass(), status = std::move(status), + read_handle = std::move(read_delay_handle)]() mutable { self->parent_->OnError(self->name_, std::move(status)); }, DEBUG_LOCATION); } - void OnResourceDoesNotExist() override { + void OnResourceDoesNotExist( + RefCountedPtr read_delay_handle) override { parent_->work_serializer()->Run( - [self = RefAsSubclass()]() { + [self = RefAsSubclass(), + read_handle = std::move(read_delay_handle)]() { self->parent_->OnResourceDoesNotExist(self->name_); }, DEBUG_LOCATION); diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_resolver.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_resolver.cc index f05d8f153a2..7e29c17553c 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_resolver.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_resolver.cc @@ -215,26 +215,33 @@ class XdsClusterResolverLb : public LoadBalancingPolicy { ~EndpointWatcher() override { discovery_mechanism_.reset(DEBUG_LOCATION, "EndpointWatcher"); } - void OnResourceChanged( - std::shared_ptr update) override { + void OnResourceChanged(std::shared_ptr update, + RefCountedPtr + read_delay_handle) override { discovery_mechanism_->parent()->work_serializer()->Run( [self = RefAsSubclass(), - update = std::move(update)]() mutable { + update = std::move(update), + read_delay_handle = std::move(read_delay_handle)]() mutable { self->OnResourceChangedHelper(std::move(update)); }, DEBUG_LOCATION); } - void OnError(absl::Status status) override { + void OnError(absl::Status status, + RefCountedPtr read_delay_handle) + override { discovery_mechanism_->parent()->work_serializer()->Run( [self = RefAsSubclass(), - status = std::move(status)]() mutable { + status = std::move(status), + read_delay_handle = std::move(read_delay_handle)]() mutable { self->OnErrorHelper(std::move(status)); }, DEBUG_LOCATION); } - void OnResourceDoesNotExist() override { + void OnResourceDoesNotExist(RefCountedPtr + read_delay_handle) override { discovery_mechanism_->parent()->work_serializer()->Run( - [self = RefAsSubclass()]() { + [self = RefAsSubclass(), + read_delay_handle = std::move(read_delay_handle)]() { self->OnResourceDoesNotExistHelper(); }, DEBUG_LOCATION); diff --git a/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc b/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc index 80d14cf17fb..2ae4b1d2d3b 100644 --- a/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc +++ b/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc @@ -55,6 +55,7 @@ #include "src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h" #include "src/core/ext/xds/xds_bootstrap.h" #include "src/core/ext/xds/xds_bootstrap_grpc.h" +#include "src/core/ext/xds/xds_client.h" #include "src/core/ext/xds/xds_client_grpc.h" #include "src/core/ext/xds/xds_http_filters.h" #include "src/core/ext/xds/xds_listener.h" @@ -99,6 +100,8 @@ TraceFlag grpc_xds_resolver_trace(false, "xds_resolver"); namespace { +using ReadDelayHandle = XdsClient::ReadDelayHandle; + // // XdsResolver // @@ -141,26 +144,31 @@ class XdsResolver : public Resolver { explicit ListenerWatcher(RefCountedPtr resolver) : resolver_(std::move(resolver)) {} void OnResourceChanged( - std::shared_ptr listener) override { + std::shared_ptr listener, + RefCountedPtr read_delay_handle) override { resolver_->work_serializer_->Run( [self = RefAsSubclass(), - listener = std::move(listener)]() mutable { + listener = std::move(listener), + read_delay_handle = std::move(read_delay_handle)]() mutable { self->resolver_->OnListenerUpdate(std::move(listener)); }, DEBUG_LOCATION); } - void OnError(absl::Status status) override { + void OnError(absl::Status status, + RefCountedPtr read_delay_handle) override { resolver_->work_serializer_->Run( - [self = RefAsSubclass(), - status = std::move(status)]() mutable { + [self = RefAsSubclass(), status = std::move(status), + read_delay_handle = std::move(read_delay_handle)]() mutable { self->resolver_->OnError(self->resolver_->lds_resource_name_, std::move(status)); }, DEBUG_LOCATION); } - void OnResourceDoesNotExist() override { + void OnResourceDoesNotExist( + RefCountedPtr read_delay_handle) override { resolver_->work_serializer_->Run( - [self = RefAsSubclass()]() { + [self = RefAsSubclass(), + read_delay_handle = std::move(read_delay_handle)]() { self->resolver_->OnResourceDoesNotExist( absl::StrCat(self->resolver_->lds_resource_name_, ": xDS listener resource does not exist")); @@ -178,28 +186,34 @@ class XdsResolver : public Resolver { explicit RouteConfigWatcher(RefCountedPtr resolver) : resolver_(std::move(resolver)) {} void OnResourceChanged( - std::shared_ptr route_config) override { + std::shared_ptr route_config, + RefCountedPtr read_delay_handle) override { resolver_->work_serializer_->Run( [self = RefAsSubclass(), - route_config = std::move(route_config)]() mutable { + route_config = std::move(route_config), + read_delay_handle = std::move(read_delay_handle)]() mutable { if (self != self->resolver_->route_config_watcher_) return; self->resolver_->OnRouteConfigUpdate(std::move(route_config)); }, DEBUG_LOCATION); } - void OnError(absl::Status status) override { + void OnError(absl::Status status, + RefCountedPtr read_delay_handle) override { resolver_->work_serializer_->Run( [self = RefAsSubclass(), - status = std::move(status)]() mutable { + status = std::move(status), + read_delay_handle = std::move(read_delay_handle)]() mutable { if (self != self->resolver_->route_config_watcher_) return; self->resolver_->OnError(self->resolver_->route_config_name_, std::move(status)); }, DEBUG_LOCATION); } - void OnResourceDoesNotExist() override { + void OnResourceDoesNotExist( + RefCountedPtr read_delay_handle) override { resolver_->work_serializer_->Run( - [self = RefAsSubclass()]() { + [self = RefAsSubclass(), + read_delay_handle = std::move(read_delay_handle)]() { if (self != self->resolver_->route_config_watcher_) return; self->resolver_->OnResourceDoesNotExist(absl::StrCat( self->resolver_->route_config_name_, diff --git a/src/core/ext/xds/xds_client.cc b/src/core/ext/xds/xds_client.cc index e6399e98d1a..7b125b771d2 100644 --- a/src/core/ext/xds/xds_client.cc +++ b/src/core/ext/xds/xds_client.cc @@ -23,8 +23,10 @@ #include #include +#include #include +#include "absl/cleanup/cleanup.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -129,6 +131,16 @@ class XdsClient::ChannelState::AdsCallState bool HasSubscribedResources() const; private: + class AdsReadDelayHandle : public ReadDelayHandle { + public: + explicit AdsReadDelayHandle(RefCountedPtr ads_call_state) + : ads_call_state_(std::move(ads_call_state)) {} + ~AdsReadDelayHandle() override; + + private: + RefCountedPtr ads_call_state_; + }; + class AdsResponseParser : public XdsApi::AdsResponseParserInterface { public: struct Result { @@ -140,6 +152,7 @@ class XdsClient::ChannelState::AdsCallState std::map> resources_seen; bool have_valid_resources = false; + RefCountedPtr read_delay_handle; }; explicit AdsResponseParser(AdsCallState* ads_call_state) @@ -259,7 +272,7 @@ class XdsClient::ChannelState::AdsCallState ResourceState& state = authority_state.resource_map[type_][name_.key]; state.meta.client_status = XdsApi::ResourceMetadata::DOES_NOT_EXIST; ads_calld_->xds_client()->NotifyWatchersOnResourceDoesNotExist( - state.watchers); + state.watchers, ReadDelayHandle::NoWait()); } ads_calld_->xds_client()->work_serializer_.DrainQueue(); ads_calld_.reset(); @@ -588,7 +601,7 @@ void XdsClient::ChannelState::SetChannelStatusLocked(absl::Status status) { [watchers = std::move(watchers), status = std::move(status)]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(xds_client_->work_serializer_) { for (const auto& watcher : watchers) { - watcher->OnError(status); + watcher->OnError(status, ReadDelayHandle::NoWait()); } }, DEBUG_LOCATION); @@ -686,6 +699,20 @@ void XdsClient::ChannelState::RetryableCall::OnRetryTimer() { } } +// +// XdsClient::ChannelState::AdsCallState::AdsReadDelayHandle +// + +XdsClient::ChannelState::AdsCallState::AdsReadDelayHandle:: + ~AdsReadDelayHandle() { + XdsClient* client = ads_call_state_->xds_client(); + MutexLock lock(&client->mu_); + auto call = ads_call_state_->call_.get(); + if (call != nullptr) { + call->StartRecvMessage(); + } +} + // // XdsClient::ChannelState::AdsCallState::AdsResponseParser // @@ -711,6 +738,8 @@ absl::Status XdsClient::ChannelState::AdsCallState::AdsResponseParser:: result_.type_url = std::move(fields.type_url); result_.version = std::move(fields.version); result_.nonce = std::move(fields.nonce); + result_.read_delay_handle = + MakeRefCounted(ads_call_state_->Ref()); return absl::OkStatus(); } @@ -841,7 +870,8 @@ void XdsClient::ChannelState::AdsCallState::AdsResponseParser::ParseResource( xds_client()->NotifyWatchersOnErrorLocked( resource_state.watchers, absl::UnavailableError( - absl::StrCat("invalid resource: ", decode_status.ToString()))); + absl::StrCat("invalid resource: ", decode_status.ToString())), + result_.read_delay_handle); UpdateResourceMetadataNacked(result_.version, decode_status.ToString(), update_time_, &resource_state.meta); return; @@ -867,10 +897,11 @@ void XdsClient::ChannelState::AdsCallState::AdsResponseParser::ParseResource( // Notify watchers. auto& watchers_list = resource_state.watchers; xds_client()->work_serializer_.Schedule( - [watchers_list, value = resource_state.resource]() + [watchers_list, value = resource_state.resource, + read_delay_handle = result_.read_delay_handle]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&xds_client()->work_serializer_) { for (const auto& p : watchers_list) { - p.first->OnGenericResourceChanged(value); + p.first->OnGenericResourceChanged(value, read_delay_handle); } }, DEBUG_LOCATION); @@ -930,6 +961,7 @@ XdsClient::ChannelState::AdsCallState::AdsCallState( for (const auto& p : state_map_) { SendMessageLocked(p.first); } + call_->StartRecvMessage(); } void XdsClient::ChannelState::AdsCallState::Orphan() { @@ -1034,12 +1066,17 @@ void XdsClient::ChannelState::AdsCallState::OnRequestSent(bool ok) { void XdsClient::ChannelState::AdsCallState::OnRecvMessage( absl::string_view payload) { + // Needs to be destroyed after the mutex is released. + RefCountedPtr read_delay_handle; { MutexLock lock(&xds_client()->mu_); if (!IsCurrentCallOnChannel()) return; // Parse and validate the response. AdsResponseParser parser(this); absl::Status status = xds_client()->api_.ParseAdsResponse(payload, &parser); + // This includes a handle that will trigger an ADS read. + AdsResponseParser::Result result = parser.TakeResult(); + read_delay_handle = std::move(result.read_delay_handle); if (!status.ok()) { // Ignore unparsable response. gpr_log(GPR_ERROR, @@ -1050,7 +1087,6 @@ void XdsClient::ChannelState::AdsCallState::OnRecvMessage( } else { seen_response_ = true; chand()->status_ = absl::OkStatus(); - AdsResponseParser::Result result = parser.TakeResult(); // Update nonce. auto& state = state_map_[result.type]; state.nonce = result.nonce; @@ -1110,7 +1146,7 @@ void XdsClient::ChannelState::AdsCallState::OnRecvMessage( resource_state.meta.client_status = XdsApi::ResourceMetadata::DOES_NOT_EXIST; xds_client()->NotifyWatchersOnResourceDoesNotExist( - resource_state.watchers); + resource_state.watchers, read_delay_handle); } } } @@ -1335,6 +1371,7 @@ XdsClient::ChannelState::LrsCallState::LrsCallState( std::string serialized_payload = xds_client()->api_.CreateLrsInitialRequest(); call_->SendMessage(std::move(serialized_payload)); send_message_pending_ = true; + call_->StartRecvMessage(); } void XdsClient::ChannelState::LrsCallState::Orphan() { @@ -1385,6 +1422,9 @@ void XdsClient::ChannelState::LrsCallState::OnRecvMessage( MutexLock lock(&xds_client()->mu_); // If we're no longer the current call, ignore the result. if (!IsCurrentCallOnChannel()) return; + // Start recv after any code branch + auto cleanup = + absl::MakeCleanup([call = call_.get()]() { call->StartRecvMessage(); }); // Parse the response. bool send_all_clusters = false; std::set new_cluster_names; @@ -1555,7 +1595,7 @@ void XdsClient::WatchResource(const XdsResourceType* type, work_serializer_.Run( [watcher = std::move(watcher), status = std::move(status)]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&work_serializer_) { - watcher->OnError(status); + watcher->OnError(status, ReadDelayHandle::NoWait()); }, DEBUG_LOCATION); }; @@ -1606,7 +1646,8 @@ void XdsClient::WatchResource(const XdsResourceType* type, work_serializer_.Schedule( [watcher, value = resource_state.resource]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&work_serializer_) { - watcher->OnGenericResourceChanged(value); + watcher->OnGenericResourceChanged(value, + ReadDelayHandle::NoWait()); }, DEBUG_LOCATION); } else if (resource_state.meta.client_status == @@ -1618,7 +1659,7 @@ void XdsClient::WatchResource(const XdsResourceType* type, } work_serializer_.Schedule( [watcher]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&work_serializer_) { - watcher->OnResourceDoesNotExist(); + watcher->OnResourceDoesNotExist(ReadDelayHandle::NoWait()); }, DEBUG_LOCATION); } else if (resource_state.meta.client_status == @@ -1638,8 +1679,9 @@ void XdsClient::WatchResource(const XdsResourceType* type, work_serializer_.Schedule( [watcher, details = std::move(details)]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&work_serializer_) { - watcher->OnError(absl::UnavailableError( - absl::StrCat("invalid resource: ", details))); + watcher->OnError(absl::UnavailableError(absl::StrCat( + "invalid resource: ", details)), + ReadDelayHandle::NoWait()); }, DEBUG_LOCATION); } @@ -1660,7 +1702,7 @@ void XdsClient::WatchResource(const XdsResourceType* type, work_serializer_.Schedule( [watcher = std::move(watcher), status = std::move(channel_status)]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&work_serializer_) mutable { - watcher->OnError(std::move(status)); + watcher->OnError(std::move(status), ReadDelayHandle::NoWait()); }, DEBUG_LOCATION); } @@ -1927,7 +1969,7 @@ void XdsClient::ResetBackoff() { void XdsClient::NotifyWatchersOnErrorLocked( const std::map>& watchers, - absl::Status status) { + absl::Status status, RefCountedPtr read_delay_handle) { const auto* node = bootstrap_->node(); if (node != nullptr) { status = absl::Status( @@ -1935,10 +1977,11 @@ void XdsClient::NotifyWatchersOnErrorLocked( absl::StrCat(status.message(), " (node ID:", node->id(), ")")); } work_serializer_.Schedule( - [watchers, status = std::move(status)]() + [watchers, status = std::move(status), + read_delay_handle = std::move(read_delay_handle)]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&work_serializer_) { for (const auto& p : watchers) { - p.first->OnError(status); + p.first->OnError(status, read_delay_handle); } }, DEBUG_LOCATION); @@ -1946,13 +1989,15 @@ void XdsClient::NotifyWatchersOnErrorLocked( void XdsClient::NotifyWatchersOnResourceDoesNotExist( const std::map>& watchers) { + RefCountedPtr>& watchers, + RefCountedPtr read_delay_handle) { work_serializer_.Schedule( - [watchers]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&work_serializer_) { - for (const auto& p : watchers) { - p.first->OnResourceDoesNotExist(); - } - }, + [watchers, read_delay_handle = std::move(read_delay_handle)]() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&work_serializer_) { + for (const auto& p : watchers) { + p.first->OnResourceDoesNotExist(read_delay_handle); + } + }, DEBUG_LOCATION); } diff --git a/src/core/ext/xds/xds_client.h b/src/core/ext/xds/xds_client.h index 0831c3c768a..779198c581e 100644 --- a/src/core/ext/xds/xds_client.h +++ b/src/core/ext/xds/xds_client.h @@ -56,6 +56,11 @@ extern TraceFlag grpc_xds_client_refcount_trace; class XdsClient : public DualRefCounted { public: + class ReadDelayHandle : public RefCounted { + public: + static RefCountedPtr NoWait() { return nullptr; } + }; + // Resource watcher interface. Implemented by callers. // Note: Most callers will not use this API directly but rather via a // resource-type-specific wrapper API provided by the relevant @@ -63,11 +68,14 @@ class XdsClient : public DualRefCounted { class ResourceWatcherInterface : public RefCounted { public: virtual void OnGenericResourceChanged( - std::shared_ptr resource) + std::shared_ptr resource, + RefCountedPtr read_delay_handle) ABSL_EXCLUSIVE_LOCKS_REQUIRED(&work_serializer_) = 0; - virtual void OnError(absl::Status status) + virtual void OnError(absl::Status status, + RefCountedPtr read_delay_handle) ABSL_EXCLUSIVE_LOCKS_REQUIRED(&work_serializer_) = 0; - virtual void OnResourceDoesNotExist() + virtual void OnResourceDoesNotExist( + RefCountedPtr read_delay_handle) ABSL_EXCLUSIVE_LOCKS_REQUIRED(&work_serializer_) = 0; }; @@ -277,11 +285,12 @@ class XdsClient : public DualRefCounted { void NotifyWatchersOnErrorLocked( const std::map>& watchers, - absl::Status status); + absl::Status status, RefCountedPtr read_delay_handle); // Sends a resource-does-not-exist notification to a specific set of watchers. void NotifyWatchersOnResourceDoesNotExist( const std::map>& watchers); + RefCountedPtr>& watchers, + RefCountedPtr read_delay_handle); void MaybeRegisterResourceTypeLocked(const XdsResourceType* resource_type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); diff --git a/src/core/ext/xds/xds_resource_type_impl.h b/src/core/ext/xds/xds_resource_type_impl.h index 35dfbbfcb4c..ed3d2d6387b 100644 --- a/src/core/ext/xds/xds_resource_type_impl.h +++ b/src/core/ext/xds/xds_resource_type_impl.h @@ -42,16 +42,18 @@ class XdsResourceTypeImpl : public XdsResourceType { class WatcherInterface : public XdsClient::ResourceWatcherInterface { public: virtual void OnResourceChanged( - std::shared_ptr resource) = 0; + std::shared_ptr resource, + RefCountedPtr read_delay_handle) = 0; private: // Get result from XdsClient generic watcher interface, perform // down-casting, and invoke the caller's OnResourceChanged() method. void OnGenericResourceChanged( - std::shared_ptr resource) - override { + std::shared_ptr resource, + RefCountedPtr read_delay_handle) override { OnResourceChanged( - std::static_pointer_cast(std::move(resource))); + std::static_pointer_cast(std::move(resource)), + std::move(read_delay_handle)); } }; diff --git a/src/core/ext/xds/xds_server_config_fetcher.cc b/src/core/ext/xds/xds_server_config_fetcher.cc index ee85f0a7d4c..1c6373f4dfa 100644 --- a/src/core/ext/xds/xds_server_config_fetcher.cc +++ b/src/core/ext/xds/xds_server_config_fetcher.cc @@ -52,6 +52,7 @@ #include "src/core/ext/xds/xds_bootstrap_grpc.h" #include "src/core/ext/xds/xds_certificate_provider.h" #include "src/core/ext/xds/xds_channel_stack_modifier.h" +#include "src/core/ext/xds/xds_client.h" #include "src/core/ext/xds/xds_client_grpc.h" #include "src/core/ext/xds/xds_common_types.h" #include "src/core/ext/xds/xds_http_filters.h" @@ -91,6 +92,8 @@ namespace grpc_core { namespace { +using ReadDelayHandle = XdsClient::ReadDelayHandle; + TraceFlag grpc_xds_server_config_fetcher_trace(false, "xds_server_config_fetcher"); @@ -151,11 +154,14 @@ class XdsServerConfigFetcher::ListenerWatcher } void OnResourceChanged( - std::shared_ptr listener) override; + std::shared_ptr listener, + RefCountedPtr read_delay_handle) override; - void OnError(absl::Status status) override; + void OnError(absl::Status status, + RefCountedPtr read_delay_handle) override; - void OnResourceDoesNotExist() override; + void OnResourceDoesNotExist( + RefCountedPtr read_delay_handle) override; const std::string& listening_address() const { return listening_address_; } @@ -292,16 +298,20 @@ class XdsServerConfigFetcher::ListenerWatcher::FilterChainMatchManager:: filter_chain_match_manager_(std::move(filter_chain_match_manager)) {} void OnResourceChanged( - std::shared_ptr route_config) override { + std::shared_ptr route_config, + RefCountedPtr /* read_delay_handle */) override { filter_chain_match_manager_->OnRouteConfigChanged(resource_name_, std::move(route_config)); } - void OnError(absl::Status status) override { + void OnError( + absl::Status status, + RefCountedPtr /* read_delay_handle */) override { filter_chain_match_manager_->OnError(resource_name_, status); } - void OnResourceDoesNotExist() override { + void OnResourceDoesNotExist( + RefCountedPtr /* read_delay_handle */) override { filter_chain_match_manager_->OnResourceDoesNotExist(resource_name_); } @@ -488,13 +498,21 @@ class XdsServerConfigFetcher::ListenerWatcher::FilterChainMatchManager:: : parent_(std::move(parent)) {} void OnResourceChanged( - std::shared_ptr route_config) override { + std::shared_ptr route_config, + RefCountedPtr /* read_delay_handle */) override { parent_->OnRouteConfigChanged(std::move(route_config)); } - void OnError(absl::Status status) override { parent_->OnError(status); } + void OnError( + absl::Status status, + RefCountedPtr /* read_delay_handle */) override { + parent_->OnError(status); + } - void OnResourceDoesNotExist() override { parent_->OnResourceDoesNotExist(); } + void OnResourceDoesNotExist( + RefCountedPtr /* read_delay_handle */) override { + parent_->OnResourceDoesNotExist(); + } private: WeakRefCountedPtr parent_; @@ -574,7 +592,8 @@ XdsServerConfigFetcher::ListenerWatcher::ListenerWatcher( listening_address_(std::move(listening_address)) {} void XdsServerConfigFetcher::ListenerWatcher::OnResourceChanged( - std::shared_ptr listener) { + std::shared_ptr listener, + RefCountedPtr /* read_delay_handle */) { if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_server_config_fetcher_trace)) { gpr_log(GPR_INFO, "[ListenerWatcher %p] Received LDS update from xds client %p: %s", @@ -610,7 +629,9 @@ void XdsServerConfigFetcher::ListenerWatcher::OnResourceChanged( } } -void XdsServerConfigFetcher::ListenerWatcher::OnError(absl::Status status) { +void XdsServerConfigFetcher::ListenerWatcher::OnError( + absl::Status status, + RefCountedPtr /* read_delay_handle */) { MutexLock lock(&mu_); if (filter_chain_match_manager_ != nullptr || pending_filter_chain_match_manager_ != nullptr) { @@ -653,7 +674,8 @@ void XdsServerConfigFetcher::ListenerWatcher::OnFatalError( } } -void XdsServerConfigFetcher::ListenerWatcher::OnResourceDoesNotExist() { +void XdsServerConfigFetcher::ListenerWatcher::OnResourceDoesNotExist( + RefCountedPtr /* read_delay_handle */) { MutexLock lock(&mu_); OnFatalError(absl::NotFoundError("Requested listener does not exist")); } diff --git a/src/core/ext/xds/xds_transport.h b/src/core/ext/xds/xds_transport.h index 40be0fb1e63..66155475903 100644 --- a/src/core/ext/xds/xds_transport.h +++ b/src/core/ext/xds/xds_transport.h @@ -58,6 +58,9 @@ class XdsTransportFactory : public InternallyRefCounted { // Only one message will be in flight at a time; subsequent // messages will not be sent until this one is done. virtual void SendMessage(std::string payload) = 0; + + // Starts a recv_message operation on the stream. + virtual void StartRecvMessage() = 0; }; // Create a streaming call on this transport for the specified method. diff --git a/src/core/ext/xds/xds_transport_grpc.cc b/src/core/ext/xds/xds_transport_grpc.cc index d2142c91396..1deff78f2a3 100644 --- a/src/core/ext/xds/xds_transport_grpc.cc +++ b/src/core/ext/xds/xds_transport_grpc.cc @@ -86,39 +86,31 @@ GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall::GrpcStreamingCall( GRPC_CLOSURE_INIT(&on_request_sent_, OnRequestSent, this, nullptr); // Start ops on the call. grpc_call_error call_error; - grpc_op ops[3]; + grpc_op ops[2]; memset(ops, 0, sizeof(ops)); - // Send initial metadata. No callback for this, since we don't really - // care when it finishes. + // Send initial metadata. grpc_op* op = ops; op->op = GRPC_OP_SEND_INITIAL_METADATA; op->data.send_initial_metadata.count = 0; op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY | GRPC_INITIAL_METADATA_WAIT_FOR_READY_EXPLICITLY_SET; op->reserved = nullptr; - op++; - call_error = grpc_call_start_batch_and_execute( - call_, ops, static_cast(op - ops), nullptr); - GPR_ASSERT(GRPC_CALL_OK == call_error); - // Start a batch with recv_initial_metadata and recv_message. - op = ops; + ++op; op->op = GRPC_OP_RECV_INITIAL_METADATA; op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv_; op->flags = 0; op->reserved = nullptr; - op++; - op->op = GRPC_OP_RECV_MESSAGE; - op->data.recv_message.recv_message = &recv_message_payload_; - op->flags = 0; - op->reserved = nullptr; - op++; - Ref(DEBUG_LOCATION, "OnResponseReceived").release(); - GRPC_CLOSURE_INIT(&on_response_received_, OnResponseReceived, this, nullptr); + ++op; + // Ref will be released in the callback + GRPC_CLOSURE_INIT( + &on_recv_initial_metadata_, OnRecvInitialMetadata, + this->Ref(DEBUG_LOCATION, "OnRecvInitialMetadata").release(), nullptr); call_error = grpc_call_start_batch_and_execute( - call_, ops, static_cast(op - ops), &on_response_received_); + call_, ops, static_cast(op - ops), &on_recv_initial_metadata_); GPR_ASSERT(GRPC_CALL_OK == call_error); // Start a batch for recv_trailing_metadata. + memset(ops, 0, sizeof(ops)); op = ops; op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv_; @@ -126,7 +118,7 @@ GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall::GrpcStreamingCall( op->data.recv_status_on_client.status_details = &status_details_; op->flags = 0; op->reserved = nullptr; - op++; + ++op; // This callback signals the end of the call, so it relies on the initial // ref instead of a new ref. When it's invoked, it's the initial ref that is // unreffed. @@ -134,11 +126,11 @@ GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall::GrpcStreamingCall( call_error = grpc_call_start_batch_and_execute( call_, ops, static_cast(op - ops), &on_status_received_); GPR_ASSERT(GRPC_CALL_OK == call_error); + GRPC_CLOSURE_INIT(&on_response_received_, OnResponseReceived, this, nullptr); } GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: ~GrpcStreamingCall() { - grpc_metadata_array_destroy(&initial_metadata_recv_); grpc_metadata_array_destroy(&trailing_metadata_recv_); grpc_byte_buffer_destroy(send_message_payload_); grpc_byte_buffer_destroy(recv_message_payload_); @@ -175,55 +167,59 @@ void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall::SendMessage( GPR_ASSERT(GRPC_CALL_OK == call_error); } +void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: + StartRecvMessage() { + Ref(DEBUG_LOCATION, "StartRecvMessage").release(); + grpc_op op; + memset(&op, 0, sizeof(op)); + op.op = GRPC_OP_RECV_MESSAGE; + op.data.recv_message.recv_message = &recv_message_payload_; + GPR_ASSERT(call_ != nullptr); + const grpc_call_error call_error = + grpc_call_start_batch_and_execute(call_, &op, 1, &on_response_received_); + GPR_ASSERT(GRPC_CALL_OK == call_error); +} + +void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: + OnRecvInitialMetadata(void* arg, grpc_error_handle /*error*/) { + RefCountedPtr self(static_cast(arg)); + grpc_metadata_array_destroy(&self->initial_metadata_recv_); +} + void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: OnRequestSent(void* arg, grpc_error_handle error) { - auto* self = static_cast(arg); + RefCountedPtr self(static_cast(arg)); // Clean up the sent message. grpc_byte_buffer_destroy(self->send_message_payload_); self->send_message_payload_ = nullptr; // Invoke request handler. self->event_handler_->OnRequestSent(error.ok()); - // Drop the ref. - self->Unref(DEBUG_LOCATION, "OnRequestSent"); } void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: OnResponseReceived(void* arg, grpc_error_handle /*error*/) { - auto* self = static_cast(arg); + RefCountedPtr self(static_cast(arg)); // If there was no payload, then we received status before we received // another message, so we stop reading. - if (self->recv_message_payload_ == nullptr) { - self->Unref(DEBUG_LOCATION, "OnResponseReceived"); - return; + if (self->recv_message_payload_ != nullptr) { + // Process the response. + grpc_byte_buffer_reader bbr; + grpc_byte_buffer_reader_init(&bbr, self->recv_message_payload_); + grpc_slice response_slice = grpc_byte_buffer_reader_readall(&bbr); + grpc_byte_buffer_reader_destroy(&bbr); + grpc_byte_buffer_destroy(self->recv_message_payload_); + self->recv_message_payload_ = nullptr; + self->event_handler_->OnRecvMessage(StringViewFromSlice(response_slice)); + CSliceUnref(response_slice); } - // Process the response. - grpc_byte_buffer_reader bbr; - grpc_byte_buffer_reader_init(&bbr, self->recv_message_payload_); - grpc_slice response_slice = grpc_byte_buffer_reader_readall(&bbr); - grpc_byte_buffer_reader_destroy(&bbr); - grpc_byte_buffer_destroy(self->recv_message_payload_); - self->recv_message_payload_ = nullptr; - self->event_handler_->OnRecvMessage(StringViewFromSlice(response_slice)); - CSliceUnref(response_slice); - // Keep reading. - grpc_op op; - memset(&op, 0, sizeof(op)); - op.op = GRPC_OP_RECV_MESSAGE; - op.data.recv_message.recv_message = &self->recv_message_payload_; - GPR_ASSERT(self->call_ != nullptr); - // Reuses the "OnResponseReceived" ref taken in ctor. - const grpc_call_error call_error = grpc_call_start_batch_and_execute( - self->call_, &op, 1, &self->on_response_received_); - GPR_ASSERT(GRPC_CALL_OK == call_error); } void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: OnStatusReceived(void* arg, grpc_error_handle /*error*/) { - auto* self = static_cast(arg); + RefCountedPtr self(static_cast(arg)); self->event_handler_->OnStatusReceived( absl::Status(static_cast(self->status_code_), StringViewFromSlice(self->status_details_))); - self->Unref(DEBUG_LOCATION, "OnStatusReceived"); } // diff --git a/src/core/ext/xds/xds_transport_grpc.h b/src/core/ext/xds/xds_transport_grpc.h index f9e29fc0e4c..a4d25ff75c9 100644 --- a/src/core/ext/xds/xds_transport_grpc.h +++ b/src/core/ext/xds/xds_transport_grpc.h @@ -100,7 +100,10 @@ class GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall void SendMessage(std::string payload) override; + void StartRecvMessage() override; + private: + static void OnRecvInitialMetadata(void* arg, grpc_error_handle /*error*/); static void OnRequestSent(void* arg, grpc_error_handle error); static void OnResponseReceived(void* arg, grpc_error_handle /*error*/); static void OnStatusReceived(void* arg, grpc_error_handle /*error*/); @@ -114,6 +117,7 @@ class GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall // recv_initial_metadata grpc_metadata_array initial_metadata_recv_; + grpc_closure on_recv_initial_metadata_; // send_message grpc_byte_buffer* send_message_payload_ = nullptr; diff --git a/test/core/xds/xds_client_fuzzer.cc b/test/core/xds/xds_client_fuzzer.cc index be166400f49..0e7459c9abe 100644 --- a/test/core/xds/xds_client_fuzzer.cc +++ b/test/core/xds/xds_client_fuzzer.cc @@ -57,7 +57,8 @@ class Fuzzer { // Leave xds_client_ unset, so Act() will be a no-op. return; } - auto transport_factory = MakeOrphanable(); + auto transport_factory = MakeOrphanable( + []() { Crash("Multiple concurrent reads"); }); transport_factory->SetAutoCompleteMessagesFromClient(false); transport_factory->SetAbortOnUndrainedMessages(false); transport_factory_ = transport_factory.get(); @@ -147,20 +148,26 @@ class Fuzzer { : resource_name_(std::move(resource_name)) {} void OnResourceChanged( - std::shared_ptr resource) + std::shared_ptr resource, + RefCountedPtr /* read_delay_handle */) override { gpr_log(GPR_INFO, "==> OnResourceChanged(%s %s): %s", std::string(ResourceType::Get()->type_url()).c_str(), resource_name_.c_str(), resource->ToString().c_str()); } - void OnError(absl::Status status) override { + void OnError( + absl::Status status, + RefCountedPtr /* read_delay_handle */) + override { gpr_log(GPR_INFO, "==> OnError(%s %s): %s", std::string(ResourceType::Get()->type_url()).c_str(), resource_name_.c_str(), status.ToString().c_str()); } - void OnResourceDoesNotExist() override { + void OnResourceDoesNotExist( + RefCountedPtr /* read_delay_handle */) + override { gpr_log(GPR_INFO, "==> OnResourceDoesNotExist(%s %s)", std::string(ResourceType::Get()->type_url()).c_str(), resource_name_.c_str()); diff --git a/test/core/xds/xds_client_test.cc b/test/core/xds/xds_client_test.cc index 729ebc1c38a..14f2ffdc9e6 100644 --- a/test/core/xds/xds_client_test.cc +++ b/test/core/xds/xds_client_test.cc @@ -235,6 +235,11 @@ class XdsClientTest : public ::testing::Test { XdsTestResourceType, ResourceStruct> { public: + struct ResourceAndReadDelayHandle { + std::shared_ptr resource; + RefCountedPtr read_delay_handle; + }; + // A watcher implementation that queues delivered watches. class Watcher : public XdsResourceTypeImpl< XdsTestResourceType WaitForNextResource( + absl::optional WaitForNextResourceAndHandle( absl::Duration timeout = absl::Seconds(1), SourceLocation location = SourceLocation()) { MutexLock lock(&mu_); - if (!WaitForEventLocked(timeout)) return nullptr; + if (!WaitForEventLocked(timeout)) return absl::nullopt; Event& event = queue_.front(); - if (!absl::holds_alternative>( - event)) { + if (!absl::holds_alternative(event)) { EXPECT_TRUE(false) << "got unexpected event " << (absl::holds_alternative(event) ? "error" : "does-not-exist") << " at " << location.file() << ":" << location.line(); - return nullptr; + return absl::nullopt; } - auto foo = - std::move(absl::get>(event)); + auto foo = std::move(absl::get(event)); queue_.pop_front(); return foo; } + std::shared_ptr WaitForNextResource( + absl::Duration timeout = absl::Seconds(1), + SourceLocation location = SourceLocation()) { + auto resource_and_handle = + WaitForNextResourceAndHandle(timeout, location); + if (!resource_and_handle.has_value()) { + return nullptr; + } + return std::move(resource_and_handle->resource); + } + absl::optional WaitForNextError( absl::Duration timeout = absl::Seconds(1), SourceLocation location = SourceLocation()) { @@ -283,8 +297,7 @@ class XdsClientTest : public ::testing::Test { if (!absl::holds_alternative(event)) { EXPECT_TRUE(false) << "got unexpected event " - << (absl::holds_alternative< - std::shared_ptr>(event) + << (absl::holds_alternative(event) ? "resource" : "does-not-exist") << " at " << location.file() << ":" << location.line(); @@ -314,21 +327,31 @@ class XdsClientTest : public ::testing::Test { private: struct DoesNotExist {}; - using Event = absl::variant, - absl::Status, DoesNotExist>; + using Event = + absl::variant; - void OnResourceChanged( - std::shared_ptr foo) override { + void OnResourceChanged(std::shared_ptr foo, + RefCountedPtr + read_delay_handle) override { MutexLock lock(&mu_); - queue_.push_back(std::move(foo)); + ResourceAndReadDelayHandle event_details = { + std::move(foo), std::move(read_delay_handle)}; + queue_.emplace_back(std::move(event_details)); cv_.Signal(); } - void OnError(absl::Status status) override { + + void OnError( + absl::Status status, + RefCountedPtr /* read_delay_handle */) + override { MutexLock lock(&mu_); queue_.push_back(std::move(status)); cv_.Signal(); } - void OnResourceDoesNotExist() override { + + void OnResourceDoesNotExist( + RefCountedPtr /* read_delay_handle */) + override { MutexLock lock(&mu_); queue_.push_back(DoesNotExist()); cv_.Signal(); @@ -577,7 +600,8 @@ class XdsClientTest : public ::testing::Test { void InitXdsClient( FakeXdsBootstrap::Builder bootstrap_builder = FakeXdsBootstrap::Builder(), Duration resource_request_timeout = Duration::Seconds(15)) { - auto transport_factory = MakeOrphanable(); + auto transport_factory = MakeOrphanable( + []() { FAIL() << "Multiple concurrent reads"; }); transport_factory_ = transport_factory->Ref().TakeAsSubclass(); xds_client_ = MakeRefCounted( @@ -671,8 +695,9 @@ class XdsClientTest : public ::testing::Test { // Helper function to check the fields of a DiscoveryRequest. void CheckRequest(const DiscoveryRequest& request, absl::string_view type_url, absl::string_view version_info, - absl::string_view response_nonce, absl::Status error_detail, - std::set resource_names, + absl::string_view response_nonce, + const absl::Status& error_detail, + const std::set& resource_names, SourceLocation location = SourceLocation()) { EXPECT_EQ(request.type_url(), absl::StrCat("type.googleapis.com/", type_url)) @@ -2695,6 +2720,91 @@ TEST_F(XdsClientTest, FederationChannelFailureReportedToWatchers) { EXPECT_TRUE(stream2->Orphaned()); } +TEST_F(XdsClientTest, AdsReadWaitsForHandleRelease) { + InitXdsClient(); + // Start watches for "foo1" and "foo2". + auto watcher1 = StartFooWatch("foo1"); + // XdsClient should have created an ADS stream. + auto stream = WaitForAdsStream(); + ASSERT_TRUE(stream != nullptr); + // XdsClient should have sent a subscription request on the ADS stream. + auto request = WaitForRequest(stream.get()); + ASSERT_TRUE(request.has_value()); + CheckRequest(*request, XdsFooResourceType::Get()->type_url(), + /*version_info=*/"", /*response_nonce=*/"", + /*error_detail=*/absl::OkStatus(), + /*resource_names=*/{"foo1"}); + auto watcher2 = StartFooWatch("foo2"); + request = WaitForRequest(stream.get()); + ASSERT_TRUE(request.has_value()); + CheckRequest(*request, XdsFooResourceType::Get()->type_url(), + /*version_info=*/"", /*response_nonce=*/"", + /*error_detail=*/absl::OkStatus(), + /*resource_names=*/{"foo1", "foo2"}); + // Send a response with 2 resources. + stream->SendMessageToClient( + ResponseBuilder(XdsFooResourceType::Get()->type_url()) + .set_version_info("1") + .set_nonce("A") + .AddFooResource(XdsFooResource("foo1", 6)) + .AddFooResource(XdsFooResource("foo2", 10)) + .Serialize()); + // Send a response with a single resource, will not be read until the handle + // is released + stream->SendMessageToClient( + ResponseBuilder(XdsFooResourceType::Get()->type_url()) + .set_version_info("2") + .set_nonce("B") + .AddFooResource(XdsFooResource("foo1", 8)) + .Serialize()); + // XdsClient should have delivered the response to the watcher. + auto resource1 = watcher1->WaitForNextResourceAndHandle(); + ASSERT_NE(resource1, absl::nullopt); + EXPECT_EQ(resource1->resource->name, "foo1"); + EXPECT_EQ(resource1->resource->value, 6); + auto resource2 = watcher2->WaitForNextResourceAndHandle(); + ASSERT_NE(resource2, absl::nullopt); + EXPECT_EQ(resource2->resource->name, "foo2"); + EXPECT_EQ(resource2->resource->value, 10); + // XdsClient should have sent an ACK message to the xDS server. + request = WaitForRequest(stream.get()); + ASSERT_TRUE(request.has_value()); + CheckRequest(*request, XdsFooResourceType::Get()->type_url(), + /*version_info=*/"1", /*response_nonce=*/"A", + /*error_detail=*/absl::OkStatus(), + /*resource_names=*/{"foo1", "foo2"}); + EXPECT_EQ(stream->reads_started(), 1); + resource1->read_delay_handle.reset(); + EXPECT_EQ(stream->reads_started(), 1); + resource2->read_delay_handle.reset(); + EXPECT_EQ(stream->reads_started(), 2); + resource1 = watcher1->WaitForNextResourceAndHandle(); + ASSERT_NE(resource1, absl::nullopt); + EXPECT_EQ(resource1->resource->name, "foo1"); + EXPECT_EQ(resource1->resource->value, 8); + EXPECT_EQ(watcher2->WaitForNextResourceAndHandle(), absl::nullopt); + // XdsClient should have sent an ACK message to the xDS server. + request = WaitForRequest(stream.get()); + ASSERT_TRUE(request.has_value()); + CheckRequest(*request, XdsFooResourceType::Get()->type_url(), + /*version_info=*/"2", /*response_nonce=*/"B", + /*error_detail=*/absl::OkStatus(), + /*resource_names=*/{"foo1", "foo2"}); + EXPECT_EQ(stream->reads_started(), 2); + resource1->read_delay_handle.reset(); + EXPECT_EQ(stream->reads_started(), 3); + // Cancel watch. + CancelFooWatch(watcher1.get(), "foo1"); + request = WaitForRequest(stream.get()); + ASSERT_TRUE(request.has_value()); + CheckRequest(*request, XdsFooResourceType::Get()->type_url(), + /*version_info=*/"2", /*response_nonce=*/"B", + /*error_detail=*/absl::OkStatus(), + /*resource_names=*/{"foo2"}); + CancelFooWatch(watcher2.get(), "foo2"); + EXPECT_TRUE(stream->Orphaned()); +} + } // namespace } // namespace testing } // namespace grpc_core diff --git a/test/core/xds/xds_transport_fake.cc b/test/core/xds/xds_transport_fake.cc index 054aa4b7cba..300e849acef 100644 --- a/test/core/xds/xds_transport_fake.cc +++ b/test/core/xds/xds_transport_fake.cc @@ -20,6 +20,8 @@ #include #include +#include +#include #include #include @@ -46,6 +48,10 @@ FakeXdsTransportFactory::FakeStreamingCall::~FakeStreamingCall() { { MutexLock lock(&mu_); if (transport_->abort_on_undrained_messages()) { + for (const auto& message : from_client_messages_) { + gpr_log(GPR_ERROR, "From client message left in queue: %s", + message.c_str()); + } GPR_ASSERT(from_client_messages_.empty()); } } @@ -120,15 +126,49 @@ void FakeXdsTransportFactory::FakeStreamingCall::CompleteSendMessageFromClient( CompleteSendMessageFromClientLocked(ok); } +void FakeXdsTransportFactory::FakeStreamingCall::StartRecvMessage() { + MutexLock lock(&mu_); + if (num_pending_reads_ > 0) { + transport_->factory()->too_many_pending_reads_callback_(); + } + ++reads_started_; + ++num_pending_reads_; + if (!to_client_messages_.empty()) { + // Dispatch pending message (if there's one) on a separate thread to avoid + // recursion + GetDefaultEventEngine()->Run([call = RefAsSubclass()]() { + call->MaybeDeliverMessageToClient(); + }); + } +} + void FakeXdsTransportFactory::FakeStreamingCall::SendMessageToClient( absl::string_view payload) { - ExecCtx exec_ctx; - RefCountedPtr event_handler; { MutexLock lock(&mu_); - event_handler = event_handler_->Ref(); + to_client_messages_.emplace_back(payload); + } + MaybeDeliverMessageToClient(); +} + +void FakeXdsTransportFactory::FakeStreamingCall::MaybeDeliverMessageToClient() { + RefCountedPtr event_handler; + std::string message; + // Loop terminates with a break inside + while (true) { + { + MutexLock lock(&mu_); + if (num_pending_reads_ == 0 || to_client_messages_.empty()) { + break; + } + --num_pending_reads_; + message = std::move(to_client_messages_.front()); + to_client_messages_.pop_front(); + event_handler = event_handler_; + } + ExecCtx exec_ctx; + event_handler->OnRecvMessage(message); } - event_handler->OnRecvMessage(payload); } void FakeXdsTransportFactory::FakeStreamingCall::MaybeSendStatusToClient( diff --git a/test/core/xds/xds_transport_fake.h b/test/core/xds/xds_transport_fake.h index 1dc52083552..f3d51de240e 100644 --- a/test/core/xds/xds_transport_fake.h +++ b/test/core/xds/xds_transport_fake.h @@ -19,6 +19,8 @@ #include +#include + #include #include #include @@ -66,6 +68,8 @@ class FakeXdsTransportFactory : public XdsTransportFactory { void Orphan() override; + void StartRecvMessage() override; + using StreamingCall::Ref; // Make it public. bool HaveMessageFromClient(); @@ -84,6 +88,11 @@ class FakeXdsTransportFactory : public XdsTransportFactory { bool Orphaned(); + size_t reads_started() { + MutexLock lock(&mu_); + return reads_started_; + } + private: class RefCountedEventHandler : public RefCounted { public: @@ -107,6 +116,7 @@ class FakeXdsTransportFactory : public XdsTransportFactory { void CompleteSendMessageFromClientLocked(bool ok) ABSL_EXCLUSIVE_LOCKS_REQUIRED(&mu_); + void MaybeDeliverMessageToClient(); RefCountedPtr transport_; const char* method_; @@ -117,9 +127,15 @@ class FakeXdsTransportFactory : public XdsTransportFactory { std::deque from_client_messages_ ABSL_GUARDED_BY(&mu_); bool status_sent_ ABSL_GUARDED_BY(&mu_) = false; bool orphaned_ ABSL_GUARDED_BY(&mu_) = false; + size_t reads_started_ ABSL_GUARDED_BY(&mu_) = 0; + size_t num_pending_reads_ ABSL_GUARDED_BY(&mu_) = 0; + std::deque to_client_messages_ ABSL_GUARDED_BY(&mu_); }; - FakeXdsTransportFactory() = default; + explicit FakeXdsTransportFactory( + std::function too_many_pending_reads_callback) + : too_many_pending_reads_callback_( + std::move(too_many_pending_reads_callback)) {} using XdsTransportFactory::Ref; // Make it public. @@ -130,7 +146,7 @@ class FakeXdsTransportFactory : public XdsTransportFactory { // EventHandler::OnRequestSent() upon reading a request from the client. // If this is set to false, that behavior will be inhibited, and // EventHandler::OnRequestSent() will not be called until the test - // expicitly calls FakeStreamingCall::CompleteSendMessageFromClient(). + // explicitly calls FakeStreamingCall::CompleteSendMessageFromClient(). // // This value affects all transports created after this call is // complete. Any transport that already exists prior to this call @@ -189,6 +205,8 @@ class FakeXdsTransportFactory : public XdsTransportFactory { void RemoveStream(const char* method, FakeStreamingCall* call); + FakeXdsTransportFactory* factory() const { return factory_.get(); } + private: class RefCountedOnConnectivityFailure : public RefCounted { @@ -237,6 +255,7 @@ class FakeXdsTransportFactory : public XdsTransportFactory { transport_map_ ABSL_GUARDED_BY(&mu_); bool auto_complete_messages_from_client_ ABSL_GUARDED_BY(&mu_) = true; bool abort_on_undrained_messages_ ABSL_GUARDED_BY(&mu_) = true; + std::function too_many_pending_reads_callback_; }; } // namespace grpc_core