diff --git a/src/core/ext/filters/backend_metrics/backend_metric_filter.cc b/src/core/ext/filters/backend_metrics/backend_metric_filter.cc index c3ce55921d8..d9f02b41bf7 100644 --- a/src/core/ext/filters/backend_metrics/backend_metric_filter.cc +++ b/src/core/ext/filters/backend_metrics/backend_metric_filter.cc @@ -52,6 +52,7 @@ TraceFlag grpc_backend_metric_filter_trace(false, "backend_metric_filter"); const NoInterceptor BackendMetricFilter::Call::OnClientInitialMetadata; const NoInterceptor BackendMetricFilter::Call::OnServerInitialMetadata; const NoInterceptor BackendMetricFilter::Call::OnClientToServerMessage; +const NoInterceptor BackendMetricFilter::Call::OnClientToServerHalfClose; const NoInterceptor BackendMetricFilter::Call::OnServerToClientMessage; const NoInterceptor BackendMetricFilter::Call::OnFinalize; diff --git a/src/core/ext/filters/backend_metrics/backend_metric_filter.h b/src/core/ext/filters/backend_metrics/backend_metric_filter.h index d97b0c8cb65..114fc3cc7bc 100644 --- a/src/core/ext/filters/backend_metrics/backend_metric_filter.h +++ b/src/core/ext/filters/backend_metrics/backend_metric_filter.h @@ -44,6 +44,7 @@ class BackendMetricFilter : public ImplementChannelFilter { static const NoInterceptor OnServerInitialMetadata; void OnServerTrailingMetadata(ServerMetadata& md); static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; diff --git a/src/core/ext/filters/fault_injection/fault_injection_filter.cc b/src/core/ext/filters/fault_injection/fault_injection_filter.cc index 87d5a4d2f1f..ae64f83ddc8 100644 --- a/src/core/ext/filters/fault_injection/fault_injection_filter.cc +++ b/src/core/ext/filters/fault_injection/fault_injection_filter.cc @@ -58,6 +58,7 @@ TraceFlag grpc_fault_injection_filter_trace(false, "fault_injection_filter"); const NoInterceptor FaultInjectionFilter::Call::OnServerInitialMetadata; const NoInterceptor FaultInjectionFilter::Call::OnServerTrailingMetadata; const NoInterceptor FaultInjectionFilter::Call::OnClientToServerMessage; +const NoInterceptor FaultInjectionFilter::Call::OnClientToServerHalfClose; const NoInterceptor FaultInjectionFilter::Call::OnServerToClientMessage; const NoInterceptor FaultInjectionFilter::Call::OnFinalize; diff --git a/src/core/ext/filters/fault_injection/fault_injection_filter.h b/src/core/ext/filters/fault_injection/fault_injection_filter.h index b6b1b811cde..515df16a853 100644 --- a/src/core/ext/filters/fault_injection/fault_injection_filter.h +++ b/src/core/ext/filters/fault_injection/fault_injection_filter.h @@ -58,6 +58,7 @@ class FaultInjectionFilter static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; diff --git a/src/core/ext/filters/http/client/http_client_filter.cc b/src/core/ext/filters/http/client/http_client_filter.cc index 390df545efe..6af2b959be3 100644 --- a/src/core/ext/filters/http/client/http_client_filter.cc +++ b/src/core/ext/filters/http/client/http_client_filter.cc @@ -54,6 +54,7 @@ namespace grpc_core { const NoInterceptor HttpClientFilter::Call::OnServerToClientMessage; const NoInterceptor HttpClientFilter::Call::OnClientToServerMessage; +const NoInterceptor HttpClientFilter::Call::OnClientToServerHalfClose; const NoInterceptor HttpClientFilter::Call::OnFinalize; const grpc_channel_filter HttpClientFilter::kFilter = diff --git a/src/core/ext/filters/http/client/http_client_filter.h b/src/core/ext/filters/http/client/http_client_filter.h index f5d7875da5e..f985337f2ca 100644 --- a/src/core/ext/filters/http/client/http_client_filter.h +++ b/src/core/ext/filters/http/client/http_client_filter.h @@ -47,6 +47,7 @@ class HttpClientFilter : public ImplementChannelFilter { absl::Status OnServerInitialMetadata(ServerMetadata& md); absl::Status OnServerTrailingMetadata(ServerMetadata& md); static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; diff --git a/src/core/ext/filters/http/client_authority_filter.cc b/src/core/ext/filters/http/client_authority_filter.cc index 1d5258493e4..b6970d9ecb5 100644 --- a/src/core/ext/filters/http/client_authority_filter.cc +++ b/src/core/ext/filters/http/client_authority_filter.cc @@ -40,6 +40,7 @@ namespace grpc_core { const NoInterceptor ClientAuthorityFilter::Call::OnServerInitialMetadata; const NoInterceptor ClientAuthorityFilter::Call::OnServerTrailingMetadata; const NoInterceptor ClientAuthorityFilter::Call::OnClientToServerMessage; +const NoInterceptor ClientAuthorityFilter::Call::OnClientToServerHalfClose; const NoInterceptor ClientAuthorityFilter::Call::OnServerToClientMessage; const NoInterceptor ClientAuthorityFilter::Call::OnFinalize; diff --git a/src/core/ext/filters/http/client_authority_filter.h b/src/core/ext/filters/http/client_authority_filter.h index 44229c6cdde..da154fbac5d 100644 --- a/src/core/ext/filters/http/client_authority_filter.h +++ b/src/core/ext/filters/http/client_authority_filter.h @@ -52,6 +52,7 @@ class ClientAuthorityFilter final static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; diff --git a/src/core/ext/filters/http/message_compress/compression_filter.cc b/src/core/ext/filters/http/message_compress/compression_filter.cc index b688b866cf9..ed077de6e70 100644 --- a/src/core/ext/filters/http/message_compress/compression_filter.cc +++ b/src/core/ext/filters/http/message_compress/compression_filter.cc @@ -57,8 +57,10 @@ namespace grpc_core { +const NoInterceptor ServerCompressionFilter::Call::OnClientToServerHalfClose; const NoInterceptor ServerCompressionFilter::Call::OnServerTrailingMetadata; const NoInterceptor ServerCompressionFilter::Call::OnFinalize; +const NoInterceptor ClientCompressionFilter::Call::OnClientToServerHalfClose; const NoInterceptor ClientCompressionFilter::Call::OnServerTrailingMetadata; const NoInterceptor ClientCompressionFilter::Call::OnFinalize; diff --git a/src/core/ext/filters/http/message_compress/compression_filter.h b/src/core/ext/filters/http/message_compress/compression_filter.h index 99e57a0ac1d..5d82846d01d 100644 --- a/src/core/ext/filters/http/message_compress/compression_filter.h +++ b/src/core/ext/filters/http/message_compress/compression_filter.h @@ -129,6 +129,7 @@ class ClientCompressionFilter final absl::StatusOr OnServerToClientMessage( MessageHandle message, ClientCompressionFilter* filter); + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnFinalize; @@ -165,6 +166,7 @@ class ServerCompressionFilter final MessageHandle OnServerToClientMessage(MessageHandle message, ServerCompressionFilter* filter); + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnFinalize; diff --git a/src/core/ext/filters/http/server/http_server_filter.cc b/src/core/ext/filters/http/server/http_server_filter.cc index 925cc73c23e..4d92cef6eaa 100644 --- a/src/core/ext/filters/http/server/http_server_filter.cc +++ b/src/core/ext/filters/http/server/http_server_filter.cc @@ -50,6 +50,7 @@ namespace grpc_core { const NoInterceptor HttpServerFilter::Call::OnClientToServerMessage; +const NoInterceptor HttpServerFilter::Call::OnClientToServerHalfClose; const NoInterceptor HttpServerFilter::Call::OnServerToClientMessage; const NoInterceptor HttpServerFilter::Call::OnFinalize; diff --git a/src/core/ext/filters/http/server/http_server_filter.h b/src/core/ext/filters/http/server/http_server_filter.h index 282973ddecd..a1f330e58bb 100644 --- a/src/core/ext/filters/http/server/http_server_filter.h +++ b/src/core/ext/filters/http/server/http_server_filter.h @@ -50,6 +50,7 @@ class HttpServerFilter : public ImplementChannelFilter { void OnServerInitialMetadata(ServerMetadata& md); void OnServerTrailingMetadata(ServerMetadata& md); static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; diff --git a/src/core/ext/filters/load_reporting/server_load_reporting_filter.cc b/src/core/ext/filters/load_reporting/server_load_reporting_filter.cc index 7937ab6fe74..100584f5ba1 100644 --- a/src/core/ext/filters/load_reporting/server_load_reporting_filter.cc +++ b/src/core/ext/filters/load_reporting/server_load_reporting_filter.cc @@ -74,6 +74,7 @@ constexpr char kEmptyAddressLengthString[] = "00"; const NoInterceptor ServerLoadReportingFilter::Call::OnServerInitialMetadata; const NoInterceptor ServerLoadReportingFilter::Call::OnClientToServerMessage; +const NoInterceptor ServerLoadReportingFilter::Call::OnClientToServerHalfClose; const NoInterceptor ServerLoadReportingFilter::Call::OnServerToClientMessage; absl::StatusOr> diff --git a/src/core/ext/filters/load_reporting/server_load_reporting_filter.h b/src/core/ext/filters/load_reporting/server_load_reporting_filter.h index f11c8c38bcf..76684093a0b 100644 --- a/src/core/ext/filters/load_reporting/server_load_reporting_filter.h +++ b/src/core/ext/filters/load_reporting/server_load_reporting_filter.h @@ -54,6 +54,7 @@ class ServerLoadReportingFilter void OnServerTrailingMetadata(ServerMetadata& md, ServerLoadReportingFilter* filter); static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; void OnFinalize(const grpc_call_final_info* final_info, ServerLoadReportingFilter* filter); diff --git a/src/core/ext/filters/logging/logging_filter.cc b/src/core/ext/filters/logging/logging_filter.cc index 587e87ff506..1980f880e78 100644 --- a/src/core/ext/filters/logging/logging_filter.cc +++ b/src/core/ext/filters/logging/logging_filter.cc @@ -73,6 +73,9 @@ namespace grpc_core { +const NoInterceptor ClientLoggingFilter::Call::OnFinalize; +const NoInterceptor ServerLoggingFilter::Call::OnFinalize; + namespace { LoggingSink* g_logging_sink = nullptr; @@ -195,152 +198,147 @@ void EncodeMessageToPayload(const SliceBuffer* message, uint32_t log_len, } } -class CallData { - public: - CallData(bool is_client, const CallArgs& call_args, - const std::string& authority) - : call_id_(GetCallId()) { - absl::string_view path; - if (auto* value = call_args.client_initial_metadata->get_pointer( - HttpPathMetadata())) { - path = value->as_string_view(); - } - std::vector parts = - absl::StrSplit(path, '/', absl::SkipEmpty()); - if (parts.size() == 2) { - service_name_ = std::move(parts[0]); - method_name_ = std::move(parts[1]); +} // namespace + +namespace logging_filter_detail { + +CallData::CallData(bool is_client, + const ClientMetadata& client_initial_metadata, + const std::string& authority) + : call_id_(GetCallId()) { + absl::string_view path; + if (auto* value = client_initial_metadata.get_pointer(HttpPathMetadata())) { + path = value->as_string_view(); + } + std::vector parts = absl::StrSplit(path, '/', absl::SkipEmpty()); + if (parts.size() == 2) { + service_name_ = std::move(parts[0]); + method_name_ = std::move(parts[1]); + } + config_ = g_logging_sink->FindMatch(is_client, service_name_, method_name_); + if (config_.ShouldLog()) { + if (auto* value = + client_initial_metadata.get_pointer(HttpAuthorityMetadata())) { + authority_ = std::string(value->as_string_view()); + } else { + authority_ = authority; } - config_ = g_logging_sink->FindMatch(is_client, service_name_, method_name_); - if (config_.ShouldLog()) { - if (auto* value = call_args.client_initial_metadata->get_pointer( - HttpAuthorityMetadata())) { - authority_ = std::string(value->as_string_view()); - } else { - authority_ = authority; - } + } +} + +void CallData::LogClientHeader(bool is_client, + CallTracerAnnotationInterface* tracer, + const ClientMetadata& metadata) { + LoggingSink::Entry entry; + if (!is_client) { + if (auto* value = metadata.get_pointer(PeerString())) { + peer_ = PeerStringToAddress(*value); } } + SetCommonEntryFields(&entry, is_client, tracer, + LoggingSink::Entry::EventType::kClientHeader); + MetadataEncoder encoder(&entry.payload, nullptr, + config_.max_metadata_bytes()); + metadata.Encode(&encoder); + entry.payload_truncated = encoder.truncated(); + g_logging_sink->LogEntry(std::move(entry)); +} - bool ShouldLog() { return config_.ShouldLog(); } +void CallData::LogClientHalfClose(bool is_client, + CallTracerAnnotationInterface* tracer) { + LoggingSink::Entry entry; + SetCommonEntryFields(&entry, is_client, tracer, + LoggingSink::Entry::EventType::kClientHalfClose); + g_logging_sink->LogEntry(std::move(entry)); +} - void LogClientHeader(bool is_client, CallTracerAnnotationInterface* tracer, - const ClientMetadataHandle& metadata) { - LoggingSink::Entry entry; - if (!is_client) { +void CallData::LogServerHeader(bool is_client, + CallTracerAnnotationInterface* tracer, + const ServerMetadata* metadata) { + LoggingSink::Entry entry; + if (metadata != nullptr) { + entry.is_trailer_only = metadata->get(GrpcTrailersOnly()).value_or(false); + if (is_client) { if (auto* value = metadata->get_pointer(PeerString())) { peer_ = PeerStringToAddress(*value); } } - SetCommonEntryFields(&entry, is_client, tracer, - LoggingSink::Entry::EventType::kClientHeader); + } + SetCommonEntryFields(&entry, is_client, tracer, + LoggingSink::Entry::EventType::kServerHeader); + if (metadata != nullptr) { MetadataEncoder encoder(&entry.payload, nullptr, config_.max_metadata_bytes()); metadata->Encode(&encoder); entry.payload_truncated = encoder.truncated(); - g_logging_sink->LogEntry(std::move(entry)); - } - - void LogClientHalfClose(bool is_client, - CallTracerAnnotationInterface* tracer) { - LoggingSink::Entry entry; - SetCommonEntryFields(&entry, is_client, tracer, - LoggingSink::Entry::EventType::kClientHalfClose); - g_logging_sink->LogEntry(std::move(entry)); - } - - void LogServerHeader(bool is_client, CallTracerAnnotationInterface* tracer, - const ServerMetadata* metadata) { - LoggingSink::Entry entry; - if (metadata != nullptr) { - entry.is_trailer_only = metadata->get(GrpcTrailersOnly()).value_or(false); - if (is_client) { - if (auto* value = metadata->get_pointer(PeerString())) { - peer_ = PeerStringToAddress(*value); - } - } - } - SetCommonEntryFields(&entry, is_client, tracer, - LoggingSink::Entry::EventType::kServerHeader); - if (metadata != nullptr) { - MetadataEncoder encoder(&entry.payload, nullptr, - config_.max_metadata_bytes()); - metadata->Encode(&encoder); - entry.payload_truncated = encoder.truncated(); - } - g_logging_sink->LogEntry(std::move(entry)); - } - - void LogServerTrailer(bool is_client, CallTracerAnnotationInterface* tracer, - const ServerMetadata* metadata) { - LoggingSink::Entry entry; - SetCommonEntryFields(&entry, is_client, tracer, - LoggingSink::Entry::EventType::kServerTrailer); - if (metadata != nullptr) { - entry.is_trailer_only = metadata->get(GrpcTrailersOnly()).value_or(false); - MetadataEncoder encoder(&entry.payload, &entry.payload.status_details, - config_.max_metadata_bytes()); - metadata->Encode(&encoder); - entry.payload_truncated = encoder.truncated(); - } - g_logging_sink->LogEntry(std::move(entry)); } + g_logging_sink->LogEntry(std::move(entry)); +} - void LogClientMessage(bool is_client, CallTracerAnnotationInterface* tracer, - const SliceBuffer* message) { - LoggingSink::Entry entry; - SetCommonEntryFields(&entry, is_client, tracer, - LoggingSink::Entry::EventType::kClientMessage); - EncodeMessageToPayload(message, config_.max_message_bytes(), &entry); - g_logging_sink->LogEntry(std::move(entry)); +void CallData::LogServerTrailer(bool is_client, + CallTracerAnnotationInterface* tracer, + const ServerMetadata* metadata) { + LoggingSink::Entry entry; + SetCommonEntryFields(&entry, is_client, tracer, + LoggingSink::Entry::EventType::kServerTrailer); + if (metadata != nullptr) { + entry.is_trailer_only = metadata->get(GrpcTrailersOnly()).value_or(false); + MetadataEncoder encoder(&entry.payload, &entry.payload.status_details, + config_.max_metadata_bytes()); + metadata->Encode(&encoder); + entry.payload_truncated = encoder.truncated(); } + g_logging_sink->LogEntry(std::move(entry)); +} +void CallData::LogClientMessage(bool is_client, + CallTracerAnnotationInterface* tracer, + const SliceBuffer* message) { + LoggingSink::Entry entry; + SetCommonEntryFields(&entry, is_client, tracer, + LoggingSink::Entry::EventType::kClientMessage); + EncodeMessageToPayload(message, config_.max_message_bytes(), &entry); + g_logging_sink->LogEntry(std::move(entry)); +} - void LogServerMessage(bool is_client, CallTracerAnnotationInterface* tracer, - const SliceBuffer* message) { - LoggingSink::Entry entry; - SetCommonEntryFields(&entry, is_client, tracer, - LoggingSink::Entry::EventType::kServerMessage); - EncodeMessageToPayload(message, config_.max_message_bytes(), &entry); - g_logging_sink->LogEntry(std::move(entry)); - } +void CallData::LogServerMessage(bool is_client, + CallTracerAnnotationInterface* tracer, + const SliceBuffer* message) { + LoggingSink::Entry entry; + SetCommonEntryFields(&entry, is_client, tracer, + LoggingSink::Entry::EventType::kServerMessage); + EncodeMessageToPayload(message, config_.max_message_bytes(), &entry); + g_logging_sink->LogEntry(std::move(entry)); +} - void LogCancel(bool is_client, CallTracerAnnotationInterface* tracer) { - LoggingSink::Entry entry; - SetCommonEntryFields(&entry, is_client, tracer, - LoggingSink::Entry::EventType::kCancel); - g_logging_sink->LogEntry(std::move(entry)); - } +void CallData::LogCancel(bool is_client, + CallTracerAnnotationInterface* tracer) { + LoggingSink::Entry entry; + SetCommonEntryFields(&entry, is_client, tracer, + LoggingSink::Entry::EventType::kCancel); + g_logging_sink->LogEntry(std::move(entry)); +} - private: - void SetCommonEntryFields(LoggingSink::Entry* entry, bool is_client, - CallTracerAnnotationInterface* tracer, - LoggingSink::Entry::EventType event_type) { - entry->call_id = call_id_; - entry->sequence_id = sequence_id_++; - entry->type = event_type; - entry->logger = is_client ? LoggingSink::Entry::Logger::kClient - : LoggingSink::Entry::Logger::kServer; - entry->authority = authority_; - entry->peer = peer_; - entry->service_name = service_name_; - entry->method_name = method_name_; - entry->timestamp = Timestamp::Now(); - if (tracer != nullptr) { - entry->trace_id = tracer->TraceId(); - entry->span_id = tracer->SpanId(); - entry->is_sampled = tracer->IsSampled(); - } +void CallData::SetCommonEntryFields(LoggingSink::Entry* entry, bool is_client, + CallTracerAnnotationInterface* tracer, + LoggingSink::Entry::EventType event_type) { + entry->call_id = call_id_; + entry->sequence_id = sequence_id_++; + entry->type = event_type; + entry->logger = is_client ? LoggingSink::Entry::Logger::kClient + : LoggingSink::Entry::Logger::kServer; + entry->authority = authority_; + entry->peer = peer_; + entry->service_name = service_name_; + entry->method_name = method_name_; + entry->timestamp = Timestamp::Now(); + if (tracer != nullptr) { + entry->trace_id = tracer->TraceId(); + entry->span_id = tracer->SpanId(); + entry->is_sampled = tracer->IsSampled(); } - absl::uint128 call_id_; - uint32_t sequence_id_ = 0; - std::string service_name_; - std::string method_name_; - std::string authority_; - LoggingSink::Entry::Address peer_; - LoggingSink::Config config_; -}; +} -} // namespace +} // namespace logging_filter_detail absl::StatusOr> ClientLoggingFilter::Create(const ChannelArgs& args, @@ -361,84 +359,49 @@ ClientLoggingFilter::Create(const ChannelArgs& args, return std::make_unique(""); } -// Construct a promise for one call. -ArenaPromise ClientLoggingFilter::MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) { - CallData* calld = GetContext()->ManagedNew( - true, call_args, default_authority_); - if (!calld->ShouldLog()) { - return next_promise_factory(std::move(call_args)); +void ClientLoggingFilter::Call::OnClientInitialMetadata( + ClientMetadata& md, ClientLoggingFilter* filter) { + call_data_.emplace(true, md, filter->default_authority_); + if (!call_data_->ShouldLog()) { + call_data_.reset(); + return; } - calld->LogClientHeader( - /*is_client=*/true, - static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value), - call_args.client_initial_metadata); - call_args.server_initial_metadata->InterceptAndMap( - [calld](ServerMetadataHandle metadata) { - calld->LogServerHeader( - /*is_client=*/true, - static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value), - metadata.get()); - return metadata; - }); - call_args.client_to_server_messages->InterceptAndMapWithHalfClose( - [calld](MessageHandle message) { - calld->LogClientMessage( - /*is_client=*/true, - static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value), - message->payload()); - return message; - }, - [calld] { - calld->LogClientHalfClose( - /*is_client=*/true, - static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value)); - }); - call_args.server_to_client_messages->InterceptAndMap( - [calld](MessageHandle message) { - calld->LogServerMessage( - /*is_client=*/true, - static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value), - message->payload()); - return message; - }); - return OnCancel( - Map(next_promise_factory(std::move(call_args)), - [calld](ServerMetadataHandle md) { - calld->LogServerTrailer( - /*is_client=*/true, - static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value), - md.get()); - return md; - }), - // TODO(yashykt/ctiller): GetContext is not - // valid for the cancellation function requiring us to capture it here. - // This ought to be easy to fix once client side promises are completely - // rolled out. - [calld, ctx = GetContext()]() { - calld->LogCancel( - /*is_client=*/true, - static_cast( - ctx[GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE].value)); - }); + call_data_->LogClientHeader( + /*is_client=*/true, GetContext(), md); +} + +void ClientLoggingFilter::Call::OnServerInitialMetadata(ServerMetadata& md) { + if (!call_data_.has_value()) return; + call_data_->LogServerHeader( + /*is_client=*/true, GetContext(), &md); +} + +void ClientLoggingFilter::Call::OnServerTrailingMetadata(ServerMetadata& md) { + if (!call_data_.has_value()) return; + call_data_->LogServerTrailer( + /*is_client=*/true, GetContext(), &md); +} + +void ClientLoggingFilter::Call::OnClientToServerMessage( + const Message& message) { + if (!call_data_.has_value()) return; + call_data_->LogClientMessage( + /*is_client=*/true, GetContext(), + message.payload()); +} + +void ClientLoggingFilter::Call::OnClientToServerHalfClose() { + if (!call_data_.has_value()) return; + call_data_->LogClientHalfClose( + /*is_client=*/true, GetContext()); +} + +void ClientLoggingFilter::Call::OnServerToClientMessage( + const Message& message) { + if (!call_data_.has_value()) return; + call_data_->LogServerMessage( + /*is_client=*/true, GetContext(), + message.payload()); } const grpc_channel_filter ClientLoggingFilter::kFilter = @@ -454,79 +417,49 @@ ServerLoggingFilter::Create(const ChannelArgs& /*args*/, } // Construct a promise for one call. -ArenaPromise ServerLoggingFilter::MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) { - CallData* calld = GetContext()->ManagedNew( - false, call_args, /*default_authority=*/""); - if (!calld->ShouldLog()) { - return next_promise_factory(std::move(call_args)); +void ServerLoggingFilter::Call::OnClientInitialMetadata( + ClientMetadata& md, ServerLoggingFilter* filter) { + call_data_.emplace(false, md, ""); + if (!call_data_->ShouldLog()) { + call_data_.reset(); + return; } - auto* call_tracer = static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value); - calld->LogClientHeader( - /*is_client=*/false, call_tracer, call_args.client_initial_metadata); - call_args.server_initial_metadata->InterceptAndMap( - [calld](ServerMetadataHandle metadata) { - calld->LogServerHeader( - /*is_client=*/false, - static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value), - metadata.get()); - return metadata; - }); - call_args.client_to_server_messages->InterceptAndMapWithHalfClose( - [calld](MessageHandle message) { - calld->LogClientMessage( - /*is_client=*/false, - static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value), - message->payload()); - return message; - }, - [calld] { - calld->LogClientHalfClose( - /*is_client=*/false, - static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value)); - }); - call_args.server_to_client_messages->InterceptAndMap( - [calld](MessageHandle message) { - calld->LogServerMessage( - /*is_client=*/false, - static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value), - message->payload()); - return message; - }); - return OnCancel( - Map(next_promise_factory(std::move(call_args)), - [calld](ServerMetadataHandle md) { - calld->LogServerTrailer( - /*is_client=*/false, - static_cast( - GetContext() - [GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE] - .value), - md.get()); - return md; - }), - // TODO(yashykt/ctiller): GetContext is not - // valid for the cancellation function requiring us to capture - // call_tracer. - [calld, call_tracer]() { - calld->LogCancel( - /*is_client=*/false, call_tracer); - }); + call_data_->LogClientHeader( + /*is_client=*/false, GetContext(), md); +} + +void ServerLoggingFilter::Call::OnServerInitialMetadata(ServerMetadata& md) { + if (!call_data_.has_value()) return; + call_data_->LogServerHeader( + /*is_client=*/false, GetContext(), &md); +} + +void ServerLoggingFilter::Call::OnServerTrailingMetadata(ServerMetadata& md) { + if (!call_data_.has_value()) return; + call_data_->LogServerTrailer( + /*is_client=*/false, GetContext(), &md); +} + +void ServerLoggingFilter::Call::OnClientToServerMessage( + const Message& message) { + if (!call_data_.has_value()) return; + call_data_->LogClientMessage( + /*is_client=*/false, GetContext(), + message.payload()); +} + +void ServerLoggingFilter::Call::OnClientToServerHalfClose() { + if (!call_data_.has_value()) return; + call_data_->LogClientHalfClose( + /*is_client=*/false, GetContext()); +} + +void ServerLoggingFilter::Call::OnServerToClientMessage( + const Message& message) { + if (!call_data_.has_value()) return; + call_data_->LogServerMessage( + /*is_client=*/false, GetContext(), + message.payload()); } const grpc_channel_filter ServerLoggingFilter::kFilter = diff --git a/src/core/ext/filters/logging/logging_filter.h b/src/core/ext/filters/logging/logging_filter.h index 7d42abbc337..77832652212 100644 --- a/src/core/ext/filters/logging/logging_filter.h +++ b/src/core/ext/filters/logging/logging_filter.h @@ -35,7 +35,46 @@ namespace grpc_core { -class ClientLoggingFilter final : public ChannelFilter { +namespace logging_filter_detail { + +class CallData { + public: + CallData(bool is_client, const ClientMetadata& client_initial_metadata, + const std::string& authority); + + bool ShouldLog() { return config_.ShouldLog(); } + + void LogClientHeader(bool is_client, CallTracerAnnotationInterface* tracer, + const ClientMetadata& metadata); + void LogClientHalfClose(bool is_client, + CallTracerAnnotationInterface* tracer); + void LogServerHeader(bool is_client, CallTracerAnnotationInterface* tracer, + const ServerMetadata* metadata); + void LogServerTrailer(bool is_client, CallTracerAnnotationInterface* tracer, + const ServerMetadata* metadata); + void LogClientMessage(bool is_client, CallTracerAnnotationInterface* tracer, + const SliceBuffer* message); + void LogServerMessage(bool is_client, CallTracerAnnotationInterface* tracer, + const SliceBuffer* message); + void LogCancel(bool is_client, CallTracerAnnotationInterface* tracer); + + private: + void SetCommonEntryFields(LoggingSink::Entry* entry, bool is_client, + CallTracerAnnotationInterface* tracer, + LoggingSink::Entry::EventType event_type); + absl::uint128 call_id_; + uint32_t sequence_id_ = 0; + std::string service_name_; + std::string method_name_; + std::string authority_; + LoggingSink::Entry::Address peer_; + LoggingSink::Config config_; +}; + +} // namespace logging_filter_detail + +class ClientLoggingFilter final + : public ImplementChannelFilter { public: static const grpc_channel_filter kFilter; @@ -45,24 +84,47 @@ class ClientLoggingFilter final : public ChannelFilter { explicit ClientLoggingFilter(std::string default_authority) : default_authority_(std::move(default_authority)) {} - // Construct a promise for one call. - ArenaPromise MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) override; + class Call { + public: + void OnClientInitialMetadata(ClientMetadata& md, + ClientLoggingFilter* filter); + void OnServerInitialMetadata(ServerMetadata& md); + void OnServerTrailingMetadata(ServerMetadata& md); + void OnClientToServerMessage(const Message& message); + void OnClientToServerHalfClose(); + void OnServerToClientMessage(const Message& message); + static const NoInterceptor OnFinalize; + + private: + absl::optional call_data_; + }; private: const std::string default_authority_; }; -class ServerLoggingFilter final : public ChannelFilter { +class ServerLoggingFilter final + : public ImplementChannelFilter { public: static const grpc_channel_filter kFilter; static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args /*filter_args*/); - // Construct a promise for one call. - ArenaPromise MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) override; + class Call { + public: + void OnClientInitialMetadata(ClientMetadata& md, + ServerLoggingFilter* filter); + void OnServerInitialMetadata(ServerMetadata& md); + void OnServerTrailingMetadata(ServerMetadata& md); + void OnClientToServerMessage(const Message& message); + void OnClientToServerHalfClose(); + void OnServerToClientMessage(const Message& message); + static const NoInterceptor OnFinalize; + + private: + absl::optional call_data_; + }; }; void RegisterLoggingFilter(LoggingSink* sink); diff --git a/src/core/ext/filters/message_size/message_size_filter.cc b/src/core/ext/filters/message_size/message_size_filter.cc index 0933d633272..2c73c63a370 100644 --- a/src/core/ext/filters/message_size/message_size_filter.cc +++ b/src/core/ext/filters/message_size/message_size_filter.cc @@ -51,10 +51,12 @@ namespace grpc_core { const NoInterceptor ClientMessageSizeFilter::Call::OnClientInitialMetadata; const NoInterceptor ClientMessageSizeFilter::Call::OnServerInitialMetadata; const NoInterceptor ClientMessageSizeFilter::Call::OnServerTrailingMetadata; +const NoInterceptor ClientMessageSizeFilter::Call::OnClientToServerHalfClose; const NoInterceptor ClientMessageSizeFilter::Call::OnFinalize; const NoInterceptor ServerMessageSizeFilter::Call::OnClientInitialMetadata; const NoInterceptor ServerMessageSizeFilter::Call::OnServerInitialMetadata; const NoInterceptor ServerMessageSizeFilter::Call::OnServerTrailingMetadata; +const NoInterceptor ServerMessageSizeFilter::Call::OnClientToServerHalfClose; const NoInterceptor ServerMessageSizeFilter::Call::OnFinalize; // diff --git a/src/core/ext/filters/message_size/message_size_filter.h b/src/core/ext/filters/message_size/message_size_filter.h index 89d21201a5c..1637bfe3561 100644 --- a/src/core/ext/filters/message_size/message_size_filter.h +++ b/src/core/ext/filters/message_size/message_size_filter.h @@ -105,6 +105,7 @@ class ServerMessageSizeFilter final static const NoInterceptor OnFinalize; ServerMetadataHandle OnClientToServerMessage( const Message& message, ServerMessageSizeFilter* filter); + static const NoInterceptor OnClientToServerHalfClose; ServerMetadataHandle OnServerToClientMessage( const Message& message, ServerMessageSizeFilter* filter); }; @@ -133,6 +134,7 @@ class ClientMessageSizeFilter final static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnFinalize; ServerMetadataHandle OnClientToServerMessage(const Message& message); + static const NoInterceptor OnClientToServerHalfClose; ServerMetadataHandle OnServerToClientMessage(const Message& message); private: diff --git a/src/core/ext/filters/rbac/rbac_filter.cc b/src/core/ext/filters/rbac/rbac_filter.cc index 7c75f46ae7d..c89c6962e19 100644 --- a/src/core/ext/filters/rbac/rbac_filter.cc +++ b/src/core/ext/filters/rbac/rbac_filter.cc @@ -46,6 +46,7 @@ namespace grpc_core { const NoInterceptor RbacFilter::Call::OnServerInitialMetadata; const NoInterceptor RbacFilter::Call::OnServerTrailingMetadata; const NoInterceptor RbacFilter::Call::OnClientToServerMessage; +const NoInterceptor RbacFilter::Call::OnClientToServerHalfClose; const NoInterceptor RbacFilter::Call::OnServerToClientMessage; const NoInterceptor RbacFilter::Call::OnFinalize; diff --git a/src/core/ext/filters/rbac/rbac_filter.h b/src/core/ext/filters/rbac/rbac_filter.h index a4c41cbdd0b..d033b799d5f 100644 --- a/src/core/ext/filters/rbac/rbac_filter.h +++ b/src/core/ext/filters/rbac/rbac_filter.h @@ -55,6 +55,7 @@ class RbacFilter : public ImplementChannelFilter { static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; diff --git a/src/core/ext/filters/stateful_session/stateful_session_filter.cc b/src/core/ext/filters/stateful_session/stateful_session_filter.cc index 5ae22cd864d..f383f3fb9ed 100644 --- a/src/core/ext/filters/stateful_session/stateful_session_filter.cc +++ b/src/core/ext/filters/stateful_session/stateful_session_filter.cc @@ -60,6 +60,7 @@ namespace grpc_core { TraceFlag grpc_stateful_session_filter_trace(false, "stateful_session_filter"); const NoInterceptor StatefulSessionFilter::Call::OnClientToServerMessage; +const NoInterceptor StatefulSessionFilter::Call::OnClientToServerHalfClose; const NoInterceptor StatefulSessionFilter::Call::OnServerToClientMessage; const NoInterceptor StatefulSessionFilter::Call::OnFinalize; diff --git a/src/core/ext/filters/stateful_session/stateful_session_filter.h b/src/core/ext/filters/stateful_session/stateful_session_filter.h index 5cd534843aa..64c488bce33 100644 --- a/src/core/ext/filters/stateful_session/stateful_session_filter.h +++ b/src/core/ext/filters/stateful_session/stateful_session_filter.h @@ -86,6 +86,7 @@ class StatefulSessionFilter void OnServerInitialMetadata(ServerMetadata& md); void OnServerTrailingMetadata(ServerMetadata& md); static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; diff --git a/src/core/lib/channel/context.h b/src/core/lib/channel/context.h index ddcd01395a9..95ce62240bf 100644 --- a/src/core/lib/channel/context.h +++ b/src/core/lib/channel/context.h @@ -72,6 +72,7 @@ struct grpc_call_context_element { namespace grpc_core { class Call; +class CallTracerAnnotationInterface; // Bind the legacy context array into the new style structure // TODO(ctiller): remove as we migrate these contexts to the new system. @@ -89,6 +90,12 @@ struct OldStyleContext { static constexpr grpc_context_index kIndex = GRPC_CONTEXT_CALL; }; +template <> +struct OldStyleContext { + static constexpr grpc_context_index kIndex = + GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE; +}; + template class Context::kIndex)>> { public: diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index c0735c45e61..c845c5b3cb8 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -351,7 +351,7 @@ auto MapResult(const NoInterceptor*, Promise x, void*) { } template -auto MapResult(absl::Status (Derived::Call::*fn)(ServerMetadata&), Promise x, +auto MapResult(absl::Status (Derived::Call::* fn)(ServerMetadata&), Promise x, FilterCallData* call_data) { DCHECK(fn == &Derived::Call::OnServerTrailingMetadata); return Map(std::move(x), [call_data](ServerMetadataHandle md) { @@ -362,7 +362,7 @@ auto MapResult(absl::Status (Derived::Call::*fn)(ServerMetadata&), Promise x, } template -auto MapResult(void (Derived::Call::*fn)(ServerMetadata&), Promise x, +auto MapResult(void (Derived::Call::* fn)(ServerMetadata&), Promise x, FilterCallData* call_data) { DCHECK(fn == &Derived::Call::OnServerTrailingMetadata); return Map(std::move(x), [call_data](ServerMetadataHandle md) { @@ -372,7 +372,7 @@ auto MapResult(void (Derived::Call::*fn)(ServerMetadata&), Promise x, } template -auto MapResult(void (Derived::Call::*fn)(ServerMetadata&, Derived*), Promise x, +auto MapResult(void (Derived::Call::* fn)(ServerMetadata&, Derived*), Promise x, FilterCallData* call_data) { DCHECK(fn == &Derived::Call::OnServerTrailingMetadata); return Map(std::move(x), [call_data](ServerMetadataHandle md) { @@ -492,137 +492,203 @@ auto RunCall(Interceptor interceptor, CallArgs call_args, std::move(call_args), std::move(next_promise_factory), call_data); } -inline void InterceptClientToServerMessage(const NoInterceptor*, void*, - const CallArgs&) {} +template +inline auto InterceptClientToServerMessageHandler( + void (Derived::Call::* fn)(const Message&), + FilterCallData* call_data, const CallArgs& call_args) { + DCHECK(fn == &Derived::Call::OnClientToServerMessage); + return [call_data](MessageHandle msg) -> absl::optional { + call_data->call.OnClientToServerMessage(*msg); + return std::move(msg); + }; +} template -inline void InterceptClientToServerMessage( - ServerMetadataHandle (Derived::Call::*fn)(const Message&), +inline auto InterceptClientToServerMessageHandler( + ServerMetadataHandle (Derived::Call::* fn)(const Message&), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnClientToServerMessage); - call_args.client_to_server_messages->InterceptAndMap( - [call_data](MessageHandle msg) -> absl::optional { - auto return_md = call_data->call.OnClientToServerMessage(*msg); - if (return_md == nullptr) return std::move(msg); - if (call_data->error_latch.is_set()) return absl::nullopt; - call_data->error_latch.Set(std::move(return_md)); - return absl::nullopt; - }); + return [call_data](MessageHandle msg) -> absl::optional { + auto return_md = call_data->call.OnClientToServerMessage(*msg); + if (return_md == nullptr) return std::move(msg); + if (call_data->error_latch.is_set()) return absl::nullopt; + call_data->error_latch.Set(std::move(return_md)); + return absl::nullopt; + }; } template -inline void InterceptClientToServerMessage( - ServerMetadataHandle (Derived::Call::*fn)(const Message&, Derived*), +inline auto InterceptClientToServerMessageHandler( + ServerMetadataHandle (Derived::Call::* fn)(const Message&, Derived*), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnClientToServerMessage); - call_args.client_to_server_messages->InterceptAndMap( - [call_data](MessageHandle msg) -> absl::optional { - auto return_md = - call_data->call.OnClientToServerMessage(*msg, call_data->channel); - if (return_md == nullptr) return std::move(msg); - if (call_data->error_latch.is_set()) return absl::nullopt; - call_data->error_latch.Set(std::move(return_md)); - return absl::nullopt; - }); + return [call_data](MessageHandle msg) -> absl::optional { + auto return_md = + call_data->call.OnClientToServerMessage(*msg, call_data->channel); + if (return_md == nullptr) return std::move(msg); + if (call_data->error_latch.is_set()) return absl::nullopt; + call_data->error_latch.Set(std::move(return_md)); + return absl::nullopt; + }; } template -inline void InterceptClientToServerMessage( - MessageHandle (Derived::Call::*fn)(MessageHandle, Derived*), +inline auto InterceptClientToServerMessageHandler( + MessageHandle (Derived::Call::* fn)(MessageHandle, Derived*), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnClientToServerMessage); - call_args.client_to_server_messages->InterceptAndMap( - [call_data](MessageHandle msg) -> absl::optional { - return call_data->call.OnClientToServerMessage(std::move(msg), - call_data->channel); - }); + return [call_data](MessageHandle msg) -> absl::optional { + return call_data->call.OnClientToServerMessage(std::move(msg), + call_data->channel); + }; } template -inline void InterceptClientToServerMessage( - absl::StatusOr (Derived::Call::*fn)(MessageHandle, Derived*), +inline auto InterceptClientToServerMessageHandler( + absl::StatusOr (Derived::Call::* fn)(MessageHandle, + Derived*), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnClientToServerMessage); + return [call_data](MessageHandle msg) -> absl::optional { + auto r = call_data->call.OnClientToServerMessage(std::move(msg), + call_data->channel); + if (r.ok()) return std::move(*r); + if (call_data->error_latch.is_set()) return absl::nullopt; + call_data->error_latch.Set(ServerMetadataFromStatus(r.status())); + return absl::nullopt; + }; +} + +template +inline void InterceptClientToServerMessage(HookFunction hook, + const NoInterceptor*, + FilterCallData* call_data, + const CallArgs& call_args) { call_args.client_to_server_messages->InterceptAndMap( - [call_data](MessageHandle msg) -> absl::optional { - auto r = call_data->call.OnClientToServerMessage(std::move(msg), - call_data->channel); - if (r.ok()) return std::move(*r); - if (call_data->error_latch.is_set()) return absl::nullopt; - call_data->error_latch.Set(ServerMetadataFromStatus(r.status())); - return absl::nullopt; - }); + InterceptClientToServerMessageHandler(hook, call_data, call_args)); } -inline void InterceptClientToServerMessage(const NoInterceptor*, void*, void*, - CallSpineInterface*) {} +template +inline void InterceptClientToServerMessage(HookFunction hook, + void (Derived::Call::* half_close)(), + FilterCallData* call_data, + const CallArgs& call_args) { + call_args.client_to_server_messages->InterceptAndMapWithHalfClose( + InterceptClientToServerMessageHandler(hook, call_data, call_args), + [call_data]() { call_data->call.OnClientToServerHalfClose(); }); +} + +template +inline void InterceptClientToServerMessage(const NoInterceptor*, + const NoInterceptor*, + FilterCallData*, + const CallArgs&) {} template -inline void InterceptClientToServerMessage( - ServerMetadataHandle (Derived::Call::*fn)(const Message&), +inline auto InterceptClientToServerMessageHandler( + ServerMetadataHandle (Derived::Call::* fn)(const Message&), typename Derived::Call* call, Derived*, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnClientToServerMessage); - call_spine->client_to_server_messages().receiver.InterceptAndMap( + return [call, call_spine](MessageHandle msg) -> absl::optional { auto return_md = call->OnClientToServerMessage(*msg); if (return_md == nullptr) return std::move(msg); call_spine->PushServerTrailingMetadata(std::move(return_md)); return absl::nullopt; - }); + }; } template -inline void InterceptClientToServerMessage( - ServerMetadataHandle (Derived::Call::*fn)(const Message&, Derived*), +inline auto InterceptClientToServerMessageHandler( + void (Derived::Call::* fn)(const Message&), typename Derived::Call* call, + Derived*, PipeBasedCallSpine* call_spine) { + DCHECK(fn == &Derived::Call::OnClientToServerMessage); + return [call](MessageHandle msg) -> absl::optional { + call->OnClientToServerMessage(*msg); + return std::move(msg); + }; +} + +template +inline auto InterceptClientToServerMessageHandler( + ServerMetadataHandle (Derived::Call::* fn)(const Message&, Derived*), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnClientToServerMessage); - call_spine->client_to_server_messages().receiver.InterceptAndMap( - [call, call_spine, - channel](MessageHandle msg) -> absl::optional { - auto return_md = call->OnClientToServerMessage(*msg, channel); - if (return_md == nullptr) return std::move(msg); - call_spine->PushServerTrailingMetadata(std::move(return_md)); - return absl::nullopt; - }); + return [call, call_spine, + channel](MessageHandle msg) -> absl::optional { + auto return_md = call->OnClientToServerMessage(*msg, channel); + if (return_md == nullptr) return std::move(msg); + call_spine->PushServerTrailingMetadata(std::move(return_md)); + return absl::nullopt; + }; } template -inline void InterceptClientToServerMessage( - MessageHandle (Derived::Call::*fn)(MessageHandle, Derived*), +inline auto InterceptClientToServerMessageHandler( + MessageHandle (Derived::Call::* fn)(MessageHandle, Derived*), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnClientToServerMessage); - call_spine->client_to_server_messages().receiver.InterceptAndMap( - [call, channel](MessageHandle msg) { - return call->OnClientToServerMessage(std::move(msg), channel); - }); + return [call, channel](MessageHandle msg) { + return call->OnClientToServerMessage(std::move(msg), channel); + }; } template -inline void InterceptClientToServerMessage( - absl::StatusOr (Derived::Call::*fn)(MessageHandle, Derived*), +inline auto InterceptClientToServerMessageHandler( + absl::StatusOr (Derived::Call::* fn)(MessageHandle, + Derived*), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnClientToServerMessage); + return [call, call_spine, + channel](MessageHandle msg) -> absl::optional { + auto r = call->OnClientToServerMessage(std::move(msg), channel); + if (r.ok()) return std::move(*r); + call_spine->PushServerTrailingMetadata( + ServerMetadataFromStatus(r.status())); + return absl::nullopt; + }; +} + +template +inline void InterceptClientToServerMessage(HookFunction fn, + const NoInterceptor*, + typename Derived::Call* call, + Derived* channel, + PipeBasedCallSpine* call_spine) { + DCHECK(fn == &Derived::Call::OnClientToServerMessage); call_spine->client_to_server_messages().receiver.InterceptAndMap( - [call, call_spine, - channel](MessageHandle msg) -> absl::optional { - auto r = call->OnClientToServerMessage(std::move(msg), channel); - if (r.ok()) return std::move(*r); - call_spine->PushServerTrailingMetadata( - ServerMetadataFromStatus(r.status())); - return absl::nullopt; - }); + InterceptClientToServerMessageHandler(fn, call, channel, call_spine)); } +template +inline void InterceptClientToServerMessage(HookFunction fn, + void (Derived::Call::* half_close)(), + typename Derived::Call* call, + Derived* channel, + PipeBasedCallSpine* call_spine) { + DCHECK(fn == &Derived::Call::OnClientToServerMessage); + DCHECK(half_close == &Derived::Call::OnClientToServerHalfClose); + call_spine->client_to_server_messages().receiver.InterceptAndMapWithHalfClose( + InterceptClientToServerMessageHandler(fn, call, channel, call_spine), + [call]() { call->OnClientToServerHalfClose(); }); +} + +template +inline void InterceptClientToServerMessage(const NoInterceptor*, + const NoInterceptor*, + typename Derived::Call*, Derived*, + PipeBasedCallSpine*) {} + inline void InterceptClientInitialMetadata(const NoInterceptor*, void*, void*, PipeBasedCallSpine*) {} template inline void InterceptClientInitialMetadata( - void (Derived::Call::*fn)(ClientMetadata& md), typename Derived::Call* call, - Derived*, PipeBasedCallSpine* call_spine) { + void (Derived::Call::* fn)(ClientMetadata& md), + typename Derived::Call* call, Derived*, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnClientInitialMetadata); call_spine->client_initial_metadata().receiver.InterceptAndMap( [call](ClientMetadataHandle md) { @@ -633,7 +699,7 @@ inline void InterceptClientInitialMetadata( template inline void InterceptClientInitialMetadata( - void (Derived::Call::*fn)(ClientMetadata& md, Derived* channel), + void (Derived::Call::* fn)(ClientMetadata& md, Derived* channel), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnClientInitialMetadata); @@ -646,7 +712,7 @@ inline void InterceptClientInitialMetadata( template inline void InterceptClientInitialMetadata( - ServerMetadataHandle (Derived::Call::*fn)(ClientMetadata& md), + ServerMetadataHandle (Derived::Call::* fn)(ClientMetadata& md), typename Derived::Call* call, Derived*, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnClientInitialMetadata); call_spine->client_initial_metadata().receiver.InterceptAndMap( @@ -661,8 +727,8 @@ inline void InterceptClientInitialMetadata( template inline void InterceptClientInitialMetadata( - ServerMetadataHandle (Derived::Call::*fn)(ClientMetadata& md, - Derived* channel), + ServerMetadataHandle (Derived::Call::* fn)(ClientMetadata& md, + Derived* channel), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnClientInitialMetadata); @@ -678,7 +744,7 @@ inline void InterceptClientInitialMetadata( template inline void InterceptClientInitialMetadata( - absl::Status (Derived::Call::*fn)(ClientMetadata& md), + absl::Status (Derived::Call::* fn)(ClientMetadata& md), typename Derived::Call* call, Derived*, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnClientInitialMetadata); call_spine->client_initial_metadata().receiver.InterceptAndMap( @@ -694,7 +760,7 @@ inline void InterceptClientInitialMetadata( template inline void InterceptClientInitialMetadata( - absl::Status (Derived::Call::*fn)(ClientMetadata& md, Derived* channel), + absl::Status (Derived::Call::* fn)(ClientMetadata& md, Derived* channel), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnClientInitialMetadata); @@ -714,7 +780,7 @@ inline void InterceptClientInitialMetadata( template absl::void_t( std::declval>))> -InterceptClientInitialMetadata(Promise (Derived::Call::*promise_factory)( +InterceptClientInitialMetadata(Promise (Derived::Call::* promise_factory)( ClientMetadata& md, Derived* channel), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { @@ -740,7 +806,7 @@ inline void InterceptServerInitialMetadata(const NoInterceptor*, void*, template inline void InterceptServerInitialMetadata( - void (Derived::Call::*fn)(ServerMetadata&), + void (Derived::Call::* fn)(ServerMetadata&), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnServerInitialMetadata); call_args.server_initial_metadata->InterceptAndMap( @@ -752,7 +818,7 @@ inline void InterceptServerInitialMetadata( template inline void InterceptServerInitialMetadata( - absl::Status (Derived::Call::*fn)(ServerMetadata&), + absl::Status (Derived::Call::* fn)(ServerMetadata&), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnServerInitialMetadata); call_args.server_initial_metadata->InterceptAndMap( @@ -769,7 +835,7 @@ inline void InterceptServerInitialMetadata( template inline void InterceptServerInitialMetadata( - void (Derived::Call::*fn)(ServerMetadata&, Derived*), + void (Derived::Call::* fn)(ServerMetadata&, Derived*), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnServerInitialMetadata); call_args.server_initial_metadata->InterceptAndMap( @@ -781,7 +847,7 @@ inline void InterceptServerInitialMetadata( template inline void InterceptServerInitialMetadata( - absl::Status (Derived::Call::*fn)(ServerMetadata&, Derived*), + absl::Status (Derived::Call::* fn)(ServerMetadata&, Derived*), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnServerInitialMetadata); call_args.server_initial_metadata->InterceptAndMap( @@ -802,7 +868,7 @@ inline void InterceptServerInitialMetadata(const NoInterceptor*, void*, void*, template inline void InterceptServerInitialMetadata( - void (Derived::Call::*fn)(ServerMetadata&), typename Derived::Call* call, + void (Derived::Call::* fn)(ServerMetadata&), typename Derived::Call* call, Derived*, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnServerInitialMetadata); call_spine->server_initial_metadata().sender.InterceptAndMap( @@ -814,7 +880,7 @@ inline void InterceptServerInitialMetadata( template inline void InterceptServerInitialMetadata( - absl::Status (Derived::Call::*fn)(ServerMetadata&), + absl::Status (Derived::Call::* fn)(ServerMetadata&), typename Derived::Call* call, Derived*, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnServerInitialMetadata); call_spine->server_initial_metadata().sender.InterceptAndMap( @@ -830,7 +896,7 @@ inline void InterceptServerInitialMetadata( template inline void InterceptServerInitialMetadata( - void (Derived::Call::*fn)(ServerMetadata&, Derived*), + void (Derived::Call::* fn)(ServerMetadata&, Derived*), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnServerInitialMetadata); @@ -843,7 +909,7 @@ inline void InterceptServerInitialMetadata( template inline void InterceptServerInitialMetadata( - absl::Status (Derived::Call::*fn)(ServerMetadata&, Derived*), + absl::Status (Derived::Call::* fn)(ServerMetadata&, Derived*), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnServerInitialMetadata); @@ -863,7 +929,19 @@ inline void InterceptServerToClientMessage(const NoInterceptor*, void*, template inline void InterceptServerToClientMessage( - ServerMetadataHandle (Derived::Call::*fn)(const Message&), + void (Derived::Call::* fn)(const Message&), + FilterCallData* call_data, const CallArgs& call_args) { + DCHECK(fn == &Derived::Call::OnServerToClientMessage); + call_args.server_to_client_messages->InterceptAndMap( + [call_data](MessageHandle msg) -> absl::optional { + call_data->call.OnServerToClientMessage(*msg); + return std::move(msg); + }); +} + +template +inline void InterceptServerToClientMessage( + ServerMetadataHandle (Derived::Call::* fn)(const Message&), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnServerToClientMessage); call_args.server_to_client_messages->InterceptAndMap( @@ -878,7 +956,7 @@ inline void InterceptServerToClientMessage( template inline void InterceptServerToClientMessage( - ServerMetadataHandle (Derived::Call::*fn)(const Message&, Derived*), + ServerMetadataHandle (Derived::Call::* fn)(const Message&, Derived*), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnServerToClientMessage); call_args.server_to_client_messages->InterceptAndMap( @@ -894,7 +972,7 @@ inline void InterceptServerToClientMessage( template inline void InterceptServerToClientMessage( - MessageHandle (Derived::Call::*fn)(MessageHandle, Derived*), + MessageHandle (Derived::Call::* fn)(MessageHandle, Derived*), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnServerToClientMessage); call_args.server_to_client_messages->InterceptAndMap( @@ -906,7 +984,8 @@ inline void InterceptServerToClientMessage( template inline void InterceptServerToClientMessage( - absl::StatusOr (Derived::Call::*fn)(MessageHandle, Derived*), + absl::StatusOr (Derived::Call::* fn)(MessageHandle, + Derived*), FilterCallData* call_data, const CallArgs& call_args) { DCHECK(fn == &Derived::Call::OnServerToClientMessage); call_args.server_to_client_messages->InterceptAndMap( @@ -925,7 +1004,19 @@ inline void InterceptServerToClientMessage(const NoInterceptor*, void*, void*, template inline void InterceptServerToClientMessage( - ServerMetadataHandle (Derived::Call::*fn)(const Message&), + void (Derived::Call::* fn)(const Message&), typename Derived::Call* call, + Derived*, PipeBasedCallSpine* call_spine) { + DCHECK(fn == &Derived::Call::OnServerToClientMessage); + call_spine->server_to_client_messages().sender.InterceptAndMap( + [call](MessageHandle msg) -> absl::optional { + call->OnServerToClientMessage(*msg); + return std::move(msg); + }); +} + +template +inline void InterceptServerToClientMessage( + ServerMetadataHandle (Derived::Call::* fn)(const Message&), typename Derived::Call* call, Derived*, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnServerToClientMessage); call_spine->server_to_client_messages().sender.InterceptAndMap( @@ -939,7 +1030,7 @@ inline void InterceptServerToClientMessage( template inline void InterceptServerToClientMessage( - ServerMetadataHandle (Derived::Call::*fn)(const Message&, Derived*), + ServerMetadataHandle (Derived::Call::* fn)(const Message&, Derived*), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnServerToClientMessage); @@ -955,7 +1046,7 @@ inline void InterceptServerToClientMessage( template inline void InterceptServerToClientMessage( - MessageHandle (Derived::Call::*fn)(MessageHandle, Derived*), + MessageHandle (Derived::Call::* fn)(MessageHandle, Derived*), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnServerToClientMessage); @@ -967,7 +1058,8 @@ inline void InterceptServerToClientMessage( template inline void InterceptServerToClientMessage( - absl::StatusOr (Derived::Call::*fn)(MessageHandle, Derived*), + absl::StatusOr (Derived::Call::* fn)(MessageHandle, + Derived*), typename Derived::Call* call, Derived* channel, PipeBasedCallSpine* call_spine) { DCHECK(fn == &Derived::Call::OnServerToClientMessage); @@ -1011,7 +1103,7 @@ inline void InterceptServerTrailingMetadata( inline void InterceptFinalize(const NoInterceptor*, void*, void*) {} template -inline void InterceptFinalize(void (Call::*fn)(const grpc_call_final_info*), +inline void InterceptFinalize(void (Call::* fn)(const grpc_call_final_info*), void*, Call* call) { DCHECK(fn == &Call::OnFinalize); GetContext()->Add( @@ -1022,7 +1114,7 @@ inline void InterceptFinalize(void (Call::*fn)(const grpc_call_final_info*), template inline void InterceptFinalize( - void (Derived::Call::*fn)(const grpc_call_final_info*, Derived*), + void (Derived::Call::* fn)(const grpc_call_final_info*, Derived*), Derived* channel, typename Derived::Call* call) { DCHECK(fn == &Derived::Call::OnFinalize); GetContext()->Add( @@ -1120,7 +1212,8 @@ class ImplementChannelFilter : public ChannelFilter, promise_filter_detail::InterceptClientInitialMetadata( &Derived::Call::OnClientInitialMetadata, call, d, c); promise_filter_detail::InterceptClientToServerMessage( - &Derived::Call::OnClientToServerMessage, call, d, c); + &Derived::Call::OnClientToServerMessage, + &Derived::Call::OnClientToServerHalfClose, call, d, c); promise_filter_detail::InterceptServerInitialMetadata( &Derived::Call::OnServerInitialMetadata, call, d, c); promise_filter_detail::InterceptServerToClientMessage( @@ -1139,7 +1232,8 @@ class ImplementChannelFilter : public ChannelFilter, auto* call = promise_filter_detail::MakeFilterCall( static_cast(this)); promise_filter_detail::InterceptClientToServerMessage( - &Derived::Call::OnClientToServerMessage, call, call_args); + &Derived::Call::OnClientToServerMessage, + &Derived::Call::OnClientToServerHalfClose, call, call_args); promise_filter_detail::InterceptServerInitialMetadata( &Derived::Call::OnServerInitialMetadata, call, call_args); promise_filter_detail::InterceptServerToClientMessage( diff --git a/src/core/lib/security/authorization/grpc_server_authz_filter.cc b/src/core/lib/security/authorization/grpc_server_authz_filter.cc index 5474847701a..7207b3897fa 100644 --- a/src/core/lib/security/authorization/grpc_server_authz_filter.cc +++ b/src/core/lib/security/authorization/grpc_server_authz_filter.cc @@ -41,6 +41,7 @@ TraceFlag grpc_authz_trace(false, "grpc_authz_api"); const NoInterceptor GrpcServerAuthzFilter::Call::OnServerInitialMetadata; const NoInterceptor GrpcServerAuthzFilter::Call::OnServerTrailingMetadata; const NoInterceptor GrpcServerAuthzFilter::Call::OnClientToServerMessage; +const NoInterceptor GrpcServerAuthzFilter::Call::OnClientToServerHalfClose; const NoInterceptor GrpcServerAuthzFilter::Call::OnServerToClientMessage; const NoInterceptor GrpcServerAuthzFilter::Call::OnFinalize; diff --git a/src/core/lib/security/authorization/grpc_server_authz_filter.h b/src/core/lib/security/authorization/grpc_server_authz_filter.h index b4b0a7463cd..742b3979d88 100644 --- a/src/core/lib/security/authorization/grpc_server_authz_filter.h +++ b/src/core/lib/security/authorization/grpc_server_authz_filter.h @@ -51,6 +51,7 @@ class GrpcServerAuthzFilter final static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; diff --git a/src/core/lib/security/transport/auth_filters.h b/src/core/lib/security/transport/auth_filters.h index 06b8b6e6fae..3970ae1e4f3 100644 --- a/src/core/lib/security/transport/auth_filters.h +++ b/src/core/lib/security/transport/auth_filters.h @@ -115,6 +115,7 @@ class ServerAuthFilter final : public ImplementChannelFilter { } static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnFinalize; diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc index 4b6ef0a1f10..bfbfeb2d8ae 100644 --- a/src/core/lib/security/transport/server_auth_filter.cc +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -68,6 +68,7 @@ const grpc_channel_filter ServerAuthFilter::kFilter = "server-auth"); const NoInterceptor ServerAuthFilter::Call::OnClientToServerMessage; +const NoInterceptor ServerAuthFilter::Call::OnClientToServerHalfClose; const NoInterceptor ServerAuthFilter::Call::OnServerToClientMessage; const NoInterceptor ServerAuthFilter::Call::OnServerInitialMetadata; const NoInterceptor ServerAuthFilter::Call::OnServerTrailingMetadata; diff --git a/src/core/lib/transport/call_filters.cc b/src/core/lib/transport/call_filters.cc index d4792c02bfe..2b89babdde4 100644 --- a/src/core/lib/transport/call_filters.cc +++ b/src/core/lib/transport/call_filters.cc @@ -29,6 +29,12 @@ void* Offset(void* base, size_t amt) { return static_cast(base) + amt; } namespace filters_detail { +void RunHalfClose(absl::Span ops, void* call_data) { + for (const auto& op : ops) { + op.half_close(Offset(call_data, op.call_offset), op.channel_data); + } +} + template OperationExecutor::~OperationExecutor() { if (promise_data_ != nullptr) { diff --git a/src/core/lib/transport/call_filters.h b/src/core/lib/transport/call_filters.h index d637b6fb0de..1a8c971ee73 100644 --- a/src/core/lib/transport/call_filters.h +++ b/src/core/lib/transport/call_filters.h @@ -43,6 +43,7 @@ // - OnServerInitialMetadata - $VALUE_TYPE = ServerMetadata // - OnServerToClientMessage - $VALUE_TYPE = Message // - OnClientToServerMessage - $VALUE_TYPE = Message +// - OnClientToServerHalfClose - no value // - OnServerTrailingMetadata - $VALUE_TYPE = ServerMetadata // - OnFinalize - special, see below // These members define an interception point for a particular event in @@ -192,6 +193,16 @@ struct Operator { void (*early_destroy)(void* promise_data); }; +struct HalfCloseOperator { + // Pointer to corresponding channel data for this filter + void* channel_data; + // Offset of the call data for this filter within the call data memory + size_t call_offset; + void (*half_close)(void* call_data, void* channel_data); +}; + +void RunHalfClose(absl::Span ops, void* call_data); + // We divide operations into fallible and infallible. // Fallible operations can fail, and that failure terminates the call. // Infallible operations cannot fail. @@ -265,6 +276,32 @@ void AddOp(FilterType* channel_data, size_t call_offset, to); } +template +void AddHalfClose(FilterType* channel_data, size_t call_offset, + void (FilterType::Call::*)(), + std::vector& to) { + to.push_back( + HalfCloseOperator{channel_data, call_offset, [](void* call_data, void*) { + static_cast(call_data) + ->OnClientToServerHalfClose(); + }}); +} + +template +void AddHalfClose(FilterType* channel_data, size_t call_offset, + void (FilterType::Call::*)(FilterType*), + std::vector& to) { + to.push_back(HalfCloseOperator{ + channel_data, call_offset, [](void* call_data, void* channel_data) { + static_cast(call_data) + ->OnClientToServerHalfClose(static_cast(channel_data)); + }}); +} + +template +void AddHalfClose(FilterType*, size_t, const NoInterceptor*, + std::vector&) {} + // const NoInterceptor $EVENT // These do nothing, and specifically DO NOT add an operation to the layout. // Supported for fallible & infallible operations. @@ -276,7 +313,7 @@ struct AddOpImpl { // void $INTERCEPTOR_NAME($VALUE_TYPE&) template + void (FilterType::Call::* impl)(typename T::element_type&)> struct AddOpImpl { static void Add(FilterType* channel_data, size_t call_offset, @@ -313,8 +350,8 @@ struct AddOpImpl + void (FilterType::Call::* impl)(typename T::element_type&, + FilterType*)> struct AddOpImpl< FilterType, T, void (FilterType::Call::*)(typename T::element_type&, FilterType*), impl> { @@ -354,7 +391,7 @@ struct AddOpImpl< // $VALUE_HANDLE $INTERCEPTOR_NAME($VALUE_HANDLE, FilterType*) template + T (FilterType::Call::* impl)(T, FilterType*)> struct AddOpImpl { static void Add(FilterType* channel_data, size_t call_offset, Layout>& to) { @@ -396,7 +433,7 @@ struct AddOpImpl { // absl::Status $INTERCEPTOR_NAME($VALUE_TYPE&) template + absl::Status (FilterType::Call::* impl)(typename T::element_type&)> struct AddOpImpl { @@ -441,7 +478,7 @@ struct AddOpImpl struct AddOpImpl< FilterType, T, @@ -469,8 +506,8 @@ struct AddOpImpl< // absl::Status $INTERCEPTOR_NAME($VALUE_TYPE&, FilterType*) template + absl::Status (FilterType::Call::* impl)(typename T::element_type&, + FilterType*)> struct AddOpImpl struct AddOpImpl $INTERCEPTOR_NAME($VALUE_HANDLE, FilterType*) template (FilterType::Call::*impl)(T, FilterType*)> + absl::StatusOr (FilterType::Call::* impl)(T, FilterType*)> struct AddOpImpl (FilterType::Call::*)(T, FilterType*), impl> { @@ -557,7 +594,7 @@ struct AddOpImpl struct AddOpImpl struct AddOpImpl struct AddOpImpl struct AddOpImpl + R (FilterType::Call::* impl)(typename T::element_type&)> struct AddOpImpl< FilterType, T, R (FilterType::Call::*)(typename T::element_type&), impl, absl::enable_if_t>::value>> { @@ -726,7 +763,7 @@ struct AddOpImpl< // PROMISE_RETURNING(absl::Status) $INTERCEPTOR_NAME($VALUE_TYPE&, FilterType*) template + R (FilterType::Call::* impl)(typename T::element_type&, FilterType*)> struct AddOpImpl< FilterType, T, R (FilterType::Call::*)(typename T::element_type&, FilterType*), impl, @@ -782,7 +819,7 @@ struct AddOpImpl< // PROMISE_RETURNING(absl::StatusOr<$VALUE_HANDLE>) // $INTERCEPTOR_NAME($VALUE_HANDLE, FilterType*) template + R (FilterType::Call::* impl)(T, FilterType*)> struct AddOpImpl, PromiseResult>::value>> { @@ -852,6 +889,7 @@ struct StackData { Layout> client_initial_metadata; Layout> server_initial_metadata; Layout> client_to_server_messages; + std::vector client_to_server_half_close; Layout> server_to_client_messages; Layout> server_trailing_metadata; // A list of finalizers for this call. @@ -972,6 +1010,14 @@ struct StackData { channel_data, call_offset, client_to_server_messages); } + template + void AddClientToServerHalfClose(FilterType* channel_data, + size_t call_offset) { + AddHalfClose(channel_data, call_offset, + &FilterType::Call::OnClientToServerHalfClose, + client_to_server_half_close); + } + template void AddServerToClientMessageOp(FilterType* channel_data, size_t call_offset) { @@ -997,7 +1043,7 @@ struct StackData { template void AddFinalizer(FilterType* channel_data, size_t call_offset, - void (FilterType::Call::*p)(const grpc_call_final_info*)) { + void (FilterType::Call::* p)(const grpc_call_final_info*)) { DCHECK(p == &FilterType::Call::OnFinalize); finalizers.push_back(Finalizer{ channel_data, @@ -1011,8 +1057,8 @@ struct StackData { template void AddFinalizer(FilterType* channel_data, size_t call_offset, - void (FilterType::Call::*p)(const grpc_call_final_info*, - FilterType*)) { + void (FilterType::Call::* p)(const grpc_call_final_info*, + FilterType*)) { DCHECK(p == &FilterType::Call::OnFinalize); finalizers.push_back(Finalizer{ channel_data, @@ -1217,6 +1263,7 @@ class ServerTrailingMetadataInterceptor { static const NoInterceptor OnClientInitialMetadata; static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; void OnServerTrailingMetadata(ServerMetadata& md, @@ -1240,6 +1287,9 @@ template const NoInterceptor ServerTrailingMetadataInterceptor::Call::OnClientToServerMessage; template +const NoInterceptor + ServerTrailingMetadataInterceptor::Call::OnClientToServerHalfClose; +template const NoInterceptor ServerTrailingMetadataInterceptor::Call::OnServerToClientMessage; template @@ -1256,6 +1306,7 @@ class ClientInitialMetadataInterceptor { } static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnFinalize; @@ -1273,6 +1324,9 @@ template const NoInterceptor ClientInitialMetadataInterceptor::Call::OnClientToServerMessage; template +const NoInterceptor + ClientInitialMetadataInterceptor::Call::OnClientToServerHalfClose; +template const NoInterceptor ClientInitialMetadataInterceptor::Call::OnServerToClientMessage; template @@ -1319,6 +1373,7 @@ class CallFilters { data_.AddClientInitialMetadataOp(filter, call_offset); data_.AddServerInitialMetadataOp(filter, call_offset); data_.AddClientToServerMessageOp(filter, call_offset); + data_.AddClientToServerHalfClose(filter, call_offset); data_.AddServerToClientMessageOp(filter, call_offset); data_.AddServerTrailingMetadataOp(filter, call_offset); data_.AddFinalizer(filter, call_offset, &FilterType::Call::OnFinalize); @@ -1424,10 +1479,10 @@ class CallFilters { std::string DebugString() const; private: - template >( - filters_detail::StackData::*layout_ptr)> + filters_detail::StackData::* layout_ptr)> class PipePromise { public: class Push { @@ -1593,6 +1648,8 @@ class CallFilters { filters_detail::OperationExecutor executor_; }; + template ( + filters_detail::StackData::* half_close_layout_ptr)> class PullMessage { public: explicit PullMessage(CallFilters* filters) : filters_(filters) {} @@ -1637,7 +1694,14 @@ class CallFilters { filters_->CancelDueToFailedPipeOperation(); return Failure{}; } - if (!**r) return absl::nullopt; + if (!**r) { + if (half_close_layout_ptr != nullptr) { + filters_detail::RunHalfClose( + filters_->stack_->data_.*half_close_layout_ptr, + filters_->call_data_); + } + return absl::nullopt; + } CHECK(filters_ != nullptr); return FinishOperationExecutor(executor_.Start( layout(), push()->TakeValue(), filters_->call_data_)); @@ -1865,7 +1929,8 @@ inline auto CallFilters::PushClientToServerMessage(MessageHandle message) { } inline auto CallFilters::PullClientToServerMessage() { - return ClientToServerMessagePromises::PullMessage{this}; + return ClientToServerMessagePromises::PullMessage< + &filters_detail::StackData::client_to_server_half_close>{this}; } inline auto CallFilters::PushServerToClientMessage(MessageHandle message) { @@ -1875,7 +1940,7 @@ inline auto CallFilters::PushServerToClientMessage(MessageHandle message) { } inline auto CallFilters::PullServerToClientMessage() { - return ServerToClientMessagePromises::PullMessage{this}; + return ServerToClientMessagePromises::PullMessage{this}; } inline auto CallFilters::PullServerTrailingMetadata() { diff --git a/src/core/load_balancing/grpclb/client_load_reporting_filter.cc b/src/core/load_balancing/grpclb/client_load_reporting_filter.cc index c3de43e4b94..e03fe8d70ce 100644 --- a/src/core/load_balancing/grpclb/client_load_reporting_filter.cc +++ b/src/core/load_balancing/grpclb/client_load_reporting_filter.cc @@ -41,6 +41,7 @@ namespace grpc_core { const NoInterceptor ClientLoadReportingFilter::Call::OnServerToClientMessage; const NoInterceptor ClientLoadReportingFilter::Call::OnClientToServerMessage; +const NoInterceptor ClientLoadReportingFilter::Call::OnClientToServerHalfClose; const NoInterceptor ClientLoadReportingFilter::Call::OnFinalize; const grpc_channel_filter ClientLoadReportingFilter::kFilter = diff --git a/src/core/load_balancing/grpclb/client_load_reporting_filter.h b/src/core/load_balancing/grpclb/client_load_reporting_filter.h index 941b97abf99..f8c5b7fbcc4 100644 --- a/src/core/load_balancing/grpclb/client_load_reporting_filter.h +++ b/src/core/load_balancing/grpclb/client_load_reporting_filter.h @@ -43,6 +43,7 @@ class ClientLoadReportingFilter final void OnServerTrailingMetadata(ServerMetadata& server_trailing_metadata); static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnFinalize; private: diff --git a/src/core/resolver/xds/xds_resolver.cc b/src/core/resolver/xds/xds_resolver.cc index d64da628eed..17808406c4c 100644 --- a/src/core/resolver/xds/xds_resolver.cc +++ b/src/core/resolver/xds/xds_resolver.cc @@ -330,6 +330,7 @@ class XdsResolver final : public Resolver { static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; @@ -383,6 +384,8 @@ const NoInterceptor XdsResolver::ClusterSelectionFilter::Call::OnServerTrailingMetadata; const NoInterceptor XdsResolver::ClusterSelectionFilter::Call::OnClientToServerMessage; +const NoInterceptor + XdsResolver::ClusterSelectionFilter::Call::OnClientToServerHalfClose; const NoInterceptor XdsResolver::ClusterSelectionFilter::Call::OnServerToClientMessage; const NoInterceptor XdsResolver::ClusterSelectionFilter::Call::OnFinalize; diff --git a/src/core/server/server_call_tracer_filter.cc b/src/core/server/server_call_tracer_filter.cc index 6966b4ccc05..6ceb637d481 100644 --- a/src/core/server/server_call_tracer_filter.cc +++ b/src/core/server/server_call_tracer_filter.cc @@ -80,6 +80,7 @@ class ServerCallTracerFilter } static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; private: @@ -92,6 +93,7 @@ class ServerCallTracerFilter }; const NoInterceptor ServerCallTracerFilter::Call::OnClientToServerMessage; +const NoInterceptor ServerCallTracerFilter::Call::OnClientToServerHalfClose; const NoInterceptor ServerCallTracerFilter::Call::OnServerToClientMessage; const grpc_channel_filter ServerCallTracerFilter::kFilter = diff --git a/src/core/server/server_config_selector_filter.cc b/src/core/server/server_config_selector_filter.cc index 4da3b75398c..b4acec5329e 100644 --- a/src/core/server/server_config_selector_filter.cc +++ b/src/core/server/server_config_selector_filter.cc @@ -72,6 +72,7 @@ class ServerConfigSelectorFilter final static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; @@ -158,6 +159,7 @@ absl::Status ServerConfigSelectorFilter::Call::OnClientInitialMetadata( const NoInterceptor ServerConfigSelectorFilter::Call::OnServerInitialMetadata; const NoInterceptor ServerConfigSelectorFilter::Call::OnServerTrailingMetadata; const NoInterceptor ServerConfigSelectorFilter::Call::OnClientToServerMessage; +const NoInterceptor ServerConfigSelectorFilter::Call::OnClientToServerHalfClose; const NoInterceptor ServerConfigSelectorFilter::Call::OnServerToClientMessage; const NoInterceptor ServerConfigSelectorFilter::Call::OnFinalize; diff --git a/src/core/service_config/service_config_channel_arg_filter.cc b/src/core/service_config/service_config_channel_arg_filter.cc index a2c9a973129..c73e9dd450e 100644 --- a/src/core/service_config/service_config_channel_arg_filter.cc +++ b/src/core/service_config/service_config_channel_arg_filter.cc @@ -83,6 +83,7 @@ class ServiceConfigChannelArgFilter final static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; @@ -97,6 +98,8 @@ const NoInterceptor ServiceConfigChannelArgFilter::Call::OnServerTrailingMetadata; const NoInterceptor ServiceConfigChannelArgFilter::Call::OnClientToServerMessage; +const NoInterceptor + ServiceConfigChannelArgFilter::Call::OnClientToServerHalfClose; const NoInterceptor ServiceConfigChannelArgFilter::Call::OnServerToClientMessage; const NoInterceptor ServiceConfigChannelArgFilter::Call::OnFinalize; diff --git a/test/core/surface/channel_init_test.cc b/test/core/surface/channel_init_test.cc index a90777973c3..a548b221eac 100644 --- a/test/core/surface/channel_init_test.cc +++ b/test/core/surface/channel_init_test.cc @@ -229,6 +229,7 @@ class TestFilter1 { static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; @@ -245,6 +246,7 @@ const NoInterceptor TestFilter1::Call::OnClientInitialMetadata; const NoInterceptor TestFilter1::Call::OnServerInitialMetadata; const NoInterceptor TestFilter1::Call::OnServerTrailingMetadata; const NoInterceptor TestFilter1::Call::OnClientToServerMessage; +const NoInterceptor TestFilter1::Call::OnClientToServerHalfClose; const NoInterceptor TestFilter1::Call::OnServerToClientMessage; const NoInterceptor TestFilter1::Call::OnFinalize; diff --git a/test/core/transport/call_filters_test.cc b/test/core/transport/call_filters_test.cc index 8044aa85610..750f16b2ec1 100644 --- a/test/core/transport/call_filters_test.cc +++ b/test/core/transport/call_filters_test.cc @@ -1331,6 +1331,7 @@ TEST(CallFiltersTest, CanBuildStack) { void OnClientInitialMetadata(ClientMetadata&) {} void OnServerInitialMetadata(ServerMetadata&) {} void OnClientToServerMessage(Message&) {} + void OnClientToServerHalfClose() {} void OnServerToClientMessage(Message&) {} void OnServerTrailingMetadata(ServerMetadata&) {} void OnFinalize(const grpc_call_final_info*) {} @@ -1355,6 +1356,10 @@ TEST(CallFiltersTest, UnaryCall) { void OnClientToServerMessage(Message&, Filter* f) { f->steps.push_back(absl::StrCat(f->label, ":OnClientToServerMessage")); } + void OnClientToServerHalfClose(Filter* f) { + f->steps.push_back( + absl::StrCat(f->label, ":OnClientToServerHalfClose")); + } void OnServerToClientMessage(Message&, Filter* f) { f->steps.push_back(absl::StrCat(f->label, ":OnServerToClientMessage")); } diff --git a/test/core/transport/interception_chain_test.cc b/test/core/transport/interception_chain_test.cc index f72fd59e019..bf08abb61e3 100644 --- a/test/core/transport/interception_chain_test.cc +++ b/test/core/transport/interception_chain_test.cc @@ -82,6 +82,7 @@ class TestFilter { } static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnFinalize; @@ -102,6 +103,8 @@ const NoInterceptor TestFilter::Call::OnServerInitialMetadata; template const NoInterceptor TestFilter::Call::OnClientToServerMessage; template +const NoInterceptor TestFilter::Call::OnClientToServerHalfClose; +template const NoInterceptor TestFilter::Call::OnServerToClientMessage; template const NoInterceptor TestFilter::Call::OnServerTrailingMetadata; @@ -119,6 +122,7 @@ class FailsToInstantiateFilter { static const NoInterceptor OnClientInitialMetadata; static const NoInterceptor OnServerInitialMetadata; static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnClientToServerHalfClose; static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnServerTrailingMetadata; static const NoInterceptor OnFinalize; @@ -138,6 +142,9 @@ const NoInterceptor FailsToInstantiateFilter::Call::OnServerInitialMetadata; template const NoInterceptor FailsToInstantiateFilter::Call::OnClientToServerMessage; template +const NoInterceptor + FailsToInstantiateFilter::Call::OnClientToServerHalfClose; +template const NoInterceptor FailsToInstantiateFilter::Call::OnServerToClientMessage; template const NoInterceptor FailsToInstantiateFilter::Call::OnServerTrailingMetadata;