diff --git a/BUILD b/BUILD index 97c1ee534cd..342018e6ad1 100644 --- a/BUILD +++ b/BUILD @@ -2532,6 +2532,7 @@ grpc_cc_library( grpc_cc_library( name = "ref_counted_ptr", + external_deps = ["absl/hash"], language = "c++", public_hdrs = ["//src/core:lib/gprpp/ref_counted_ptr.h"], visibility = ["@grpc:ref_counted_ptr"], @@ -3038,6 +3039,7 @@ grpc_cc_library( "legacy_context", "orphanable", "parse_address", + "promise", "protobuf_duration_upb", "ref_counted_ptr", "server_address", @@ -3047,8 +3049,10 @@ grpc_cc_library( "work_serializer", "xds_orca_service_upb", "xds_orca_upb", + "//src/core:activity", "//src/core:arena", "//src/core:arena_promise", + "//src/core:cancel_callback", "//src/core:channel_args", "//src/core:channel_fwd", "//src/core:channel_init", @@ -3070,15 +3074,20 @@ grpc_cc_library( "//src/core:json_args", "//src/core:json_channel_args", "//src/core:json_object_loader", + "//src/core:latch", "//src/core:lb_policy", "//src/core:lb_policy_registry", + "//src/core:map", "//src/core:memory_quota", + "//src/core:pipe", + "//src/core:poll", "//src/core:pollset_set", "//src/core:proxy_mapper", "//src/core:proxy_mapper_registry", "//src/core:ref_counted", "//src/core:resolved_address", "//src/core:resource_quota", + "//src/core:seq", "//src/core:service_config_parser", "//src/core:slice", "//src/core:slice_buffer", @@ -3088,6 +3097,7 @@ grpc_cc_library( "//src/core:subchannel_interface", "//src/core:time", "//src/core:transport_fwd", + "//src/core:try_seq", "//src/core:unique_type_name", "//src/core:useful", "//src/core:validation_errors", diff --git a/CMakeLists.txt b/CMakeLists.txt index f326f0afe28..60067c6639f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6534,6 +6534,7 @@ target_include_directories(avl_test target_link_libraries(avl_test ${_gRPC_ALLTARGETS_LIBRARIES} gtest + absl::hash gpr ) @@ -10629,6 +10630,7 @@ target_include_directories(endpoint_config_test target_link_libraries(endpoint_config_test ${_gRPC_ALLTARGETS_LIBRARIES} gtest + absl::hash absl::type_traits absl::statusor gpr @@ -15099,6 +15101,7 @@ target_include_directories(latch_test target_link_libraries(latch_test ${_gRPC_ALLTARGETS_LIBRARIES} gtest + absl::hash absl::type_traits absl::statusor gpr @@ -24338,6 +24341,7 @@ target_include_directories(thread_quota_test target_link_libraries(thread_quota_test ${_gRPC_ALLTARGETS_LIBRARIES} gtest + absl::hash gpr ) @@ -25389,6 +25393,7 @@ target_include_directories(wait_for_callback_test target_link_libraries(wait_for_callback_test ${_gRPC_ALLTARGETS_LIBRARIES} gtest + absl::hash absl::type_traits absl::statusor gpr diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 44c8b46f5d9..ea4eab19b96 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -5338,6 +5338,7 @@ targets: - test/core/avl/avl_test.cc deps: - gtest + - absl/hash:hash - gpr uses_polling: false - name: aws_request_signer_test @@ -7774,6 +7775,7 @@ targets: - test/core/event_engine/endpoint_config_test.cc deps: - gtest + - absl/hash:hash - absl/meta:type_traits - absl/status:statusor - gpr @@ -10460,6 +10462,7 @@ targets: - test/core/promise/latch_test.cc deps: - gtest + - absl/hash:hash - absl/meta:type_traits - absl/status:statusor - gpr @@ -16259,6 +16262,7 @@ targets: - test/core/resource_quota/thread_quota_test.cc deps: - gtest + - absl/hash:hash - gpr uses_polling: false - name: thread_stress_test @@ -16798,6 +16802,7 @@ targets: - test/core/promise/wait_for_callback_test.cc deps: - gtest + - absl/hash:hash - absl/meta:type_traits - absl/status:statusor - gpr diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index 02b205a633c..6cfd4c48fcd 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -82,6 +82,14 @@ #include "src/core/lib/json/json.h" #include "src/core/lib/load_balancing/lb_policy_registry.h" #include "src/core/lib/load_balancing/subchannel_interface.h" +#include "src/core/lib/promise/cancel_callback.h" +#include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/latch.h" +#include "src/core/lib/promise/map.h" +#include "src/core/lib/promise/pipe.h" +#include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/promise.h" +#include "src/core/lib/promise/try_seq.h" #include "src/core/lib/resolver/resolver_registry.h" #include "src/core/lib/resolver/server_address.h" #include "src/core/lib/security/credentials/credentials.h" @@ -89,6 +97,7 @@ #include "src/core/lib/service_config/service_config_impl.h" #include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/call.h" #include "src/core/lib/surface/channel.h" #include "src/core/lib/transport/connectivity_state.h" #include "src/core/lib/transport/error_utils.h" @@ -146,7 +155,7 @@ class ClientChannel::CallData { // Accessors for data stored in the subclass. virtual ClientChannel* chand() const = 0; virtual Arena* arena() const = 0; - virtual grpc_polling_entity* pollent() const = 0; + virtual grpc_polling_entity* pollent() = 0; virtual grpc_metadata_batch* send_initial_metadata() = 0; virtual grpc_call_context_element* call_context() const = 0; @@ -205,7 +214,7 @@ class ClientChannel::FilterBasedCallData : public ClientChannel::CallData { return static_cast<ClientChannel*>(elem()->channel_data); } Arena* arena() const override { return deadline_state_.arena; } - grpc_polling_entity* pollent() const override { return pollent_; } + grpc_polling_entity* pollent() override { return pollent_; } grpc_metadata_batch* send_initial_metadata() override { return pending_batches_[0] ->payload->send_initial_metadata.send_initial_metadata; @@ -298,11 +307,105 @@ class ClientChannel::FilterBasedCallData : public ClientChannel::CallData { grpc_error_handle cancel_error_; }; +class ClientChannel::PromiseBasedCallData : public ClientChannel::CallData { + public: + explicit PromiseBasedCallData(ClientChannel* chand) : chand_(chand) {} + + ArenaPromise<absl::StatusOr<CallArgs>> MakeNameResolutionPromise( + CallArgs call_args) { + pollent_ = NowOrNever(call_args.polling_entity->WaitAndCopy()).value(); + client_initial_metadata_ = std::move(call_args.client_initial_metadata); + // If we're still in IDLE, we need to start resolving. + if (GPR_UNLIKELY(chand_->CheckConnectivityState(false) == + GRPC_CHANNEL_IDLE)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: %striggering exit idle", chand_, + this, Activity::current()->DebugTag().c_str()); + } + // Bounce into the control plane work serializer to start resolving. + GRPC_CHANNEL_STACK_REF(chand_->owning_stack_, "ExitIdle"); + chand_->work_serializer_->Run( + [chand = chand_]() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(*chand_->work_serializer_) { + chand->CheckConnectivityState(/*try_to_connect=*/true); + GRPC_CHANNEL_STACK_UNREF(chand->owning_stack_, "ExitIdle"); + }, + DEBUG_LOCATION); + } + return [this, call_args = std::move( + call_args)]() mutable -> Poll<absl::StatusOr<CallArgs>> { + auto result = CheckResolution(was_queued_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: %sCheckResolution returns %s", + chand_, this, Activity::current()->DebugTag().c_str(), + result.has_value() ? result->ToString().c_str() : "Pending"); + } + if (!result.has_value()) { + waker_ = Activity::current()->MakeNonOwningWaker(); + was_queued_ = true; + return Pending{}; + } + if (!result->ok()) return *result; + call_args.client_initial_metadata = std::move(client_initial_metadata_); + return std::move(call_args); + }; + } + + private: + ClientChannel* chand() const override { return chand_; } + Arena* arena() const override { return GetContext<Arena>(); } + grpc_polling_entity* pollent() override { return &pollent_; } + grpc_metadata_batch* send_initial_metadata() override { + return client_initial_metadata_.get(); + } + grpc_call_context_element* call_context() const override { + return GetContext<grpc_call_context_element>(); + } + + void RetryCheckResolutionLocked() override { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: RetryCheckResolutionLocked()", + chand_, this); + } + waker_.WakeupAsync(); + } + + void ResetDeadline(Duration timeout) override { + CallContext* call_context = GetContext<CallContext>(); + const Timestamp per_method_deadline = + Timestamp::FromCycleCounterRoundUp(call_context->call_start_time()) + + timeout; + call_context->UpdateDeadline(per_method_deadline); + } + + ClientChannel* chand_; + grpc_polling_entity pollent_; + ClientMetadataHandle client_initial_metadata_; + bool was_queued_ = false; + Waker waker_; +}; + // // Filter vtable // -const grpc_channel_filter ClientChannel::kFilterVtable = { +const grpc_channel_filter ClientChannel::kFilterVtableWithPromises = { + ClientChannel::FilterBasedCallData::StartTransportStreamOpBatch, + ClientChannel::MakeCallPromise, + ClientChannel::StartTransportOp, + sizeof(ClientChannel::FilterBasedCallData), + ClientChannel::FilterBasedCallData::Init, + ClientChannel::FilterBasedCallData::SetPollent, + ClientChannel::FilterBasedCallData::Destroy, + sizeof(ClientChannel), + ClientChannel::Init, + grpc_channel_stack_no_post_init, + ClientChannel::Destroy, + ClientChannel::GetChannelInfo, + "client-channel", +}; + +const grpc_channel_filter ClientChannel::kFilterVtableWithoutPromises = { ClientChannel::FilterBasedCallData::StartTransportStreamOpBatch, nullptr, ClientChannel::StartTransportOp, @@ -324,6 +427,12 @@ const grpc_channel_filter ClientChannel::kFilterVtable = { namespace { +ClientChannelServiceConfigCallData* GetServiceConfigCallData( + grpc_call_context_element* context) { + return static_cast<ClientChannelServiceConfigCallData*>( + context[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); +} + class DynamicTerminationFilter { public: class CallData; @@ -349,6 +458,19 @@ class DynamicTerminationFilter { static void GetChannelInfo(grpc_channel_element* /*elem*/, const grpc_channel_info* /*info*/) {} + static ArenaPromise<ServerMetadataHandle> MakeCallPromise( + grpc_channel_element* elem, CallArgs call_args, NextPromiseFactory) { + auto* chand = static_cast<DynamicTerminationFilter*>(elem->channel_data); + return chand->chand_->CreateLoadBalancedCallPromise( + std::move(call_args), + []() { + auto* service_config_call_data = + GetServiceConfigCallData(GetContext<grpc_call_context_element>()); + service_config_call_data->Commit(); + }, + /*is_transparent_retry=*/false); + } + private: explicit DynamicTerminationFilter(const ChannelArgs& args) : chand_(args.GetObject<ClientChannel>()) {} @@ -397,8 +519,7 @@ class DynamicTerminationFilter::CallData { /*start_time=*/0, calld->deadline_, calld->arena_, calld->call_combiner_}; auto* service_config_call_data = - static_cast<ClientChannelServiceConfigCallData*>( - calld->call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); + GetServiceConfigCallData(calld->call_context_); calld->lb_call_ = client_channel->CreateLoadBalancedCall( args, pollent, nullptr, [service_config_call_data]() { service_config_call_data->Commit(); }, @@ -433,7 +554,7 @@ class DynamicTerminationFilter::CallData { const grpc_channel_filter DynamicTerminationFilter::kFilterVtable = { DynamicTerminationFilter::CallData::StartTransportStreamOpBatch, - nullptr, + DynamicTerminationFilter::MakeCallPromise, DynamicTerminationFilter::StartTransportOp, sizeof(DynamicTerminationFilter::CallData), DynamicTerminationFilter::CallData::Init, @@ -1013,14 +1134,18 @@ class ClientChannel::ClientChannelControlHelper ClientChannel* ClientChannel::GetFromChannel(Channel* channel) { grpc_channel_element* elem = grpc_channel_stack_last_element(channel->channel_stack()); - if (elem->filter != &kFilterVtable) return nullptr; + if (elem->filter != &kFilterVtableWithPromises && + elem->filter != &kFilterVtableWithoutPromises) { + return nullptr; + } return static_cast<ClientChannel*>(elem->channel_data); } grpc_error_handle ClientChannel::Init(grpc_channel_element* elem, grpc_channel_element_args* args) { GPR_ASSERT(args->is_last); - GPR_ASSERT(elem->filter == &kFilterVtable); + GPR_ASSERT(elem->filter == &kFilterVtableWithPromises || + elem->filter == &kFilterVtableWithoutPromises); grpc_error_handle error; new (elem->channel_data) ClientChannel(args, &error); return error; @@ -1136,6 +1261,21 @@ ClientChannel::~ClientChannel() { grpc_pollset_set_destroy(interested_parties_); } +ArenaPromise<ServerMetadataHandle> ClientChannel::MakeCallPromise( + grpc_channel_element* elem, CallArgs call_args, NextPromiseFactory) { + auto* chand = static_cast<ClientChannel*>(elem->channel_data); + // TODO(roth): Is this the right lifetime story for calld? + auto* calld = GetContext<Arena>()->ManagedNew<PromiseBasedCallData>(chand); + return TrySeq( + // Name resolution. + calld->MakeNameResolutionPromise(std::move(call_args)), + // Dynamic filter stack. + [calld](CallArgs call_args) mutable { + return calld->dynamic_filters()->channel_stack()->MakeClientCallPromise( + std::move(call_args)); + }); +} + OrphanablePtr<ClientChannel::FilterBasedLoadBalancedCall> ClientChannel::CreateLoadBalancedCall( const grpc_call_element_args& args, grpc_polling_entity* pollent, @@ -1147,6 +1287,16 @@ ClientChannel::CreateLoadBalancedCall( std::move(on_commit), is_transparent_retry)); } +ArenaPromise<ServerMetadataHandle> ClientChannel::CreateLoadBalancedCallPromise( + CallArgs call_args, absl::AnyInvocable<void()> on_commit, + bool is_transparent_retry) { + OrphanablePtr<PromiseBasedLoadBalancedCall> lb_call( + GetContext<Arena>()->New<PromiseBasedLoadBalancedCall>( + this, std::move(on_commit), is_transparent_retry)); + auto* call_ptr = lb_call.get(); + return call_ptr->MakeCallPromise(std::move(call_args), std::move(lb_call)); +} + ChannelArgs ClientChannel::MakeSubchannelArgs( const ChannelArgs& channel_args, const ChannelArgs& address_args, const RefCountedPtr<SubchannelPoolInterface>& subchannel_pool, @@ -1610,7 +1760,7 @@ void ClientChannel::UpdateStateAndPickerLocked( MutexLock lock(&lb_mu_); picker_.swap(picker); // Reprocess queued picks. - for (LoadBalancedCall* call : lb_queued_calls_) { + for (auto& call : lb_queued_calls_) { call->RemoveCallFromLbQueuedCallsLocked(); call->RetryPickLocked(); } @@ -1840,8 +1990,10 @@ void ClientChannel::CallData::RemoveCallFromResolverQueuedCallsLocked() { void ClientChannel::CallData::AddCallToResolverQueuedCallsLocked() { if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { - gpr_log(GPR_INFO, "chand=%p calld=%p: adding to resolver queued picks list", - chand(), this); + gpr_log( + GPR_INFO, + "chand=%p calld=%p: adding to resolver queued picks list; pollent=%s", + chand(), this, grpc_polling_entity_string(pollent()).c_str()); } // Add call's pollent to channel's interested_parties, so that I/O // can be done under the call's CQ. @@ -2351,8 +2503,7 @@ void ClientChannel::FilterBasedCallData:: auto* calld = static_cast<FilterBasedCallData*>(arg); auto* chand = calld->chand(); auto* service_config_call_data = - static_cast<ClientChannelServiceConfigCallData*>( - calld->call_context()[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); + GetServiceConfigCallData(calld->call_context()); if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { gpr_log(GPR_INFO, "chand=%p calld=%p: got recv_trailing_metadata_ready: error=%s " @@ -2470,8 +2621,8 @@ class ClientChannel::LoadBalancedCall::Metadata ServiceConfigCallData::CallAttributeInterface* ClientChannel::LoadBalancedCall::LbCallState::GetCallAttribute( UniqueTypeName type) const { - auto* service_config_call_data = static_cast<ServiceConfigCallData*>( - lb_call_->call_context()[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); + auto* service_config_call_data = + GetServiceConfigCallData(lb_call_->call_context()); return service_config_call_data->GetCallAttribute(type); } @@ -2559,16 +2710,6 @@ ClientChannel::LoadBalancedCall::~LoadBalancedCall() { } } -void ClientChannel::LoadBalancedCall::Orphan() { - // Compute latency and report it to the tracer. - if (call_attempt_tracer() != nullptr) { - gpr_timespec latency = - gpr_cycle_counter_sub(gpr_get_cycle_counter(), lb_call_start_time_); - call_attempt_tracer()->RecordEnd(latency); - } - Unref(); -} - void ClientChannel::LoadBalancedCall::RecordCallCompletion( absl::Status status, grpc_metadata_batch* recv_trailing_metadata, grpc_transport_stream_stats* transport_stream_stats, @@ -2590,6 +2731,15 @@ void ClientChannel::LoadBalancedCall::RecordCallCompletion( } } +void ClientChannel::LoadBalancedCall::RecordLatency() { + // Compute latency and report it to the tracer. + if (call_attempt_tracer() != nullptr) { + gpr_timespec latency = + gpr_cycle_counter_sub(gpr_get_cycle_counter(), lb_call_start_time_); + call_attempt_tracer()->RecordEnd(latency); + } +} + void ClientChannel::LoadBalancedCall::RemoveCallFromLbQueuedCallsLocked() { if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) { gpr_log(GPR_INFO, "chand=%p lb_call=%p: removing from queued picks list", @@ -2614,7 +2764,7 @@ void ClientChannel::LoadBalancedCall::AddCallToLbQueuedCallsLocked() { grpc_polling_entity_add_to_pollset_set(pollent(), chand_->interested_parties_); // Add to queue. - chand_->lb_queued_calls_.insert(this); + chand_->lb_queued_calls_.insert(Ref()); OnAddToQueueLocked(); } @@ -2813,6 +2963,7 @@ void ClientChannel::FilterBasedLoadBalancedCall::Orphan() { RecordCallCompletion(absl::CancelledError("call cancelled"), nullptr, nullptr, ""); } + RecordLatency(); // Delegate to parent. LoadBalancedCall::Orphan(); } @@ -3153,7 +3304,7 @@ class ClientChannel::FilterBasedLoadBalancedCall::LbQueuedCallCanceller { // Remove pick from list of queued picks. lb_call->RemoveCallFromLbQueuedCallsLocked(); // Remove from queued picks list. - chand->lb_queued_calls_.erase(lb_call); + chand->lb_queued_calls_.erase(self->lb_call_); // Fail pending batches on the call. lb_call->PendingBatchesFail(error, YieldCallCombinerIfPendingBatchesFound); @@ -3243,4 +3394,152 @@ void ClientChannel::FilterBasedLoadBalancedCall::CreateSubchannelCall() { } } +// +// ClientChannel::PromiseBasedLoadBalancedCall +// + +ClientChannel::PromiseBasedLoadBalancedCall::PromiseBasedLoadBalancedCall( + ClientChannel* chand, absl::AnyInvocable<void()> on_commit, + bool is_transparent_retry) + : LoadBalancedCall(chand, GetContext<grpc_call_context_element>(), + std::move(on_commit), is_transparent_retry) {} + +ArenaPromise<ServerMetadataHandle> +ClientChannel::PromiseBasedLoadBalancedCall::MakeCallPromise( + CallArgs call_args, OrphanablePtr<PromiseBasedLoadBalancedCall> lb_call) { + pollent_ = NowOrNever(call_args.polling_entity->WaitAndCopy()).value(); + // Record ops in tracer. + if (call_attempt_tracer() != nullptr) { + call_attempt_tracer()->RecordSendInitialMetadata( + call_args.client_initial_metadata.get()); + // TODO(ctiller): Find a way to do this without registering a no-op mapper. + call_args.client_to_server_messages->InterceptAndMapWithHalfClose( + [](MessageHandle message) { return message; }, // No-op. + [this]() { + // TODO(roth): Change CallTracer API to not pass metadata + // batch to this method, since the batch is always empty. + grpc_metadata_batch metadata(GetContext<Arena>()); + call_attempt_tracer()->RecordSendTrailingMetadata(&metadata); + }); + } + // Extract peer name from server initial metadata. + call_args.server_initial_metadata->InterceptAndMap( + [this](ServerMetadataHandle metadata) { + if (call_attempt_tracer() != nullptr) { + call_attempt_tracer()->RecordReceivedInitialMetadata(metadata.get()); + } + Slice* peer_string = metadata->get_pointer(PeerString()); + if (peer_string != nullptr) peer_string_ = peer_string->Ref(); + return metadata; + }); + client_initial_metadata_ = std::move(call_args.client_initial_metadata); + return OnCancel( + Map(TrySeq( + // LB pick. + [this]() -> Poll<absl::Status> { + auto result = PickSubchannel(was_queued_); + if (GRPC_TRACE_FLAG_ENABLED( + grpc_client_channel_lb_call_trace)) { + gpr_log(GPR_INFO, + "chand=%p lb_call=%p: %sPickSubchannel() returns %s", + chand(), this, + Activity::current()->DebugTag().c_str(), + result.has_value() ? result->ToString().c_str() + : "Pending"); + } + if (result == absl::nullopt) return Pending{}; + return std::move(*result); + }, + [this, call_args = std::move(call_args)]() mutable + -> ArenaPromise<ServerMetadataHandle> { + call_args.client_initial_metadata = + std::move(client_initial_metadata_); + return connected_subchannel()->MakeCallPromise( + std::move(call_args)); + }), + // Record call completion. + [this](ServerMetadataHandle metadata) { + if (call_attempt_tracer() != nullptr || + lb_subchannel_call_tracker() != nullptr) { + absl::Status status; + grpc_status_code code = metadata->get(GrpcStatusMetadata()) + .value_or(GRPC_STATUS_UNKNOWN); + if (code != GRPC_STATUS_OK) { + absl::string_view message; + if (const auto* grpc_message = + metadata->get_pointer(GrpcMessageMetadata())) { + message = grpc_message->as_string_view(); + } + status = + absl::Status(static_cast<absl::StatusCode>(code), message); + } + RecordCallCompletion(status, metadata.get(), + &GetContext<CallContext>() + ->call_stats() + ->transport_stream_stats, + peer_string_.as_string_view()); + } + RecordLatency(); + return metadata; + }), + [lb_call = std::move(lb_call)]() { + // If the waker is pending, then we need to remove ourself from + // the list of queued LB calls. + if (!lb_call->waker_.is_unwakeable()) { + MutexLock lock(&lb_call->chand()->lb_mu_); + lb_call->Commit(); + // Remove pick from list of queued picks. + lb_call->RemoveCallFromLbQueuedCallsLocked(); + // Remove from queued picks list. + lb_call->chand()->lb_queued_calls_.erase(lb_call.get()); + } + // TODO(ctiller): We don't have access to the call's actual status + // here, so we just assume CANCELLED. We could change this to use + // CallFinalization instead of OnCancel() so that we can get the + // actual status. But we should also have access to the trailing + // metadata, which we don't have in either case. Ultimately, we + // need a better story for code that needs to run at the end of a + // call in both cancellation and non-cancellation cases that needs + // access to server trailing metadata and the call's real status. + if (lb_call->call_attempt_tracer() != nullptr) { + lb_call->call_attempt_tracer()->RecordCancel( + absl::CancelledError("call cancelled")); + } + if (lb_call->call_attempt_tracer() != nullptr || + lb_call->lb_subchannel_call_tracker() != nullptr) { + // If we were cancelled without recording call completion, then + // record call completion here, as best we can. We assume status + // CANCELLED in this case. + lb_call->RecordCallCompletion(absl::CancelledError("call cancelled"), + nullptr, nullptr, ""); + } + }); +} + +Arena* ClientChannel::PromiseBasedLoadBalancedCall::arena() const { + return GetContext<Arena>(); +} + +grpc_call_context_element* +ClientChannel::PromiseBasedLoadBalancedCall::call_context() const { + return GetContext<grpc_call_context_element>(); +} + +grpc_metadata_batch* +ClientChannel::PromiseBasedLoadBalancedCall::send_initial_metadata() const { + return client_initial_metadata_.get(); +} + +void ClientChannel::PromiseBasedLoadBalancedCall::OnAddToQueueLocked() { + waker_ = Activity::current()->MakeNonOwningWaker(); + was_queued_ = true; +} + +void ClientChannel::PromiseBasedLoadBalancedCall::RetryPickLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) { + gpr_log(GPR_INFO, "chand=%p lb_call=%p: RetryPickLocked()", chand(), this); + } + waker_.WakeupAsync(); +} + } // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/client_channel.h b/src/core/ext/filters/client_channel/client_channel.h index e43b8ce7b97..6597200cbc5 100644 --- a/src/core/ext/filters/client_channel/client_channel.h +++ b/src/core/ext/filters/client_channel/client_channel.h @@ -62,6 +62,8 @@ #include "src/core/lib/iomgr/iomgr_fwd.h" #include "src/core/lib/iomgr/polling_entity.h" #include "src/core/lib/load_balancing/lb_policy.h" +#include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/resolver/resolver.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/service_config/service_config.h" @@ -102,10 +104,12 @@ namespace grpc_core { class ClientChannel { public: - static const grpc_channel_filter kFilterVtable; + static const grpc_channel_filter kFilterVtableWithPromises; + static const grpc_channel_filter kFilterVtableWithoutPromises; class LoadBalancedCall; class FilterBasedLoadBalancedCall; + class PromiseBasedLoadBalancedCall; // Flag that this object gets stored in channel args as a raw pointer. struct RawPointerChannelArgTag {}; @@ -115,6 +119,10 @@ class ClientChannel { // is not a client channel. static ClientChannel* GetFromChannel(Channel* channel); + static ArenaPromise<ServerMetadataHandle> MakeCallPromise( + grpc_channel_element* elem, CallArgs call_args, + NextPromiseFactory next_promise_factory); + grpc_connectivity_state CheckConnectivityState(bool try_to_connect); // Starts a one-time connectivity state watch. When the channel's state @@ -164,6 +172,10 @@ class ClientChannel { grpc_closure* on_call_destruction_complete, absl::AnyInvocable<void()> on_commit, bool is_transparent_retry); + ArenaPromise<ServerMetadataHandle> CreateLoadBalancedCallPromise( + CallArgs call_args, absl::AnyInvocable<void()> on_commit, + bool is_transparent_retry); + // Exposed for testing only. static ChannelArgs MakeSubchannelArgs( const ChannelArgs& channel_args, const ChannelArgs& address_args, @@ -173,6 +185,7 @@ class ClientChannel { private: class CallData; class FilterBasedCallData; + class PromiseBasedCallData; class ResolverResultHandler; class SubchannelWrapper; class ClientChannelControlHelper; @@ -315,8 +328,10 @@ class ClientChannel { mutable Mutex lb_mu_; RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker_ ABSL_GUARDED_BY(lb_mu_); - absl::flat_hash_set<LoadBalancedCall*> lb_queued_calls_ - ABSL_GUARDED_BY(lb_mu_); + absl::flat_hash_set<RefCountedPtr<LoadBalancedCall>, + RefCountedPtrHash<LoadBalancedCall>, + RefCountedPtrEq<LoadBalancedCall>> + lb_queued_calls_ ABSL_GUARDED_BY(lb_mu_); // // Fields used in the control plane. Guarded by work_serializer. @@ -377,7 +392,7 @@ class ClientChannel::LoadBalancedCall bool is_transparent_retry); ~LoadBalancedCall() override; - void Orphan() override; + void Orphan() override { Unref(); } // Called by channel when removing a call from the list of queued calls. void RemoveCallFromLbQueuedCallsLocked() @@ -394,7 +409,6 @@ class ClientChannel::LoadBalancedCall return static_cast<ClientCallTracer::CallAttemptTracer*>( call_context()[GRPC_CONTEXT_CALL_TRACER].value); } - gpr_cycle_counter lb_call_start_time() const { return lb_call_start_time_; } ConnectedSubchannel* connected_subchannel() const { return connected_subchannel_.get(); } @@ -425,6 +439,8 @@ class ClientChannel::LoadBalancedCall grpc_transport_stream_stats* transport_stream_stats, absl::string_view peer_address); + void RecordLatency(); + private: class LbCallState; class Metadata; @@ -432,7 +448,7 @@ class ClientChannel::LoadBalancedCall virtual Arena* arena() const = 0; virtual grpc_call_context_element* call_context() const = 0; - virtual grpc_polling_entity* pollent() const = 0; + virtual grpc_polling_entity* pollent() = 0; virtual grpc_metadata_batch* send_initial_metadata() const = 0; // Helper function for performing an LB pick with a specified picker. @@ -445,7 +461,7 @@ class ClientChannel::LoadBalancedCall // Called when adding the call to the LB queue. virtual void OnAddToQueueLocked() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::lb_mu_) {} + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::lb_mu_) = 0; ClientChannel* chand_; @@ -496,7 +512,7 @@ class ClientChannel::FilterBasedLoadBalancedCall grpc_call_context_element* call_context() const override { return call_context_; } - grpc_polling_entity* pollent() const override { return pollent_; } + grpc_polling_entity* pollent() override { return pollent_; } grpc_metadata_batch* send_initial_metadata() const override { return pending_batches_[0] ->payload->send_initial_metadata.send_initial_metadata; @@ -590,6 +606,34 @@ class ClientChannel::FilterBasedLoadBalancedCall grpc_transport_stream_op_batch* pending_batches_[MAX_PENDING_BATCHES] = {}; }; +class ClientChannel::PromiseBasedLoadBalancedCall + : public ClientChannel::LoadBalancedCall { + public: + PromiseBasedLoadBalancedCall(ClientChannel* chand, + absl::AnyInvocable<void()> on_commit, + bool is_transparent_retry); + + ArenaPromise<ServerMetadataHandle> MakeCallPromise( + CallArgs call_args, OrphanablePtr<PromiseBasedLoadBalancedCall> lb_call); + + private: + Arena* arena() const override; + grpc_call_context_element* call_context() const override; + grpc_polling_entity* pollent() override { return &pollent_; } + grpc_metadata_batch* send_initial_metadata() const override; + + void RetryPickLocked() override; + + void OnAddToQueueLocked() override + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::lb_mu_); + + grpc_polling_entity pollent_; + ClientMetadataHandle client_initial_metadata_; + Waker waker_; + bool was_queued_ = false; + Slice peer_string_; +}; + } // namespace grpc_core #endif // GRPC_SRC_CORE_EXT_FILTERS_CLIENT_CHANNEL_CLIENT_CHANNEL_H diff --git a/src/core/ext/filters/client_channel/client_channel_plugin.cc b/src/core/ext/filters/client_channel/client_channel_plugin.cc index f2272d783df..6f32bf74bb1 100644 --- a/src/core/ext/filters/client_channel/client_channel_plugin.cc +++ b/src/core/ext/filters/client_channel/client_channel_plugin.cc @@ -18,9 +18,14 @@ #include <grpc/support/port_platform.h> +#include "absl/types/optional.h" + +#include <grpc/impl/channel_arg_names.h> + #include "src/core/ext/filters/client_channel/client_channel.h" #include "src/core/ext/filters/client_channel/client_channel_service_config.h" #include "src/core/ext/filters/client_channel/retry_service_config.h" +#include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_stack_builder.h" #include "src/core/lib/config/core_configuration.h" #include "src/core/lib/surface/channel_init.h" @@ -28,13 +33,22 @@ namespace grpc_core { +namespace { +bool IsEverythingBelowClientChannelPromiseSafe(const ChannelArgs& args) { + return !args.GetBool(GRPC_ARG_ENABLE_RETRIES).value_or(true); +} +} // namespace + void BuildClientChannelConfiguration(CoreConfiguration::Builder* builder) { internal::ClientChannelServiceConfigParser::Register(builder); internal::RetryServiceConfigParser::Register(builder); builder->channel_init()->RegisterStage( GRPC_CLIENT_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, [](ChannelStackBuilder* builder) { - builder->AppendFilter(&ClientChannel::kFilterVtable); + builder->AppendFilter( + IsEverythingBelowClientChannelPromiseSafe(builder->channel_args()) + ? &ClientChannel::kFilterVtableWithPromises + : &ClientChannel::kFilterVtableWithoutPromises); return true; }); } diff --git a/src/core/ext/filters/client_channel/dynamic_filters.h b/src/core/ext/filters/client_channel/dynamic_filters.h index 87712aa6c5c..9944c10f0b9 100644 --- a/src/core/ext/filters/client_channel/dynamic_filters.h +++ b/src/core/ext/filters/client_channel/dynamic_filters.h @@ -99,6 +99,8 @@ class DynamicFilters : public RefCounted<DynamicFilters> { RefCountedPtr<Call> CreateCall(Call::Args args, grpc_error_handle* error); + grpc_channel_stack* channel_stack() const { return channel_stack_.get(); } + private: RefCountedPtr<grpc_channel_stack> channel_stack_; }; diff --git a/src/core/ext/filters/client_channel/subchannel.cc b/src/core/ext/filters/client_channel/subchannel.cc index 80b5c7606a0..0a090637b2d 100644 --- a/src/core/ext/filters/client_channel/subchannel.cc +++ b/src/core/ext/filters/client_channel/subchannel.cc @@ -59,6 +59,8 @@ #include "src/core/lib/handshaker/proxy_mapper_registry.h" #include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/promise/cancel_callback.h" +#include "src/core/lib/promise/seq.h" #include "src/core/lib/slice/slice_internal.h" #include "src/core/lib/surface/channel_init.h" #include "src/core/lib/surface/channel_stack_type.h" @@ -133,6 +135,36 @@ size_t ConnectedSubchannel::GetInitialCallSizeEstimate() const { channel_stack_->call_stack_size; } +ArenaPromise<ServerMetadataHandle> ConnectedSubchannel::MakeCallPromise( + CallArgs call_args) { + // If not using channelz, we just need to call the channel stack. + if (channelz_subchannel() == nullptr) { + return channel_stack_->MakeClientCallPromise(std::move(call_args)); + } + // Otherwise, we need to wrap the channel stack promise with code that + // handles the channelz updates. + return OnCancel( + Seq(channel_stack_->MakeClientCallPromise(std::move(call_args)), + [self = Ref()](ServerMetadataHandle metadata) { + channelz::SubchannelNode* channelz_subchannel = + self->channelz_subchannel(); + GPR_ASSERT(channelz_subchannel != nullptr); + if (metadata->get(GrpcStatusMetadata()) + .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) { + channelz_subchannel->RecordCallFailed(); + } else { + channelz_subchannel->RecordCallSucceeded(); + } + return metadata; + }), + [self = Ref()]() { + channelz::SubchannelNode* channelz_subchannel = + self->channelz_subchannel(); + GPR_ASSERT(channelz_subchannel != nullptr); + channelz_subchannel->RecordCallFailed(); + }); +} + // // SubchannelCall // diff --git a/src/core/ext/filters/client_channel/subchannel.h b/src/core/ext/filters/client_channel/subchannel.h index 4d46415d40a..48b8d9cf3e7 100644 --- a/src/core/ext/filters/client_channel/subchannel.h +++ b/src/core/ext/filters/client_channel/subchannel.h @@ -54,6 +54,7 @@ #include "src/core/lib/iomgr/iomgr_fwd.h" #include "src/core/lib/iomgr/polling_entity.h" #include "src/core/lib/iomgr/resolved_address.h" +#include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice.h" #include "src/core/lib/transport/connectivity_state.h" @@ -84,6 +85,8 @@ class ConnectedSubchannel : public RefCounted<ConnectedSubchannel> { size_t GetInitialCallSizeEstimate() const; + ArenaPromise<ServerMetadataHandle> MakeCallPromise(CallArgs call_args); + private: grpc_channel_stack* channel_stack_; ChannelArgs args_; diff --git a/src/core/lib/gprpp/ref_counted_ptr.h b/src/core/lib/gprpp/ref_counted_ptr.h index c6c3b97f9d0..adcffe40f1d 100644 --- a/src/core/lib/gprpp/ref_counted_ptr.h +++ b/src/core/lib/gprpp/ref_counted_ptr.h @@ -21,10 +21,14 @@ #include <grpc/support/port_platform.h> +#include <stddef.h> + #include <iosfwd> #include <type_traits> #include <utility> +#include "absl/hash/hash.h" + #include "src/core/lib/gprpp/debug_location.h" namespace grpc_core { @@ -333,6 +337,65 @@ bool operator<(const WeakRefCountedPtr<T>& p1, const WeakRefCountedPtr<T>& p2) { return p1.get() < p2.get(); } +// +// absl::Hash integration +// + +template <typename H, typename T> +H AbslHashValue(H h, const RefCountedPtr<T>& p) { + return H::combine(std::move(h), p.get()); +} + +template <typename H, typename T> +H AbslHashValue(H h, const WeakRefCountedPtr<T>& p) { + return H::combine(std::move(h), p.get()); +} + +// Heterogenous lookup support. +template <typename T> +struct RefCountedPtrHash { + using is_transparent = void; + size_t operator()(const RefCountedPtr<T>& p) const { + return absl::Hash<RefCountedPtr<T>>{}(p); + } + size_t operator()(const WeakRefCountedPtr<T>& p) const { + return absl::Hash<WeakRefCountedPtr<T>>{}(p); + } + size_t operator()(T* p) const { return absl::Hash<T*>{}(p); } +}; +template <typename T> +struct RefCountedPtrEq { + using is_transparent = void; + bool operator()(const RefCountedPtr<T>& p1, + const RefCountedPtr<T>& p2) const { + return p1 == p2; + } + bool operator()(const WeakRefCountedPtr<T>& p1, + const WeakRefCountedPtr<T>& p2) const { + return p1 == p2; + } + bool operator()(const RefCountedPtr<T>& p1, + const WeakRefCountedPtr<T>& p2) const { + return p1 == p2.get(); + } + bool operator()(const WeakRefCountedPtr<T>& p1, + const RefCountedPtr<T>& p2) const { + return p1 == p2.get(); + } + bool operator()(const RefCountedPtr<T>& p1, const T* p2) const { + return p1 == p2; + } + bool operator()(const WeakRefCountedPtr<T>& p1, const T* p2) const { + return p1 == p2; + } + bool operator()(const T* p1, const RefCountedPtr<T>& p2) const { + return p2 == p1; + } + bool operator()(const T* p1, const WeakRefCountedPtr<T>& p2) const { + return p2 == p1; + } +}; + } // namespace grpc_core #endif // GRPC_SRC_CORE_LIB_GPRPP_REF_COUNTED_PTR_H diff --git a/src/core/lib/iomgr/polling_entity.cc b/src/core/lib/iomgr/polling_entity.cc index 5df50e1ce41..b5caa8a3156 100644 --- a/src/core/lib/iomgr/polling_entity.cc +++ b/src/core/lib/iomgr/polling_entity.cc @@ -98,3 +98,13 @@ void grpc_polling_entity_del_from_pollset_set(grpc_polling_entity* pollent, absl::StrFormat("Invalid grpc_polling_entity tag '%d'", pollent->tag)); } } + +std::string grpc_polling_entity_string(grpc_polling_entity* pollent) { + if (pollent->tag == GRPC_POLLS_POLLSET) { + return absl::StrFormat("pollset:%p", pollent->pollent.pollset); + } else if (pollent->tag == GRPC_POLLS_POLLSET_SET) { + return absl::StrFormat("pollset_set:%p", pollent->pollent.pollset_set); + } else { + return absl::StrFormat("invalid_tag:%d", pollent->tag); + } +} diff --git a/src/core/lib/iomgr/polling_entity.h b/src/core/lib/iomgr/polling_entity.h index 39214ce62ac..18525868481 100644 --- a/src/core/lib/iomgr/polling_entity.h +++ b/src/core/lib/iomgr/polling_entity.h @@ -66,6 +66,8 @@ void grpc_polling_entity_add_to_pollset_set(grpc_polling_entity* pollent, void grpc_polling_entity_del_from_pollset_set(grpc_polling_entity* pollent, grpc_pollset_set* pss_dst); +std::string grpc_polling_entity_string(grpc_polling_entity* pollent); + namespace grpc_core { template <> struct ContextType<grpc_polling_entity> {}; diff --git a/src/core/lib/promise/latch.h b/src/core/lib/promise/latch.h index 9d33fe7d280..914a61acb53 100644 --- a/src/core/lib/promise/latch.h +++ b/src/core/lib/promise/latch.h @@ -44,6 +44,7 @@ class Latch { public: Latch() = default; Latch(const Latch&) = delete; + explicit Latch(T value) : value_(std::move(value)), has_value_(true) {} Latch& operator=(const Latch&) = delete; Latch(Latch&& other) noexcept : value_(std::move(other.value_)), has_value_(other.has_value_) { diff --git a/src/core/lib/promise/pipe.h b/src/core/lib/promise/pipe.h index f485d558f5b..6de4f4ae584 100644 --- a/src/core/lib/promise/pipe.h +++ b/src/core/lib/promise/pipe.h @@ -541,7 +541,9 @@ class Next { Next(Next&& other) noexcept = default; Next& operator=(Next&& other) noexcept = default; - Poll<absl::optional<T>> operator()() { return center_->Next(); } + Poll<absl::optional<T>> operator()() { + return center_ == nullptr ? absl::nullopt : center_->Next(); + } private: friend class PipeReceiver<T>; @@ -572,29 +574,27 @@ class PipeReceiver { // Blocks the promise until the receiver is either closed or a message is // available. auto Next() { - return Seq( - pipe_detail::Next<T>(center_->Ref()), - [center = center_->Ref()](absl::optional<T> value) { - bool open = value.has_value(); - bool cancelled = center->cancelled(); - return If( - open, - [center = std::move(center), value = std::move(value)]() mutable { - auto run = center->Run(std::move(value)); - return Map(std::move(run), - [center = std::move(center)]( - absl::optional<T> value) mutable { - if (value.has_value()) { - center->value() = std::move(*value); - return NextResult<T>(std::move(center)); - } else { - center->MarkCancelled(); - return NextResult<T>(true); - } - }); - }, - [cancelled]() { return NextResult<T>(cancelled); }); - }); + return Seq(pipe_detail::Next<T>(center_), [center = center_]( + absl::optional<T> value) { + bool open = value.has_value(); + bool cancelled = center == nullptr ? true : center->cancelled(); + return If( + open, + [center = std::move(center), value = std::move(value)]() mutable { + auto run = center->Run(std::move(value)); + return Map(std::move(run), [center = std::move(center)]( + absl::optional<T> value) mutable { + if (value.has_value()) { + center->value() = std::move(*value); + return NextResult<T>(std::move(center)); + } else { + center->MarkCancelled(); + return NextResult<T>(true); + } + }); + }, + [cancelled]() { return NextResult<T>(cancelled); }); + }); } // Return a promise that resolves when the receiver is closed. diff --git a/src/core/lib/surface/call.cc b/src/core/lib/surface/call.cc index 8f25ce457a6..d567bcece38 100644 --- a/src/core/lib/surface/call.cc +++ b/src/core/lib/surface/call.cc @@ -1980,10 +1980,21 @@ class PromiseBasedCall : public Call, void SetCompletionQueue(grpc_completion_queue* cq) override; bool Completed() final { return finished_.IsSet(); } + virtual void OrphanCall() = 0; + // Implementation of call refcounting: move this to DualRefCounted once we // don't need to maintain FilterStackCall compatibility - void ExternalRef() final { InternalRef("external"); } - void ExternalUnref() final { InternalUnref("external"); } + void ExternalRef() final { + if (external_refs_.fetch_add(1, std::memory_order_relaxed) == 0) { + InternalRef("external"); + } + } + void ExternalUnref() final { + if (external_refs_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + OrphanCall(); + InternalUnref("external"); + } + } void InternalRef(const char* reason) final { if (grpc_call_refcount_trace.enabled()) { gpr_log(GPR_DEBUG, "INTERNAL_REF:%p:%s", this, reason); @@ -2346,14 +2357,16 @@ class PromiseBasedCall : public Call, } CallContext call_context_{this}; - + // Double refcounted for now: party owns the internal refcount, we track the + // external refcount. Figure out a better scheme post-promise conversion. + std::atomic<size_t> external_refs_; // Contexts for various subsystems (security, tracing, ...). grpc_call_context_element context_[GRPC_CONTEXT_COUNT] = {}; grpc_completion_queue* cq_; CompletionInfo completion_info_[6]; grpc_call_stats final_stats_{}; Slice final_message_; - grpc_status_code final_status_; + grpc_status_code final_status_ = GRPC_STATUS_UNKNOWN; CallFinalization finalization_; // Current deadline. Mutex deadline_mu_; @@ -2391,7 +2404,8 @@ PromiseBasedCall::PromiseBasedCall(Arena* arena, uint32_t initial_external_refs, const grpc_call_create_args& args) : Call(arena, args.server_transport_data == nullptr, args.send_deadline, args.channel->Ref()), - Party(arena, initial_external_refs), + Party(arena, initial_external_refs != 0 ? 1 : 0), + external_refs_(initial_external_refs), cq_(args.cq) { if (args.cq != nullptr) { GRPC_CQ_INTERNAL_REF(args.cq, "bind"); @@ -2684,18 +2698,20 @@ void PublishMetadataArray(grpc_metadata_batch* md, grpc_metadata_array* array, class ClientPromiseBasedCall final : public PromiseBasedCall { public: ClientPromiseBasedCall(Arena* arena, grpc_call_create_args* args) - : PromiseBasedCall(arena, 1, *args) { + : PromiseBasedCall(arena, 1, *args), + polling_entity_( + args->cq != nullptr + ? grpc_polling_entity_create_from_pollset( + grpc_cq_pollset(args->cq)) + : (args->pollset_set_alternative != nullptr + ? grpc_polling_entity_create_from_pollset_set( + args->pollset_set_alternative) + : grpc_polling_entity{})) { global_stats().IncrementClientCallsCreated(); if (args->cq != nullptr) { GPR_ASSERT(args->pollset_set_alternative == nullptr && "Only one of 'cq' and 'pollset_set_alternative' should be " "non-nullptr."); - polling_entity_.Set( - grpc_polling_entity_create_from_pollset(grpc_cq_pollset(args->cq))); - } - if (args->pollset_set_alternative != nullptr) { - polling_entity_.Set(grpc_polling_entity_create_from_pollset_set( - args->pollset_set_alternative)); } ScopedContext context(this); send_initial_metadata_ = @@ -2711,8 +2727,18 @@ class ClientPromiseBasedCall final : public PromiseBasedCall { if (args->send_deadline != Timestamp::InfFuture()) { UpdateDeadline(args->send_deadline); } + Call* parent = Call::FromC(args->parent); + if (parent != nullptr) { + auto parent_status = InitParent(parent, args->propagation_mask); + if (!parent_status.ok()) { + CancelWithError(std::move(parent_status)); + } + PublishToParent(parent); + } } + void OrphanCall() override { MaybeUnpublishFromParent(); } + ~ClientPromiseBasedCall() override { ScopedContext context(this); send_initial_metadata_.reset(); @@ -2740,7 +2766,9 @@ class ClientPromiseBasedCall final : public PromiseBasedCall { "cancel_with_error", [error = std::move(error), this]() { if (!cancel_error_.is_set()) { - cancel_error_.Set(ServerMetadataFromStatus(error)); + auto md = ServerMetadataFromStatus(error); + md->Set(GrpcCallWasCancelled(), true); + cancel_error_.Set(std::move(md)); } return Empty{}; }, @@ -2790,7 +2818,7 @@ class ClientPromiseBasedCall final : public PromiseBasedCall { Latch<grpc_polling_entity> polling_entity_; Pipe<MessageHandle> client_to_server_messages_{arena()}; Pipe<MessageHandle> server_to_client_messages_{arena()}; - bool is_trailers_only_; + bool is_trailers_only_ = false; // True once the promise for the call is started. // This corresponds to sending initial metadata, or cancelling before doing // so. @@ -2868,7 +2896,9 @@ grpc_call_error ClientPromiseBasedCall::ValidateBatch(const grpc_op* ops, case GRPC_OP_SEND_STATUS_FROM_SERVER: return GRPC_CALL_ERROR_NOT_ON_CLIENT; } - if (got_ops.is_set(op.op)) return GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; + if (got_ops.is_set(op.op)) { + return GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; + } got_ops.set(op.op); } return GRPC_CALL_OK; @@ -2972,9 +3002,16 @@ void ClientPromiseBasedCall::StartRecvInitialMetadata( NextResult<ServerMetadataHandle> next_metadata) mutable { server_initial_metadata_.sender.Close(); ServerMetadataHandle metadata; + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[call] RecvTrailingMetadata: %s", + DebugTag().c_str(), + next_metadata.has_value() + ? next_metadata.value()->DebugString().c_str() + : "null"); + } if (next_metadata.has_value()) { - is_trailers_only_ = false; metadata = std::move(next_metadata.value()); + is_trailers_only_ = metadata->get(GrpcTrailersOnly()).value_or(false); } else { is_trailers_only_ = true; metadata = arena()->MakePooled<ServerMetadata>(arena()); @@ -2993,7 +3030,11 @@ void ClientPromiseBasedCall::Finish(ServerMetadataHandle trailing_metadata) { } ResetDeadline(); set_completed(); - client_to_server_messages_.sender.Close(); + client_to_server_messages_.sender.CloseWithError(); + client_to_server_messages_.receiver.CloseWithError(); + if (trailing_metadata->get(GrpcCallWasCancelled()).value_or(false)) { + server_to_client_messages_.receiver.CloseWithError(); + } if (auto* channelz_channel = channel()->channelz_node()) { if (trailing_metadata->get(GrpcStatusMetadata()) .value_or(GRPC_STATUS_UNKNOWN) == GRPC_STATUS_OK) { @@ -3071,6 +3112,7 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { public: ServerPromiseBasedCall(Arena* arena, grpc_call_create_args* args); + void OrphanCall() override {} void CancelWithError(grpc_error_handle) override; grpc_call_error StartBatch(const grpc_op* ops, size_t nops, void* notify_tag, bool is_notify_tag_closure) override; diff --git a/src/core/lib/surface/call.h b/src/core/lib/surface/call.h index d3c1e94bf8e..94fa53a5b96 100644 --- a/src/core/lib/surface/call.h +++ b/src/core/lib/surface/call.h @@ -38,6 +38,7 @@ #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/context.h" #include "src/core/lib/debug/trace.h" +#include "src/core/lib/gpr/time_precise.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/time.h" #include "src/core/lib/iomgr/closure.h" @@ -126,6 +127,7 @@ class CallContext { grpc_call_stats* call_stats() { return &call_stats_; } gpr_atm* peer_string_atm_ptr(); + gpr_cycle_counter call_start_time() { return start_time_; } ServerCallContext* server_call_context(); @@ -139,6 +141,7 @@ class CallContext { // TODO(ctiller): remove this once transport APIs are promise based and we // don't need refcounting here. PromiseBasedCall* const call_; + gpr_cycle_counter start_time_ = gpr_get_cycle_counter(); // Is this call traced? bool traced_ = false; }; diff --git a/test/core/end2end/tests/filter_causes_close.cc b/test/core/end2end/tests/filter_causes_close.cc index 08cda31d1bd..a6f4f7498b3 100644 --- a/test/core/end2end/tests/filter_causes_close.cc +++ b/test/core/end2end/tests/filter_causes_close.cc @@ -18,6 +18,8 @@ #include <stdint.h> +#include <memory> + #include "absl/status/status.h" #include "gtest/gtest.h" @@ -32,6 +34,8 @@ #include "src/core/lib/gprpp/time.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/error.h" +#include "src/core/lib/promise/arena_promise.h" +#include "src/core/lib/promise/promise.h" #include "src/core/lib/surface/channel_stack_type.h" #include "src/core/lib/transport/transport.h" #include "test/core/end2end/end2end_tests.h" @@ -92,7 +96,11 @@ void destroy_channel_elem(grpc_channel_element* /*elem*/) {} const grpc_channel_filter test_filter = { start_transport_stream_op_batch, - nullptr, + [](grpc_channel_element*, CallArgs, + NextPromiseFactory) -> ArenaPromise<ServerMetadataHandle> { + return Immediate(ServerMetadataFromStatus( + absl::PermissionDeniedError("Failure that's not preventable."))); + }, grpc_channel_next_op, sizeof(call_data), init_call_elem, diff --git a/test/core/end2end/tests/server_streaming.cc b/test/core/end2end/tests/server_streaming.cc index c7e78e3b5fb..18ec90815c4 100644 --- a/test/core/end2end/tests/server_streaming.cc +++ b/test/core/end2end/tests/server_streaming.cc @@ -67,6 +67,8 @@ void ServerStreaming(CoreEnd2endTest& test, int num_messages) { test.Expect(104, true); test.Step(); + gpr_log(GPR_DEBUG, "SEEN_STATUS:%d", seen_status); + // Client keeps reading messages till it gets the status int num_messages_received = 0; while (true) { diff --git a/test/core/gprpp/BUILD b/test/core/gprpp/BUILD index dd2e90387d9..09f7c9a281d 100644 --- a/test/core/gprpp/BUILD +++ b/test/core/gprpp/BUILD @@ -200,6 +200,7 @@ grpc_cc_test( name = "ref_counted_ptr_test", srcs = ["ref_counted_ptr_test.cc"], external_deps = [ + "absl/container:flat_hash_set", "gtest", ], language = "C++", diff --git a/test/core/gprpp/ref_counted_ptr_test.cc b/test/core/gprpp/ref_counted_ptr_test.cc index d1f1da8a8aa..c6e7b0a4588 100644 --- a/test/core/gprpp/ref_counted_ptr_test.cc +++ b/test/core/gprpp/ref_counted_ptr_test.cc @@ -18,6 +18,7 @@ #include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "absl/container/flat_hash_set.h" #include "gtest/gtest.h" #include <grpc/support/log.h> @@ -511,6 +512,65 @@ TEST(WeakRefCountedPtr, CanPassWeakSubclassToFunctionExpectingWeakSubclass) { FunctionTakingWeakSubclass(p); } +// +// tests for absl hash integration +// + +TEST(AbslHashIntegration, RefCountedPtr) { + absl::flat_hash_set<RefCountedPtr<Foo>> set; + auto p = MakeRefCounted<Foo>(5); + set.insert(p); + auto it = set.find(p); + ASSERT_NE(it, set.end()); + EXPECT_EQ(*it, p); +} + +TEST(AbslHashIntegration, WeakRefCountedPtr) { + absl::flat_hash_set<WeakRefCountedPtr<Bar>> set; + auto p = MakeRefCounted<Bar>(5); + auto q = p->WeakRef(); + set.insert(q); + auto it = set.find(q); + ASSERT_NE(it, set.end()); + EXPECT_EQ(*it, q); +} + +TEST(AbslHashIntegration, RefCountedPtrHeterogenousLookup) { + absl::flat_hash_set<RefCountedPtr<Bar>, RefCountedPtrHash<Bar>, + RefCountedPtrEq<Bar>> + set; + auto p = MakeRefCounted<Bar>(5); + set.insert(p); + auto it = set.find(p); + ASSERT_NE(it, set.end()); + EXPECT_EQ(*it, p); + auto q = p->WeakRef(); + it = set.find(q); + ASSERT_NE(it, set.end()); + EXPECT_EQ(*it, p); + it = set.find(p.get()); + ASSERT_NE(it, set.end()); + EXPECT_EQ(*it, p); +} + +TEST(AbslHashIntegration, WeakRefCountedPtrHeterogenousLookup) { + absl::flat_hash_set<WeakRefCountedPtr<Bar>, RefCountedPtrHash<Bar>, + RefCountedPtrEq<Bar>> + set; + auto p = MakeRefCounted<Bar>(5); + auto q = p->WeakRef(); + set.insert(q); + auto it = set.find(q); + ASSERT_NE(it, set.end()); + EXPECT_EQ(*it, q); + it = set.find(p); + ASSERT_NE(it, set.end()); + EXPECT_EQ(*it, q); + it = set.find(p.get()); + ASSERT_NE(it, set.end()); + EXPECT_EQ(*it, q); +} + } // namespace } // namespace testing } // namespace grpc_core diff --git a/test/cpp/microbenchmarks/bm_call_create.cc b/test/cpp/microbenchmarks/bm_call_create.cc index cdb948c24b2..c60fc4a3f81 100644 --- a/test/cpp/microbenchmarks/bm_call_create.cc +++ b/test/cpp/microbenchmarks/bm_call_create.cc @@ -570,7 +570,7 @@ BENCHMARK_TEMPLATE(BM_IsolatedFilter, NoFilter, NoOp); typedef Fixture<&phony_filter::phony_filter, 0> PhonyFilter; BENCHMARK_TEMPLATE(BM_IsolatedFilter, PhonyFilter, NoOp); BENCHMARK_TEMPLATE(BM_IsolatedFilter, PhonyFilter, SendEmptyMetadata); -typedef Fixture<&grpc_core::ClientChannel::kFilterVtable, 0> +typedef Fixture<&grpc_core::ClientChannel::kFilterVtableWithoutPromises, 0> ClientChannelFilter; BENCHMARK_TEMPLATE(BM_IsolatedFilter, ClientChannelFilter, NoOp); typedef Fixture<&grpc_core::ClientCompressionFilter::kFilter, CHECKS_NOT_LAST>