diff --git a/src/core/ext/filters/backend_metrics/backend_metric_filter.cc b/src/core/ext/filters/backend_metrics/backend_metric_filter.cc index 33566871cdf..e2caef67939 100644 --- a/src/core/ext/filters/backend_metrics/backend_metric_filter.cc +++ b/src/core/ext/filters/backend_metrics/backend_metric_filter.cc @@ -25,6 +25,7 @@ #include #include "absl/strings/string_view.h" +#include "third_party/grpc/src/core/lib/channel/promise_based_filter.h" #include "upb/base/string_view.h" #include "upb/mem/arena.hpp" #include "xds/data/orca/v3/orca_load_report.upb.h" @@ -121,9 +122,9 @@ const grpc_channel_filter BackendMetricFilter::kFilter = MakePromiseBasedFilter( "backend_metric"); -absl::StatusOr BackendMetricFilter::Create( - const ChannelArgs&, ChannelFilter::Args) { - return BackendMetricFilter(); +absl::StatusOr> +BackendMetricFilter::Create(const ChannelArgs&, ChannelFilter::Args) { + return std::make_unique(); } void BackendMetricFilter::Call::OnServerTrailingMetadata(ServerMetadata& md) { diff --git a/src/core/ext/filters/backend_metrics/backend_metric_filter.h b/src/core/ext/filters/backend_metrics/backend_metric_filter.h index 2c71a9d19cb..d97b0c8cb65 100644 --- a/src/core/ext/filters/backend_metrics/backend_metric_filter.h +++ b/src/core/ext/filters/backend_metrics/backend_metric_filter.h @@ -35,8 +35,8 @@ class BackendMetricFilter : public ImplementChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create(const ChannelArgs& args, - ChannelFilter::Args); + static absl::StatusOr> Create( + const ChannelArgs& args, ChannelFilter::Args); class Call { public: diff --git a/src/core/ext/filters/channel_idle/legacy_channel_idle_filter.cc b/src/core/ext/filters/channel_idle/legacy_channel_idle_filter.cc index 7d7757ac081..d189d4080a7 100644 --- a/src/core/ext/filters/channel_idle/legacy_channel_idle_filter.cc +++ b/src/core/ext/filters/channel_idle/legacy_channel_idle_filter.cc @@ -25,6 +25,7 @@ #include "absl/base/thread_annotations.h" #include "absl/meta/type_traits.h" #include "absl/random/random.h" +#include "absl/status/statusor.h" #include "absl/types/optional.h" #include @@ -133,18 +134,17 @@ struct LegacyMaxAgeFilter::Config { // will be removed at that time also, so just disable the deprecation warning // for now. ABSL_INTERNAL_DISABLE_DEPRECATED_DECLARATION_WARNING -absl::StatusOr LegacyClientIdleFilter::Create( - const ChannelArgs& args, ChannelFilter::Args filter_args) { - LegacyClientIdleFilter filter(filter_args.channel_stack(), - GetClientIdleTimeout(args)); - return absl::StatusOr(std::move(filter)); +absl::StatusOr> +LegacyClientIdleFilter::Create(const ChannelArgs& args, + ChannelFilter::Args filter_args) { + return std::make_unique(filter_args.channel_stack(), + GetClientIdleTimeout(args)); } -absl::StatusOr LegacyMaxAgeFilter::Create( +absl::StatusOr> LegacyMaxAgeFilter::Create( const ChannelArgs& args, ChannelFilter::Args filter_args) { - LegacyMaxAgeFilter filter(filter_args.channel_stack(), - Config::FromChannelArgs(args)); - return absl::StatusOr(std::move(filter)); + return std::make_unique(filter_args.channel_stack(), + Config::FromChannelArgs(args)); } ABSL_INTERNAL_RESTORE_DEPRECATED_DECLARATION_WARNING diff --git a/src/core/ext/filters/channel_idle/legacy_channel_idle_filter.h b/src/core/ext/filters/channel_idle/legacy_channel_idle_filter.h index 8e06505d732..8e6215cebba 100644 --- a/src/core/ext/filters/channel_idle/legacy_channel_idle_filter.h +++ b/src/core/ext/filters/channel_idle/legacy_channel_idle_filter.h @@ -42,6 +42,11 @@ namespace grpc_core { class LegacyChannelIdleFilter : public ChannelFilter { public: + LegacyChannelIdleFilter(grpc_channel_stack* channel_stack, + Duration client_idle_timeout) + : channel_stack_(channel_stack), + client_idle_timeout_(client_idle_timeout) {} + ~LegacyChannelIdleFilter() override = default; LegacyChannelIdleFilter(const LegacyChannelIdleFilter&) = delete; @@ -59,11 +64,6 @@ class LegacyChannelIdleFilter : public ChannelFilter { using SingleSetActivityPtr = SingleSetPtr; - LegacyChannelIdleFilter(grpc_channel_stack* channel_stack, - Duration client_idle_timeout) - : channel_stack_(channel_stack), - client_idle_timeout_(client_idle_timeout) {} - grpc_channel_stack* channel_stack() { return channel_stack_; }; virtual void Shutdown(); @@ -94,10 +94,9 @@ class LegacyClientIdleFilter final : public LegacyChannelIdleFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); - private: using LegacyChannelIdleFilter::LegacyChannelIdleFilter; }; @@ -106,9 +105,12 @@ class LegacyMaxAgeFilter final : public LegacyChannelIdleFilter { static const grpc_channel_filter kFilter; struct Config; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); + LegacyMaxAgeFilter(grpc_channel_stack* channel_stack, + const Config& max_age_config); + void PostInit() override; private: @@ -128,9 +130,6 @@ class LegacyMaxAgeFilter final : public LegacyChannelIdleFilter { LegacyMaxAgeFilter* filter_; }; - LegacyMaxAgeFilter(grpc_channel_stack* channel_stack, - const Config& max_age_config); - void Shutdown() override; SingleSetActivityPtr max_age_activity_; diff --git a/src/core/ext/filters/fault_injection/fault_injection_filter.cc b/src/core/ext/filters/fault_injection/fault_injection_filter.cc index 2464d01d7db..87d5a4d2f1f 100644 --- a/src/core/ext/filters/fault_injection/fault_injection_filter.cc +++ b/src/core/ext/filters/fault_injection/fault_injection_filter.cc @@ -29,6 +29,7 @@ #include "absl/meta/type_traits.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -135,16 +136,16 @@ class FaultInjectionFilter::InjectionDecision { FaultHandle active_fault_{false}; }; -absl::StatusOr FaultInjectionFilter::Create( - const ChannelArgs&, ChannelFilter::Args filter_args) { - return FaultInjectionFilter(filter_args); +absl::StatusOr> +FaultInjectionFilter::Create(const ChannelArgs&, + ChannelFilter::Args filter_args) { + return std::make_unique(filter_args); } FaultInjectionFilter::FaultInjectionFilter(ChannelFilter::Args filter_args) : index_(filter_args.instance_id()), service_config_parser_index_( - FaultInjectionServiceConfigParser::ParserIndex()), - mu_(new Mutex) {} + FaultInjectionServiceConfigParser::ParserIndex()) {} // Construct a promise for one call. ArenaPromise FaultInjectionFilter::Call::OnClientInitialMetadata( @@ -226,7 +227,7 @@ FaultInjectionFilter::MakeInjectionDecision( bool delay_request = delay != Duration::Zero(); bool abort_request = abort_code != GRPC_STATUS_OK; if (delay_request || abort_request) { - MutexLock lock(mu_.get()); + MutexLock lock(&mu_); if (delay_request) { delay_request = UnderFraction(&delay_rand_generator_, delay_percentage_numerator, diff --git a/src/core/ext/filters/fault_injection/fault_injection_filter.h b/src/core/ext/filters/fault_injection/fault_injection_filter.h index e2918e854f5..b6b1b811cde 100644 --- a/src/core/ext/filters/fault_injection/fault_injection_filter.h +++ b/src/core/ext/filters/fault_injection/fault_injection_filter.h @@ -45,9 +45,11 @@ class FaultInjectionFilter public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); + explicit FaultInjectionFilter(ChannelFilter::Args filter_args); + // Construct a promise for one call. class Call { public: @@ -61,8 +63,6 @@ class FaultInjectionFilter }; private: - explicit FaultInjectionFilter(ChannelFilter::Args filter_args); - class InjectionDecision; InjectionDecision MakeInjectionDecision( const ClientMetadata& initial_metadata); @@ -70,7 +70,7 @@ class FaultInjectionFilter // The relative index of instances of the same filter. size_t index_; const size_t service_config_parser_index_; - std::unique_ptr mu_; + Mutex mu_; absl::InsecureBitGen abort_rand_generator_ ABSL_GUARDED_BY(mu_); absl::InsecureBitGen delay_rand_generator_ ABSL_GUARDED_BY(mu_); }; diff --git a/src/core/ext/filters/http/client/http_client_filter.cc b/src/core/ext/filters/http/client/http_client_filter.cc index ac8004cdd3e..390df545efe 100644 --- a/src/core/ext/filters/http/client/http_client_filter.cc +++ b/src/core/ext/filters/http/client/http_client_filter.cc @@ -27,6 +27,7 @@ #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -136,16 +137,16 @@ HttpClientFilter::HttpClientFilter(HttpSchemeMetadata::ValueType scheme, Slice user_agent, bool test_only_use_put_requests) : scheme_(scheme), - user_agent_(std::move(user_agent)), - test_only_use_put_requests_(test_only_use_put_requests) {} + test_only_use_put_requests_(test_only_use_put_requests), + user_agent_(std::move(user_agent)) {} -absl::StatusOr HttpClientFilter::Create( +absl::StatusOr> HttpClientFilter::Create( const ChannelArgs& args, ChannelFilter::Args) { auto* transport = args.GetObject(); if (transport == nullptr) { return absl::InvalidArgumentError("HttpClientFilter needs a transport"); } - return HttpClientFilter( + return std::make_unique( SchemeFromArgs(args), UserAgentFromArgs(args, transport->GetTransportName()), args.GetInt(GRPC_ARG_TEST_ONLY_USE_PUT_REQUESTS).value_or(false)); diff --git a/src/core/ext/filters/http/client/http_client_filter.h b/src/core/ext/filters/http/client/http_client_filter.h index dbfa9f7f2d0..f5d7875da5e 100644 --- a/src/core/ext/filters/http/client/http_client_filter.h +++ b/src/core/ext/filters/http/client/http_client_filter.h @@ -35,9 +35,12 @@ class HttpClientFilter : public ImplementChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); + HttpClientFilter(HttpSchemeMetadata::ValueType scheme, Slice user_agent, + bool test_only_use_put_requests); + class Call { public: void OnClientInitialMetadata(ClientMetadata& md, HttpClientFilter* filter); @@ -49,12 +52,9 @@ class HttpClientFilter : public ImplementChannelFilter { }; private: - HttpClientFilter(HttpSchemeMetadata::ValueType scheme, Slice user_agent, - bool test_only_use_put_requests); - HttpSchemeMetadata::ValueType scheme_; - Slice user_agent_; bool test_only_use_put_requests_; + Slice user_agent_; }; // A test-only channel arg to allow testing gRPC Core server behavior on PUT diff --git a/src/core/ext/filters/http/client_authority_filter.cc b/src/core/ext/filters/http/client_authority_filter.cc index 3df6d275f6a..1d5258493e4 100644 --- a/src/core/ext/filters/http/client_authority_filter.cc +++ b/src/core/ext/filters/http/client_authority_filter.cc @@ -43,8 +43,8 @@ const NoInterceptor ClientAuthorityFilter::Call::OnClientToServerMessage; const NoInterceptor ClientAuthorityFilter::Call::OnServerToClientMessage; const NoInterceptor ClientAuthorityFilter::Call::OnFinalize; -absl::StatusOr ClientAuthorityFilter::Create( - const ChannelArgs& args, ChannelFilter::Args) { +absl::StatusOr> +ClientAuthorityFilter::Create(const ChannelArgs& args, ChannelFilter::Args) { absl::optional default_authority = args.GetString(GRPC_ARG_DEFAULT_AUTHORITY); if (!default_authority.has_value()) { @@ -52,7 +52,8 @@ absl::StatusOr ClientAuthorityFilter::Create( "GRPC_ARG_DEFAULT_AUTHORITY string channel arg. not found. Note that " "direct channels must explicitly specify a value for this argument."); } - return ClientAuthorityFilter(Slice::FromCopiedString(*default_authority)); + return std::make_unique( + Slice::FromCopiedString(*default_authority)); } void ClientAuthorityFilter::Call::OnClientInitialMetadata( diff --git a/src/core/ext/filters/http/client_authority_filter.h b/src/core/ext/filters/http/client_authority_filter.h index 064e0a0a450..44229c6cdde 100644 --- a/src/core/ext/filters/http/client_authority_filter.h +++ b/src/core/ext/filters/http/client_authority_filter.h @@ -39,8 +39,11 @@ class ClientAuthorityFilter final public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create(const ChannelArgs& args, - ChannelFilter::Args); + static absl::StatusOr> Create( + const ChannelArgs& args, ChannelFilter::Args); + + explicit ClientAuthorityFilter(Slice default_authority) + : default_authority_(std::move(default_authority)) {} class Call { public: @@ -54,8 +57,6 @@ class ClientAuthorityFilter final }; private: - explicit ClientAuthorityFilter(Slice default_authority) - : default_authority_(std::move(default_authority)) {} Slice default_authority_; }; diff --git a/src/core/ext/filters/http/message_compress/compression_filter.cc b/src/core/ext/filters/http/message_compress/compression_filter.cc index c9b3acaf1f7..1ad6081ffac 100644 --- a/src/core/ext/filters/http/message_compress/compression_filter.cc +++ b/src/core/ext/filters/http/message_compress/compression_filter.cc @@ -72,14 +72,14 @@ const grpc_channel_filter ServerCompressionFilter::kFilter = kFilterExaminesInboundMessages | kFilterExaminesOutboundMessages>("compression"); -absl::StatusOr ClientCompressionFilter::Create( - const ChannelArgs& args, ChannelFilter::Args) { - return ClientCompressionFilter(args); +absl::StatusOr> +ClientCompressionFilter::Create(const ChannelArgs& args, ChannelFilter::Args) { + return std::make_unique(args); } -absl::StatusOr ServerCompressionFilter::Create( - const ChannelArgs& args, ChannelFilter::Args) { - return ServerCompressionFilter(args); +absl::StatusOr> +ServerCompressionFilter::Create(const ChannelArgs& args, ChannelFilter::Args) { + return std::make_unique(args); } ChannelCompression::ChannelCompression(const ChannelArgs& args) diff --git a/src/core/ext/filters/http/message_compress/compression_filter.h b/src/core/ext/filters/http/message_compress/compression_filter.h index adf9fb49355..99e57a0ac1d 100644 --- a/src/core/ext/filters/http/message_compress/compression_filter.h +++ b/src/core/ext/filters/http/message_compress/compression_filter.h @@ -110,9 +110,12 @@ class ClientCompressionFilter final public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); + explicit ClientCompressionFilter(const ChannelArgs& args) + : compression_engine_(args) {} + // Construct a promise for one call. class Call { public: @@ -135,9 +138,6 @@ class ClientCompressionFilter final }; private: - explicit ClientCompressionFilter(const ChannelArgs& args) - : compression_engine_(args) {} - ChannelCompression compression_engine_; }; @@ -146,9 +146,12 @@ class ServerCompressionFilter final public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); + explicit ServerCompressionFilter(const ChannelArgs& args) + : compression_engine_(args) {} + // Construct a promise for one call. class Call { public: @@ -171,9 +174,6 @@ class ServerCompressionFilter final }; private: - explicit ServerCompressionFilter(const ChannelArgs& args) - : compression_engine_(args) {} - ChannelCompression compression_engine_; }; diff --git a/src/core/ext/filters/http/server/http_server_filter.cc b/src/core/ext/filters/http/server/http_server_filter.cc index 775eebcb1d3..925cc73c23e 100644 --- a/src/core/ext/filters/http/server/http_server_filter.cc +++ b/src/core/ext/filters/http/server/http_server_filter.cc @@ -152,9 +152,9 @@ void HttpServerFilter::Call::OnServerTrailingMetadata(ServerMetadata& md) { FilterOutgoingMetadata(&md); } -absl::StatusOr HttpServerFilter::Create( +absl::StatusOr> HttpServerFilter::Create( const ChannelArgs& args, ChannelFilter::Args) { - return HttpServerFilter( + return std::make_unique( args.GetBool(GRPC_ARG_SURFACE_USER_AGENT).value_or(true), args.GetBool( GRPC_ARG_DO_NOT_USE_UNLESS_YOU_HAVE_PERMISSION_FROM_GRPC_TEAM_ALLOW_BROKEN_PUT_REQUESTS) diff --git a/src/core/ext/filters/http/server/http_server_filter.h b/src/core/ext/filters/http/server/http_server_filter.h index a87c83518ee..282973ddecd 100644 --- a/src/core/ext/filters/http/server/http_server_filter.h +++ b/src/core/ext/filters/http/server/http_server_filter.h @@ -36,9 +36,13 @@ class HttpServerFilter : public ImplementChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); + HttpServerFilter(bool surface_user_agent, bool allow_put_requests) + : surface_user_agent_(surface_user_agent), + allow_put_requests_(allow_put_requests) {} + class Call { public: ServerMetadataHandle OnClientInitialMetadata(ClientMetadata& md, @@ -51,10 +55,6 @@ class HttpServerFilter : public ImplementChannelFilter { }; private: - HttpServerFilter(bool surface_user_agent, bool allow_put_requests) - : surface_user_agent_(surface_user_agent), - allow_put_requests_(allow_put_requests) {} - bool surface_user_agent_; bool allow_put_requests_; }; diff --git a/src/core/ext/filters/load_reporting/server_load_reporting_filter.cc b/src/core/ext/filters/load_reporting/server_load_reporting_filter.cc index 624ca70e79e..7937ab6fe74 100644 --- a/src/core/ext/filters/load_reporting/server_load_reporting_filter.cc +++ b/src/core/ext/filters/load_reporting/server_load_reporting_filter.cc @@ -76,10 +76,11 @@ const NoInterceptor ServerLoadReportingFilter::Call::OnServerInitialMetadata; const NoInterceptor ServerLoadReportingFilter::Call::OnClientToServerMessage; const NoInterceptor ServerLoadReportingFilter::Call::OnServerToClientMessage; -absl::StatusOr ServerLoadReportingFilter::Create( - const ChannelArgs& channel_args, ChannelFilter::Args) { +absl::StatusOr> +ServerLoadReportingFilter::Create(const ChannelArgs& channel_args, + ChannelFilter::Args) { // Find and record the peer_identity. - ServerLoadReportingFilter filter; + auto filter = std::make_unique(); const auto* auth_context = channel_args.GetObject(); if (auth_context != nullptr && grpc_auth_context_peer_is_authenticated(auth_context)) { @@ -88,7 +89,7 @@ absl::StatusOr ServerLoadReportingFilter::Create( const grpc_auth_property* auth_property = grpc_auth_property_iterator_next(&auth_it); if (auth_property != nullptr) { - filter.peer_identity_ = + filter->peer_identity_ = std::string(auth_property->value, auth_property->value_length); } } diff --git a/src/core/ext/filters/load_reporting/server_load_reporting_filter.h b/src/core/ext/filters/load_reporting/server_load_reporting_filter.h index ddb8d0f2482..f11c8c38bcf 100644 --- a/src/core/ext/filters/load_reporting/server_load_reporting_filter.h +++ b/src/core/ext/filters/load_reporting/server_load_reporting_filter.h @@ -39,7 +39,7 @@ class ServerLoadReportingFilter public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args); // Getters. diff --git a/src/core/ext/filters/logging/logging_filter.cc b/src/core/ext/filters/logging/logging_filter.cc index 89fe46bcb66..1c76e64b06a 100644 --- a/src/core/ext/filters/logging/logging_filter.cc +++ b/src/core/ext/filters/logging/logging_filter.cc @@ -342,21 +342,23 @@ class CallData { } // namespace -absl::StatusOr ClientLoggingFilter::Create( - const ChannelArgs& args, ChannelFilter::Args /*filter_args*/) { +absl::StatusOr> +ClientLoggingFilter::Create(const ChannelArgs& args, + ChannelFilter::Args /*filter_args*/) { absl::optional default_authority = args.GetString(GRPC_ARG_DEFAULT_AUTHORITY); if (default_authority.has_value()) { - return ClientLoggingFilter(std::string(default_authority.value())); + return std::make_unique( + std::string(default_authority.value())); } absl::optional server_uri = args.GetOwnedString(GRPC_ARG_SERVER_URI); if (server_uri.has_value()) { - return ClientLoggingFilter( + return std::make_unique( CoreConfiguration::Get().resolver_registry().GetDefaultAuthority( *server_uri)); } - return ClientLoggingFilter(""); + return std::make_unique(""); } // Construct a promise for one call. @@ -445,9 +447,10 @@ const grpc_channel_filter ClientLoggingFilter::kFilter = kFilterExaminesInboundMessages | kFilterExaminesOutboundMessages>("logging"); -absl::StatusOr ServerLoggingFilter::Create( - const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/) { - return ServerLoggingFilter(); +absl::StatusOr> +ServerLoggingFilter::Create(const ChannelArgs& /*args*/, + ChannelFilter::Args /*filter_args*/) { + return std::make_unique(); } // Construct a promise for one call. diff --git a/src/core/ext/filters/logging/logging_filter.h b/src/core/ext/filters/logging/logging_filter.h index 6a27a653860..7d42abbc337 100644 --- a/src/core/ext/filters/logging/logging_filter.h +++ b/src/core/ext/filters/logging/logging_filter.h @@ -39,24 +39,25 @@ class ClientLoggingFilter final : public ChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args /*filter_args*/); + explicit ClientLoggingFilter(std::string default_authority) + : default_authority_(std::move(default_authority)) {} + // Construct a promise for one call. ArenaPromise MakeCallPromise( CallArgs call_args, NextPromiseFactory next_promise_factory) override; private: - explicit ClientLoggingFilter(std::string default_authority) - : default_authority_(std::move(default_authority)) {} - std::string default_authority_; + const std::string default_authority_; }; class ServerLoggingFilter final : public ChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args /*filter_args*/); // Construct a promise for one call. diff --git a/src/core/ext/filters/message_size/message_size_filter.cc b/src/core/ext/filters/message_size/message_size_filter.cc index d975282432a..379d4944788 100644 --- a/src/core/ext/filters/message_size/message_size_filter.cc +++ b/src/core/ext/filters/message_size/message_size_filter.cc @@ -25,7 +25,6 @@ #include "absl/strings/str_format.h" -#include #include #include #include @@ -142,19 +141,20 @@ const grpc_channel_filter ClientMessageSizeFilter::kFilter = MakePromiseBasedFilter("message_size"); + const grpc_channel_filter ServerMessageSizeFilter::kFilter = MakePromiseBasedFilter("message_size"); -absl::StatusOr ClientMessageSizeFilter::Create( - const ChannelArgs& args, ChannelFilter::Args) { - return ClientMessageSizeFilter(args); +absl::StatusOr> +ClientMessageSizeFilter::Create(const ChannelArgs& args, ChannelFilter::Args) { + return std::make_unique(args); } -absl::StatusOr ServerMessageSizeFilter::Create( - const ChannelArgs& args, ChannelFilter::Args) { - return ServerMessageSizeFilter(args); +absl::StatusOr> +ServerMessageSizeFilter::Create(const ChannelArgs& args, ChannelFilter::Args) { + return std::make_unique(args); } namespace { diff --git a/src/core/ext/filters/message_size/message_size_filter.h b/src/core/ext/filters/message_size/message_size_filter.h index 3a9ea1f2064..89d21201a5c 100644 --- a/src/core/ext/filters/message_size/message_size_filter.h +++ b/src/core/ext/filters/message_size/message_size_filter.h @@ -91,9 +91,12 @@ class ServerMessageSizeFilter final public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); + explicit ServerMessageSizeFilter(const ChannelArgs& args) + : parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {} + class Call { public: static const NoInterceptor OnClientInitialMetadata; @@ -107,8 +110,6 @@ class ServerMessageSizeFilter final }; private: - explicit ServerMessageSizeFilter(const ChannelArgs& args) - : parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {} const MessageSizeParsedConfig parsed_config_; }; @@ -117,9 +118,12 @@ class ClientMessageSizeFilter final public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); + explicit ClientMessageSizeFilter(const ChannelArgs& args) + : parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {} + class Call { public: explicit Call(ClientMessageSizeFilter* filter); @@ -136,8 +140,6 @@ class ClientMessageSizeFilter final }; private: - explicit ClientMessageSizeFilter(const ChannelArgs& args) - : parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {} const size_t service_config_parser_index_{MessageSizeParser::ParserIndex()}; const MessageSizeParsedConfig parsed_config_; }; diff --git a/src/core/ext/filters/rbac/rbac_filter.cc b/src/core/ext/filters/rbac/rbac_filter.cc index 68c5e3d7699..7c75f46ae7d 100644 --- a/src/core/ext/filters/rbac/rbac_filter.cc +++ b/src/core/ext/filters/rbac/rbac_filter.cc @@ -82,14 +82,21 @@ RbacFilter::RbacFilter(size_t index, service_config_parser_index_(RbacServiceConfigParser::ParserIndex()), per_channel_evaluate_args_(std::move(per_channel_evaluate_args)) {} -absl::StatusOr RbacFilter::Create(const ChannelArgs& args, - ChannelFilter::Args filter_args) { +absl::StatusOr> RbacFilter::Create( + const ChannelArgs& args, ChannelFilter::Args filter_args) { auto* auth_context = args.GetObject(); if (auth_context == nullptr) { return GRPC_ERROR_CREATE("No auth context found"); } - return RbacFilter(filter_args.instance_id(), - EvaluateArgs::PerChannelArgs(auth_context, args)); + auto* transport = args.GetObject(); + if (transport == nullptr) { + // This should never happen since the transport is always set on the server + // side. + return GRPC_ERROR_CREATE("No transport configured"); + } + return std::make_unique( + filter_args.instance_id(), + EvaluateArgs::PerChannelArgs(auth_context, args)); } void RbacFilterRegister(CoreConfiguration::Builder* builder) { diff --git a/src/core/ext/filters/rbac/rbac_filter.h b/src/core/ext/filters/rbac/rbac_filter.h index f2208753b6b..a4c41cbdd0b 100644 --- a/src/core/ext/filters/rbac/rbac_filter.h +++ b/src/core/ext/filters/rbac/rbac_filter.h @@ -42,8 +42,11 @@ class RbacFilter : public ImplementChannelFilter { // and enforces the RBAC policy. static const grpc_channel_filter kFilterVtable; - static absl::StatusOr Create(const ChannelArgs& args, - ChannelFilter::Args filter_args); + static absl::StatusOr> Create( + const ChannelArgs& args, ChannelFilter::Args filter_args); + + RbacFilter(size_t index, + EvaluateArgs::PerChannelArgs per_channel_evaluate_args); class Call { public: @@ -57,9 +60,6 @@ class RbacFilter : public ImplementChannelFilter { }; private: - RbacFilter(size_t index, - EvaluateArgs::PerChannelArgs per_channel_evaluate_args); - // The index of this filter instance among instances of the same filter. size_t index_; // Assigned index for service config data from the parser. diff --git a/src/core/ext/filters/server_config_selector/server_config_selector_filter.cc b/src/core/ext/filters/server_config_selector/server_config_selector_filter.cc index 0be80a238a0..0d37c5c646c 100644 --- a/src/core/ext/filters/server_config_selector/server_config_selector_filter.cc +++ b/src/core/ext/filters/server_config_selector/server_config_selector_filter.cc @@ -49,19 +49,22 @@ namespace grpc_core { namespace { class ServerConfigSelectorFilter final - : public ImplementChannelFilter { + : public ImplementChannelFilter, + public InternallyRefCounted { public: - ~ServerConfigSelectorFilter() override; + explicit ServerConfigSelectorFilter( + RefCountedPtr + server_config_selector_provider); ServerConfigSelectorFilter(const ServerConfigSelectorFilter&) = delete; ServerConfigSelectorFilter& operator=(const ServerConfigSelectorFilter&) = delete; - ServerConfigSelectorFilter(ServerConfigSelectorFilter&&) = default; - ServerConfigSelectorFilter& operator=(ServerConfigSelectorFilter&&) = default; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args); + void Orphan() override; + class Call { public: absl::Status OnClientInitialMetadata(ClientMetadata& md, @@ -74,70 +77,66 @@ class ServerConfigSelectorFilter final }; absl::StatusOr> config_selector() { - MutexLock lock(&state_->mu); - return state_->config_selector.value(); + MutexLock lock(&mu_); + return config_selector_.value(); } private: - struct State { - Mutex mu; - absl::optional>> - config_selector ABSL_GUARDED_BY(mu); - }; class ServerConfigSelectorWatcher : public ServerConfigSelectorProvider::ServerConfigSelectorWatcher { public: - explicit ServerConfigSelectorWatcher(std::shared_ptr state) - : state_(state) {} + explicit ServerConfigSelectorWatcher( + RefCountedPtr filter) + : filter_(filter) {} void OnServerConfigSelectorUpdate( absl::StatusOr> update) override { - MutexLock lock(&state_->mu); - state_->config_selector = std::move(update); + MutexLock lock(&filter_->mu_); + filter_->config_selector_ = std::move(update); } private: - std::shared_ptr state_; + RefCountedPtr filter_; }; - explicit ServerConfigSelectorFilter( - RefCountedPtr - server_config_selector_provider); - RefCountedPtr server_config_selector_provider_; - std::shared_ptr state_; + Mutex mu_; + absl::optional>> + config_selector_ ABSL_GUARDED_BY(mu_); }; -absl::StatusOr ServerConfigSelectorFilter::Create( - const ChannelArgs& args, ChannelFilter::Args) { +absl::StatusOr> +ServerConfigSelectorFilter::Create(const ChannelArgs& args, + ChannelFilter::Args) { ServerConfigSelectorProvider* server_config_selector_provider = args.GetObject(); if (server_config_selector_provider == nullptr) { return absl::UnknownError("No ServerConfigSelectorProvider object found"); } - return ServerConfigSelectorFilter(server_config_selector_provider->Ref()); + return MakeOrphanable( + server_config_selector_provider->Ref()); } ServerConfigSelectorFilter::ServerConfigSelectorFilter( RefCountedPtr server_config_selector_provider) : server_config_selector_provider_( - std::move(server_config_selector_provider)), - state_(std::make_shared()) { + std::move(server_config_selector_provider)) { GPR_ASSERT(server_config_selector_provider_ != nullptr); auto server_config_selector_watcher = - std::make_unique(state_); + std::make_unique(Ref()); auto config_selector = server_config_selector_provider_->Watch( std::move(server_config_selector_watcher)); - MutexLock lock(&state_->mu); + MutexLock lock(&mu_); // It's possible for the watcher to have already updated config_selector_ - if (!state_->config_selector.has_value()) { - state_->config_selector = std::move(config_selector); + if (!config_selector_.has_value()) { + config_selector_ = std::move(config_selector); } } -ServerConfigSelectorFilter::~ServerConfigSelectorFilter() { +void ServerConfigSelectorFilter::Orphan() { if (server_config_selector_provider_ != nullptr) { server_config_selector_provider_->CancelWatch(); } + Unref(); } absl::Status ServerConfigSelectorFilter::Call::OnClientInitialMetadata( diff --git a/src/core/ext/filters/stateful_session/stateful_session_filter.cc b/src/core/ext/filters/stateful_session/stateful_session_filter.cc index 9345edeed7d..b64affbb61f 100644 --- a/src/core/ext/filters/stateful_session/stateful_session_filter.cc +++ b/src/core/ext/filters/stateful_session/stateful_session_filter.cc @@ -72,9 +72,10 @@ const grpc_channel_filter StatefulSessionFilter::kFilter = kFilterExaminesServerInitialMetadata>( "stateful_session_filter"); -absl::StatusOr StatefulSessionFilter::Create( - const ChannelArgs&, ChannelFilter::Args filter_args) { - return StatefulSessionFilter(filter_args); +absl::StatusOr> +StatefulSessionFilter::Create(const ChannelArgs&, + ChannelFilter::Args filter_args) { + return std::make_unique(filter_args); } StatefulSessionFilter::StatefulSessionFilter(ChannelFilter::Args filter_args) diff --git a/src/core/ext/filters/stateful_session/stateful_session_filter.h b/src/core/ext/filters/stateful_session/stateful_session_filter.h index 1fc27e1508b..5cd534843aa 100644 --- a/src/core/ext/filters/stateful_session/stateful_session_filter.h +++ b/src/core/ext/filters/stateful_session/stateful_session_filter.h @@ -74,9 +74,11 @@ class StatefulSessionFilter public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); + explicit StatefulSessionFilter(ChannelFilter::Args filter_args); + class Call { public: void OnClientInitialMetadata(ClientMetadata& md, @@ -97,7 +99,6 @@ class StatefulSessionFilter }; private: - explicit StatefulSessionFilter(ChannelFilter::Args filter_args); // The relative index of instances of the same filter. const size_t index_; // Index of the service config parser. diff --git a/src/core/lib/channel/promise_based_filter.cc b/src/core/lib/channel/promise_based_filter.cc index 746e171334b..9e47aacae55 100644 --- a/src/core/lib/channel/promise_based_filter.cc +++ b/src/core/lib/channel/promise_based_filter.cc @@ -106,7 +106,7 @@ BaseCallData::BaseCallData( ? arena_->New(this, make_recv_interceptor()) : nullptr), event_engine_( - static_cast(elem->channel_data) + ChannelFilterFromElem(elem) ->hack_until_per_channel_stack_event_engines_land_get_event_engine()) { } @@ -1572,7 +1572,7 @@ void ClientCallData::Cancel(grpc_error_handle error, Flusher* flusher) { // metadata and return some trailing metadata. void ClientCallData::StartPromise(Flusher* flusher) { GPR_ASSERT(send_initial_state_ == SendInitialState::kQueued); - ChannelFilter* filter = static_cast(elem()->channel_data); + ChannelFilter* filter = promise_filter_detail::ChannelFilterFromElem(elem()); // Construct the promise. PollContext ctx(this, flusher); @@ -2369,7 +2369,7 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) { // Start the promise. ScopedContext context(this); // Construct the promise. - ChannelFilter* filter = static_cast(elem()->channel_data); + ChannelFilter* filter = promise_filter_detail::ChannelFilterFromElem(elem()); FakeActivity(this).Run([this, filter] { promise_ = filter->MakeCallPromise( CallArgs{WrapMetadata(recv_initial_metadata_), diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index 347f2004b8c..491cbe9c144 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -1841,6 +1841,15 @@ struct BaseCallDataMethods { } }; +// The type of object returned by a filter's Create method. +template +using CreatedType = typename decltype(T::Create(ChannelArgs(), {}))::value_type; + +template +inline ChannelFilter* ChannelFilterFromElem(GrpcChannelOrCallElement* elem) { + return *static_cast(elem->channel_data); +} + template struct CallDataFilterWithFlagsMethods { static absl::Status InitCallElem(grpc_call_element* elem, @@ -1865,32 +1874,25 @@ struct ChannelFilterMethods { static ArenaPromise MakeCallPromise( grpc_channel_element* elem, CallArgs call_args, NextPromiseFactory next_promise_factory) { - return static_cast(elem->channel_data) - ->MakeCallPromise(std::move(call_args), - std::move(next_promise_factory)); + return ChannelFilterFromElem(elem)->MakeCallPromise( + std::move(call_args), std::move(next_promise_factory)); } static void StartTransportOp(grpc_channel_element* elem, grpc_transport_op* op) { - if (!static_cast(elem->channel_data) - ->StartTransportOp(op)) { + if (!ChannelFilterFromElem(elem)->StartTransportOp(op)) { grpc_channel_next_op(elem, op); } } static void PostInitChannelElem(grpc_channel_stack*, grpc_channel_element* elem) { - static_cast(elem->channel_data)->PostInit(); - } - - static void DestroyChannelElem(grpc_channel_element* elem) { - static_cast(elem->channel_data)->~ChannelFilter(); + ChannelFilterFromElem(elem)->PostInit(); } static void GetChannelInfo(grpc_channel_element* elem, const grpc_channel_info* info) { - if (!static_cast(elem->channel_data) - ->GetChannelInfo(info)) { + if (!ChannelFilterFromElem(elem)->GetChannelInfo(info)) { grpc_channel_next_get_info(elem, info); } } @@ -1904,15 +1906,16 @@ struct ChannelFilterWithFlagsMethods { auto status = F::Create(args->channel_args, ChannelFilter::Args(args->channel_stack, elem)); if (!status.ok()) { - static_assert( - sizeof(promise_filter_detail::InvalidChannelFilter) <= sizeof(F), - "InvalidChannelFilter must fit in F"); - new (elem->channel_data) promise_filter_detail::InvalidChannelFilter(); + new (elem->channel_data) F*(nullptr); return absl_status_to_grpc_error(status.status()); } - new (elem->channel_data) F(std::move(*status)); + new (elem->channel_data) F*(status->release()); return absl::OkStatus(); } + + static void DestroyChannelElem(grpc_channel_element* elem) { + CreatedType channel_elem(DownCast(ChannelFilterFromElem(elem))); + } }; } // namespace promise_filter_detail @@ -1958,7 +1961,8 @@ MakePromiseBasedFilter(const char* name) { // post_init_channel_elem promise_filter_detail::ChannelFilterMethods::PostInitChannelElem, // destroy_channel_elem - promise_filter_detail::ChannelFilterMethods::DestroyChannelElem, + promise_filter_detail::ChannelFilterWithFlagsMethods< + F, kFlags>::DestroyChannelElem, // get_channel_info promise_filter_detail::ChannelFilterMethods::GetChannelInfo, // name @@ -2004,7 +2008,8 @@ MakePromiseBasedFilter(const char* name) { // post_init_channel_elem promise_filter_detail::ChannelFilterMethods::PostInitChannelElem, // destroy_channel_elem - promise_filter_detail::ChannelFilterMethods::DestroyChannelElem, + promise_filter_detail::ChannelFilterWithFlagsMethods< + F, kFlags>::DestroyChannelElem, // get_channel_info promise_filter_detail::ChannelFilterMethods::GetChannelInfo, // name @@ -2046,7 +2051,8 @@ MakePromiseBasedFilter(const char* name) { // post_init_channel_elem promise_filter_detail::ChannelFilterMethods::PostInitChannelElem, // destroy_channel_elem - promise_filter_detail::ChannelFilterMethods::DestroyChannelElem, + promise_filter_detail::ChannelFilterWithFlagsMethods< + F, kFlags>::DestroyChannelElem, // get_channel_info promise_filter_detail::ChannelFilterMethods::GetChannelInfo, // name diff --git a/src/core/lib/channel/server_call_tracer_filter.cc b/src/core/lib/channel/server_call_tracer_filter.cc index 982f160468e..5effe8f8538 100644 --- a/src/core/lib/channel/server_call_tracer_filter.cc +++ b/src/core/lib/channel/server_call_tracer_filter.cc @@ -17,6 +17,7 @@ #include "src/core/lib/channel/server_call_tracer_filter.h" #include +#include #include #include "absl/status/status.h" @@ -49,7 +50,7 @@ class ServerCallTracerFilter public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/); class Call { @@ -98,9 +99,10 @@ const grpc_channel_filter ServerCallTracerFilter::kFilter = kFilterExaminesServerInitialMetadata>( "server_call_tracer"); -absl::StatusOr ServerCallTracerFilter::Create( - const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/) { - return ServerCallTracerFilter(); +absl::StatusOr> +ServerCallTracerFilter::Create(const ChannelArgs& /*args*/, + ChannelFilter::Args /*filter_args*/) { + return std::make_unique(); } } // namespace diff --git a/src/core/lib/security/authorization/grpc_server_authz_filter.cc b/src/core/lib/security/authorization/grpc_server_authz_filter.cc index 199c0a8fa5b..5474847701a 100644 --- a/src/core/lib/security/authorization/grpc_server_authz_filter.cc +++ b/src/core/lib/security/authorization/grpc_server_authz_filter.cc @@ -51,14 +51,14 @@ GrpcServerAuthzFilter::GrpcServerAuthzFilter( per_channel_evaluate_args_(auth_context_.get(), args), provider_(std::move(provider)) {} -absl::StatusOr GrpcServerAuthzFilter::Create( - const ChannelArgs& args, ChannelFilter::Args) { +absl::StatusOr> +GrpcServerAuthzFilter::Create(const ChannelArgs& args, ChannelFilter::Args) { auto* auth_context = args.GetObject(); auto* provider = args.GetObject(); if (provider == nullptr) { return absl::InvalidArgumentError("Failed to get authorization provider."); } - return GrpcServerAuthzFilter( + return std::make_unique( auth_context != nullptr ? auth_context->Ref() : nullptr, args, provider->Ref()); } diff --git a/src/core/lib/security/authorization/grpc_server_authz_filter.h b/src/core/lib/security/authorization/grpc_server_authz_filter.h index fd29197a3bf..b4b0a7463cd 100644 --- a/src/core/lib/security/authorization/grpc_server_authz_filter.h +++ b/src/core/lib/security/authorization/grpc_server_authz_filter.h @@ -37,8 +37,12 @@ class GrpcServerAuthzFilter final public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create(const ChannelArgs& args, - ChannelFilter::Args); + static absl::StatusOr> Create( + const ChannelArgs& args, ChannelFilter::Args); + + GrpcServerAuthzFilter( + RefCountedPtr auth_context, const ChannelArgs& args, + RefCountedPtr provider); class Call { public: @@ -52,10 +56,6 @@ class GrpcServerAuthzFilter final }; private: - GrpcServerAuthzFilter( - RefCountedPtr auth_context, const ChannelArgs& args, - RefCountedPtr provider); - bool IsAuthorized(ClientMetadata& initial_metadata); RefCountedPtr auth_context_; diff --git a/src/core/lib/security/transport/auth_filters.h b/src/core/lib/security/transport/auth_filters.h index e4e80909526..ced5319bb55 100644 --- a/src/core/lib/security/transport/auth_filters.h +++ b/src/core/lib/security/transport/auth_filters.h @@ -42,18 +42,18 @@ class ClientAuthFilter final : public ChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create(const ChannelArgs& args, - ChannelFilter::Args); + ClientAuthFilter( + RefCountedPtr security_connector, + RefCountedPtr auth_context); + + static absl::StatusOr> Create( + const ChannelArgs& args, ChannelFilter::Args); // Construct a promise for one call. ArenaPromise MakeCallPromise( CallArgs call_args, NextPromiseFactory next_promise_factory) override; private: - ClientAuthFilter( - RefCountedPtr security_connector, - RefCountedPtr auth_context); - ArenaPromise> GetCallCredsMetadata( CallArgs call_args); @@ -63,9 +63,6 @@ class ClientAuthFilter final : public ChannelFilter { class ServerAuthFilter final : public ImplementChannelFilter { private: - ServerAuthFilter(RefCountedPtr server_credentials, - RefCountedPtr auth_context); - class RunApplicationCode { public: RunApplicationCode(ServerAuthFilter* filter, ClientMetadata& metadata); @@ -98,8 +95,11 @@ class ServerAuthFilter final : public ImplementChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create(const ChannelArgs& args, - ChannelFilter::Args); + ServerAuthFilter(RefCountedPtr server_credentials, + RefCountedPtr auth_context); + + static absl::StatusOr> Create( + const ChannelArgs& args, ChannelFilter::Args); class Call { public: diff --git a/src/core/lib/security/transport/client_auth_filter.cc b/src/core/lib/security/transport/client_auth_filter.cc index 1840522e6c7..400087dcedd 100644 --- a/src/core/lib/security/transport/client_auth_filter.cc +++ b/src/core/lib/security/transport/client_auth_filter.cc @@ -203,7 +203,7 @@ ArenaPromise ClientAuthFilter::MakeCallPromise( next_promise_factory); } -absl::StatusOr ClientAuthFilter::Create( +absl::StatusOr> ClientAuthFilter::Create( const ChannelArgs& args, ChannelFilter::Args) { auto* sc = args.GetObject(); if (sc == nullptr) { @@ -215,8 +215,9 @@ absl::StatusOr ClientAuthFilter::Create( return absl::InvalidArgumentError( "Auth context missing from client auth filter args"); } - return ClientAuthFilter(sc->RefAsSubclass(), - auth_context->Ref()); + return std::make_unique( + sc->RefAsSubclass(), + auth_context->Ref()); } const grpc_channel_filter ClientAuthFilter::kFilter = diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc index 0354398684a..f2e403b329f 100644 --- a/src/core/lib/security/transport/server_auth_filter.cc +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -212,12 +212,13 @@ ServerAuthFilter::ServerAuthFilter( RefCountedPtr auth_context) : server_credentials_(server_credentials), auth_context_(auth_context) {} -absl::StatusOr ServerAuthFilter::Create( +absl::StatusOr> ServerAuthFilter::Create( const ChannelArgs& args, ChannelFilter::Args) { auto auth_context = args.GetObjectRef(); GPR_ASSERT(auth_context != nullptr); auto creds = args.GetObjectRef(); - return ServerAuthFilter(std::move(creds), std::move(auth_context)); + return std::make_unique(std::move(creds), + std::move(auth_context)); } } // namespace grpc_core diff --git a/src/core/lib/surface/channel_init.h b/src/core/lib/surface/channel_init.h index 48de857f76d..5f1bacdeff1 100644 --- a/src/core/lib/surface/channel_init.h +++ b/src/core/lib/surface/channel_init.h @@ -285,6 +285,11 @@ class ChannelInit { grpc_channel_stack_type type, const ChannelArgs& args) const; private: + // The type of object returned by a filter's Create method. + template + using CreatedType = + typename decltype(T::Create(ChannelArgs(), {}))::value_type; + struct Filter { Filter(const grpc_channel_filter* filter, const ChannelFilterVtable* vtable, std::vector predicates, bool skip_v3, @@ -328,17 +333,17 @@ class ChannelInit { template const ChannelInit::ChannelFilterVtable ChannelInit::VtableForType>::kVtable = { - sizeof(T), alignof(T), + sizeof(CreatedType), alignof(CreatedType), [](void* data, const ChannelArgs& args) -> absl::Status { // TODO(ctiller): fill in ChannelFilter::Args (2nd arg) - absl::StatusOr r = T::Create(args, {}); + absl::StatusOr> r = T::Create(args, {}); if (!r.ok()) return r.status(); - new (data) T(std::move(*r)); + new (data) CreatedType(std::move(*r)); return absl::OkStatus(); }, - [](void* data) { static_cast(data)->~T(); }, + [](void* data) { Destruct(static_cast*>(data)); }, [](void* data, CallFilters::StackBuilder& builder) { - builder.Add(static_cast(data)); + builder.Add(static_cast*>(data)->get()); }}; } // namespace grpc_core diff --git a/src/core/lib/surface/lame_client.cc b/src/core/lib/surface/lame_client.cc index ce8a4ea1880..8966f34ebf7 100644 --- a/src/core/lib/surface/lame_client.cc +++ b/src/core/lib/surface/lame_client.cc @@ -59,17 +59,15 @@ const grpc_channel_filter LameClientFilter::kFilter = MakePromiseBasedFilter("lame-client"); -absl::StatusOr LameClientFilter::Create( +absl::StatusOr> LameClientFilter::Create( const ChannelArgs& args, ChannelFilter::Args) { - return LameClientFilter( + return std::make_unique( *args.GetPointer(GRPC_ARG_LAME_FILTER_ERROR)); } LameClientFilter::LameClientFilter(absl::Status error) - : error_(std::move(error)), state_(std::make_unique()) {} - -LameClientFilter::State::State() - : state_tracker("lame_client", GRPC_CHANNEL_SHUTDOWN) {} + : error_(std::move(error)), + state_tracker_("lame_client", GRPC_CHANNEL_SHUTDOWN) {} ArenaPromise LameClientFilter::MakeCallPromise( CallArgs args, NextPromiseFactory) { @@ -92,13 +90,13 @@ bool LameClientFilter::GetChannelInfo(const grpc_channel_info*) { return true; } bool LameClientFilter::StartTransportOp(grpc_transport_op* op) { { - MutexLock lock(&state_->mu); + MutexLock lock(&mu_); if (op->start_connectivity_watch != nullptr) { - state_->state_tracker.AddWatcher(op->start_connectivity_watch_state, - std::move(op->start_connectivity_watch)); + state_tracker_.AddWatcher(op->start_connectivity_watch_state, + std::move(op->start_connectivity_watch)); } if (op->stop_connectivity_watch != nullptr) { - state_->state_tracker.RemoveWatcher(op->stop_connectivity_watch); + state_tracker_.RemoveWatcher(op->stop_connectivity_watch); } } if (op->send_ping.on_initiate != nullptr) { diff --git a/src/core/lib/surface/lame_client.h b/src/core/lib/surface/lame_client.h index a8464fc9b96..1f8d323d352 100644 --- a/src/core/lib/surface/lame_client.h +++ b/src/core/lib/surface/lame_client.h @@ -47,7 +47,9 @@ class LameClientFilter : public ChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + explicit LameClientFilter(absl::Status error); + + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); ArenaPromise MakeCallPromise( CallArgs call_args, NextPromiseFactory next_promise_factory) override; @@ -55,15 +57,9 @@ class LameClientFilter : public ChannelFilter { bool GetChannelInfo(const grpc_channel_info*) override; private: - explicit LameClientFilter(absl::Status error); - absl::Status error_; - struct State { - State(); - Mutex mu; - ConnectivityStateTracker state_tracker ABSL_GUARDED_BY(mu); - }; - std::unique_ptr state_; + Mutex mu_; + ConnectivityStateTracker state_tracker_ ABSL_GUARDED_BY(mu_); }; extern const grpc_arg_pointer_vtable kLameFilterErrorArgVtable; diff --git a/src/core/load_balancing/grpclb/client_load_reporting_filter.cc b/src/core/load_balancing/grpclb/client_load_reporting_filter.cc index 3b139fa8a4c..c3de43e4b94 100644 --- a/src/core/load_balancing/grpclb/client_load_reporting_filter.cc +++ b/src/core/load_balancing/grpclb/client_load_reporting_filter.cc @@ -16,8 +16,6 @@ // // -#include - #include "src/core/load_balancing/grpclb/client_load_reporting_filter.h" #include @@ -27,7 +25,8 @@ #include "absl/types/optional.h" -#include "src/core/load_balancing/grpclb/grpclb_client_stats.h" +#include + #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/promise/context.h" @@ -36,48 +35,47 @@ #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" +#include "src/core/load_balancing/grpclb/grpclb_client_stats.h" namespace grpc_core { + +const NoInterceptor ClientLoadReportingFilter::Call::OnServerToClientMessage; +const NoInterceptor ClientLoadReportingFilter::Call::OnClientToServerMessage; +const NoInterceptor ClientLoadReportingFilter::Call::OnFinalize; + const grpc_channel_filter ClientLoadReportingFilter::kFilter = MakePromiseBasedFilter( "client_load_reporting"); -absl::StatusOr ClientLoadReportingFilter::Create( - const ChannelArgs&, ChannelFilter::Args) { - return ClientLoadReportingFilter(); +absl::StatusOr> +ClientLoadReportingFilter::Create(const ChannelArgs&, ChannelFilter::Args) { + return std::make_unique(); } -ArenaPromise ClientLoadReportingFilter::MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) { - // Stats object to update. - RefCountedPtr client_stats; - +void ClientLoadReportingFilter::Call::OnClientInitialMetadata( + ClientMetadata& client_initial_metadata) { // Handle client initial metadata. // Grab client stats object from metadata. auto client_stats_md = - call_args.client_initial_metadata->Take(GrpcLbClientStatsMetadata()); + client_initial_metadata.Take(GrpcLbClientStatsMetadata()); if (client_stats_md.has_value()) { - client_stats.reset(*client_stats_md); + client_stats_.reset(*client_stats_md); } +} - auto* saw_initial_metadata = GetContext()->New(false); - call_args.server_initial_metadata->InterceptAndMap( - [saw_initial_metadata](ServerMetadataHandle md) { - *saw_initial_metadata = true; - return md; - }); +void ClientLoadReportingFilter::Call::OnServerInitialMetadata(ServerMetadata&) { + saw_initial_metadata_ = true; +} - return Map(next_promise_factory(std::move(call_args)), - [saw_initial_metadata, client_stats = std::move(client_stats)]( - ServerMetadataHandle trailing_metadata) { - if (client_stats != nullptr) { - client_stats->AddCallFinished( - trailing_metadata->get(GrpcStreamNetworkState()) == - GrpcStreamNetworkState::kNotSentOnWire, - *saw_initial_metadata); - } - return trailing_metadata; - }); +void ClientLoadReportingFilter::Call::OnServerTrailingMetadata( + ServerMetadata& server_trailing_metadata) { + if (client_stats_ != nullptr) { + client_stats_->AddCallFinished( + server_trailing_metadata.get(GrpcStreamNetworkState()) == + GrpcStreamNetworkState::kNotSentOnWire, + saw_initial_metadata_); + } } + } // namespace grpc_core diff --git a/src/core/load_balancing/grpclb/client_load_reporting_filter.h b/src/core/load_balancing/grpclb/client_load_reporting_filter.h index 57e89251d0b..941b97abf99 100644 --- a/src/core/load_balancing/grpclb/client_load_reporting_filter.h +++ b/src/core/load_balancing/grpclb/client_load_reporting_filter.h @@ -19,10 +19,10 @@ #ifndef GRPC_SRC_CORE_LOAD_BALANCING_GRPCLB_CLIENT_LOAD_REPORTING_FILTER_H #define GRPC_SRC_CORE_LOAD_BALANCING_GRPCLB_CLIENT_LOAD_REPORTING_FILTER_H -#include - #include "absl/status/statusor.h" +#include + #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_fwd.h" #include "src/core/lib/channel/promise_based_filter.h" @@ -31,16 +31,27 @@ namespace grpc_core { -class ClientLoadReportingFilter final : public ChannelFilter { +class ClientLoadReportingFilter final + : public ImplementChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + class Call { + public: + void OnClientInitialMetadata(ClientMetadata& client_initial_metadata); + void OnServerInitialMetadata(ServerMetadata& server_initial_metadata); + void OnServerTrailingMetadata(ServerMetadata& server_trailing_metadata); + static const NoInterceptor OnServerToClientMessage; + static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnFinalize; + + private: + RefCountedPtr client_stats_; + bool saw_initial_metadata_ = false; + }; + + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args filter_args); - - // Construct a promise for one call. - ArenaPromise MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) override; }; } // namespace grpc_core diff --git a/src/core/resolver/xds/xds_resolver.cc b/src/core/resolver/xds/xds_resolver.cc index 6fc59c59113..13efe8a5099 100644 --- a/src/core/resolver/xds/xds_resolver.cc +++ b/src/core/resolver/xds/xds_resolver.cc @@ -317,9 +317,10 @@ class XdsResolver final : public Resolver { public: const static grpc_channel_filter kFilter; - static absl::StatusOr Create( - const ChannelArgs& /* unused */, ChannelFilter::Args filter_args) { - return ClusterSelectionFilter(filter_args); + static absl::StatusOr> Create( + const ChannelArgs& /* unused */, + ChannelFilter::Args /* filter_args */) { + return std::make_unique(); } // Construct a promise for one call. @@ -332,12 +333,6 @@ class XdsResolver final : public Resolver { static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnFinalize; }; - - private: - explicit ClusterSelectionFilter(ChannelFilter::Args filter_args) - : filter_args_(filter_args) {} - - ChannelFilter::Args filter_args_; }; RefCountedPtr GetOrCreateClusterRef( diff --git a/src/core/service_config/service_config_channel_arg_filter.cc b/src/core/service_config/service_config_channel_arg_filter.cc index 43069611022..a2c9a973129 100644 --- a/src/core/service_config/service_config_channel_arg_filter.cc +++ b/src/core/service_config/service_config_channel_arg_filter.cc @@ -17,8 +17,6 @@ // This filter reads GRPC_ARG_SERVICE_CONFIG and populates ServiceConfigCallData // in the call context per call for direct channels. -#include - #include #include #include @@ -30,6 +28,7 @@ #include #include +#include #include "src/core/ext/filters/message_size/message_size_filter.h" #include "src/core/lib/channel/channel_args.h" @@ -42,13 +41,13 @@ #include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/promise/context.h" #include "src/core/lib/resource_quota/arena.h" +#include "src/core/lib/surface/channel_stack_type.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/transport.h" #include "src/core/service_config/service_config.h" #include "src/core/service_config/service_config_call_data.h" #include "src/core/service_config/service_config_impl.h" #include "src/core/service_config/service_config_parser.h" -#include "src/core/lib/surface/channel_stack_type.h" -#include "src/core/lib/transport/metadata_batch.h" -#include "src/core/lib/transport/transport.h" namespace grpc_core { @@ -59,9 +58,9 @@ class ServiceConfigChannelArgFilter final public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs& args, ChannelFilter::Args) { - return ServiceConfigChannelArgFilter(args); + return std::make_unique(args); } explicit ServiceConfigChannelArgFilter(const ChannelArgs& args) { diff --git a/src/cpp/ext/filters/census/client_filter.cc b/src/cpp/ext/filters/census/client_filter.cc index 8dbc0959bd7..16131ce8af4 100644 --- a/src/cpp/ext/filters/census/client_filter.cc +++ b/src/cpp/ext/filters/census/client_filter.cc @@ -84,11 +84,13 @@ const grpc_channel_filter OpenCensusClientFilter::kFilter = grpc_core::FilterEndpoint::kClient, 0>( "opencensus_client"); -absl::StatusOr OpenCensusClientFilter::Create( - const grpc_core::ChannelArgs& args, ChannelFilter::Args /*filter_args*/) { +absl::StatusOr> +OpenCensusClientFilter::Create(const grpc_core::ChannelArgs& args, + ChannelFilter::Args /*filter_args*/) { bool observability_enabled = args.GetInt(GRPC_ARG_ENABLE_OBSERVABILITY).value_or(true); - return OpenCensusClientFilter(/*tracing_enabled=*/observability_enabled); + return std::make_unique( + /*tracing_enabled=*/observability_enabled); } grpc_core::ArenaPromise diff --git a/src/cpp/ext/filters/census/client_filter.h b/src/cpp/ext/filters/census/client_filter.h index 011a22e58a1..a549510bbe2 100644 --- a/src/cpp/ext/filters/census/client_filter.h +++ b/src/cpp/ext/filters/census/client_filter.h @@ -38,16 +38,17 @@ class OpenCensusClientFilter : public grpc_core::ChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + static absl::StatusOr> Create( const grpc_core::ChannelArgs& args, ChannelFilter::Args /*filter_args*/); + explicit OpenCensusClientFilter(bool tracing_enabled) + : tracing_enabled_(tracing_enabled) {} + grpc_core::ArenaPromise MakeCallPromise( grpc_core::CallArgs call_args, grpc_core::NextPromiseFactory next_promise_factory) override; private: - explicit OpenCensusClientFilter(bool tracing_enabled) - : tracing_enabled_(tracing_enabled) {} bool tracing_enabled_ = true; }; diff --git a/test/core/filters/filter_test.h b/test/core/filters/filter_test.h index 796b0a22213..c557e6aee16 100644 --- a/test/core/filters/filter_test.h +++ b/test/core/filters/filter_test.h @@ -231,7 +231,7 @@ class FilterTest : public FilterTestBase { absl::StatusOr MakeChannel(const ChannelArgs& args) { auto filter = Filter::Create(args, ChannelFilter::Args()); if (!filter.ok()) return filter.status(); - return Channel(std::make_unique(std::move(*filter)), this); + return Channel(std::move(*filter), this); } }; diff --git a/test/core/filters/filter_test_test.cc b/test/core/filters/filter_test_test.cc index e0dad7ad671..ec836316f86 100644 --- a/test/core/filters/filter_test_test.cc +++ b/test/core/filters/filter_test_test.cc @@ -49,9 +49,9 @@ class NoOpFilter final : public ChannelFilter { return next(std::move(args)); } - static absl::StatusOr Create(const ChannelArgs&, - ChannelFilter::Args) { - return NoOpFilter(); + static absl::StatusOr> Create( + const ChannelArgs&, ChannelFilter::Args) { + return std::make_unique(); } }; using NoOpFilterTest = FilterTest; @@ -70,9 +70,9 @@ class DelayStartFilter final : public ChannelFilter { next); } - static absl::StatusOr Create(const ChannelArgs&, - ChannelFilter::Args) { - return DelayStartFilter(); + static absl::StatusOr> Create( + const ChannelArgs&, ChannelFilter::Args) { + return std::make_unique(); } }; using DelayStartFilterTest = FilterTest; @@ -86,9 +86,9 @@ class AddClientInitialMetadataFilter final : public ChannelFilter { return next(std::move(args)); } - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs&, ChannelFilter::Args) { - return AddClientInitialMetadataFilter(); + return absl::make_unique(); } }; using AddClientInitialMetadataFilterTest = @@ -104,9 +104,9 @@ class AddServerTrailingMetadataFilter final : public ChannelFilter { }); } - static absl::StatusOr Create( - const ChannelArgs&, ChannelFilter::Args) { - return AddServerTrailingMetadataFilter(); + static absl::StatusOr> + Create(const ChannelArgs&, ChannelFilter::Args) { + return absl::make_unique(); } }; using AddServerTrailingMetadataFilterTest = @@ -122,10 +122,9 @@ class AddServerInitialMetadataFilter final : public ChannelFilter { }); return next(std::move(args)); } - - static absl::StatusOr Create( + static absl::StatusOr> Create( const ChannelArgs&, ChannelFilter::Args) { - return AddServerInitialMetadataFilter(); + return absl::make_unique(); } }; using AddServerInitialMetadataFilterTest = diff --git a/test/core/surface/channel_init_test.cc b/test/core/surface/channel_init_test.cc index cd7bfc7f347..4986cc3a839 100644 --- a/test/core/surface/channel_init_test.cc +++ b/test/core/surface/channel_init_test.cc @@ -15,6 +15,7 @@ #include "src/core/lib/surface/channel_init.h" #include +#include #include #include "absl/strings/string_view.h" @@ -206,9 +207,10 @@ class TestFilter1 { public: explicit TestFilter1(int* p) : p_(p) {} - static absl::StatusOr Create(const ChannelArgs& args, Empty) { + static absl::StatusOr> Create( + const ChannelArgs& args, Empty) { EXPECT_EQ(args.GetInt("foo"), 1); - return TestFilter1(args.GetPointer("p")); + return std::make_unique(args.GetPointer("p")); } static const grpc_channel_filter kFilter; diff --git a/test/core/util/fake_stats_plugin.cc b/test/core/util/fake_stats_plugin.cc index 2ef159d5b75..00055d06d94 100644 --- a/test/core/util/fake_stats_plugin.cc +++ b/test/core/util/fake_stats_plugin.cc @@ -22,15 +22,16 @@ class FakeStatsClientFilter : public ChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + explicit FakeStatsClientFilter( + FakeClientCallTracerFactory* fake_client_call_tracer_factory); + + static absl::StatusOr> Create( const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/); ArenaPromise MakeCallPromise( CallArgs call_args, NextPromiseFactory next_promise_factory) override; private: - explicit FakeStatsClientFilter( - FakeClientCallTracerFactory* fake_client_call_tracer_factory); FakeClientCallTracerFactory* const fake_client_call_tracer_factory_; }; @@ -38,13 +39,15 @@ const grpc_channel_filter FakeStatsClientFilter::kFilter = MakePromiseBasedFilter( "fake_stats_client"); -absl::StatusOr FakeStatsClientFilter::Create( - const ChannelArgs& args, ChannelFilter::Args /*filter_args*/) { +absl::StatusOr> +FakeStatsClientFilter::Create(const ChannelArgs& args, + ChannelFilter::Args /*filter_args*/) { auto* fake_client_call_tracer_factory = args.GetPointer( GRPC_ARG_INJECT_FAKE_CLIENT_CALL_TRACER_FACTORY); GPR_ASSERT(fake_client_call_tracer_factory != nullptr); - return FakeStatsClientFilter(fake_client_call_tracer_factory); + return std::make_unique( + fake_client_call_tracer_factory); } ArenaPromise FakeStatsClientFilter::MakeCallPromise( diff --git a/test/cpp/ext/otel/otel_test_library.cc b/test/cpp/ext/otel/otel_test_library.cc index 0aac08a4da3..81ab0842128 100644 --- a/test/cpp/ext/otel/otel_test_library.cc +++ b/test/cpp/ext/otel/otel_test_library.cc @@ -50,9 +50,15 @@ class AddLabelsFilter : public grpc_core::ChannelFilter { public: static const grpc_channel_filter kFilter; - static absl::StatusOr Create( + explicit AddLabelsFilter( + std::map + labels_to_inject) + : labels_to_inject_(std::move(labels_to_inject)) {} + + static absl::StatusOr> Create( const grpc_core::ChannelArgs& args, ChannelFilter::Args /*filter_args*/) { - return AddLabelsFilter( + return absl::make_unique( *args.GetPointer>(GRPC_ARG_LABELS_TO_INJECT)); @@ -73,12 +79,6 @@ class AddLabelsFilter : public grpc_core::ChannelFilter { } private: - explicit AddLabelsFilter( - std::map - labels_to_inject) - : labels_to_inject_(std::move(labels_to_inject)) {} - const std::map< grpc_core::ClientCallTracer::CallAttemptTracer::OptionalLabelKey, grpc_core::RefCountedStringValue>