diff --git a/src/core/ext/filters/http/client/http_client_filter.cc b/src/core/ext/filters/http/client/http_client_filter.cc index 07fcba62581..ac8004cdd3e 100644 --- a/src/core/ext/filters/http/client/http_client_filter.cc +++ b/src/core/ext/filters/http/client/http_client_filter.cc @@ -53,6 +53,7 @@ namespace grpc_core { const NoInterceptor HttpClientFilter::Call::OnServerToClientMessage; const NoInterceptor HttpClientFilter::Call::OnClientToServerMessage; +const NoInterceptor HttpClientFilter::Call::OnFinalize; const grpc_channel_filter HttpClientFilter::kFilter = MakePromiseBasedFilter { absl::Status OnServerTrailingMetadata(ServerMetadata& md); static const NoInterceptor OnClientToServerMessage; static const NoInterceptor OnServerToClientMessage; + static const NoInterceptor OnFinalize; }; private: diff --git a/src/core/ext/filters/http/server/http_server_filter.cc b/src/core/ext/filters/http/server/http_server_filter.cc index 830b931520f..38c2dda0e16 100644 --- a/src/core/ext/filters/http/server/http_server_filter.cc +++ b/src/core/ext/filters/http/server/http_server_filter.cc @@ -51,6 +51,7 @@ namespace grpc_core { const NoInterceptor HttpServerFilter::Call::OnClientToServerMessage; const NoInterceptor HttpServerFilter::Call::OnServerToClientMessage; +const NoInterceptor HttpServerFilter::Call::OnFinalize; const grpc_channel_filter HttpServerFilter::kFilter = MakePromiseBasedFilter { void OnServerTrailingMetadata(ServerMetadata& md); static const NoInterceptor OnClientToServerMessage; static const NoInterceptor OnServerToClientMessage; + static const NoInterceptor OnFinalize; }; private: diff --git a/src/core/ext/filters/message_size/message_size_filter.cc b/src/core/ext/filters/message_size/message_size_filter.cc index fcd5677ff51..9cfc00474c4 100644 --- a/src/core/ext/filters/message_size/message_size_filter.cc +++ b/src/core/ext/filters/message_size/message_size_filter.cc @@ -53,9 +53,11 @@ namespace grpc_core { const NoInterceptor ClientMessageSizeFilter::Call::OnClientInitialMetadata; const NoInterceptor ClientMessageSizeFilter::Call::OnServerInitialMetadata; const NoInterceptor ClientMessageSizeFilter::Call::OnServerTrailingMetadata; +const NoInterceptor ClientMessageSizeFilter::Call::OnFinalize; const NoInterceptor ServerMessageSizeFilter::Call::OnClientInitialMetadata; const NoInterceptor ServerMessageSizeFilter::Call::OnServerInitialMetadata; const NoInterceptor ServerMessageSizeFilter::Call::OnServerTrailingMetadata; +const NoInterceptor ServerMessageSizeFilter::Call::OnFinalize; // // MessageSizeParsedConfig diff --git a/src/core/ext/filters/message_size/message_size_filter.h b/src/core/ext/filters/message_size/message_size_filter.h index fdfba2fa788..647aeeed94f 100644 --- a/src/core/ext/filters/message_size/message_size_filter.h +++ b/src/core/ext/filters/message_size/message_size_filter.h @@ -99,6 +99,7 @@ class ServerMessageSizeFilter final static const NoInterceptor OnClientInitialMetadata; static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnServerTrailingMetadata; + static const NoInterceptor OnFinalize; ServerMetadataHandle OnClientToServerMessage( const Message& message, ServerMessageSizeFilter* filter); ServerMetadataHandle OnServerToClientMessage( @@ -126,6 +127,7 @@ class ClientMessageSizeFilter final static const NoInterceptor OnClientInitialMetadata; static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnServerTrailingMetadata; + static const NoInterceptor OnFinalize; ServerMetadataHandle OnClientToServerMessage(const Message& message); ServerMetadataHandle OnServerToClientMessage(const Message& message); diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index abb9ce0fbe0..ad33f366378 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -186,6 +186,11 @@ inline constexpr bool HasChannelAccess(R (T::*)(A)) { return false; } +template +inline constexpr bool HasChannelAccess(R (T::*)()) { + return false; +} + template inline constexpr bool HasChannelAccess(R (T::*)(A, C)) { return true; @@ -208,7 +213,8 @@ inline constexpr bool CallHasChannelAccess() { &Derived::Call::OnClientToServerMessage, &Derived::Call::OnServerInitialMetadata, &Derived::Call::OnServerToClientMessage, - &Derived::Call::OnServerTrailingMetadata); + &Derived::Call::OnServerTrailingMetadata, + &Derived::Call::OnFinalize); } // Given a boolean X export a type: @@ -642,6 +648,18 @@ inline void InterceptServerTrailingMetadata( }); } +inline void InterceptFinalize(const NoInterceptor*, void*) {} + +template +inline void InterceptFinalize(void (Call::*fn)(const grpc_call_final_info*), + Call* call) { + GPR_DEBUG_ASSERT(fn == &Call::OnFinalize); + GetContext()->Add( + [call](const grpc_call_final_info* final_info) { + call->OnFinalize(final_info); + }); +} + template absl::enable_if_t>::value, FilterCallData*> @@ -674,6 +692,7 @@ MakeFilterCall(Derived* derived) { // - OnServerToClientMessage - $VALUE_TYPE = Message // - OnClientToServerMessage - $VALUE_TYPE = Message // - OnServerTrailingMetadata - $VALUE_TYPE = ServerMetadata +// - OnFinalize - special, see below // These members define an interception point for a particular event in // the call lifecycle. // The type of these members matters, and is selectable by the class @@ -706,6 +725,12 @@ MakeFilterCall(Derived* derived) { // the filter can return nullptr for success, or a metadata handle for // failure (in which case the call will be aborted). // useful for cases where the exact metadata returned needs to be customized. +// Finally, OnFinalize can be added to intecept call finalization. +// It must have one of the signatures: +// - static const NoInterceptor OnFinalize: +// the filter does not intercept call finalization. +// - void OnFinalize(const grpc_call_final_info*): +// the filter intercepts call finalization. template class ImplementChannelFilter : public ChannelFilter { public: @@ -730,6 +755,7 @@ class ImplementChannelFilter : public ChannelFilter { promise_filter_detail::InterceptServerTrailingMetadata( &Derived::Call::OnServerTrailingMetadata, call, static_cast(this), call_spine); + promise_filter_detail::InterceptFinalize(&Derived::Call::OnFinalize, call); } // Polyfill for the original promise scheme. @@ -745,6 +771,9 @@ class ImplementChannelFilter : public ChannelFilter { &Derived::Call::OnServerInitialMetadata, call, call_args); promise_filter_detail::InterceptServerToClientMessage( &Derived::Call::OnServerToClientMessage, call, call_args); + promise_filter_detail::InterceptFinalize( + &Derived::Call::OnFinalize, + static_cast(&call->call)); return promise_filter_detail::MapResult( &Derived::Call::OnServerTrailingMetadata, promise_filter_detail::RaceAsyncCompletion< diff --git a/src/core/lib/channel/server_call_tracer_filter.cc b/src/core/lib/channel/server_call_tracer_filter.cc index 026216fd59e..c2450a97d4c 100644 --- a/src/core/lib/channel/server_call_tracer_filter.cc +++ b/src/core/lib/channel/server_call_tracer_filter.cc @@ -42,19 +42,55 @@ namespace grpc_core { namespace { -// TODO(yashykt): This filter is not really needed. We should be able to move -// this to the connected filter. -class ServerCallTracerFilter : public ChannelFilter { +class ServerCallTracerFilter + : public ImplementChannelFilter { public: static const grpc_channel_filter kFilter; static absl::StatusOr Create( const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/); - ArenaPromise MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) override; + class Call { + public: + void OnClientInitialMetadata(ClientMetadata& client_initial_metadata) { + auto* call_tracer = CallTracer(); + if (call_tracer == nullptr) return; + call_tracer->RecordReceivedInitialMetadata(&client_initial_metadata); + } + + void OnServerInitialMetadata(ServerMetadata& server_initial_metadata) { + auto* call_tracer = CallTracer(); + if (call_tracer == nullptr) return; + call_tracer->RecordSendInitialMetadata(&server_initial_metadata); + } + + void OnFinalize(const grpc_call_final_info* final_info) { + auto* call_tracer = CallTracer(); + if (call_tracer == nullptr) return; + call_tracer->RecordEnd(final_info); + } + + void OnServerTrailingMetadata(ServerMetadata& server_trailing_metadata) { + auto* call_tracer = CallTracer(); + if (call_tracer == nullptr) return; + call_tracer->RecordSendTrailingMetadata(&server_trailing_metadata); + } + + static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnServerToClientMessage; + + private: + static ServerCallTracer* CallTracer() { + auto* call_context = GetContext(); + return static_cast( + call_context[GRPC_CONTEXT_CALL_TRACER].value); + } + }; }; +const NoInterceptor ServerCallTracerFilter::Call::OnClientToServerMessage; +const NoInterceptor ServerCallTracerFilter::Call::OnServerToClientMessage; + const grpc_channel_filter ServerCallTracerFilter::kFilter = MakePromiseBasedFilter( @@ -65,34 +101,6 @@ absl::StatusOr ServerCallTracerFilter::Create( return ServerCallTracerFilter(); } -ArenaPromise ServerCallTracerFilter::MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) { - auto* call_context = GetContext(); - auto* call_tracer = static_cast( - call_context[GRPC_CONTEXT_CALL_TRACER].value); - if (call_tracer == nullptr) { - return next_promise_factory(std::move(call_args)); - } - call_tracer->RecordReceivedInitialMetadata( - call_args.client_initial_metadata.get()); - call_args.server_initial_metadata->InterceptAndMap( - [call_tracer](ServerMetadataHandle metadata) { - call_tracer->RecordSendInitialMetadata(metadata.get()); - return metadata; - }); - GetContext()->Add( - [call_tracer](const grpc_call_final_info* final_info) { - call_tracer->RecordEnd(final_info); - }); - return OnCancel( - Map(next_promise_factory(std::move(call_args)), - [call_tracer](ServerMetadataHandle md) { - call_tracer->RecordSendTrailingMetadata(md.get()); - return md; - }), - [call_tracer]() { call_tracer->RecordCancel(absl::CancelledError()); }); -} - } // namespace void RegisterServerCallTracerFilter(CoreConfiguration::Builder* builder) {