From 41debbf1a7087038ec38a3af089e477aed7c8302 Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Mon, 21 Mar 2022 13:11:05 -0700 Subject: [PATCH] HTTP Client Filter --> promises (#29031) * Remove idempotent/cacheable requests * more cleanup * bump core version * fix * fix * fix * review feedback * fixes * fix * remove more * objc * fix * fix * fix * scrub * introduce call args * bs * x * Automated change: Fix sanity tests * fix * roughing out * push/pull impl * comment * prove out new combinator * Simplify naming * fix * Automated change: Fix sanity tests * Automated change: Fix sanity tests * progress * builds * Automated change: Fix sanity tests * progress * tweak * merge * progress * fix * first test passes * progress * fix * repair name * fix * small fix * small fix * properly stop call combiner * fix allocation in this benchmark * fix * fix * fix merge * fix bad merge Co-authored-by: ctiller --- BUILD | 2 + build_autogenerated.yaml | 2 + gRPC-C++.podspec | 2 + gRPC-Core.podspec | 2 + grpc.gemspec | 1 + package.xml | 1 + .../filters/http/client/http_client_filter.cc | 278 ++------- .../filters/http/client/http_client_filter.h | 24 +- .../ext/filters/http/http_filters_plugin.cc | 4 +- src/core/lib/channel/promise_based_filter.cc | 576 +++++++++++++----- src/core/lib/channel/promise_based_filter.h | 42 +- src/core/lib/promise/call_push_pull.h | 4 +- src/core/lib/promise/detail/status.h | 3 +- src/core/lib/transport/transport.h | 2 +- test/cpp/microbenchmarks/bm_call_create.cc | 5 +- tools/doxygen/Doxyfile.c++.internal | 1 + tools/doxygen/Doxyfile.core.internal | 1 + 17 files changed, 562 insertions(+), 388 deletions(-) diff --git a/BUILD b/BUILD index a27da2ec4d1..b8f07d53bdb 100644 --- a/BUILD +++ b/BUILD @@ -2746,10 +2746,12 @@ grpc_cc_library( ], language = "c++", deps = [ + "call_push_pull", "config", "gpr_base", "grpc_base", "grpc_message_size_filter", + "seq", "slice", ], ) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index b7992335411..05137fd0b5e 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -815,6 +815,7 @@ libs: - src/core/lib/matchers/matchers.h - src/core/lib/promise/activity.h - src/core/lib/promise/arena_promise.h + - src/core/lib/promise/call_push_pull.h - src/core/lib/promise/context.h - src/core/lib/promise/detail/basic_seq.h - src/core/lib/promise/detail/promise_factory.h @@ -1990,6 +1991,7 @@ libs: - src/core/lib/json/json_util.h - src/core/lib/promise/activity.h - src/core/lib/promise/arena_promise.h + - src/core/lib/promise/call_push_pull.h - src/core/lib/promise/context.h - src/core/lib/promise/detail/basic_seq.h - src/core/lib/promise/detail/promise_factory.h diff --git a/gRPC-C++.podspec b/gRPC-C++.podspec index 083a2998934..5b350c9accf 100644 --- a/gRPC-C++.podspec +++ b/gRPC-C++.podspec @@ -783,6 +783,7 @@ Pod::Spec.new do |s| 'src/core/lib/profiling/timers.h', 'src/core/lib/promise/activity.h', 'src/core/lib/promise/arena_promise.h', + 'src/core/lib/promise/call_push_pull.h', 'src/core/lib/promise/context.h', 'src/core/lib/promise/detail/basic_seq.h', 'src/core/lib/promise/detail/promise_factory.h', @@ -1588,6 +1589,7 @@ Pod::Spec.new do |s| 'src/core/lib/profiling/timers.h', 'src/core/lib/promise/activity.h', 'src/core/lib/promise/arena_promise.h', + 'src/core/lib/promise/call_push_pull.h', 'src/core/lib/promise/context.h', 'src/core/lib/promise/detail/basic_seq.h', 'src/core/lib/promise/detail/promise_factory.h', diff --git a/gRPC-Core.podspec b/gRPC-Core.podspec index 1b106b25d7e..ed547ea0e36 100644 --- a/gRPC-Core.podspec +++ b/gRPC-Core.podspec @@ -1286,6 +1286,7 @@ Pod::Spec.new do |s| 'src/core/lib/promise/activity.cc', 'src/core/lib/promise/activity.h', 'src/core/lib/promise/arena_promise.h', + 'src/core/lib/promise/call_push_pull.h', 'src/core/lib/promise/context.h', 'src/core/lib/promise/detail/basic_seq.h', 'src/core/lib/promise/detail/promise_factory.h', @@ -2186,6 +2187,7 @@ Pod::Spec.new do |s| 'src/core/lib/profiling/timers.h', 'src/core/lib/promise/activity.h', 'src/core/lib/promise/arena_promise.h', + 'src/core/lib/promise/call_push_pull.h', 'src/core/lib/promise/context.h', 'src/core/lib/promise/detail/basic_seq.h', 'src/core/lib/promise/detail/promise_factory.h', diff --git a/grpc.gemspec b/grpc.gemspec index 8a56c3841a0..f6346486b94 100644 --- a/grpc.gemspec +++ b/grpc.gemspec @@ -1205,6 +1205,7 @@ Gem::Specification.new do |s| s.files += %w( src/core/lib/promise/activity.cc ) s.files += %w( src/core/lib/promise/activity.h ) s.files += %w( src/core/lib/promise/arena_promise.h ) + s.files += %w( src/core/lib/promise/call_push_pull.h ) s.files += %w( src/core/lib/promise/context.h ) s.files += %w( src/core/lib/promise/detail/basic_seq.h ) s.files += %w( src/core/lib/promise/detail/promise_factory.h ) diff --git a/package.xml b/package.xml index f7c2e8ff347..b4c3e57fa2d 100644 --- a/package.xml +++ b/package.xml @@ -1185,6 +1185,7 @@ + diff --git a/src/core/ext/filters/http/client/http_client_filter.cc b/src/core/ext/filters/http/client/http_client_filter.cc index 8d58c979b7c..fcd07c89f00 100644 --- a/src/core/ext/filters/http/client/http_client_filter.cc +++ b/src/core/ext/filters/http/client/http_client_filter.cc @@ -33,218 +33,60 @@ #include #include "src/core/lib/channel/channel_args.h" -#include "src/core/lib/gpr/string.h" -#include "src/core/lib/gprpp/manual_constructor.h" -#include "src/core/lib/profiling/timers.h" -#include "src/core/lib/slice/b64.h" +#include "src/core/lib/promise/call_push_pull.h" +#include "src/core/lib/promise/seq.h" #include "src/core/lib/slice/percent_encoding.h" -#include "src/core/lib/slice/slice_internal.h" -#include "src/core/lib/slice/slice_string_helpers.h" #include "src/core/lib/transport/status_conversion.h" #include "src/core/lib/transport/transport_impl.h" -#define EXPECTED_CONTENT_TYPE "application/grpc" -#define EXPECTED_CONTENT_TYPE_LENGTH (sizeof(EXPECTED_CONTENT_TYPE) - 1) +namespace grpc_core { -static void recv_initial_metadata_ready(void* user_data, - grpc_error_handle error); -static void recv_trailing_metadata_ready(void* user_data, - grpc_error_handle error); +const grpc_channel_filter HttpClientFilter::kFilter = + MakePromiseBasedFilter("http-client"); namespace { -struct call_data { - call_data(grpc_call_element* elem, const grpc_call_element_args& args) - : call_combiner(args.call_combiner) { - GRPC_CLOSURE_INIT(&recv_initial_metadata_ready, - ::recv_initial_metadata_ready, elem, - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready, - ::recv_trailing_metadata_ready, elem, - grpc_schedule_on_exec_ctx); - } - - ~call_data() { GRPC_ERROR_UNREF(recv_initial_metadata_error); } - - grpc_core::CallCombiner* call_combiner; - // State for handling recv_initial_metadata ops. - grpc_metadata_batch* recv_initial_metadata; - grpc_error_handle recv_initial_metadata_error = GRPC_ERROR_NONE; - grpc_closure* original_recv_initial_metadata_ready = nullptr; - grpc_closure recv_initial_metadata_ready; - // State for handling recv_trailing_metadata ops. - grpc_metadata_batch* recv_trailing_metadata; - grpc_closure* original_recv_trailing_metadata_ready; - grpc_closure recv_trailing_metadata_ready; - grpc_error_handle recv_trailing_metadata_error = GRPC_ERROR_NONE; - bool seen_recv_trailing_metadata_ready = false; -}; - -struct channel_data { - grpc_core::HttpSchemeMetadata::ValueType static_scheme; - grpc_core::Slice user_agent; -}; -} // namespace - -static grpc_error_handle client_filter_incoming_metadata( - grpc_metadata_batch* b) { - if (auto* status = b->get_pointer(grpc_core::HttpStatusMetadata())) { +absl::Status CheckServerMetadata(ServerMetadata* b) { + if (auto* status = b->get_pointer(HttpStatusMetadata())) { /* If both gRPC status and HTTP status are provided in the response, we * should prefer the gRPC status code, as mentioned in * https://github.com/grpc/grpc/blob/master/doc/http-grpc-status-mapping.md. */ - const grpc_status_code* grpc_status = - b->get_pointer(grpc_core::GrpcStatusMetadata()); + const grpc_status_code* grpc_status = b->get_pointer(GrpcStatusMetadata()); if (grpc_status != nullptr || *status == 200) { - b->Remove(grpc_core::HttpStatusMetadata()); + b->Remove(HttpStatusMetadata()); } else { - std::string msg = - absl::StrCat("Received http2 header with status: ", *status); - grpc_error_handle e = grpc_error_set_str( - grpc_error_set_int( - grpc_error_set_str( - GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Received http2 :status header with non-200 OK status"), - GRPC_ERROR_STR_VALUE, std::to_string(*status)), - GRPC_ERROR_INT_GRPC_STATUS, + return absl::Status( + static_cast( grpc_http2_status_to_grpc_status(*status)), - GRPC_ERROR_STR_GRPC_MESSAGE, msg); - return e; + absl::StrCat("Received http2 header with status: ", *status)); } } - if (grpc_core::Slice* grpc_message = - b->get_pointer(grpc_core::GrpcMessageMetadata())) { - *grpc_message = - grpc_core::PermissivePercentDecodeSlice(std::move(*grpc_message)); - } - - b->Remove(grpc_core::ContentTypeMetadata()); - - return GRPC_ERROR_NONE; -} - -static void recv_initial_metadata_ready(void* user_data, - grpc_error_handle error) { - grpc_call_element* elem = static_cast(user_data); - call_data* calld = static_cast(elem->call_data); - if (error == GRPC_ERROR_NONE) { - error = client_filter_incoming_metadata(calld->recv_initial_metadata); - calld->recv_initial_metadata_error = GRPC_ERROR_REF(error); - } else { - (void)GRPC_ERROR_REF(error); - } - grpc_closure* closure = calld->original_recv_initial_metadata_ready; - calld->original_recv_initial_metadata_ready = nullptr; - if (calld->seen_recv_trailing_metadata_ready) { - GRPC_CALL_COMBINER_START( - calld->call_combiner, &calld->recv_trailing_metadata_ready, - calld->recv_trailing_metadata_error, "continue recv_trailing_metadata"); - } - grpc_core::Closure::Run(DEBUG_LOCATION, closure, error); -} - -static void recv_trailing_metadata_ready(void* user_data, - grpc_error_handle error) { - grpc_call_element* elem = static_cast(user_data); - call_data* calld = static_cast(elem->call_data); - if (calld->original_recv_initial_metadata_ready != nullptr) { - calld->recv_trailing_metadata_error = GRPC_ERROR_REF(error); - calld->seen_recv_trailing_metadata_ready = true; - GRPC_CALL_COMBINER_STOP(calld->call_combiner, - "deferring recv_trailing_metadata_ready until " - "after recv_initial_metadata_ready"); - return; - } - if (error == GRPC_ERROR_NONE) { - error = client_filter_incoming_metadata(calld->recv_trailing_metadata); - } else { - (void)GRPC_ERROR_REF(error); - } - error = grpc_error_add_child( - error, GRPC_ERROR_REF(calld->recv_initial_metadata_error)); - grpc_core::Closure::Run(DEBUG_LOCATION, - calld->original_recv_trailing_metadata_ready, error); -} - -static void http_client_start_transport_stream_op_batch( - grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { - call_data* calld = static_cast(elem->call_data); - channel_data* channeld = static_cast(elem->channel_data); - GPR_TIMER_SCOPE("http_client_start_transport_stream_op_batch", 0); - - if (batch->recv_initial_metadata) { - /* substitute our callback for the higher callback */ - calld->recv_initial_metadata = - batch->payload->recv_initial_metadata.recv_initial_metadata; - calld->original_recv_initial_metadata_ready = - batch->payload->recv_initial_metadata.recv_initial_metadata_ready; - batch->payload->recv_initial_metadata.recv_initial_metadata_ready = - &calld->recv_initial_metadata_ready; - } - - if (batch->recv_trailing_metadata) { - /* substitute our callback for the higher callback */ - calld->recv_trailing_metadata = - batch->payload->recv_trailing_metadata.recv_trailing_metadata; - calld->original_recv_trailing_metadata_ready = - batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; - batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = - &calld->recv_trailing_metadata_ready; - } - - if (batch->send_initial_metadata) { - /* Send : prefixed headers, which have to be before any application - layer headers. */ - batch->payload->send_initial_metadata.send_initial_metadata->Set( - grpc_core::HttpMethodMetadata(), grpc_core::HttpMethodMetadata::kPost); - batch->payload->send_initial_metadata.send_initial_metadata->Set( - grpc_core::HttpSchemeMetadata(), channeld->static_scheme); - batch->payload->send_initial_metadata.send_initial_metadata->Set( - grpc_core::TeMetadata(), grpc_core::TeMetadata::kTrailers); - batch->payload->send_initial_metadata.send_initial_metadata->Set( - grpc_core::ContentTypeMetadata(), - grpc_core::ContentTypeMetadata::kApplicationGrpc); - batch->payload->send_initial_metadata.send_initial_metadata->Set( - grpc_core::UserAgentMetadata(), channeld->user_agent.Ref()); + if (Slice* grpc_message = b->get_pointer(GrpcMessageMetadata())) { + *grpc_message = PermissivePercentDecodeSlice(std::move(*grpc_message)); } - grpc_call_next_op(elem, batch); + b->Remove(ContentTypeMetadata()); + return absl::OkStatus(); } -/* Constructor for call_data */ -static grpc_error_handle http_client_init_call_elem( - grpc_call_element* elem, const grpc_call_element_args* args) { - new (elem->call_data) call_data(elem, *args); - return GRPC_ERROR_NONE; -} - -/* Destructor for call_data */ -static void http_client_destroy_call_elem( - grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, - grpc_closure* /*ignored*/) { - call_data* calld = static_cast(elem->call_data); - calld->~call_data(); -} - -static grpc_core::HttpSchemeMetadata::ValueType scheme_from_args( - const grpc_channel_args* args) { +HttpSchemeMetadata::ValueType SchemeFromArgs(const grpc_channel_args* args) { if (args != nullptr) { for (size_t i = 0; i < args->num_args; ++i) { if (args->args[i].type == GRPC_ARG_STRING && 0 == strcmp(args->args[i].key, GRPC_ARG_HTTP2_SCHEME)) { - grpc_core::HttpSchemeMetadata::ValueType scheme = - grpc_core::HttpSchemeMetadata::Parse( - args->args[i].value.string, - [](absl::string_view, const grpc_core::Slice&) {}); - if (scheme != grpc_core::HttpSchemeMetadata::kInvalid) return scheme; + HttpSchemeMetadata::ValueType scheme = HttpSchemeMetadata::Parse( + args->args[i].value.string, [](absl::string_view, const Slice&) {}); + if (scheme != HttpSchemeMetadata::kInvalid) return scheme; } } } - return grpc_core::HttpSchemeMetadata::kHttp; + return HttpSchemeMetadata::kHttp; } -static grpc_core::Slice user_agent_from_args(const grpc_channel_args* args, - const char* transport_name) { +Slice UserAgentFromArgs(const grpc_channel_args* args, + const char* transport_name) { std::vector user_agent_fields; for (size_t i = 0; args && i < args->num_args; i++) { @@ -274,39 +116,51 @@ static grpc_core::Slice user_agent_from_args(const grpc_channel_args* args, } std::string user_agent_string = absl::StrJoin(user_agent_fields, " "); - return grpc_core::Slice::FromCopiedString(user_agent_string.c_str()); + return Slice::FromCopiedString(user_agent_string.c_str()); } +} // namespace -/* Constructor for channel_data */ -static grpc_error_handle http_client_init_channel_elem( - grpc_channel_element* elem, grpc_channel_element_args* args) { - channel_data* chand = static_cast(elem->channel_data); - new (chand) channel_data(); - GPR_ASSERT(!args->is_last); - auto* transport = grpc_channel_args_find_pointer( - args->channel_args, GRPC_ARG_TRANSPORT); - GPR_ASSERT(transport != nullptr); - chand->static_scheme = scheme_from_args(args->channel_args); - chand->user_agent = grpc_core::Slice( - user_agent_from_args(args->channel_args, transport->vtable->name)); - return GRPC_ERROR_NONE; +ArenaPromise HttpClientFilter::MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) { + auto& md = call_args.client_initial_metadata; + md->Set(HttpMethodMetadata(), HttpMethodMetadata::kPost); + md->Set(HttpSchemeMetadata(), scheme_); + md->Set(TeMetadata(), TeMetadata::kTrailers); + md->Set(ContentTypeMetadata(), ContentTypeMetadata::kApplicationGrpc); + md->Set(UserAgentMetadata(), user_agent_.Ref()); + + auto* read_latch = GetContext()->New>(); + auto* write_latch = + absl::exchange(call_args.server_initial_metadata, read_latch); + + return CallPushPull( + Seq(next_promise_factory(std::move(call_args)), + [](ServerMetadataHandle md) -> ServerMetadataHandle { + auto r = CheckServerMetadata(md.get()); + if (!r.ok()) return ServerMetadataHandle(r); + return md; + }), + []() { return absl::OkStatus(); }, + Seq(read_latch->Wait(), + [write_latch](ServerMetadata** md) -> absl::Status { + auto r = + *md == nullptr ? absl::OkStatus() : CheckServerMetadata(*md); + write_latch->Set(*md); + return r; + })); } -/* Destructor for channel data */ -static void http_client_destroy_channel_elem(grpc_channel_element* elem) { - static_cast(elem->channel_data)->~channel_data(); +HttpClientFilter::HttpClientFilter(HttpSchemeMetadata::ValueType scheme, + Slice user_agent) + : scheme_(scheme), user_agent_(std::move(user_agent)) {} + +absl::StatusOr HttpClientFilter::Create( + const grpc_channel_args* args, ChannelFilter::Args) { + auto* transport = + grpc_channel_args_find_pointer(args, GRPC_ARG_TRANSPORT); + GPR_ASSERT(transport != nullptr); + return HttpClientFilter(SchemeFromArgs(args), + UserAgentFromArgs(args, transport->vtable->name)); } -const grpc_channel_filter grpc_http_client_filter = { - http_client_start_transport_stream_op_batch, - nullptr, - grpc_channel_next_op, - sizeof(call_data), - http_client_init_call_elem, - grpc_call_stack_ignore_set_pollset_or_pollset_set, - http_client_destroy_call_elem, - sizeof(channel_data), - http_client_init_channel_elem, - http_client_destroy_channel_elem, - grpc_channel_next_get_info, - "http-client"}; +} // namespace grpc_core diff --git a/src/core/ext/filters/http/client/http_client_filter.h b/src/core/ext/filters/http/client/http_client_filter.h index ce577b44178..dd0190bbee1 100644 --- a/src/core/ext/filters/http/client/http_client_filter.h +++ b/src/core/ext/filters/http/client/http_client_filter.h @@ -21,8 +21,28 @@ #include #include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/promise_based_filter.h" -/* Processes metadata on the client side for HTTP2 transports */ -extern const grpc_channel_filter grpc_http_client_filter; +namespace grpc_core { + +class HttpClientFilter : public ChannelFilter { + public: + static const grpc_channel_filter kFilter; + + static absl::StatusOr Create( + const grpc_channel_args* args, ChannelFilter::Args filter_args); + + // Construct a promise for one call. + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) override; + + private: + HttpClientFilter(HttpSchemeMetadata::ValueType scheme, Slice user_agent); + + HttpSchemeMetadata::ValueType scheme_; + Slice user_agent_; +}; + +} // namespace grpc_core #endif /* GRPC_CORE_EXT_FILTERS_HTTP_CLIENT_HTTP_CLIENT_FILTER_H */ diff --git a/src/core/ext/filters/http/http_filters_plugin.cc b/src/core/ext/filters/http/http_filters_plugin.cc index 28650d14a8a..b0ed558e050 100644 --- a/src/core/ext/filters/http/http_filters_plugin.cc +++ b/src/core/ext/filters/http/http_filters_plugin.cc @@ -83,8 +83,8 @@ void RegisterHttpFilters(CoreConfiguration::Builder* builder) { GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION, &MessageDecompressFilter); optional(GRPC_SERVER_CHANNEL, kMinimalStackHasDecompression, GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION, &MessageDecompressFilter); - required(GRPC_CLIENT_SUBCHANNEL, &grpc_http_client_filter); - required(GRPC_CLIENT_DIRECT_CHANNEL, &grpc_http_client_filter); + required(GRPC_CLIENT_SUBCHANNEL, &HttpClientFilter::kFilter); + required(GRPC_CLIENT_DIRECT_CHANNEL, &HttpClientFilter::kFilter); required(GRPC_SERVER_CHANNEL, &grpc_http_server_filter); } } // namespace grpc_core diff --git a/src/core/lib/channel/promise_based_filter.cc b/src/core/lib/channel/promise_based_filter.cc index 77f10ff0566..480d085344d 100644 --- a/src/core/lib/channel/promise_based_filter.cc +++ b/src/core/lib/channel/promise_based_filter.cc @@ -16,6 +16,8 @@ #include "src/core/lib/channel/promise_based_filter.h" +#include + #include "src/core/lib/channel/channel_stack.h" namespace grpc_core { @@ -24,6 +26,25 @@ namespace promise_filter_detail { /////////////////////////////////////////////////////////////////////////////// // BaseCallData +BaseCallData::BaseCallData(grpc_call_element* elem, + const grpc_call_element_args* args, uint8_t flags) + : call_stack_(args->call_stack), + elem_(elem), + arena_(args->arena), + call_combiner_(args->call_combiner), + deadline_(args->deadline), + context_(args->context) { + if (flags & kFilterExaminesServerInitialMetadata) { + server_initial_metadata_latch_ = arena_->New>(); + } +} + +BaseCallData::~BaseCallData() { + if (server_initial_metadata_latch_ != nullptr) { + server_initial_metadata_latch_->~Latch(); + } +} + // We don't form ActivityPtr's to this type, and consequently don't need // Orphan(). void BaseCallData::Orphan() { abort(); } @@ -52,23 +73,285 @@ void BaseCallData::Drop() { GRPC_CALL_STACK_UNREF(call_stack_, "waker"); } /////////////////////////////////////////////////////////////////////////////// // ClientCallData +struct ClientCallData::RecvInitialMetadata final { + enum State { + // Initial state; no op seen + kInitial, + // No op seen, but we have a latch that would like to modify it when we do + kGotLatch, + // Hooked, no latch yet + kHookedWaitingForLatch, + // Hooked, latch seen + kHookedAndGotLatch, + // Got the callback, haven't set latch yet + kCompleteWaitingForLatch, + // Got the callback and got the latch + kCompleteAndGotLatch, + // Got the callback and set the latch + kCompleteAndSetLatch, + // Called the original callback + kResponded, + }; + + State state = kInitial; + grpc_closure* original_on_ready = nullptr; + grpc_closure on_ready; + grpc_metadata_batch* metadata = nullptr; + Latch* server_initial_metadata_publisher = nullptr; +}; + +class ClientCallData::PollContext { + public: + explicit PollContext(ClientCallData* self) : self_(self) { + GPR_ASSERT(self_->poll_ctx_ == nullptr); + self_->poll_ctx_ = this; + scoped_activity_.Init(self_); + have_scoped_activity_ = true; + } + + PollContext(const PollContext&) = delete; + PollContext& operator=(const PollContext&) = delete; + + void Run() { + GPR_ASSERT(have_scoped_activity_); + repoll_ = false; + if (self_->server_initial_metadata_latch() != nullptr) { + switch (self_->recv_initial_metadata_->state) { + case RecvInitialMetadata::kInitial: + case RecvInitialMetadata::kGotLatch: + case RecvInitialMetadata::kHookedWaitingForLatch: + case RecvInitialMetadata::kHookedAndGotLatch: + case RecvInitialMetadata::kCompleteWaitingForLatch: + case RecvInitialMetadata::kResponded: + break; + case RecvInitialMetadata::kCompleteAndGotLatch: + self_->recv_initial_metadata_->state = + RecvInitialMetadata::kCompleteAndSetLatch; + self_->recv_initial_metadata_->server_initial_metadata_publisher->Set( + self_->recv_initial_metadata_->metadata); + ABSL_FALLTHROUGH_INTENDED; + case RecvInitialMetadata::kCompleteAndSetLatch: { + Poll p = + self_->server_initial_metadata_latch()->Wait()(); + if (ServerMetadata*** ppp = absl::get_if(&p)) { + ServerMetadata* md = **ppp; + if (self_->recv_initial_metadata_->metadata != md) { + *self_->recv_initial_metadata_->metadata = std::move(*md); + } + self_->recv_initial_metadata_->state = + RecvInitialMetadata::kResponded; + call_closures_.Add( + absl::exchange(self_->recv_initial_metadata_->original_on_ready, + nullptr), + GRPC_ERROR_NONE, + "wake_inside_combiner:recv_initial_metadata_ready"); + } + } break; + } + } + if (self_->recv_trailing_state_ == RecvTrailingState::kCancelled || + self_->recv_trailing_state_ == RecvTrailingState::kResponded) { + return; + } + switch (self_->send_initial_state_) { + case SendInitialState::kQueued: + case SendInitialState::kForwarded: { + // Poll the promise once since we're waiting for it. + Poll poll = self_->promise_(); + if (auto* r = absl::get_if(&poll)) { + auto* md = UnwrapMetadata(std::move(*r)); + bool destroy_md = true; + if (self_->recv_trailing_state_ == RecvTrailingState::kComplete) { + if (self_->recv_trailing_metadata_ != md) { + *self_->recv_trailing_metadata_ = std::move(*md); + } else { + destroy_md = false; + } + self_->recv_trailing_state_ = RecvTrailingState::kResponded; + call_closures_.Add( + absl::exchange(self_->original_recv_trailing_metadata_ready_, + nullptr), + GRPC_ERROR_NONE, "wake_inside_combiner:recv_trailing_ready:1"); + } else { + GPR_ASSERT(*md->get_pointer(GrpcStatusMetadata()) != + GRPC_STATUS_OK); + grpc_error_handle error = grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "early return from promise based filter"), + GRPC_ERROR_INT_GRPC_STATUS, + *md->get_pointer(GrpcStatusMetadata())); + if (auto* message = md->get_pointer(GrpcMessageMetadata())) { + error = grpc_error_set_str(error, GRPC_ERROR_STR_GRPC_MESSAGE, + message->as_string_view()); + } + GRPC_ERROR_UNREF(self_->cancelled_error_); + self_->cancelled_error_ = GRPC_ERROR_REF(error); + if (self_->recv_initial_metadata_ != nullptr) { + switch (self_->recv_initial_metadata_->state) { + case RecvInitialMetadata::kInitial: + case RecvInitialMetadata::kGotLatch: + case RecvInitialMetadata::kHookedWaitingForLatch: + case RecvInitialMetadata::kHookedAndGotLatch: + case RecvInitialMetadata::kResponded: + break; + case RecvInitialMetadata::kCompleteWaitingForLatch: + case RecvInitialMetadata::kCompleteAndGotLatch: + case RecvInitialMetadata::kCompleteAndSetLatch: + self_->recv_initial_metadata_->state = + RecvInitialMetadata::kResponded; + call_closures_.Add( + absl::exchange( + self_->recv_initial_metadata_->original_on_ready, + nullptr), + GRPC_ERROR_REF(error), + "wake_inside_combiner:recv_initial_metadata_ready"); + } + } + if (self_->send_initial_state_ == SendInitialState::kQueued) { + self_->send_initial_state_ = SendInitialState::kCancelled; + cancel_send_initial_metadata_error_ = error; + } else { + GPR_ASSERT( + self_->recv_trailing_state_ == RecvTrailingState::kInitial || + self_->recv_trailing_state_ == RecvTrailingState::kForwarded); + self_->call_combiner()->Cancel(GRPC_ERROR_REF(error)); + forward_batch_ = + grpc_make_transport_stream_op(GRPC_CLOSURE_CREATE( + [](void* p, grpc_error_handle) { + GRPC_CALL_COMBINER_STOP(static_cast(p), + "finish_cancel"); + }, + self_->call_combiner(), nullptr)); + forward_batch_->cancel_stream = true; + forward_batch_->payload->cancel_stream.cancel_error = error; + } + self_->recv_trailing_state_ = RecvTrailingState::kCancelled; + } + if (destroy_md) { + md->~grpc_metadata_batch(); + } + scoped_activity_.Destroy(); + have_scoped_activity_ = false; + self_->promise_ = ArenaPromise(); + } + } break; + case SendInitialState::kInitial: + case SendInitialState::kCancelled: + // If we get a response without sending anything, we just propagate + // that up. (note: that situation isn't possible once we finish the + // promise transition). + if (self_->recv_trailing_state_ == RecvTrailingState::kComplete) { + self_->recv_trailing_state_ = RecvTrailingState::kResponded; + call_closures_.Add( + absl::exchange(self_->original_recv_trailing_metadata_ready_, + nullptr), + GRPC_ERROR_NONE, "wake_inside_combiner:recv_trailing_ready:2"); + } + break; + } + } + + ~PollContext() { + self_->poll_ctx_ = nullptr; + if (have_scoped_activity_) scoped_activity_.Destroy(); + GRPC_CALL_STACK_REF(self_->call_stack(), "finish_poll"); + bool in_combiner = true; + if (call_closures_.size() != 0) { + if (forward_batch_ != nullptr) { + call_closures_.RunClosuresWithoutYielding(self_->call_combiner()); + } else { + in_combiner = false; + call_closures_.RunClosures(self_->call_combiner()); + } + } + if (forward_batch_ != nullptr) { + GPR_ASSERT(in_combiner); + in_combiner = false; + forward_send_initial_metadata_ = false; + grpc_call_next_op(self_->elem(), forward_batch_); + } + if (cancel_send_initial_metadata_error_ != GRPC_ERROR_NONE) { + GPR_ASSERT(in_combiner); + forward_send_initial_metadata_ = false; + in_combiner = false; + grpc_transport_stream_op_batch_finish_with_failure( + absl::exchange(self_->send_initial_metadata_batch_, nullptr), + cancel_send_initial_metadata_error_, self_->call_combiner()); + } + if (absl::exchange(forward_send_initial_metadata_, false)) { + GPR_ASSERT(in_combiner); + in_combiner = false; + grpc_call_next_op( + self_->elem(), + absl::exchange(self_->send_initial_metadata_batch_, nullptr)); + } + if (repoll_) { + if (in_combiner) { + self_->WakeInsideCombiner(); + } else { + struct NextPoll : public grpc_closure { + grpc_call_stack* call_stack; + ClientCallData* call_data; + }; + auto run = [](void* p, grpc_error_handle) { + auto* next_poll = static_cast(p); + next_poll->call_data->WakeInsideCombiner(); + GRPC_CALL_STACK_UNREF(next_poll->call_stack, "re-poll"); + delete next_poll; + }; + auto* p = absl::make_unique().release(); + p->call_stack = self_->call_stack(); + p->call_data = self_; + GRPC_CALL_STACK_REF(self_->call_stack(), "re-poll"); + GRPC_CLOSURE_INIT(p, run, p, nullptr); + GRPC_CALL_COMBINER_START(self_->call_combiner(), p, GRPC_ERROR_NONE, + "re-poll"); + } + } else if (in_combiner) { + GRPC_CALL_COMBINER_STOP(self_->call_combiner(), "poll paused"); + } + GRPC_CALL_STACK_UNREF(self_->call_stack(), "finish_poll"); + } + + void Repoll() { repoll_ = true; } + + void ForwardSendInitialMetadata() { forward_send_initial_metadata_ = true; } + + private: + ManualConstructor scoped_activity_; + ClientCallData* self_; + CallCombinerClosureList call_closures_; + grpc_error_handle cancel_send_initial_metadata_error_ = GRPC_ERROR_NONE; + grpc_transport_stream_op_batch* forward_batch_ = nullptr; + bool repoll_ = false; + bool forward_send_initial_metadata_ = false; + bool have_scoped_activity_; +}; + ClientCallData::ClientCallData(grpc_call_element* elem, - const grpc_call_element_args* args) - : BaseCallData(elem, args) { + const grpc_call_element_args* args, + uint8_t flags) + : BaseCallData(elem, args, flags) { GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, RecvTrailingMetadataReadyCallback, this, grpc_schedule_on_exec_ctx); + if (server_initial_metadata_latch() != nullptr) { + recv_initial_metadata_ = arena()->New(); + } } ClientCallData::~ClientCallData() { - GPR_ASSERT(!is_polling_); + GPR_ASSERT(poll_ctx_ == nullptr); GRPC_ERROR_UNREF(cancelled_error_); + if (recv_initial_metadata_ != nullptr) { + recv_initial_metadata_->~RecvInitialMetadata(); + } } // Activity implementation. void ClientCallData::ForceImmediateRepoll() { - GPR_ASSERT(is_polling_); - repoll_ = true; + GPR_ASSERT(poll_ctx_ != nullptr); + poll_ctx_->Repoll(); } // Handle one grpc_transport_stream_op_batch @@ -88,6 +371,36 @@ void ClientCallData::StartBatch(grpc_transport_stream_op_batch* batch) { return; } + if (recv_initial_metadata_ != nullptr && batch->recv_initial_metadata) { + switch (recv_initial_metadata_->state) { + case RecvInitialMetadata::kInitial: + recv_initial_metadata_->state = + RecvInitialMetadata::kHookedWaitingForLatch; + break; + case RecvInitialMetadata::kGotLatch: + recv_initial_metadata_->state = RecvInitialMetadata::kHookedAndGotLatch; + break; + case RecvInitialMetadata::kHookedWaitingForLatch: + case RecvInitialMetadata::kHookedAndGotLatch: + case RecvInitialMetadata::kCompleteWaitingForLatch: + case RecvInitialMetadata::kCompleteAndGotLatch: + case RecvInitialMetadata::kCompleteAndSetLatch: + case RecvInitialMetadata::kResponded: + abort(); // unreachable + } + auto cb = [](void* ptr, grpc_error_handle error) { + ClientCallData* self = static_cast(ptr); + self->RecvInitialMetadataReady(error); + }; + recv_initial_metadata_->metadata = + batch->payload->recv_initial_metadata.recv_initial_metadata; + recv_initial_metadata_->original_on_ready = + batch->payload->recv_initial_metadata.recv_initial_metadata_ready; + GRPC_CLOSURE_INIT(&recv_initial_metadata_->on_ready, cb, this, nullptr); + batch->payload->recv_initial_metadata.recv_initial_metadata_ready = + &recv_initial_metadata_->on_ready; + } + // send_initial_metadata: seeing this triggers the start of the promise part // of this filter. if (batch->send_initial_metadata) { @@ -164,6 +477,25 @@ void ClientCallData::Cancel(grpc_error_handle error) { } else { send_initial_state_ = SendInitialState::kCancelled; } + if (recv_initial_metadata_ != nullptr) { + switch (recv_initial_metadata_->state) { + case RecvInitialMetadata::kCompleteWaitingForLatch: + case RecvInitialMetadata::kCompleteAndGotLatch: + case RecvInitialMetadata::kCompleteAndSetLatch: + recv_initial_metadata_->state = RecvInitialMetadata::kResponded; + GRPC_CALL_COMBINER_START( + call_combiner(), + absl::exchange(recv_initial_metadata_->original_on_ready, nullptr), + GRPC_ERROR_REF(error), "propagate cancellation"); + break; + case RecvInitialMetadata::kInitial: + case RecvInitialMetadata::kGotLatch: + case RecvInitialMetadata::kHookedWaitingForLatch: + case RecvInitialMetadata::kHookedAndGotLatch: + case RecvInitialMetadata::kResponded: + break; + } + } } // Begin running the promise - which will ultimately take some initial @@ -173,18 +505,48 @@ void ClientCallData::StartPromise() { ChannelFilter* filter = static_cast(elem()->channel_data); // Construct the promise. - { - ScopedActivity activity(this); - promise_ = filter->MakeCallPromise( - CallArgs{ - WrapMetadata(send_initial_metadata_batch_->payload - ->send_initial_metadata.send_initial_metadata), - nullptr}, - [this](CallArgs call_args) { - return MakeNextPromise(std::move(call_args)); - }); + PollContext ctx(this); + promise_ = filter->MakeCallPromise( + CallArgs{WrapMetadata(send_initial_metadata_batch_->payload + ->send_initial_metadata.send_initial_metadata), + server_initial_metadata_latch()}, + [this](CallArgs call_args) { + return MakeNextPromise(std::move(call_args)); + }); + ctx.Run(); +} + +void ClientCallData::RecvInitialMetadataReady(grpc_error_handle error) { + ScopedContext context(this); + switch (recv_initial_metadata_->state) { + case RecvInitialMetadata::kHookedWaitingForLatch: + recv_initial_metadata_->state = + RecvInitialMetadata::kCompleteWaitingForLatch; + break; + case RecvInitialMetadata::kHookedAndGotLatch: + recv_initial_metadata_->state = RecvInitialMetadata::kCompleteAndGotLatch; + break; + case RecvInitialMetadata::kInitial: + case RecvInitialMetadata::kGotLatch: + case RecvInitialMetadata::kCompleteWaitingForLatch: + case RecvInitialMetadata::kCompleteAndGotLatch: + case RecvInitialMetadata::kCompleteAndSetLatch: + case RecvInitialMetadata::kResponded: + abort(); // unreachable + } + if (error != GRPC_ERROR_NONE) { + recv_initial_metadata_->state = RecvInitialMetadata::kResponded; + GRPC_CALL_COMBINER_START( + call_combiner(), + absl::exchange(recv_initial_metadata_->original_on_ready, nullptr), + GRPC_ERROR_REF(error), "propagate cancellation"); + } else if (send_initial_state_ == SendInitialState::kCancelled) { + recv_initial_metadata_->state = RecvInitialMetadata::kResponded; + GRPC_CALL_COMBINER_START( + call_combiner(), + absl::exchange(recv_initial_metadata_->original_on_ready, nullptr), + GRPC_ERROR_REF(cancelled_error_), "propagate cancellation"); } - // Poll once. WakeInsideCombiner(); } @@ -207,10 +569,42 @@ void ClientCallData::HookRecvTrailingMetadata( // - return a wrapper around PollTrailingMetadata as the promise. ArenaPromise ClientCallData::MakeNextPromise( CallArgs call_args) { + GPR_ASSERT(poll_ctx_ != nullptr); GPR_ASSERT(send_initial_state_ == SendInitialState::kQueued); send_initial_metadata_batch_->payload->send_initial_metadata .send_initial_metadata = UnwrapMetadata(std::move(call_args.client_initial_metadata)); + if (recv_initial_metadata_ != nullptr) { + // Call args should contain a latch for receiving initial metadata. + // It might be the one we passed in - in which case we know this filter + // only wants to examine the metadata, or it might be a new instance, in + // which case we know the filter wants to mutate. + GPR_ASSERT(call_args.server_initial_metadata != nullptr); + recv_initial_metadata_->server_initial_metadata_publisher = + call_args.server_initial_metadata; + switch (recv_initial_metadata_->state) { + case RecvInitialMetadata::kInitial: + recv_initial_metadata_->state = RecvInitialMetadata::kGotLatch; + break; + case RecvInitialMetadata::kHookedWaitingForLatch: + recv_initial_metadata_->state = RecvInitialMetadata::kHookedAndGotLatch; + poll_ctx_->Repoll(); + break; + case RecvInitialMetadata::kCompleteWaitingForLatch: + recv_initial_metadata_->state = + RecvInitialMetadata::kCompleteAndGotLatch; + poll_ctx_->Repoll(); + break; + case RecvInitialMetadata::kGotLatch: + case RecvInitialMetadata::kHookedAndGotLatch: + case RecvInitialMetadata::kCompleteAndGotLatch: + case RecvInitialMetadata::kCompleteAndSetLatch: + case RecvInitialMetadata::kResponded: + abort(); // unreachable + } + } else { + GPR_ASSERT(call_args.server_initial_metadata == nullptr); + } return ArenaPromise( [this]() { return PollTrailingMetadata(); }); } @@ -220,6 +614,7 @@ ArenaPromise ClientCallData::MakeNextPromise( // All polls: await receiving the trailing metadata, then return it to the // application. Poll ClientCallData::PollTrailingMetadata() { + GPR_ASSERT(poll_ctx_ != nullptr); if (send_initial_state_ == SendInitialState::kQueued) { // First poll: pass the send_initial_metadata op down the stack. GPR_ASSERT(send_initial_metadata_batch_ != nullptr); @@ -229,7 +624,7 @@ Poll ClientCallData::PollTrailingMetadata() { HookRecvTrailingMetadata(send_initial_metadata_batch_); recv_trailing_state_ = RecvTrailingState::kForwarded; } - forward_send_initial_metadata_ = true; + poll_ctx_->ForwardSendInitialMetadata(); } switch (recv_trailing_state_) { case RecvTrailingState::kInitial: @@ -264,6 +659,11 @@ void ClientCallData::RecvTrailingMetadataReadyCallback( } void ClientCallData::RecvTrailingMetadataReady(grpc_error_handle error) { + if (recv_trailing_state_ == RecvTrailingState::kCancelled) { + Closure::Run(DEBUG_LOCATION, original_recv_trailing_metadata_ready_, + GRPC_ERROR_REF(cancelled_error_)); + return; + } // If there was an error, we'll put that into the trailing metadata and // proceed as if there was not. if (error != GRPC_ERROR_NONE) { @@ -291,131 +691,7 @@ void ClientCallData::SetStatusFromError(grpc_metadata_batch* metadata, } // Wakeup and poll the promise if appropriate. -void ClientCallData::WakeInsideCombiner() { - GPR_ASSERT(!is_polling_); - grpc_closure* call_closure = nullptr; - is_polling_ = true; - grpc_error_handle cancel_send_initial_metadata_error = GRPC_ERROR_NONE; - grpc_transport_stream_op_batch* forward_batch = nullptr; - switch (send_initial_state_) { - case SendInitialState::kQueued: - case SendInitialState::kForwarded: { - // Poll the promise once since we're waiting for it. - Poll poll; - { - ScopedActivity activity(this); - poll = promise_(); - } - if (auto* r = absl::get_if(&poll)) { - promise_ = ArenaPromise(); - auto* md = UnwrapMetadata(std::move(*r)); - bool destroy_md = true; - if (recv_trailing_state_ == RecvTrailingState::kComplete) { - if (recv_trailing_metadata_ != md) { - *recv_trailing_metadata_ = std::move(*md); - } else { - destroy_md = false; - } - recv_trailing_state_ = RecvTrailingState::kResponded; - call_closure = - absl::exchange(original_recv_trailing_metadata_ready_, nullptr); - } else { - GPR_ASSERT(*md->get_pointer(GrpcStatusMetadata()) != GRPC_STATUS_OK); - grpc_error_handle error = - grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "early return from promise based filter"), - GRPC_ERROR_INT_GRPC_STATUS, - *md->get_pointer(GrpcStatusMetadata())); - if (auto* message = md->get_pointer(GrpcMessageMetadata())) { - error = grpc_error_set_str(error, GRPC_ERROR_STR_GRPC_MESSAGE, - message->as_string_view()); - } - GRPC_ERROR_UNREF(cancelled_error_); - cancelled_error_ = GRPC_ERROR_REF(error); - if (send_initial_state_ == SendInitialState::kQueued) { - send_initial_state_ = SendInitialState::kCancelled; - cancel_send_initial_metadata_error = error; - } else { - GPR_ASSERT(recv_trailing_state_ == RecvTrailingState::kInitial || - recv_trailing_state_ == RecvTrailingState::kForwarded); - call_combiner()->Cancel(GRPC_ERROR_REF(error)); - forward_batch = grpc_make_transport_stream_op(GRPC_CLOSURE_CREATE( - [](void*, grpc_error_handle) {}, nullptr, nullptr)); - forward_batch->cancel_stream = true; - forward_batch->payload->cancel_stream.cancel_error = error; - } - recv_trailing_state_ = RecvTrailingState::kCancelled; - } - if (destroy_md) { - md->~grpc_metadata_batch(); - } - } - } break; - case SendInitialState::kInitial: - case SendInitialState::kCancelled: - // If we get a response without sending anything, we just propagate - // that up. (note: that situation isn't possible once we finish the - // promise transition). - if (recv_trailing_state_ == RecvTrailingState::kComplete) { - recv_trailing_state_ = RecvTrailingState::kResponded; - call_closure = - absl::exchange(original_recv_trailing_metadata_ready_, nullptr); - } - break; - } - GRPC_CALL_STACK_REF(call_stack(), "finish_poll"); - is_polling_ = false; - bool in_combiner = true; - bool repoll = absl::exchange(repoll_, false); - if (forward_batch != nullptr) { - GPR_ASSERT(in_combiner); - in_combiner = false; - forward_send_initial_metadata_ = false; - grpc_call_next_op(elem(), forward_batch); - } - if (cancel_send_initial_metadata_error != GRPC_ERROR_NONE) { - GPR_ASSERT(in_combiner); - forward_send_initial_metadata_ = false; - in_combiner = false; - grpc_transport_stream_op_batch_finish_with_failure( - absl::exchange(send_initial_metadata_batch_, nullptr), - cancel_send_initial_metadata_error, call_combiner()); - } - if (absl::exchange(forward_send_initial_metadata_, false)) { - GPR_ASSERT(in_combiner); - in_combiner = false; - grpc_call_next_op(elem(), - absl::exchange(send_initial_metadata_batch_, nullptr)); - } - if (call_closure != nullptr) { - GPR_ASSERT(in_combiner); - in_combiner = false; - Closure::Run(DEBUG_LOCATION, call_closure, GRPC_ERROR_NONE); - } - if (repoll) { - if (in_combiner) { - WakeInsideCombiner(); - } else { - struct NextPoll : public grpc_closure { - grpc_call_stack* call_stack; - ClientCallData* call_data; - }; - auto run = [](void* p, grpc_error_handle) { - auto* next_poll = static_cast(p); - next_poll->call_data->WakeInsideCombiner(); - GRPC_CALL_STACK_UNREF(next_poll->call_stack, "re-poll"); - delete next_poll; - }; - auto* p = new NextPoll; - GRPC_CALL_STACK_REF(call_stack(), "re-poll"); - GRPC_CLOSURE_INIT(p, run, p, nullptr); - GRPC_CALL_COMBINER_START(call_combiner(), p, GRPC_ERROR_NONE, "re-poll"); - } - } else if (in_combiner) { - GRPC_CALL_COMBINER_STOP(call_combiner(), "poll paused"); - } - GRPC_CALL_STACK_UNREF(call_stack(), "finish_poll"); -} +void ClientCallData::WakeInsideCombiner() { PollContext(this).Run(); } void ClientCallData::OnWakeup() { ScopedContext context(this); @@ -426,8 +702,9 @@ void ClientCallData::OnWakeup() { // ServerCallData ServerCallData::ServerCallData(grpc_call_element* elem, - const grpc_call_element_args* args) - : BaseCallData(elem, args) { + const grpc_call_element_args* args, + uint8_t flags) + : BaseCallData(elem, args, flags) { GRPC_CLOSURE_INIT(&recv_initial_metadata_ready_, RecvInitialMetadataReadyCallback, this, grpc_schedule_on_exec_ctx); @@ -589,11 +866,12 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) { ScopedContext context(this); // Construct the promise. ChannelFilter* filter = static_cast(elem()->channel_data); - promise_ = filter->MakeCallPromise( - CallArgs{WrapMetadata(recv_initial_metadata_), nullptr}, - [this](CallArgs call_args) { - return MakeNextPromise(std::move(call_args)); - }); + promise_ = + filter->MakeCallPromise(CallArgs{WrapMetadata(recv_initial_metadata_), + server_initial_metadata_latch()}, + [this](CallArgs call_args) { + return MakeNextPromise(std::move(call_args)); + }); // Poll once. bool own_error = false; WakeInsideCombiner([&error, &own_error](grpc_error_handle new_error) { diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index f653a10b745..09d6794b208 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -76,18 +76,17 @@ enum class FilterEndpoint { kServer, }; +// Flags for MakePromiseBasedFilter. +static constexpr uint8_t kFilterExaminesServerInitialMetadata = 1; + namespace promise_filter_detail { // Call data shared between all implementations of promise-based filters. class BaseCallData : public Activity, private Wakeable { public: - BaseCallData(grpc_call_element* elem, const grpc_call_element_args* args) - : call_stack_(args->call_stack), - elem_(elem), - arena_(args->arena), - call_combiner_(args->call_combiner), - deadline_(args->deadline), - context_(args->context) {} + BaseCallData(grpc_call_element* elem, const grpc_call_element_args* args, + uint8_t flags); + ~BaseCallData() override; void set_pollent(grpc_polling_entity* pollent) { GPR_ASSERT(nullptr == @@ -130,10 +129,14 @@ class BaseCallData : public Activity, private Wakeable { return p.Unwrap(); } + Arena* arena() { return arena_; } grpc_call_element* elem() const { return elem_; } CallCombiner* call_combiner() const { return call_combiner_; } Timestamp deadline() const { return deadline_; } grpc_call_stack* call_stack() const { return call_stack_; } + Latch* server_initial_metadata_latch() const { + return server_initial_metadata_latch_; + } private: // Wakeable implementation. @@ -150,11 +153,13 @@ class BaseCallData : public Activity, private Wakeable { CallFinalization finalization_; grpc_call_context_element* const context_; std::atomic pollent_{nullptr}; + Latch* server_initial_metadata_latch_ = nullptr; }; class ClientCallData : public BaseCallData { public: - ClientCallData(grpc_call_element* elem, const grpc_call_element_args* args); + ClientCallData(grpc_call_element* elem, const grpc_call_element_args* args, + uint8_t flags); ~ClientCallData() override; // Activity implementation. @@ -195,6 +200,9 @@ class ClientCallData : public BaseCallData { kCancelled }; + struct RecvInitialMetadata; + class PollContext; + // Handle cancellation. void Cancel(grpc_error_handle error); // Begin running the promise - which will ultimately take some initial @@ -217,6 +225,7 @@ class ClientCallData : public BaseCallData { static void RecvTrailingMetadataReadyCallback(void* arg, grpc_error_handle error); void RecvTrailingMetadataReady(grpc_error_handle error); + void RecvInitialMetadataReady(grpc_error_handle error); // Given an error, fill in ServerMetadataHandle to represent that error. void SetStatusFromError(grpc_metadata_batch* metadata, grpc_error_handle error); @@ -230,6 +239,8 @@ class ClientCallData : public BaseCallData { grpc_transport_stream_op_batch* send_initial_metadata_batch_ = nullptr; // Pointer to where trailing metadata will be stored. grpc_metadata_batch* recv_trailing_metadata_ = nullptr; + // State tracking recv initial metadata for filters that care about it. + RecvInitialMetadata* recv_initial_metadata_ = nullptr; // Closure to call when we're done with the trailing metadata. grpc_closure* original_recv_trailing_metadata_ready_ = nullptr; // Our closure pointing to RecvTrailingMetadataReadyCallback. @@ -240,17 +251,14 @@ class ClientCallData : public BaseCallData { SendInitialState send_initial_state_ = SendInitialState::kInitial; // State of the recv_trailing_metadata op. RecvTrailingState recv_trailing_state_ = RecvTrailingState::kInitial; - // Whether we're currently polling the promise. - bool is_polling_ = false; - // Should we repoll after completing polling? - bool repoll_ = false; - // Whether we should forward send initial metadata after polling? - bool forward_send_initial_metadata_ = false; + // Polling related data. Non-null if we're actively polling + PollContext* poll_ctx_ = nullptr; }; class ServerCallData : public BaseCallData { public: - ServerCallData(grpc_call_element* elem, const grpc_call_element_args* args); + ServerCallData(grpc_call_element* elem, const grpc_call_element_args* args, + uint8_t flags); ~ServerCallData() override; // Activity implementation. @@ -356,7 +364,7 @@ class CallData : public ServerCallData { // }; // TODO(ctiller): allow implementing get_channel_info, start_transport_op in // some way on ChannelFilter. -template +template absl::enable_if_t::value, grpc_channel_filter> MakePromiseBasedFilter(const char* name) { using CallData = promise_filter_detail::CallData; @@ -383,7 +391,7 @@ MakePromiseBasedFilter(const char* name) { sizeof(CallData), // init_call_elem [](grpc_call_element* elem, const grpc_call_element_args* args) { - new (elem->call_data) CallData(elem, args); + new (elem->call_data) CallData(elem, args, kFlags); return GRPC_ERROR_NONE; }, // set_pollset_or_pollset_set diff --git a/src/core/lib/promise/call_push_pull.h b/src/core/lib/promise/call_push_pull.h index 26d54a0e613..a56496002f1 100644 --- a/src/core/lib/promise/call_push_pull.h +++ b/src/core/lib/promise/call_push_pull.h @@ -75,7 +75,7 @@ class CallPushPull { if (IsStatusOk(*status)) { done_.set(kDonePush); } else { - return std::move(*status); + return Result(std::move(*status)); } } } @@ -97,7 +97,7 @@ class CallPushPull { if (IsStatusOk(*status)) { done_.set(kDonePull); } else { - return std::move(*status); + return Result(std::move(*status)); } } } diff --git a/src/core/lib/promise/detail/status.h b/src/core/lib/promise/detail/status.h index 69dfede1bc8..a56c6fe08c0 100644 --- a/src/core/lib/promise/detail/status.h +++ b/src/core/lib/promise/detail/status.h @@ -39,11 +39,12 @@ inline absl::Status IntoStatus(absl::Status* status) { } } // namespace promise_detail -} // namespace grpc_core // Return true if the status represented by the argument is ok, false if not. // By implementing this function for other, non-absl::Status types, those types // can participate in TrySeq as result types that affect control flow. inline bool IsStatusOk(const absl::Status& status) { return status.ok(); } +} // namespace grpc_core + #endif // GRPC_CORE_LIB_PROMISE_DETAIL_STATUS_H diff --git a/src/core/lib/transport/transport.h b/src/core/lib/transport/transport.h index b392d1bd6bd..1168e22b649 100644 --- a/src/core/lib/transport/transport.h +++ b/src/core/lib/transport/transport.h @@ -98,7 +98,7 @@ class MetadataHandle { T* handle_ = nullptr; }; -// Trailing metadata type +// Server metadata type // TODO(ctiller): This should be a bespoke instance of MetadataMap<> using ServerMetadata = grpc_metadata_batch; using ServerMetadataHandle = MetadataHandle; diff --git a/test/cpp/microbenchmarks/bm_call_create.cc b/test/cpp/microbenchmarks/bm_call_create.cc index 137e14e0a5b..c56172006b5 100644 --- a/test/cpp/microbenchmarks/bm_call_create.cc +++ b/test/cpp/microbenchmarks/bm_call_create.cc @@ -561,7 +561,7 @@ static void BM_IsolatedFilter(benchmark::State& state) { grpc_slice method = grpc_slice_from_static_string("/foo/bar"); grpc_call_final_info final_info; TestOp test_op_data; - const int kArenaSize = 4096; + const int kArenaSize = 32 * 1024 * 1024; grpc_call_context_element context[GRPC_CONTEXT_COUNT] = {}; grpc_call_element_args call_args{ call_stack, @@ -617,7 +617,8 @@ typedef Fixture<&grpc_server_deadline_filter, CHECKS_NOT_LAST> ServerDeadlineFilter; BENCHMARK_TEMPLATE(BM_IsolatedFilter, ServerDeadlineFilter, NoOp); BENCHMARK_TEMPLATE(BM_IsolatedFilter, ServerDeadlineFilter, SendEmptyMetadata); -typedef Fixture<&grpc_http_client_filter, CHECKS_NOT_LAST | REQUIRES_TRANSPORT> +typedef Fixture<&grpc_core::HttpClientFilter::kFilter, + CHECKS_NOT_LAST | REQUIRES_TRANSPORT> HttpClientFilter; BENCHMARK_TEMPLATE(BM_IsolatedFilter, HttpClientFilter, NoOp); BENCHMARK_TEMPLATE(BM_IsolatedFilter, HttpClientFilter, SendEmptyMetadata); diff --git a/tools/doxygen/Doxyfile.c++.internal b/tools/doxygen/Doxyfile.c++.internal index 1625b8c33de..41296d921a0 100644 --- a/tools/doxygen/Doxyfile.c++.internal +++ b/tools/doxygen/Doxyfile.c++.internal @@ -2184,6 +2184,7 @@ src/core/lib/profiling/timers.h \ src/core/lib/promise/activity.cc \ src/core/lib/promise/activity.h \ src/core/lib/promise/arena_promise.h \ +src/core/lib/promise/call_push_pull.h \ src/core/lib/promise/context.h \ src/core/lib/promise/detail/basic_seq.h \ src/core/lib/promise/detail/promise_factory.h \ diff --git a/tools/doxygen/Doxyfile.core.internal b/tools/doxygen/Doxyfile.core.internal index ec96d23da06..00920470885 100644 --- a/tools/doxygen/Doxyfile.core.internal +++ b/tools/doxygen/Doxyfile.core.internal @@ -1979,6 +1979,7 @@ src/core/lib/profiling/timers.h \ src/core/lib/promise/activity.cc \ src/core/lib/promise/activity.h \ src/core/lib/promise/arena_promise.h \ +src/core/lib/promise/call_push_pull.h \ src/core/lib/promise/context.h \ src/core/lib/promise/detail/basic_seq.h \ src/core/lib/promise/detail/promise_factory.h \