diff --git a/src/core/ext/filters/http/client/http_client_filter.cc b/src/core/ext/filters/http/client/http_client_filter.cc index 8b547bac6ee..07fcba62581 100644 --- a/src/core/ext/filters/http/client/http_client_filter.cc +++ b/src/core/ext/filters/http/client/http_client_filter.cc @@ -51,6 +51,9 @@ namespace grpc_core { +const NoInterceptor HttpClientFilter::Call::OnServerToClientMessage; +const NoInterceptor HttpClientFilter::Call::OnClientToServerMessage; + const grpc_channel_filter HttpClientFilter::kFilter = MakePromiseBasedFilter("http-client"); @@ -105,40 +108,27 @@ Slice UserAgentFromArgs(const ChannelArgs& args, } } // namespace -ArenaPromise HttpClientFilter::MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) { - auto& md = call_args.client_initial_metadata; - if (test_only_use_put_requests_) { - md->Set(HttpMethodMetadata(), HttpMethodMetadata::kPut); +void HttpClientFilter::Call::OnClientInitialMetadata(ClientMetadata& md, + HttpClientFilter* filter) { + if (filter->test_only_use_put_requests_) { + md.Set(HttpMethodMetadata(), HttpMethodMetadata::kPut); } else { - md->Set(HttpMethodMetadata(), HttpMethodMetadata::kPost); + md.Set(HttpMethodMetadata(), HttpMethodMetadata::kPost); } - md->Set(HttpSchemeMetadata(), scheme_); - md->Set(TeMetadata(), TeMetadata::kTrailers); - md->Set(ContentTypeMetadata(), ContentTypeMetadata::kApplicationGrpc); - md->Set(UserAgentMetadata(), user_agent_.Ref()); - - auto* initial_metadata_err = - GetContext()->New>(); - - call_args.server_initial_metadata->InterceptAndMap( - [initial_metadata_err]( - ServerMetadataHandle md) -> absl::optional { - auto r = CheckServerMetadata(md.get()); - if (!r.ok()) { - initial_metadata_err->Set(ServerMetadataFromStatus(r)); - return absl::nullopt; - } - return std::move(md); - }); - - return Race(initial_metadata_err->Wait(), - Map(next_promise_factory(std::move(call_args)), - [](ServerMetadataHandle md) -> ServerMetadataHandle { - auto r = CheckServerMetadata(md.get()); - if (!r.ok()) return ServerMetadataFromStatus(r); - return md; - })); + md.Set(HttpSchemeMetadata(), filter->scheme_); + md.Set(TeMetadata(), TeMetadata::kTrailers); + md.Set(ContentTypeMetadata(), ContentTypeMetadata::kApplicationGrpc); + md.Set(UserAgentMetadata(), filter->user_agent_.Ref()); +} + +absl::Status HttpClientFilter::Call::OnServerInitialMetadata( + ServerMetadata& md) { + return CheckServerMetadata(&md); +} + +absl::Status HttpClientFilter::Call::OnServerTrailingMetadata( + ServerMetadata& md) { + return CheckServerMetadata(&md); } HttpClientFilter::HttpClientFilter(HttpSchemeMetadata::ValueType scheme, diff --git a/src/core/ext/filters/http/client/http_client_filter.h b/src/core/ext/filters/http/client/http_client_filter.h index 298daf03c67..3146ea07385 100644 --- a/src/core/ext/filters/http/client/http_client_filter.h +++ b/src/core/ext/filters/http/client/http_client_filter.h @@ -25,23 +25,27 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_fwd.h" #include "src/core/lib/channel/promise_based_filter.h" -#include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/slice/slice.h" #include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" namespace grpc_core { -class HttpClientFilter : public ChannelFilter { +class HttpClientFilter : public ImplementChannelFilter { public: static const grpc_channel_filter kFilter; static absl::StatusOr Create( const ChannelArgs& args, ChannelFilter::Args filter_args); - // Construct a promise for one call. - ArenaPromise MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) override; + class Call { + public: + void OnClientInitialMetadata(ClientMetadata& md, HttpClientFilter* filter); + absl::Status OnServerInitialMetadata(ServerMetadata& md); + absl::Status OnServerTrailingMetadata(ServerMetadata& md); + static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnServerToClientMessage; + }; private: HttpClientFilter(HttpSchemeMetadata::ValueType scheme, Slice user_agent, diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index 19efe505db2..25ee1230ef1 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -43,6 +43,7 @@ #include #include "src/core/lib/channel/call_finalization.h" +#include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_fwd.h" #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/context.h" @@ -60,6 +61,7 @@ #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/pipe.h" #include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/race.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice_buffer.h" #include "src/core/lib/surface/call.h" @@ -122,6 +124,289 @@ class ChannelFilter { grpc_event_engine::experimental::GetDefaultEventEngine(); }; +struct NoInterceptor {}; + +namespace promise_filter_detail { + +// Determine if a list of interceptors has any that need to asyncronously error +// the promise. If so, we need to allocate a latch for the generated promise for +// the original promise stack polyfill code that's generated. + +inline constexpr bool HasAsyncErrorInterceptor() { return false; } + +inline constexpr bool HasAsyncErrorInterceptor(const NoInterceptor*) { + return false; +} + +template +inline constexpr bool HasAsyncErrorInterceptor(absl::Status (T::*)(A...)) { + return true; +} + +template +inline constexpr bool HasAsyncErrorInterceptor(void (T::*)(A...)) { + return false; +} + +// For the list case we do two interceptors to avoid amiguities with the single +// argument forms above. +template +inline constexpr bool HasAsyncErrorInterceptor(I1 i1, I2 i2, + Interceptors... interceptors) { + return HasAsyncErrorInterceptor(i1) || HasAsyncErrorInterceptor(i2) || + HasAsyncErrorInterceptor(interceptors...); +} + +// Composite for a given channel type to determine if any of its interceptors +// fall into this category: later code should use this. +template +inline constexpr bool CallHasAsyncErrorInterceptor() { + return HasAsyncErrorInterceptor(&Derived::Call::OnClientToServerMessage, + &Derived::Call::OnServerInitialMetadata, + &Derived::Call::OnServerToClientMessage); +} + +// Determine if an interceptor needs to access the channel via one of its +// arguments. If so, we need to allocate a pointer to the channel for the +// generated polyfill promise for the original promise stack. + +inline constexpr bool HasChannelAccess() { return false; } + +inline constexpr bool HasChannelAccess(const NoInterceptor*) { return false; } + +template +inline constexpr bool HasChannelAccess(R (T::*)(A)) { + return false; +} + +template +inline constexpr bool HasChannelAccess(R (T::*)(A, C)) { + return true; +} + +// For the list case we do two interceptors to avoid amiguities with the single +// argument forms above. +template +inline constexpr bool HasChannelAccess(I1 i1, I2 i2, + Interceptors... interceptors) { + return HasChannelAccess(i1) || HasChannelAccess(i2) || + HasChannelAccess(interceptors...); +} + +// Composite for a given channel type to determine if any of its interceptors +// fall into this category: later code should use this. +template +inline constexpr bool CallHasChannelAccess() { + return HasChannelAccess(&Derived::Call::OnClientInitialMetadata, + &Derived::Call::OnClientToServerMessage, + &Derived::Call::OnServerInitialMetadata, + &Derived::Call::OnServerToClientMessage, + &Derived::Call::OnServerTrailingMetadata); +} + +// Given a boolean X export a type: +// either T if X is true +// or an empty type if it is false +template +struct TypeIfNeeded; + +template +struct TypeIfNeeded { + struct Type { + Type() = default; + template + explicit Type(Whatever) : Type() {} + }; +}; + +template +struct TypeIfNeeded { + using Type = T; +}; + +// For the original promise scheme polyfill: +// If a set of interceptors might fail asynchronously, wrap the main +// promise in a race with the cancellation latch. +// If not, just return the main promise. +template +struct RaceAsyncCompletion; + +template <> +struct RaceAsyncCompletion { + template + static Promise Run(Promise x, void*) { + return x; + } +}; + +template <> +struct RaceAsyncCompletion { + template + static Promise Run(Promise x, Latch* latch) { + return Race(latch->Wait(), std::move(x)); + } +}; + +// For the original promise scheme polyfill: data associated with once call. +template +struct FilterCallData { + explicit FilterCallData(Derived* channel) : channel(channel) {} + GPR_NO_UNIQUE_ADDRESS typename Derived::Call call; + GPR_NO_UNIQUE_ADDRESS + typename TypeIfNeeded, + CallHasAsyncErrorInterceptor()>::Type + error_latch; + GPR_NO_UNIQUE_ADDRESS + typename TypeIfNeeded()>::Type + channel; +}; + +template +auto MapResult(const NoInterceptor*, Promise x, void*) { + return x; +} + +template +auto MapResult(absl::Status (Derived::Call::*fn)(ServerMetadata&), Promise x, + FilterCallData* call_data) { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerTrailingMetadata); + return Map(std::move(x), [call_data](ServerMetadataHandle md) { + auto status = call_data->call.OnServerTrailingMetadata(*md); + if (!status.ok()) return ServerMetadataFromStatus(status); + return md; + }); +} + +inline auto RunCall(const NoInterceptor*, CallArgs call_args, + NextPromiseFactory next_promise_factory, void*) { + return next_promise_factory(std::move(call_args)); +} + +template +inline auto RunCall(void (Derived::Call::*fn)(ClientMetadata& md), + CallArgs call_args, NextPromiseFactory next_promise_factory, + FilterCallData* call_data) { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientInitialMetadata); + call_data->call.OnClientInitialMetadata(*call_args.client_initial_metadata); + return next_promise_factory(std::move(call_args)); +} + +template +inline auto RunCall(void (Derived::Call::*fn)(ClientMetadata& md, + Derived* channel), + CallArgs call_args, NextPromiseFactory next_promise_factory, + FilterCallData* call_data) { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientInitialMetadata); + call_data->call.OnClientInitialMetadata(*call_args.client_initial_metadata, + call_data->channel); + return next_promise_factory(std::move(call_args)); +} + +inline void InterceptClientToServerMessage(const NoInterceptor*, void*, + CallArgs&) {} + +inline void InterceptServerInitialMetadata(const NoInterceptor*, void*, + CallArgs&) {} + +template +inline void InterceptServerInitialMetadata( + absl::Status (Derived::Call::*fn)(ServerMetadata&), + FilterCallData* call_data, CallArgs& call_args) { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerInitialMetadata); + call_args.server_initial_metadata->InterceptAndMap( + [call_data]( + ServerMetadataHandle md) -> absl::optional { + auto status = call_data->call.OnServerInitialMetadata(*md); + if (!status.ok() && !call_data->error_latch.is_set()) { + call_data->error_latch.Set(ServerMetadataFromStatus(status)); + return absl::nullopt; + } + return std::move(md); + }); +} + +inline void InterceptServerToClientMessage(const NoInterceptor*, void*, + CallArgs&) {} + +template +absl::enable_if_t>::value, + FilterCallData*> +MakeFilterCall(Derived*) { + static FilterCallData call{nullptr}; + return &call; +} + +template +absl::enable_if_t>::value, + FilterCallData*> +MakeFilterCall(Derived* derived) { + return GetContext()->ManagedNew>(derived); +} + +} // namespace promise_filter_detail + +// Base class for promise-based channel filters. +// Eventually this machinery will move elsewhere (the interception logic will +// move directly into the channel stack, and so filters will just directly +// derive from `ChannelFilter`) +// +// Implements new-style call filters, and polyfills them into the previous +// scheme. +// +// Call filters: +// Derived types should declare a class `Call` with the following members: +// - OnClientInitialMetadata - $VALUE_TYPE = ClientMetadata +// - OnServerInitialMetadata - $VALUE_TYPE = ServerMetadata +// - OnServerToClientMessage - $VALUE_TYPE = Message +// - OnClientToServerMessage - $VALUE_TYPE = Message +// - OnServerTrailingMetadata - $VALUE_TYPE = ServerMetadata +// These members define an interception point for a particular event in +// the call lifecycle. +// The type of these members matters, and is selectable by the class +// author. For $INTERCEPTOR_NAME in the above list: +// - static const NoInterceptor $INTERCEPTOR_NAME: +// defines that this filter does not intercept this event. +// there is zero runtime cost added to handling that event by this filter. +// - void $INTERCEPTOR_NAME($VALUE_TYPE&): +// the filter intercepts this event, and can modify the value. +// it never fails. +// - absl::Status $INTERCEPTOR_NAME($VALUE_TYPE&): +// the filter intercepts this event, and can modify the value. +// it can fail, in which case the call will be aborted. +// - void $INTERCEPTOR_NAME($VALUE_TYPE&, Derived*): +// the filter intercepts this event, and can modify the value. +// it can access the channel via the second argument. +// it never fails. +// - absl::Status $INTERCEPTOR_NAME($VALUE_TYPE&, Derived*): +// the filter intercepts this event, and can modify the value. +// it can access the channel via the second argument. +// it can fail, in which case the call will be aborted. +template +class ImplementChannelFilter : public ChannelFilter { + public: + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) final { + auto* call = promise_filter_detail::MakeFilterCall( + static_cast(this)); + promise_filter_detail::InterceptClientToServerMessage( + &Derived::Call::OnClientToServerMessage, call, call_args); + promise_filter_detail::InterceptServerInitialMetadata( + &Derived::Call::OnServerInitialMetadata, call, call_args); + promise_filter_detail::InterceptServerToClientMessage( + &Derived::Call::OnServerToClientMessage, call, call_args); + return promise_filter_detail::MapResult( + &Derived::Call::OnServerTrailingMetadata, + promise_filter_detail::RaceAsyncCompletion< + promise_filter_detail::CallHasAsyncErrorInterceptor()>:: + Run(promise_filter_detail::RunCall( + &Derived::Call::OnClientInitialMetadata, + std::move(call_args), std::move(next_promise_factory), + call), + &call->error_latch), + call); + } +}; + // Designator for whether a filter is client side or server side. // Please don't use this outside calls to MakePromiseBasedFilter - it's // intended to be deleted once the promise conversion is complete.