[call-v3] Channel filter construction returns pointers (#36355)

Currently channel filter construction returns a `StatusOr<T>`, this change makes it return a `StatusOr<P<T>>` where P is `unique_ptr`, `OrphanablePtr`, `RefCountedPtr`, `DualRefCountedPtr`, etc (most of the code really doesn't need to know, so I'm choosing to leave the flexibility).

That smart pointer is then stored in the channel stack instance, and dereferenced when needed.

This means that channel filters no longer need to be movable (which is a nice simplification), and puts these level-1 filters on a similar memory management track as the level-2 filters we have planned.

(this change also converts client load reporting to v3 apis -- it's a bit accidentally picked up, but seems ok to pull through too)

Closes #36355

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/36355 from ctiller:objectify-me 0eb054b748
PiperOrigin-RevId: 625390977
pull/36345/head
Craig Tiller 8 months ago committed by Yash Tibrewal
parent 7561250649
commit 58b254dacf
  1. 7
      src/core/ext/filters/backend_metrics/backend_metric_filter.cc
  2. 4
      src/core/ext/filters/backend_metrics/backend_metric_filter.h
  3. 18
      src/core/ext/filters/channel_idle/legacy_channel_idle_filter.cc
  4. 21
      src/core/ext/filters/channel_idle/legacy_channel_idle_filter.h
  5. 13
      src/core/ext/filters/fault_injection/fault_injection_filter.cc
  6. 8
      src/core/ext/filters/fault_injection/fault_injection_filter.h
  7. 9
      src/core/ext/filters/http/client/http_client_filter.cc
  8. 10
      src/core/ext/filters/http/client/http_client_filter.h
  9. 7
      src/core/ext/filters/http/client_authority_filter.cc
  10. 9
      src/core/ext/filters/http/client_authority_filter.h
  11. 12
      src/core/ext/filters/http/message_compress/compression_filter.cc
  12. 16
      src/core/ext/filters/http/message_compress/compression_filter.h
  13. 4
      src/core/ext/filters/http/server/http_server_filter.cc
  14. 10
      src/core/ext/filters/http/server/http_server_filter.h
  15. 9
      src/core/ext/filters/load_reporting/server_load_reporting_filter.cc
  16. 2
      src/core/ext/filters/load_reporting/server_load_reporting_filter.h
  17. 19
      src/core/ext/filters/logging/logging_filter.cc
  18. 11
      src/core/ext/filters/logging/logging_filter.h
  19. 14
      src/core/ext/filters/message_size/message_size_filter.cc
  20. 14
      src/core/ext/filters/message_size/message_size_filter.h
  21. 15
      src/core/ext/filters/rbac/rbac_filter.cc
  22. 10
      src/core/ext/filters/rbac/rbac_filter.h
  23. 63
      src/core/ext/filters/server_config_selector/server_config_selector_filter.cc
  24. 7
      src/core/ext/filters/stateful_session/stateful_session_filter.cc
  25. 5
      src/core/ext/filters/stateful_session/stateful_session_filter.h
  26. 6
      src/core/lib/channel/promise_based_filter.cc
  27. 46
      src/core/lib/channel/promise_based_filter.h
  28. 10
      src/core/lib/channel/server_call_tracer_filter.cc
  29. 6
      src/core/lib/security/authorization/grpc_server_authz_filter.cc
  30. 12
      src/core/lib/security/authorization/grpc_server_authz_filter.h
  31. 22
      src/core/lib/security/transport/auth_filters.h
  32. 7
      src/core/lib/security/transport/client_auth_filter.cc
  33. 5
      src/core/lib/security/transport/server_auth_filter.cc
  34. 15
      src/core/lib/surface/channel_init.h
  35. 18
      src/core/lib/surface/lame_client.cc
  36. 14
      src/core/lib/surface/lame_client.h
  37. 58
      src/core/load_balancing/grpclb/client_load_reporting_filter.cc
  38. 27
      src/core/load_balancing/grpclb/client_load_reporting_filter.h
  39. 13
      src/core/resolver/xds/xds_resolver.cc
  40. 13
      src/core/service_config/service_config_channel_arg_filter.cc
  41. 8
      src/cpp/ext/filters/census/client_filter.cc
  42. 7
      src/cpp/ext/filters/census/client_filter.h
  43. 2
      test/core/filters/filter_test.h
  44. 27
      test/core/filters/filter_test_test.cc
  45. 6
      test/core/surface/channel_init_test.cc
  46. 15
      test/core/util/fake_stats_plugin.cc
  47. 16
      test/cpp/ext/otel/otel_test_library.cc

@ -25,6 +25,7 @@
#include <utility> #include <utility>
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "third_party/grpc/src/core/lib/channel/promise_based_filter.h"
#include "upb/base/string_view.h" #include "upb/base/string_view.h"
#include "upb/mem/arena.hpp" #include "upb/mem/arena.hpp"
#include "xds/data/orca/v3/orca_load_report.upb.h" #include "xds/data/orca/v3/orca_load_report.upb.h"
@ -121,9 +122,9 @@ const grpc_channel_filter BackendMetricFilter::kFilter =
MakePromiseBasedFilter<BackendMetricFilter, FilterEndpoint::kServer>( MakePromiseBasedFilter<BackendMetricFilter, FilterEndpoint::kServer>(
"backend_metric"); "backend_metric");
absl::StatusOr<BackendMetricFilter> BackendMetricFilter::Create( absl::StatusOr<std::unique_ptr<BackendMetricFilter>>
const ChannelArgs&, ChannelFilter::Args) { BackendMetricFilter::Create(const ChannelArgs&, ChannelFilter::Args) {
return BackendMetricFilter(); return std::make_unique<BackendMetricFilter>();
} }
void BackendMetricFilter::Call::OnServerTrailingMetadata(ServerMetadata& md) { void BackendMetricFilter::Call::OnServerTrailingMetadata(ServerMetadata& md) {

@ -35,8 +35,8 @@ class BackendMetricFilter : public ImplementChannelFilter<BackendMetricFilter> {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<BackendMetricFilter> Create(const ChannelArgs& args, static absl::StatusOr<std::unique_ptr<BackendMetricFilter>> Create(
ChannelFilter::Args); const ChannelArgs& args, ChannelFilter::Args);
class Call { class Call {
public: public:

@ -25,6 +25,7 @@
#include "absl/base/thread_annotations.h" #include "absl/base/thread_annotations.h"
#include "absl/meta/type_traits.h" #include "absl/meta/type_traits.h"
#include "absl/random/random.h" #include "absl/random/random.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include <grpc/impl/channel_arg_names.h> #include <grpc/impl/channel_arg_names.h>
@ -133,18 +134,17 @@ struct LegacyMaxAgeFilter::Config {
// will be removed at that time also, so just disable the deprecation warning // will be removed at that time also, so just disable the deprecation warning
// for now. // for now.
ABSL_INTERNAL_DISABLE_DEPRECATED_DECLARATION_WARNING ABSL_INTERNAL_DISABLE_DEPRECATED_DECLARATION_WARNING
absl::StatusOr<LegacyClientIdleFilter> LegacyClientIdleFilter::Create( absl::StatusOr<std::unique_ptr<LegacyClientIdleFilter>>
const ChannelArgs& args, ChannelFilter::Args filter_args) { LegacyClientIdleFilter::Create(const ChannelArgs& args,
LegacyClientIdleFilter filter(filter_args.channel_stack(), ChannelFilter::Args filter_args) {
GetClientIdleTimeout(args)); return std::make_unique<LegacyClientIdleFilter>(filter_args.channel_stack(),
return absl::StatusOr<LegacyClientIdleFilter>(std::move(filter)); GetClientIdleTimeout(args));
} }
absl::StatusOr<LegacyMaxAgeFilter> LegacyMaxAgeFilter::Create( absl::StatusOr<std::unique_ptr<LegacyMaxAgeFilter>> LegacyMaxAgeFilter::Create(
const ChannelArgs& args, ChannelFilter::Args filter_args) { const ChannelArgs& args, ChannelFilter::Args filter_args) {
LegacyMaxAgeFilter filter(filter_args.channel_stack(), return std::make_unique<LegacyMaxAgeFilter>(filter_args.channel_stack(),
Config::FromChannelArgs(args)); Config::FromChannelArgs(args));
return absl::StatusOr<LegacyMaxAgeFilter>(std::move(filter));
} }
ABSL_INTERNAL_RESTORE_DEPRECATED_DECLARATION_WARNING ABSL_INTERNAL_RESTORE_DEPRECATED_DECLARATION_WARNING

@ -42,6 +42,11 @@ namespace grpc_core {
class LegacyChannelIdleFilter : public ChannelFilter { class LegacyChannelIdleFilter : public ChannelFilter {
public: public:
LegacyChannelIdleFilter(grpc_channel_stack* channel_stack,
Duration client_idle_timeout)
: channel_stack_(channel_stack),
client_idle_timeout_(client_idle_timeout) {}
~LegacyChannelIdleFilter() override = default; ~LegacyChannelIdleFilter() override = default;
LegacyChannelIdleFilter(const LegacyChannelIdleFilter&) = delete; LegacyChannelIdleFilter(const LegacyChannelIdleFilter&) = delete;
@ -59,11 +64,6 @@ class LegacyChannelIdleFilter : public ChannelFilter {
using SingleSetActivityPtr = using SingleSetActivityPtr =
SingleSetPtr<Activity, typename ActivityPtr::deleter_type>; SingleSetPtr<Activity, typename ActivityPtr::deleter_type>;
LegacyChannelIdleFilter(grpc_channel_stack* channel_stack,
Duration client_idle_timeout)
: channel_stack_(channel_stack),
client_idle_timeout_(client_idle_timeout) {}
grpc_channel_stack* channel_stack() { return channel_stack_; }; grpc_channel_stack* channel_stack() { return channel_stack_; };
virtual void Shutdown(); virtual void Shutdown();
@ -94,10 +94,9 @@ class LegacyClientIdleFilter final : public LegacyChannelIdleFilter {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<LegacyClientIdleFilter> Create( static absl::StatusOr<std::unique_ptr<LegacyClientIdleFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
private:
using LegacyChannelIdleFilter::LegacyChannelIdleFilter; using LegacyChannelIdleFilter::LegacyChannelIdleFilter;
}; };
@ -106,9 +105,12 @@ class LegacyMaxAgeFilter final : public LegacyChannelIdleFilter {
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
struct Config; struct Config;
static absl::StatusOr<LegacyMaxAgeFilter> Create( static absl::StatusOr<std::unique_ptr<LegacyMaxAgeFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
LegacyMaxAgeFilter(grpc_channel_stack* channel_stack,
const Config& max_age_config);
void PostInit() override; void PostInit() override;
private: private:
@ -128,9 +130,6 @@ class LegacyMaxAgeFilter final : public LegacyChannelIdleFilter {
LegacyMaxAgeFilter* filter_; LegacyMaxAgeFilter* filter_;
}; };
LegacyMaxAgeFilter(grpc_channel_stack* channel_stack,
const Config& max_age_config);
void Shutdown() override; void Shutdown() override;
SingleSetActivityPtr max_age_activity_; SingleSetActivityPtr max_age_activity_;

@ -29,6 +29,7 @@
#include "absl/meta/type_traits.h" #include "absl/meta/type_traits.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
@ -135,16 +136,16 @@ class FaultInjectionFilter::InjectionDecision {
FaultHandle active_fault_{false}; FaultHandle active_fault_{false};
}; };
absl::StatusOr<FaultInjectionFilter> FaultInjectionFilter::Create( absl::StatusOr<std::unique_ptr<FaultInjectionFilter>>
const ChannelArgs&, ChannelFilter::Args filter_args) { FaultInjectionFilter::Create(const ChannelArgs&,
return FaultInjectionFilter(filter_args); ChannelFilter::Args filter_args) {
return std::make_unique<FaultInjectionFilter>(filter_args);
} }
FaultInjectionFilter::FaultInjectionFilter(ChannelFilter::Args filter_args) FaultInjectionFilter::FaultInjectionFilter(ChannelFilter::Args filter_args)
: index_(filter_args.instance_id()), : index_(filter_args.instance_id()),
service_config_parser_index_( service_config_parser_index_(
FaultInjectionServiceConfigParser::ParserIndex()), FaultInjectionServiceConfigParser::ParserIndex()) {}
mu_(new Mutex) {}
// Construct a promise for one call. // Construct a promise for one call.
ArenaPromise<absl::Status> FaultInjectionFilter::Call::OnClientInitialMetadata( ArenaPromise<absl::Status> FaultInjectionFilter::Call::OnClientInitialMetadata(
@ -226,7 +227,7 @@ FaultInjectionFilter::MakeInjectionDecision(
bool delay_request = delay != Duration::Zero(); bool delay_request = delay != Duration::Zero();
bool abort_request = abort_code != GRPC_STATUS_OK; bool abort_request = abort_code != GRPC_STATUS_OK;
if (delay_request || abort_request) { if (delay_request || abort_request) {
MutexLock lock(mu_.get()); MutexLock lock(&mu_);
if (delay_request) { if (delay_request) {
delay_request = delay_request =
UnderFraction(&delay_rand_generator_, delay_percentage_numerator, UnderFraction(&delay_rand_generator_, delay_percentage_numerator,

@ -45,9 +45,11 @@ class FaultInjectionFilter
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<FaultInjectionFilter> Create( static absl::StatusOr<std::unique_ptr<FaultInjectionFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
explicit FaultInjectionFilter(ChannelFilter::Args filter_args);
// Construct a promise for one call. // Construct a promise for one call.
class Call { class Call {
public: public:
@ -61,8 +63,6 @@ class FaultInjectionFilter
}; };
private: private:
explicit FaultInjectionFilter(ChannelFilter::Args filter_args);
class InjectionDecision; class InjectionDecision;
InjectionDecision MakeInjectionDecision( InjectionDecision MakeInjectionDecision(
const ClientMetadata& initial_metadata); const ClientMetadata& initial_metadata);
@ -70,7 +70,7 @@ class FaultInjectionFilter
// The relative index of instances of the same filter. // The relative index of instances of the same filter.
size_t index_; size_t index_;
const size_t service_config_parser_index_; const size_t service_config_parser_index_;
std::unique_ptr<Mutex> mu_; Mutex mu_;
absl::InsecureBitGen abort_rand_generator_ ABSL_GUARDED_BY(mu_); absl::InsecureBitGen abort_rand_generator_ ABSL_GUARDED_BY(mu_);
absl::InsecureBitGen delay_rand_generator_ ABSL_GUARDED_BY(mu_); absl::InsecureBitGen delay_rand_generator_ ABSL_GUARDED_BY(mu_);
}; };

@ -27,6 +27,7 @@
#include <vector> #include <vector>
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
@ -136,16 +137,16 @@ HttpClientFilter::HttpClientFilter(HttpSchemeMetadata::ValueType scheme,
Slice user_agent, Slice user_agent,
bool test_only_use_put_requests) bool test_only_use_put_requests)
: scheme_(scheme), : scheme_(scheme),
user_agent_(std::move(user_agent)), test_only_use_put_requests_(test_only_use_put_requests),
test_only_use_put_requests_(test_only_use_put_requests) {} user_agent_(std::move(user_agent)) {}
absl::StatusOr<HttpClientFilter> HttpClientFilter::Create( absl::StatusOr<std::unique_ptr<HttpClientFilter>> HttpClientFilter::Create(
const ChannelArgs& args, ChannelFilter::Args) { const ChannelArgs& args, ChannelFilter::Args) {
auto* transport = args.GetObject<Transport>(); auto* transport = args.GetObject<Transport>();
if (transport == nullptr) { if (transport == nullptr) {
return absl::InvalidArgumentError("HttpClientFilter needs a transport"); return absl::InvalidArgumentError("HttpClientFilter needs a transport");
} }
return HttpClientFilter( return std::make_unique<HttpClientFilter>(
SchemeFromArgs(args), SchemeFromArgs(args),
UserAgentFromArgs(args, transport->GetTransportName()), UserAgentFromArgs(args, transport->GetTransportName()),
args.GetInt(GRPC_ARG_TEST_ONLY_USE_PUT_REQUESTS).value_or(false)); args.GetInt(GRPC_ARG_TEST_ONLY_USE_PUT_REQUESTS).value_or(false));

@ -35,9 +35,12 @@ class HttpClientFilter : public ImplementChannelFilter<HttpClientFilter> {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<HttpClientFilter> Create( static absl::StatusOr<std::unique_ptr<HttpClientFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
HttpClientFilter(HttpSchemeMetadata::ValueType scheme, Slice user_agent,
bool test_only_use_put_requests);
class Call { class Call {
public: public:
void OnClientInitialMetadata(ClientMetadata& md, HttpClientFilter* filter); void OnClientInitialMetadata(ClientMetadata& md, HttpClientFilter* filter);
@ -49,12 +52,9 @@ class HttpClientFilter : public ImplementChannelFilter<HttpClientFilter> {
}; };
private: private:
HttpClientFilter(HttpSchemeMetadata::ValueType scheme, Slice user_agent,
bool test_only_use_put_requests);
HttpSchemeMetadata::ValueType scheme_; HttpSchemeMetadata::ValueType scheme_;
Slice user_agent_;
bool test_only_use_put_requests_; bool test_only_use_put_requests_;
Slice user_agent_;
}; };
// A test-only channel arg to allow testing gRPC Core server behavior on PUT // A test-only channel arg to allow testing gRPC Core server behavior on PUT

@ -43,8 +43,8 @@ const NoInterceptor ClientAuthorityFilter::Call::OnClientToServerMessage;
const NoInterceptor ClientAuthorityFilter::Call::OnServerToClientMessage; const NoInterceptor ClientAuthorityFilter::Call::OnServerToClientMessage;
const NoInterceptor ClientAuthorityFilter::Call::OnFinalize; const NoInterceptor ClientAuthorityFilter::Call::OnFinalize;
absl::StatusOr<ClientAuthorityFilter> ClientAuthorityFilter::Create( absl::StatusOr<std::unique_ptr<ClientAuthorityFilter>>
const ChannelArgs& args, ChannelFilter::Args) { ClientAuthorityFilter::Create(const ChannelArgs& args, ChannelFilter::Args) {
absl::optional<absl::string_view> default_authority = absl::optional<absl::string_view> default_authority =
args.GetString(GRPC_ARG_DEFAULT_AUTHORITY); args.GetString(GRPC_ARG_DEFAULT_AUTHORITY);
if (!default_authority.has_value()) { if (!default_authority.has_value()) {
@ -52,7 +52,8 @@ absl::StatusOr<ClientAuthorityFilter> ClientAuthorityFilter::Create(
"GRPC_ARG_DEFAULT_AUTHORITY string channel arg. not found. Note that " "GRPC_ARG_DEFAULT_AUTHORITY string channel arg. not found. Note that "
"direct channels must explicitly specify a value for this argument."); "direct channels must explicitly specify a value for this argument.");
} }
return ClientAuthorityFilter(Slice::FromCopiedString(*default_authority)); return std::make_unique<ClientAuthorityFilter>(
Slice::FromCopiedString(*default_authority));
} }
void ClientAuthorityFilter::Call::OnClientInitialMetadata( void ClientAuthorityFilter::Call::OnClientInitialMetadata(

@ -39,8 +39,11 @@ class ClientAuthorityFilter final
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ClientAuthorityFilter> Create(const ChannelArgs& args, static absl::StatusOr<std::unique_ptr<ClientAuthorityFilter>> Create(
ChannelFilter::Args); const ChannelArgs& args, ChannelFilter::Args);
explicit ClientAuthorityFilter(Slice default_authority)
: default_authority_(std::move(default_authority)) {}
class Call { class Call {
public: public:
@ -54,8 +57,6 @@ class ClientAuthorityFilter final
}; };
private: private:
explicit ClientAuthorityFilter(Slice default_authority)
: default_authority_(std::move(default_authority)) {}
Slice default_authority_; Slice default_authority_;
}; };

@ -72,14 +72,14 @@ const grpc_channel_filter ServerCompressionFilter::kFilter =
kFilterExaminesInboundMessages | kFilterExaminesInboundMessages |
kFilterExaminesOutboundMessages>("compression"); kFilterExaminesOutboundMessages>("compression");
absl::StatusOr<ClientCompressionFilter> ClientCompressionFilter::Create( absl::StatusOr<std::unique_ptr<ClientCompressionFilter>>
const ChannelArgs& args, ChannelFilter::Args) { ClientCompressionFilter::Create(const ChannelArgs& args, ChannelFilter::Args) {
return ClientCompressionFilter(args); return std::make_unique<ClientCompressionFilter>(args);
} }
absl::StatusOr<ServerCompressionFilter> ServerCompressionFilter::Create( absl::StatusOr<std::unique_ptr<ServerCompressionFilter>>
const ChannelArgs& args, ChannelFilter::Args) { ServerCompressionFilter::Create(const ChannelArgs& args, ChannelFilter::Args) {
return ServerCompressionFilter(args); return std::make_unique<ServerCompressionFilter>(args);
} }
ChannelCompression::ChannelCompression(const ChannelArgs& args) ChannelCompression::ChannelCompression(const ChannelArgs& args)

@ -110,9 +110,12 @@ class ClientCompressionFilter final
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ClientCompressionFilter> Create( static absl::StatusOr<std::unique_ptr<ClientCompressionFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
explicit ClientCompressionFilter(const ChannelArgs& args)
: compression_engine_(args) {}
// Construct a promise for one call. // Construct a promise for one call.
class Call { class Call {
public: public:
@ -135,9 +138,6 @@ class ClientCompressionFilter final
}; };
private: private:
explicit ClientCompressionFilter(const ChannelArgs& args)
: compression_engine_(args) {}
ChannelCompression compression_engine_; ChannelCompression compression_engine_;
}; };
@ -146,9 +146,12 @@ class ServerCompressionFilter final
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ServerCompressionFilter> Create( static absl::StatusOr<std::unique_ptr<ServerCompressionFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
explicit ServerCompressionFilter(const ChannelArgs& args)
: compression_engine_(args) {}
// Construct a promise for one call. // Construct a promise for one call.
class Call { class Call {
public: public:
@ -171,9 +174,6 @@ class ServerCompressionFilter final
}; };
private: private:
explicit ServerCompressionFilter(const ChannelArgs& args)
: compression_engine_(args) {}
ChannelCompression compression_engine_; ChannelCompression compression_engine_;
}; };

@ -152,9 +152,9 @@ void HttpServerFilter::Call::OnServerTrailingMetadata(ServerMetadata& md) {
FilterOutgoingMetadata(&md); FilterOutgoingMetadata(&md);
} }
absl::StatusOr<HttpServerFilter> HttpServerFilter::Create( absl::StatusOr<std::unique_ptr<HttpServerFilter>> HttpServerFilter::Create(
const ChannelArgs& args, ChannelFilter::Args) { const ChannelArgs& args, ChannelFilter::Args) {
return HttpServerFilter( return std::make_unique<HttpServerFilter>(
args.GetBool(GRPC_ARG_SURFACE_USER_AGENT).value_or(true), args.GetBool(GRPC_ARG_SURFACE_USER_AGENT).value_or(true),
args.GetBool( args.GetBool(
GRPC_ARG_DO_NOT_USE_UNLESS_YOU_HAVE_PERMISSION_FROM_GRPC_TEAM_ALLOW_BROKEN_PUT_REQUESTS) GRPC_ARG_DO_NOT_USE_UNLESS_YOU_HAVE_PERMISSION_FROM_GRPC_TEAM_ALLOW_BROKEN_PUT_REQUESTS)

@ -36,9 +36,13 @@ class HttpServerFilter : public ImplementChannelFilter<HttpServerFilter> {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<HttpServerFilter> Create( static absl::StatusOr<std::unique_ptr<HttpServerFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
HttpServerFilter(bool surface_user_agent, bool allow_put_requests)
: surface_user_agent_(surface_user_agent),
allow_put_requests_(allow_put_requests) {}
class Call { class Call {
public: public:
ServerMetadataHandle OnClientInitialMetadata(ClientMetadata& md, ServerMetadataHandle OnClientInitialMetadata(ClientMetadata& md,
@ -51,10 +55,6 @@ class HttpServerFilter : public ImplementChannelFilter<HttpServerFilter> {
}; };
private: private:
HttpServerFilter(bool surface_user_agent, bool allow_put_requests)
: surface_user_agent_(surface_user_agent),
allow_put_requests_(allow_put_requests) {}
bool surface_user_agent_; bool surface_user_agent_;
bool allow_put_requests_; bool allow_put_requests_;
}; };

@ -76,10 +76,11 @@ const NoInterceptor ServerLoadReportingFilter::Call::OnServerInitialMetadata;
const NoInterceptor ServerLoadReportingFilter::Call::OnClientToServerMessage; const NoInterceptor ServerLoadReportingFilter::Call::OnClientToServerMessage;
const NoInterceptor ServerLoadReportingFilter::Call::OnServerToClientMessage; const NoInterceptor ServerLoadReportingFilter::Call::OnServerToClientMessage;
absl::StatusOr<ServerLoadReportingFilter> ServerLoadReportingFilter::Create( absl::StatusOr<std::unique_ptr<ServerLoadReportingFilter>>
const ChannelArgs& channel_args, ChannelFilter::Args) { ServerLoadReportingFilter::Create(const ChannelArgs& channel_args,
ChannelFilter::Args) {
// Find and record the peer_identity. // Find and record the peer_identity.
ServerLoadReportingFilter filter; auto filter = std::make_unique<ServerLoadReportingFilter>();
const auto* auth_context = channel_args.GetObject<grpc_auth_context>(); const auto* auth_context = channel_args.GetObject<grpc_auth_context>();
if (auth_context != nullptr && if (auth_context != nullptr &&
grpc_auth_context_peer_is_authenticated(auth_context)) { grpc_auth_context_peer_is_authenticated(auth_context)) {
@ -88,7 +89,7 @@ absl::StatusOr<ServerLoadReportingFilter> ServerLoadReportingFilter::Create(
const grpc_auth_property* auth_property = const grpc_auth_property* auth_property =
grpc_auth_property_iterator_next(&auth_it); grpc_auth_property_iterator_next(&auth_it);
if (auth_property != nullptr) { if (auth_property != nullptr) {
filter.peer_identity_ = filter->peer_identity_ =
std::string(auth_property->value, auth_property->value_length); std::string(auth_property->value, auth_property->value_length);
} }
} }

@ -39,7 +39,7 @@ class ServerLoadReportingFilter
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ServerLoadReportingFilter> Create( static absl::StatusOr<std::unique_ptr<ServerLoadReportingFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args); const ChannelArgs& args, ChannelFilter::Args);
// Getters. // Getters.

@ -342,21 +342,23 @@ class CallData {
} // namespace } // namespace
absl::StatusOr<ClientLoggingFilter> ClientLoggingFilter::Create( absl::StatusOr<std::unique_ptr<ClientLoggingFilter>>
const ChannelArgs& args, ChannelFilter::Args /*filter_args*/) { ClientLoggingFilter::Create(const ChannelArgs& args,
ChannelFilter::Args /*filter_args*/) {
absl::optional<absl::string_view> default_authority = absl::optional<absl::string_view> default_authority =
args.GetString(GRPC_ARG_DEFAULT_AUTHORITY); args.GetString(GRPC_ARG_DEFAULT_AUTHORITY);
if (default_authority.has_value()) { if (default_authority.has_value()) {
return ClientLoggingFilter(std::string(default_authority.value())); return std::make_unique<ClientLoggingFilter>(
std::string(default_authority.value()));
} }
absl::optional<std::string> server_uri = absl::optional<std::string> server_uri =
args.GetOwnedString(GRPC_ARG_SERVER_URI); args.GetOwnedString(GRPC_ARG_SERVER_URI);
if (server_uri.has_value()) { if (server_uri.has_value()) {
return ClientLoggingFilter( return std::make_unique<ClientLoggingFilter>(
CoreConfiguration::Get().resolver_registry().GetDefaultAuthority( CoreConfiguration::Get().resolver_registry().GetDefaultAuthority(
*server_uri)); *server_uri));
} }
return ClientLoggingFilter(""); return std::make_unique<ClientLoggingFilter>("");
} }
// Construct a promise for one call. // Construct a promise for one call.
@ -445,9 +447,10 @@ const grpc_channel_filter ClientLoggingFilter::kFilter =
kFilterExaminesInboundMessages | kFilterExaminesInboundMessages |
kFilterExaminesOutboundMessages>("logging"); kFilterExaminesOutboundMessages>("logging");
absl::StatusOr<ServerLoggingFilter> ServerLoggingFilter::Create( absl::StatusOr<std::unique_ptr<ServerLoggingFilter>>
const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/) { ServerLoggingFilter::Create(const ChannelArgs& /*args*/,
return ServerLoggingFilter(); ChannelFilter::Args /*filter_args*/) {
return std::make_unique<ServerLoggingFilter>();
} }
// Construct a promise for one call. // Construct a promise for one call.

@ -39,24 +39,25 @@ class ClientLoggingFilter final : public ChannelFilter {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ClientLoggingFilter> Create( static absl::StatusOr<std::unique_ptr<ClientLoggingFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args /*filter_args*/); const ChannelArgs& args, ChannelFilter::Args /*filter_args*/);
explicit ClientLoggingFilter(std::string default_authority)
: default_authority_(std::move(default_authority)) {}
// Construct a promise for one call. // Construct a promise for one call.
ArenaPromise<ServerMetadataHandle> MakeCallPromise( ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override; CallArgs call_args, NextPromiseFactory next_promise_factory) override;
private: private:
explicit ClientLoggingFilter(std::string default_authority) const std::string default_authority_;
: default_authority_(std::move(default_authority)) {}
std::string default_authority_;
}; };
class ServerLoggingFilter final : public ChannelFilter { class ServerLoggingFilter final : public ChannelFilter {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ServerLoggingFilter> Create( static absl::StatusOr<std::unique_ptr<ServerLoggingFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args /*filter_args*/); const ChannelArgs& args, ChannelFilter::Args /*filter_args*/);
// Construct a promise for one call. // Construct a promise for one call.

@ -25,7 +25,6 @@
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include <grpc/grpc.h>
#include <grpc/impl/channel_arg_names.h> #include <grpc/impl/channel_arg_names.h>
#include <grpc/status.h> #include <grpc/status.h>
#include <grpc/support/log.h> #include <grpc/support/log.h>
@ -142,19 +141,20 @@ const grpc_channel_filter ClientMessageSizeFilter::kFilter =
MakePromiseBasedFilter<ClientMessageSizeFilter, FilterEndpoint::kClient, MakePromiseBasedFilter<ClientMessageSizeFilter, FilterEndpoint::kClient,
kFilterExaminesOutboundMessages | kFilterExaminesOutboundMessages |
kFilterExaminesInboundMessages>("message_size"); kFilterExaminesInboundMessages>("message_size");
const grpc_channel_filter ServerMessageSizeFilter::kFilter = const grpc_channel_filter ServerMessageSizeFilter::kFilter =
MakePromiseBasedFilter<ServerMessageSizeFilter, FilterEndpoint::kServer, MakePromiseBasedFilter<ServerMessageSizeFilter, FilterEndpoint::kServer,
kFilterExaminesOutboundMessages | kFilterExaminesOutboundMessages |
kFilterExaminesInboundMessages>("message_size"); kFilterExaminesInboundMessages>("message_size");
absl::StatusOr<ClientMessageSizeFilter> ClientMessageSizeFilter::Create( absl::StatusOr<std::unique_ptr<ClientMessageSizeFilter>>
const ChannelArgs& args, ChannelFilter::Args) { ClientMessageSizeFilter::Create(const ChannelArgs& args, ChannelFilter::Args) {
return ClientMessageSizeFilter(args); return std::make_unique<ClientMessageSizeFilter>(args);
} }
absl::StatusOr<ServerMessageSizeFilter> ServerMessageSizeFilter::Create( absl::StatusOr<std::unique_ptr<ServerMessageSizeFilter>>
const ChannelArgs& args, ChannelFilter::Args) { ServerMessageSizeFilter::Create(const ChannelArgs& args, ChannelFilter::Args) {
return ServerMessageSizeFilter(args); return std::make_unique<ServerMessageSizeFilter>(args);
} }
namespace { namespace {

@ -91,9 +91,12 @@ class ServerMessageSizeFilter final
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ServerMessageSizeFilter> Create( static absl::StatusOr<std::unique_ptr<ServerMessageSizeFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
explicit ServerMessageSizeFilter(const ChannelArgs& args)
: parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {}
class Call { class Call {
public: public:
static const NoInterceptor OnClientInitialMetadata; static const NoInterceptor OnClientInitialMetadata;
@ -107,8 +110,6 @@ class ServerMessageSizeFilter final
}; };
private: private:
explicit ServerMessageSizeFilter(const ChannelArgs& args)
: parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {}
const MessageSizeParsedConfig parsed_config_; const MessageSizeParsedConfig parsed_config_;
}; };
@ -117,9 +118,12 @@ class ClientMessageSizeFilter final
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ClientMessageSizeFilter> Create( static absl::StatusOr<std::unique_ptr<ClientMessageSizeFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
explicit ClientMessageSizeFilter(const ChannelArgs& args)
: parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {}
class Call { class Call {
public: public:
explicit Call(ClientMessageSizeFilter* filter); explicit Call(ClientMessageSizeFilter* filter);
@ -136,8 +140,6 @@ class ClientMessageSizeFilter final
}; };
private: private:
explicit ClientMessageSizeFilter(const ChannelArgs& args)
: parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {}
const size_t service_config_parser_index_{MessageSizeParser::ParserIndex()}; const size_t service_config_parser_index_{MessageSizeParser::ParserIndex()};
const MessageSizeParsedConfig parsed_config_; const MessageSizeParsedConfig parsed_config_;
}; };

@ -82,14 +82,21 @@ RbacFilter::RbacFilter(size_t index,
service_config_parser_index_(RbacServiceConfigParser::ParserIndex()), service_config_parser_index_(RbacServiceConfigParser::ParserIndex()),
per_channel_evaluate_args_(std::move(per_channel_evaluate_args)) {} per_channel_evaluate_args_(std::move(per_channel_evaluate_args)) {}
absl::StatusOr<RbacFilter> RbacFilter::Create(const ChannelArgs& args, absl::StatusOr<std::unique_ptr<RbacFilter>> RbacFilter::Create(
ChannelFilter::Args filter_args) { const ChannelArgs& args, ChannelFilter::Args filter_args) {
auto* auth_context = args.GetObject<grpc_auth_context>(); auto* auth_context = args.GetObject<grpc_auth_context>();
if (auth_context == nullptr) { if (auth_context == nullptr) {
return GRPC_ERROR_CREATE("No auth context found"); return GRPC_ERROR_CREATE("No auth context found");
} }
return RbacFilter(filter_args.instance_id(), auto* transport = args.GetObject<Transport>();
EvaluateArgs::PerChannelArgs(auth_context, args)); if (transport == nullptr) {
// This should never happen since the transport is always set on the server
// side.
return GRPC_ERROR_CREATE("No transport configured");
}
return std::make_unique<RbacFilter>(
filter_args.instance_id(),
EvaluateArgs::PerChannelArgs(auth_context, args));
} }
void RbacFilterRegister(CoreConfiguration::Builder* builder) { void RbacFilterRegister(CoreConfiguration::Builder* builder) {

@ -42,8 +42,11 @@ class RbacFilter : public ImplementChannelFilter<RbacFilter> {
// and enforces the RBAC policy. // and enforces the RBAC policy.
static const grpc_channel_filter kFilterVtable; static const grpc_channel_filter kFilterVtable;
static absl::StatusOr<RbacFilter> Create(const ChannelArgs& args, static absl::StatusOr<std::unique_ptr<RbacFilter>> Create(
ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
RbacFilter(size_t index,
EvaluateArgs::PerChannelArgs per_channel_evaluate_args);
class Call { class Call {
public: public:
@ -57,9 +60,6 @@ class RbacFilter : public ImplementChannelFilter<RbacFilter> {
}; };
private: private:
RbacFilter(size_t index,
EvaluateArgs::PerChannelArgs per_channel_evaluate_args);
// The index of this filter instance among instances of the same filter. // The index of this filter instance among instances of the same filter.
size_t index_; size_t index_;
// Assigned index for service config data from the parser. // Assigned index for service config data from the parser.

@ -49,19 +49,22 @@ namespace grpc_core {
namespace { namespace {
class ServerConfigSelectorFilter final class ServerConfigSelectorFilter final
: public ImplementChannelFilter<ServerConfigSelectorFilter> { : public ImplementChannelFilter<ServerConfigSelectorFilter>,
public InternallyRefCounted<ServerConfigSelectorFilter> {
public: public:
~ServerConfigSelectorFilter() override; explicit ServerConfigSelectorFilter(
RefCountedPtr<ServerConfigSelectorProvider>
server_config_selector_provider);
ServerConfigSelectorFilter(const ServerConfigSelectorFilter&) = delete; ServerConfigSelectorFilter(const ServerConfigSelectorFilter&) = delete;
ServerConfigSelectorFilter& operator=(const ServerConfigSelectorFilter&) = ServerConfigSelectorFilter& operator=(const ServerConfigSelectorFilter&) =
delete; delete;
ServerConfigSelectorFilter(ServerConfigSelectorFilter&&) = default;
ServerConfigSelectorFilter& operator=(ServerConfigSelectorFilter&&) = default;
static absl::StatusOr<ServerConfigSelectorFilter> Create( static absl::StatusOr<OrphanablePtr<ServerConfigSelectorFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args); const ChannelArgs& args, ChannelFilter::Args);
void Orphan() override;
class Call { class Call {
public: public:
absl::Status OnClientInitialMetadata(ClientMetadata& md, absl::Status OnClientInitialMetadata(ClientMetadata& md,
@ -74,70 +77,66 @@ class ServerConfigSelectorFilter final
}; };
absl::StatusOr<RefCountedPtr<ServerConfigSelector>> config_selector() { absl::StatusOr<RefCountedPtr<ServerConfigSelector>> config_selector() {
MutexLock lock(&state_->mu); MutexLock lock(&mu_);
return state_->config_selector.value(); return config_selector_.value();
} }
private: private:
struct State {
Mutex mu;
absl::optional<absl::StatusOr<RefCountedPtr<ServerConfigSelector>>>
config_selector ABSL_GUARDED_BY(mu);
};
class ServerConfigSelectorWatcher class ServerConfigSelectorWatcher
: public ServerConfigSelectorProvider::ServerConfigSelectorWatcher { : public ServerConfigSelectorProvider::ServerConfigSelectorWatcher {
public: public:
explicit ServerConfigSelectorWatcher(std::shared_ptr<State> state) explicit ServerConfigSelectorWatcher(
: state_(state) {} RefCountedPtr<ServerConfigSelectorFilter> filter)
: filter_(filter) {}
void OnServerConfigSelectorUpdate( void OnServerConfigSelectorUpdate(
absl::StatusOr<RefCountedPtr<ServerConfigSelector>> update) override { absl::StatusOr<RefCountedPtr<ServerConfigSelector>> update) override {
MutexLock lock(&state_->mu); MutexLock lock(&filter_->mu_);
state_->config_selector = std::move(update); filter_->config_selector_ = std::move(update);
} }
private: private:
std::shared_ptr<State> state_; RefCountedPtr<ServerConfigSelectorFilter> filter_;
}; };
explicit ServerConfigSelectorFilter(
RefCountedPtr<ServerConfigSelectorProvider>
server_config_selector_provider);
RefCountedPtr<ServerConfigSelectorProvider> server_config_selector_provider_; RefCountedPtr<ServerConfigSelectorProvider> server_config_selector_provider_;
std::shared_ptr<State> state_; Mutex mu_;
absl::optional<absl::StatusOr<RefCountedPtr<ServerConfigSelector>>>
config_selector_ ABSL_GUARDED_BY(mu_);
}; };
absl::StatusOr<ServerConfigSelectorFilter> ServerConfigSelectorFilter::Create( absl::StatusOr<OrphanablePtr<ServerConfigSelectorFilter>>
const ChannelArgs& args, ChannelFilter::Args) { ServerConfigSelectorFilter::Create(const ChannelArgs& args,
ChannelFilter::Args) {
ServerConfigSelectorProvider* server_config_selector_provider = ServerConfigSelectorProvider* server_config_selector_provider =
args.GetObject<ServerConfigSelectorProvider>(); args.GetObject<ServerConfigSelectorProvider>();
if (server_config_selector_provider == nullptr) { if (server_config_selector_provider == nullptr) {
return absl::UnknownError("No ServerConfigSelectorProvider object found"); return absl::UnknownError("No ServerConfigSelectorProvider object found");
} }
return ServerConfigSelectorFilter(server_config_selector_provider->Ref()); return MakeOrphanable<ServerConfigSelectorFilter>(
server_config_selector_provider->Ref());
} }
ServerConfigSelectorFilter::ServerConfigSelectorFilter( ServerConfigSelectorFilter::ServerConfigSelectorFilter(
RefCountedPtr<ServerConfigSelectorProvider> server_config_selector_provider) RefCountedPtr<ServerConfigSelectorProvider> server_config_selector_provider)
: server_config_selector_provider_( : server_config_selector_provider_(
std::move(server_config_selector_provider)), std::move(server_config_selector_provider)) {
state_(std::make_shared<State>()) {
GPR_ASSERT(server_config_selector_provider_ != nullptr); GPR_ASSERT(server_config_selector_provider_ != nullptr);
auto server_config_selector_watcher = auto server_config_selector_watcher =
std::make_unique<ServerConfigSelectorWatcher>(state_); std::make_unique<ServerConfigSelectorWatcher>(Ref());
auto config_selector = server_config_selector_provider_->Watch( auto config_selector = server_config_selector_provider_->Watch(
std::move(server_config_selector_watcher)); std::move(server_config_selector_watcher));
MutexLock lock(&state_->mu); MutexLock lock(&mu_);
// It's possible for the watcher to have already updated config_selector_ // It's possible for the watcher to have already updated config_selector_
if (!state_->config_selector.has_value()) { if (!config_selector_.has_value()) {
state_->config_selector = std::move(config_selector); config_selector_ = std::move(config_selector);
} }
} }
ServerConfigSelectorFilter::~ServerConfigSelectorFilter() { void ServerConfigSelectorFilter::Orphan() {
if (server_config_selector_provider_ != nullptr) { if (server_config_selector_provider_ != nullptr) {
server_config_selector_provider_->CancelWatch(); server_config_selector_provider_->CancelWatch();
} }
Unref();
} }
absl::Status ServerConfigSelectorFilter::Call::OnClientInitialMetadata( absl::Status ServerConfigSelectorFilter::Call::OnClientInitialMetadata(

@ -72,9 +72,10 @@ const grpc_channel_filter StatefulSessionFilter::kFilter =
kFilterExaminesServerInitialMetadata>( kFilterExaminesServerInitialMetadata>(
"stateful_session_filter"); "stateful_session_filter");
absl::StatusOr<StatefulSessionFilter> StatefulSessionFilter::Create( absl::StatusOr<std::unique_ptr<StatefulSessionFilter>>
const ChannelArgs&, ChannelFilter::Args filter_args) { StatefulSessionFilter::Create(const ChannelArgs&,
return StatefulSessionFilter(filter_args); ChannelFilter::Args filter_args) {
return std::make_unique<StatefulSessionFilter>(filter_args);
} }
StatefulSessionFilter::StatefulSessionFilter(ChannelFilter::Args filter_args) StatefulSessionFilter::StatefulSessionFilter(ChannelFilter::Args filter_args)

@ -74,9 +74,11 @@ class StatefulSessionFilter
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<StatefulSessionFilter> Create( static absl::StatusOr<std::unique_ptr<StatefulSessionFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
explicit StatefulSessionFilter(ChannelFilter::Args filter_args);
class Call { class Call {
public: public:
void OnClientInitialMetadata(ClientMetadata& md, void OnClientInitialMetadata(ClientMetadata& md,
@ -97,7 +99,6 @@ class StatefulSessionFilter
}; };
private: private:
explicit StatefulSessionFilter(ChannelFilter::Args filter_args);
// The relative index of instances of the same filter. // The relative index of instances of the same filter.
const size_t index_; const size_t index_;
// Index of the service config parser. // Index of the service config parser.

@ -106,7 +106,7 @@ BaseCallData::BaseCallData(
? arena_->New<ReceiveMessage>(this, make_recv_interceptor()) ? arena_->New<ReceiveMessage>(this, make_recv_interceptor())
: nullptr), : nullptr),
event_engine_( event_engine_(
static_cast<ChannelFilter*>(elem->channel_data) ChannelFilterFromElem(elem)
->hack_until_per_channel_stack_event_engines_land_get_event_engine()) { ->hack_until_per_channel_stack_event_engines_land_get_event_engine()) {
} }
@ -1572,7 +1572,7 @@ void ClientCallData::Cancel(grpc_error_handle error, Flusher* flusher) {
// metadata and return some trailing metadata. // metadata and return some trailing metadata.
void ClientCallData::StartPromise(Flusher* flusher) { void ClientCallData::StartPromise(Flusher* flusher) {
GPR_ASSERT(send_initial_state_ == SendInitialState::kQueued); GPR_ASSERT(send_initial_state_ == SendInitialState::kQueued);
ChannelFilter* filter = static_cast<ChannelFilter*>(elem()->channel_data); ChannelFilter* filter = promise_filter_detail::ChannelFilterFromElem(elem());
// Construct the promise. // Construct the promise.
PollContext ctx(this, flusher); PollContext ctx(this, flusher);
@ -2369,7 +2369,7 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) {
// Start the promise. // Start the promise.
ScopedContext context(this); ScopedContext context(this);
// Construct the promise. // Construct the promise.
ChannelFilter* filter = static_cast<ChannelFilter*>(elem()->channel_data); ChannelFilter* filter = promise_filter_detail::ChannelFilterFromElem(elem());
FakeActivity(this).Run([this, filter] { FakeActivity(this).Run([this, filter] {
promise_ = filter->MakeCallPromise( promise_ = filter->MakeCallPromise(
CallArgs{WrapMetadata(recv_initial_metadata_), CallArgs{WrapMetadata(recv_initial_metadata_),

@ -1841,6 +1841,15 @@ struct BaseCallDataMethods {
} }
}; };
// The type of object returned by a filter's Create method.
template <typename T>
using CreatedType = typename decltype(T::Create(ChannelArgs(), {}))::value_type;
template <typename GrpcChannelOrCallElement>
inline ChannelFilter* ChannelFilterFromElem(GrpcChannelOrCallElement* elem) {
return *static_cast<ChannelFilter**>(elem->channel_data);
}
template <typename CallData, uint8_t kFlags> template <typename CallData, uint8_t kFlags>
struct CallDataFilterWithFlagsMethods { struct CallDataFilterWithFlagsMethods {
static absl::Status InitCallElem(grpc_call_element* elem, static absl::Status InitCallElem(grpc_call_element* elem,
@ -1865,32 +1874,25 @@ struct ChannelFilterMethods {
static ArenaPromise<ServerMetadataHandle> MakeCallPromise( static ArenaPromise<ServerMetadataHandle> MakeCallPromise(
grpc_channel_element* elem, CallArgs call_args, grpc_channel_element* elem, CallArgs call_args,
NextPromiseFactory next_promise_factory) { NextPromiseFactory next_promise_factory) {
return static_cast<ChannelFilter*>(elem->channel_data) return ChannelFilterFromElem(elem)->MakeCallPromise(
->MakeCallPromise(std::move(call_args), std::move(call_args), std::move(next_promise_factory));
std::move(next_promise_factory));
} }
static void StartTransportOp(grpc_channel_element* elem, static void StartTransportOp(grpc_channel_element* elem,
grpc_transport_op* op) { grpc_transport_op* op) {
if (!static_cast<ChannelFilter*>(elem->channel_data) if (!ChannelFilterFromElem(elem)->StartTransportOp(op)) {
->StartTransportOp(op)) {
grpc_channel_next_op(elem, op); grpc_channel_next_op(elem, op);
} }
} }
static void PostInitChannelElem(grpc_channel_stack*, static void PostInitChannelElem(grpc_channel_stack*,
grpc_channel_element* elem) { grpc_channel_element* elem) {
static_cast<ChannelFilter*>(elem->channel_data)->PostInit(); ChannelFilterFromElem(elem)->PostInit();
}
static void DestroyChannelElem(grpc_channel_element* elem) {
static_cast<ChannelFilter*>(elem->channel_data)->~ChannelFilter();
} }
static void GetChannelInfo(grpc_channel_element* elem, static void GetChannelInfo(grpc_channel_element* elem,
const grpc_channel_info* info) { const grpc_channel_info* info) {
if (!static_cast<ChannelFilter*>(elem->channel_data) if (!ChannelFilterFromElem(elem)->GetChannelInfo(info)) {
->GetChannelInfo(info)) {
grpc_channel_next_get_info(elem, info); grpc_channel_next_get_info(elem, info);
} }
} }
@ -1904,15 +1906,16 @@ struct ChannelFilterWithFlagsMethods {
auto status = F::Create(args->channel_args, auto status = F::Create(args->channel_args,
ChannelFilter::Args(args->channel_stack, elem)); ChannelFilter::Args(args->channel_stack, elem));
if (!status.ok()) { if (!status.ok()) {
static_assert( new (elem->channel_data) F*(nullptr);
sizeof(promise_filter_detail::InvalidChannelFilter) <= sizeof(F),
"InvalidChannelFilter must fit in F");
new (elem->channel_data) promise_filter_detail::InvalidChannelFilter();
return absl_status_to_grpc_error(status.status()); return absl_status_to_grpc_error(status.status());
} }
new (elem->channel_data) F(std::move(*status)); new (elem->channel_data) F*(status->release());
return absl::OkStatus(); return absl::OkStatus();
} }
static void DestroyChannelElem(grpc_channel_element* elem) {
CreatedType<F> channel_elem(DownCast<F*>(ChannelFilterFromElem(elem)));
}
}; };
} // namespace promise_filter_detail } // namespace promise_filter_detail
@ -1958,7 +1961,8 @@ MakePromiseBasedFilter(const char* name) {
// post_init_channel_elem // post_init_channel_elem
promise_filter_detail::ChannelFilterMethods::PostInitChannelElem, promise_filter_detail::ChannelFilterMethods::PostInitChannelElem,
// destroy_channel_elem // destroy_channel_elem
promise_filter_detail::ChannelFilterMethods::DestroyChannelElem, promise_filter_detail::ChannelFilterWithFlagsMethods<
F, kFlags>::DestroyChannelElem,
// get_channel_info // get_channel_info
promise_filter_detail::ChannelFilterMethods::GetChannelInfo, promise_filter_detail::ChannelFilterMethods::GetChannelInfo,
// name // name
@ -2004,7 +2008,8 @@ MakePromiseBasedFilter(const char* name) {
// post_init_channel_elem // post_init_channel_elem
promise_filter_detail::ChannelFilterMethods::PostInitChannelElem, promise_filter_detail::ChannelFilterMethods::PostInitChannelElem,
// destroy_channel_elem // destroy_channel_elem
promise_filter_detail::ChannelFilterMethods::DestroyChannelElem, promise_filter_detail::ChannelFilterWithFlagsMethods<
F, kFlags>::DestroyChannelElem,
// get_channel_info // get_channel_info
promise_filter_detail::ChannelFilterMethods::GetChannelInfo, promise_filter_detail::ChannelFilterMethods::GetChannelInfo,
// name // name
@ -2046,7 +2051,8 @@ MakePromiseBasedFilter(const char* name) {
// post_init_channel_elem // post_init_channel_elem
promise_filter_detail::ChannelFilterMethods::PostInitChannelElem, promise_filter_detail::ChannelFilterMethods::PostInitChannelElem,
// destroy_channel_elem // destroy_channel_elem
promise_filter_detail::ChannelFilterMethods::DestroyChannelElem, promise_filter_detail::ChannelFilterWithFlagsMethods<
F, kFlags>::DestroyChannelElem,
// get_channel_info // get_channel_info
promise_filter_detail::ChannelFilterMethods::GetChannelInfo, promise_filter_detail::ChannelFilterMethods::GetChannelInfo,
// name // name

@ -17,6 +17,7 @@
#include "src/core/lib/channel/server_call_tracer_filter.h" #include "src/core/lib/channel/server_call_tracer_filter.h"
#include <functional> #include <functional>
#include <memory>
#include <utility> #include <utility>
#include "absl/status/status.h" #include "absl/status/status.h"
@ -49,7 +50,7 @@ class ServerCallTracerFilter
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ServerCallTracerFilter> Create( static absl::StatusOr<std::unique_ptr<ServerCallTracerFilter>> Create(
const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/); const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/);
class Call { class Call {
@ -98,9 +99,10 @@ const grpc_channel_filter ServerCallTracerFilter::kFilter =
kFilterExaminesServerInitialMetadata>( kFilterExaminesServerInitialMetadata>(
"server_call_tracer"); "server_call_tracer");
absl::StatusOr<ServerCallTracerFilter> ServerCallTracerFilter::Create( absl::StatusOr<std::unique_ptr<ServerCallTracerFilter>>
const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/) { ServerCallTracerFilter::Create(const ChannelArgs& /*args*/,
return ServerCallTracerFilter(); ChannelFilter::Args /*filter_args*/) {
return std::make_unique<ServerCallTracerFilter>();
} }
} // namespace } // namespace

@ -51,14 +51,14 @@ GrpcServerAuthzFilter::GrpcServerAuthzFilter(
per_channel_evaluate_args_(auth_context_.get(), args), per_channel_evaluate_args_(auth_context_.get(), args),
provider_(std::move(provider)) {} provider_(std::move(provider)) {}
absl::StatusOr<GrpcServerAuthzFilter> GrpcServerAuthzFilter::Create( absl::StatusOr<std::unique_ptr<GrpcServerAuthzFilter>>
const ChannelArgs& args, ChannelFilter::Args) { GrpcServerAuthzFilter::Create(const ChannelArgs& args, ChannelFilter::Args) {
auto* auth_context = args.GetObject<grpc_auth_context>(); auto* auth_context = args.GetObject<grpc_auth_context>();
auto* provider = args.GetObject<grpc_authorization_policy_provider>(); auto* provider = args.GetObject<grpc_authorization_policy_provider>();
if (provider == nullptr) { if (provider == nullptr) {
return absl::InvalidArgumentError("Failed to get authorization provider."); return absl::InvalidArgumentError("Failed to get authorization provider.");
} }
return GrpcServerAuthzFilter( return std::make_unique<GrpcServerAuthzFilter>(
auth_context != nullptr ? auth_context->Ref() : nullptr, args, auth_context != nullptr ? auth_context->Ref() : nullptr, args,
provider->Ref()); provider->Ref());
} }

@ -37,8 +37,12 @@ class GrpcServerAuthzFilter final
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<GrpcServerAuthzFilter> Create(const ChannelArgs& args, static absl::StatusOr<std::unique_ptr<GrpcServerAuthzFilter>> Create(
ChannelFilter::Args); const ChannelArgs& args, ChannelFilter::Args);
GrpcServerAuthzFilter(
RefCountedPtr<grpc_auth_context> auth_context, const ChannelArgs& args,
RefCountedPtr<grpc_authorization_policy_provider> provider);
class Call { class Call {
public: public:
@ -52,10 +56,6 @@ class GrpcServerAuthzFilter final
}; };
private: private:
GrpcServerAuthzFilter(
RefCountedPtr<grpc_auth_context> auth_context, const ChannelArgs& args,
RefCountedPtr<grpc_authorization_policy_provider> provider);
bool IsAuthorized(ClientMetadata& initial_metadata); bool IsAuthorized(ClientMetadata& initial_metadata);
RefCountedPtr<grpc_auth_context> auth_context_; RefCountedPtr<grpc_auth_context> auth_context_;

@ -42,18 +42,18 @@ class ClientAuthFilter final : public ChannelFilter {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ClientAuthFilter> Create(const ChannelArgs& args, ClientAuthFilter(
ChannelFilter::Args); RefCountedPtr<grpc_channel_security_connector> security_connector,
RefCountedPtr<grpc_auth_context> auth_context);
static absl::StatusOr<std::unique_ptr<ClientAuthFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args);
// Construct a promise for one call. // Construct a promise for one call.
ArenaPromise<ServerMetadataHandle> MakeCallPromise( ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override; CallArgs call_args, NextPromiseFactory next_promise_factory) override;
private: private:
ClientAuthFilter(
RefCountedPtr<grpc_channel_security_connector> security_connector,
RefCountedPtr<grpc_auth_context> auth_context);
ArenaPromise<absl::StatusOr<CallArgs>> GetCallCredsMetadata( ArenaPromise<absl::StatusOr<CallArgs>> GetCallCredsMetadata(
CallArgs call_args); CallArgs call_args);
@ -63,9 +63,6 @@ class ClientAuthFilter final : public ChannelFilter {
class ServerAuthFilter final : public ImplementChannelFilter<ServerAuthFilter> { class ServerAuthFilter final : public ImplementChannelFilter<ServerAuthFilter> {
private: private:
ServerAuthFilter(RefCountedPtr<grpc_server_credentials> server_credentials,
RefCountedPtr<grpc_auth_context> auth_context);
class RunApplicationCode { class RunApplicationCode {
public: public:
RunApplicationCode(ServerAuthFilter* filter, ClientMetadata& metadata); RunApplicationCode(ServerAuthFilter* filter, ClientMetadata& metadata);
@ -98,8 +95,11 @@ class ServerAuthFilter final : public ImplementChannelFilter<ServerAuthFilter> {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ServerAuthFilter> Create(const ChannelArgs& args, ServerAuthFilter(RefCountedPtr<grpc_server_credentials> server_credentials,
ChannelFilter::Args); RefCountedPtr<grpc_auth_context> auth_context);
static absl::StatusOr<std::unique_ptr<ServerAuthFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args);
class Call { class Call {
public: public:

@ -203,7 +203,7 @@ ArenaPromise<ServerMetadataHandle> ClientAuthFilter::MakeCallPromise(
next_promise_factory); next_promise_factory);
} }
absl::StatusOr<ClientAuthFilter> ClientAuthFilter::Create( absl::StatusOr<std::unique_ptr<ClientAuthFilter>> ClientAuthFilter::Create(
const ChannelArgs& args, ChannelFilter::Args) { const ChannelArgs& args, ChannelFilter::Args) {
auto* sc = args.GetObject<grpc_security_connector>(); auto* sc = args.GetObject<grpc_security_connector>();
if (sc == nullptr) { if (sc == nullptr) {
@ -215,8 +215,9 @@ absl::StatusOr<ClientAuthFilter> ClientAuthFilter::Create(
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
"Auth context missing from client auth filter args"); "Auth context missing from client auth filter args");
} }
return ClientAuthFilter(sc->RefAsSubclass<grpc_channel_security_connector>(), return std::make_unique<ClientAuthFilter>(
auth_context->Ref()); sc->RefAsSubclass<grpc_channel_security_connector>(),
auth_context->Ref());
} }
const grpc_channel_filter ClientAuthFilter::kFilter = const grpc_channel_filter ClientAuthFilter::kFilter =

@ -212,12 +212,13 @@ ServerAuthFilter::ServerAuthFilter(
RefCountedPtr<grpc_auth_context> auth_context) RefCountedPtr<grpc_auth_context> auth_context)
: server_credentials_(server_credentials), auth_context_(auth_context) {} : server_credentials_(server_credentials), auth_context_(auth_context) {}
absl::StatusOr<ServerAuthFilter> ServerAuthFilter::Create( absl::StatusOr<std::unique_ptr<ServerAuthFilter>> ServerAuthFilter::Create(
const ChannelArgs& args, ChannelFilter::Args) { const ChannelArgs& args, ChannelFilter::Args) {
auto auth_context = args.GetObjectRef<grpc_auth_context>(); auto auth_context = args.GetObjectRef<grpc_auth_context>();
GPR_ASSERT(auth_context != nullptr); GPR_ASSERT(auth_context != nullptr);
auto creds = args.GetObjectRef<grpc_server_credentials>(); auto creds = args.GetObjectRef<grpc_server_credentials>();
return ServerAuthFilter(std::move(creds), std::move(auth_context)); return std::make_unique<ServerAuthFilter>(std::move(creds),
std::move(auth_context));
} }
} // namespace grpc_core } // namespace grpc_core

@ -285,6 +285,11 @@ class ChannelInit {
grpc_channel_stack_type type, const ChannelArgs& args) const; grpc_channel_stack_type type, const ChannelArgs& args) const;
private: private:
// The type of object returned by a filter's Create method.
template <typename T>
using CreatedType =
typename decltype(T::Create(ChannelArgs(), {}))::value_type;
struct Filter { struct Filter {
Filter(const grpc_channel_filter* filter, const ChannelFilterVtable* vtable, Filter(const grpc_channel_filter* filter, const ChannelFilterVtable* vtable,
std::vector<InclusionPredicate> predicates, bool skip_v3, std::vector<InclusionPredicate> predicates, bool skip_v3,
@ -328,17 +333,17 @@ class ChannelInit {
template <typename T> template <typename T>
const ChannelInit::ChannelFilterVtable const ChannelInit::ChannelFilterVtable
ChannelInit::VtableForType<T, absl::void_t<typename T::Call>>::kVtable = { ChannelInit::VtableForType<T, absl::void_t<typename T::Call>>::kVtable = {
sizeof(T), alignof(T), sizeof(CreatedType<T>), alignof(CreatedType<T>),
[](void* data, const ChannelArgs& args) -> absl::Status { [](void* data, const ChannelArgs& args) -> absl::Status {
// TODO(ctiller): fill in ChannelFilter::Args (2nd arg) // TODO(ctiller): fill in ChannelFilter::Args (2nd arg)
absl::StatusOr<T> r = T::Create(args, {}); absl::StatusOr<CreatedType<T>> r = T::Create(args, {});
if (!r.ok()) return r.status(); if (!r.ok()) return r.status();
new (data) T(std::move(*r)); new (data) CreatedType<T>(std::move(*r));
return absl::OkStatus(); return absl::OkStatus();
}, },
[](void* data) { static_cast<T*>(data)->~T(); }, [](void* data) { Destruct(static_cast<CreatedType<T>*>(data)); },
[](void* data, CallFilters::StackBuilder& builder) { [](void* data, CallFilters::StackBuilder& builder) {
builder.Add(static_cast<T*>(data)); builder.Add(static_cast<CreatedType<T>*>(data)->get());
}}; }};
} // namespace grpc_core } // namespace grpc_core

@ -59,17 +59,15 @@ const grpc_channel_filter LameClientFilter::kFilter =
MakePromiseBasedFilter<LameClientFilter, FilterEndpoint::kClient, MakePromiseBasedFilter<LameClientFilter, FilterEndpoint::kClient,
kFilterIsLast>("lame-client"); kFilterIsLast>("lame-client");
absl::StatusOr<LameClientFilter> LameClientFilter::Create( absl::StatusOr<std::unique_ptr<LameClientFilter>> LameClientFilter::Create(
const ChannelArgs& args, ChannelFilter::Args) { const ChannelArgs& args, ChannelFilter::Args) {
return LameClientFilter( return std::make_unique<LameClientFilter>(
*args.GetPointer<absl::Status>(GRPC_ARG_LAME_FILTER_ERROR)); *args.GetPointer<absl::Status>(GRPC_ARG_LAME_FILTER_ERROR));
} }
LameClientFilter::LameClientFilter(absl::Status error) LameClientFilter::LameClientFilter(absl::Status error)
: error_(std::move(error)), state_(std::make_unique<State>()) {} : error_(std::move(error)),
state_tracker_("lame_client", GRPC_CHANNEL_SHUTDOWN) {}
LameClientFilter::State::State()
: state_tracker("lame_client", GRPC_CHANNEL_SHUTDOWN) {}
ArenaPromise<ServerMetadataHandle> LameClientFilter::MakeCallPromise( ArenaPromise<ServerMetadataHandle> LameClientFilter::MakeCallPromise(
CallArgs args, NextPromiseFactory) { CallArgs args, NextPromiseFactory) {
@ -92,13 +90,13 @@ bool LameClientFilter::GetChannelInfo(const grpc_channel_info*) { return true; }
bool LameClientFilter::StartTransportOp(grpc_transport_op* op) { bool LameClientFilter::StartTransportOp(grpc_transport_op* op) {
{ {
MutexLock lock(&state_->mu); MutexLock lock(&mu_);
if (op->start_connectivity_watch != nullptr) { if (op->start_connectivity_watch != nullptr) {
state_->state_tracker.AddWatcher(op->start_connectivity_watch_state, state_tracker_.AddWatcher(op->start_connectivity_watch_state,
std::move(op->start_connectivity_watch)); std::move(op->start_connectivity_watch));
} }
if (op->stop_connectivity_watch != nullptr) { if (op->stop_connectivity_watch != nullptr) {
state_->state_tracker.RemoveWatcher(op->stop_connectivity_watch); state_tracker_.RemoveWatcher(op->stop_connectivity_watch);
} }
} }
if (op->send_ping.on_initiate != nullptr) { if (op->send_ping.on_initiate != nullptr) {

@ -47,7 +47,9 @@ class LameClientFilter : public ChannelFilter {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<LameClientFilter> Create( explicit LameClientFilter(absl::Status error);
static absl::StatusOr<std::unique_ptr<LameClientFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
ArenaPromise<ServerMetadataHandle> MakeCallPromise( ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override; CallArgs call_args, NextPromiseFactory next_promise_factory) override;
@ -55,15 +57,9 @@ class LameClientFilter : public ChannelFilter {
bool GetChannelInfo(const grpc_channel_info*) override; bool GetChannelInfo(const grpc_channel_info*) override;
private: private:
explicit LameClientFilter(absl::Status error);
absl::Status error_; absl::Status error_;
struct State { Mutex mu_;
State(); ConnectivityStateTracker state_tracker_ ABSL_GUARDED_BY(mu_);
Mutex mu;
ConnectivityStateTracker state_tracker ABSL_GUARDED_BY(mu);
};
std::unique_ptr<State> state_;
}; };
extern const grpc_arg_pointer_vtable kLameFilterErrorArgVtable; extern const grpc_arg_pointer_vtable kLameFilterErrorArgVtable;

@ -16,8 +16,6 @@
// //
// //
#include <grpc/support/port_platform.h>
#include "src/core/load_balancing/grpclb/client_load_reporting_filter.h" #include "src/core/load_balancing/grpclb/client_load_reporting_filter.h"
#include <functional> #include <functional>
@ -27,7 +25,8 @@
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "src/core/load_balancing/grpclb/grpclb_client_stats.h" #include <grpc/support/port_platform.h>
#include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/promise/context.h" #include "src/core/lib/promise/context.h"
@ -36,48 +35,47 @@
#include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/transport.h" #include "src/core/lib/transport/transport.h"
#include "src/core/load_balancing/grpclb/grpclb_client_stats.h"
namespace grpc_core { namespace grpc_core {
const NoInterceptor ClientLoadReportingFilter::Call::OnServerToClientMessage;
const NoInterceptor ClientLoadReportingFilter::Call::OnClientToServerMessage;
const NoInterceptor ClientLoadReportingFilter::Call::OnFinalize;
const grpc_channel_filter ClientLoadReportingFilter::kFilter = const grpc_channel_filter ClientLoadReportingFilter::kFilter =
MakePromiseBasedFilter<ClientLoadReportingFilter, FilterEndpoint::kClient, MakePromiseBasedFilter<ClientLoadReportingFilter, FilterEndpoint::kClient,
kFilterExaminesServerInitialMetadata>( kFilterExaminesServerInitialMetadata>(
"client_load_reporting"); "client_load_reporting");
absl::StatusOr<ClientLoadReportingFilter> ClientLoadReportingFilter::Create( absl::StatusOr<std::unique_ptr<ClientLoadReportingFilter>>
const ChannelArgs&, ChannelFilter::Args) { ClientLoadReportingFilter::Create(const ChannelArgs&, ChannelFilter::Args) {
return ClientLoadReportingFilter(); return std::make_unique<ClientLoadReportingFilter>();
} }
ArenaPromise<ServerMetadataHandle> ClientLoadReportingFilter::MakeCallPromise( void ClientLoadReportingFilter::Call::OnClientInitialMetadata(
CallArgs call_args, NextPromiseFactory next_promise_factory) { ClientMetadata& client_initial_metadata) {
// Stats object to update.
RefCountedPtr<GrpcLbClientStats> client_stats;
// Handle client initial metadata. // Handle client initial metadata.
// Grab client stats object from metadata. // Grab client stats object from metadata.
auto client_stats_md = auto client_stats_md =
call_args.client_initial_metadata->Take(GrpcLbClientStatsMetadata()); client_initial_metadata.Take(GrpcLbClientStatsMetadata());
if (client_stats_md.has_value()) { if (client_stats_md.has_value()) {
client_stats.reset(*client_stats_md); client_stats_.reset(*client_stats_md);
} }
}
auto* saw_initial_metadata = GetContext<Arena>()->New<bool>(false); void ClientLoadReportingFilter::Call::OnServerInitialMetadata(ServerMetadata&) {
call_args.server_initial_metadata->InterceptAndMap( saw_initial_metadata_ = true;
[saw_initial_metadata](ServerMetadataHandle md) { }
*saw_initial_metadata = true;
return md;
});
return Map(next_promise_factory(std::move(call_args)), void ClientLoadReportingFilter::Call::OnServerTrailingMetadata(
[saw_initial_metadata, client_stats = std::move(client_stats)]( ServerMetadata& server_trailing_metadata) {
ServerMetadataHandle trailing_metadata) { if (client_stats_ != nullptr) {
if (client_stats != nullptr) { client_stats_->AddCallFinished(
client_stats->AddCallFinished( server_trailing_metadata.get(GrpcStreamNetworkState()) ==
trailing_metadata->get(GrpcStreamNetworkState()) == GrpcStreamNetworkState::kNotSentOnWire,
GrpcStreamNetworkState::kNotSentOnWire, saw_initial_metadata_);
*saw_initial_metadata); }
}
return trailing_metadata;
});
} }
} // namespace grpc_core } // namespace grpc_core

@ -19,10 +19,10 @@
#ifndef GRPC_SRC_CORE_LOAD_BALANCING_GRPCLB_CLIENT_LOAD_REPORTING_FILTER_H #ifndef GRPC_SRC_CORE_LOAD_BALANCING_GRPCLB_CLIENT_LOAD_REPORTING_FILTER_H
#define GRPC_SRC_CORE_LOAD_BALANCING_GRPCLB_CLIENT_LOAD_REPORTING_FILTER_H #define GRPC_SRC_CORE_LOAD_BALANCING_GRPCLB_CLIENT_LOAD_REPORTING_FILTER_H
#include <grpc/support/port_platform.h>
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include <grpc/support/port_platform.h>
#include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/channel_fwd.h" #include "src/core/lib/channel/channel_fwd.h"
#include "src/core/lib/channel/promise_based_filter.h" #include "src/core/lib/channel/promise_based_filter.h"
@ -31,16 +31,27 @@
namespace grpc_core { namespace grpc_core {
class ClientLoadReportingFilter final : public ChannelFilter { class ClientLoadReportingFilter final
: public ImplementChannelFilter<ClientLoadReportingFilter> {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ClientLoadReportingFilter> Create( class Call {
public:
void OnClientInitialMetadata(ClientMetadata& client_initial_metadata);
void OnServerInitialMetadata(ServerMetadata& server_initial_metadata);
void OnServerTrailingMetadata(ServerMetadata& server_trailing_metadata);
static const NoInterceptor OnServerToClientMessage;
static const NoInterceptor OnClientToServerMessage;
static const NoInterceptor OnFinalize;
private:
RefCountedPtr<GrpcLbClientStats> client_stats_;
bool saw_initial_metadata_ = false;
};
static absl::StatusOr<std::unique_ptr<ClientLoadReportingFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
// Construct a promise for one call.
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;
}; };
} // namespace grpc_core } // namespace grpc_core

@ -317,9 +317,10 @@ class XdsResolver final : public Resolver {
public: public:
const static grpc_channel_filter kFilter; const static grpc_channel_filter kFilter;
static absl::StatusOr<ClusterSelectionFilter> Create( static absl::StatusOr<std::unique_ptr<ClusterSelectionFilter>> Create(
const ChannelArgs& /* unused */, ChannelFilter::Args filter_args) { const ChannelArgs& /* unused */,
return ClusterSelectionFilter(filter_args); ChannelFilter::Args /* filter_args */) {
return std::make_unique<ClusterSelectionFilter>();
} }
// Construct a promise for one call. // Construct a promise for one call.
@ -332,12 +333,6 @@ class XdsResolver final : public Resolver {
static const NoInterceptor OnServerToClientMessage; static const NoInterceptor OnServerToClientMessage;
static const NoInterceptor OnFinalize; static const NoInterceptor OnFinalize;
}; };
private:
explicit ClusterSelectionFilter(ChannelFilter::Args filter_args)
: filter_args_(filter_args) {}
ChannelFilter::Args filter_args_;
}; };
RefCountedPtr<ClusterRef> GetOrCreateClusterRef( RefCountedPtr<ClusterRef> GetOrCreateClusterRef(

@ -17,8 +17,6 @@
// This filter reads GRPC_ARG_SERVICE_CONFIG and populates ServiceConfigCallData // This filter reads GRPC_ARG_SERVICE_CONFIG and populates ServiceConfigCallData
// in the call context per call for direct channels. // in the call context per call for direct channels.
#include <grpc/support/port_platform.h>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
@ -30,6 +28,7 @@
#include <grpc/impl/channel_arg_names.h> #include <grpc/impl/channel_arg_names.h>
#include <grpc/support/log.h> #include <grpc/support/log.h>
#include <grpc/support/port_platform.h>
#include "src/core/ext/filters/message_size/message_size_filter.h" #include "src/core/ext/filters/message_size/message_size_filter.h"
#include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_args.h"
@ -42,13 +41,13 @@
#include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/context.h" #include "src/core/lib/promise/context.h"
#include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/surface/channel_stack_type.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/transport.h"
#include "src/core/service_config/service_config.h" #include "src/core/service_config/service_config.h"
#include "src/core/service_config/service_config_call_data.h" #include "src/core/service_config/service_config_call_data.h"
#include "src/core/service_config/service_config_impl.h" #include "src/core/service_config/service_config_impl.h"
#include "src/core/service_config/service_config_parser.h" #include "src/core/service_config/service_config_parser.h"
#include "src/core/lib/surface/channel_stack_type.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/transport.h"
namespace grpc_core { namespace grpc_core {
@ -59,9 +58,9 @@ class ServiceConfigChannelArgFilter final
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ServiceConfigChannelArgFilter> Create( static absl::StatusOr<std::unique_ptr<ServiceConfigChannelArgFilter>> Create(
const ChannelArgs& args, ChannelFilter::Args) { const ChannelArgs& args, ChannelFilter::Args) {
return ServiceConfigChannelArgFilter(args); return std::make_unique<ServiceConfigChannelArgFilter>(args);
} }
explicit ServiceConfigChannelArgFilter(const ChannelArgs& args) { explicit ServiceConfigChannelArgFilter(const ChannelArgs& args) {

@ -84,11 +84,13 @@ const grpc_channel_filter OpenCensusClientFilter::kFilter =
grpc_core::FilterEndpoint::kClient, 0>( grpc_core::FilterEndpoint::kClient, 0>(
"opencensus_client"); "opencensus_client");
absl::StatusOr<OpenCensusClientFilter> OpenCensusClientFilter::Create( absl::StatusOr<std::unique_ptr<OpenCensusClientFilter>>
const grpc_core::ChannelArgs& args, ChannelFilter::Args /*filter_args*/) { OpenCensusClientFilter::Create(const grpc_core::ChannelArgs& args,
ChannelFilter::Args /*filter_args*/) {
bool observability_enabled = bool observability_enabled =
args.GetInt(GRPC_ARG_ENABLE_OBSERVABILITY).value_or(true); args.GetInt(GRPC_ARG_ENABLE_OBSERVABILITY).value_or(true);
return OpenCensusClientFilter(/*tracing_enabled=*/observability_enabled); return std::make_unique<OpenCensusClientFilter>(
/*tracing_enabled=*/observability_enabled);
} }
grpc_core::ArenaPromise<grpc_core::ServerMetadataHandle> grpc_core::ArenaPromise<grpc_core::ServerMetadataHandle>

@ -38,16 +38,17 @@ class OpenCensusClientFilter : public grpc_core::ChannelFilter {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<OpenCensusClientFilter> Create( static absl::StatusOr<std::unique_ptr<OpenCensusClientFilter>> Create(
const grpc_core::ChannelArgs& args, ChannelFilter::Args /*filter_args*/); const grpc_core::ChannelArgs& args, ChannelFilter::Args /*filter_args*/);
explicit OpenCensusClientFilter(bool tracing_enabled)
: tracing_enabled_(tracing_enabled) {}
grpc_core::ArenaPromise<grpc_core::ServerMetadataHandle> MakeCallPromise( grpc_core::ArenaPromise<grpc_core::ServerMetadataHandle> MakeCallPromise(
grpc_core::CallArgs call_args, grpc_core::CallArgs call_args,
grpc_core::NextPromiseFactory next_promise_factory) override; grpc_core::NextPromiseFactory next_promise_factory) override;
private: private:
explicit OpenCensusClientFilter(bool tracing_enabled)
: tracing_enabled_(tracing_enabled) {}
bool tracing_enabled_ = true; bool tracing_enabled_ = true;
}; };

@ -231,7 +231,7 @@ class FilterTest : public FilterTestBase {
absl::StatusOr<Channel> MakeChannel(const ChannelArgs& args) { absl::StatusOr<Channel> MakeChannel(const ChannelArgs& args) {
auto filter = Filter::Create(args, ChannelFilter::Args()); auto filter = Filter::Create(args, ChannelFilter::Args());
if (!filter.ok()) return filter.status(); if (!filter.ok()) return filter.status();
return Channel(std::make_unique<Filter>(std::move(*filter)), this); return Channel(std::move(*filter), this);
} }
}; };

@ -49,9 +49,9 @@ class NoOpFilter final : public ChannelFilter {
return next(std::move(args)); return next(std::move(args));
} }
static absl::StatusOr<NoOpFilter> Create(const ChannelArgs&, static absl::StatusOr<std::unique_ptr<NoOpFilter>> Create(
ChannelFilter::Args) { const ChannelArgs&, ChannelFilter::Args) {
return NoOpFilter(); return std::make_unique<NoOpFilter>();
} }
}; };
using NoOpFilterTest = FilterTest<NoOpFilter>; using NoOpFilterTest = FilterTest<NoOpFilter>;
@ -70,9 +70,9 @@ class DelayStartFilter final : public ChannelFilter {
next); next);
} }
static absl::StatusOr<DelayStartFilter> Create(const ChannelArgs&, static absl::StatusOr<std::unique_ptr<DelayStartFilter>> Create(
ChannelFilter::Args) { const ChannelArgs&, ChannelFilter::Args) {
return DelayStartFilter(); return std::make_unique<DelayStartFilter>();
} }
}; };
using DelayStartFilterTest = FilterTest<DelayStartFilter>; using DelayStartFilterTest = FilterTest<DelayStartFilter>;
@ -86,9 +86,9 @@ class AddClientInitialMetadataFilter final : public ChannelFilter {
return next(std::move(args)); return next(std::move(args));
} }
static absl::StatusOr<AddClientInitialMetadataFilter> Create( static absl::StatusOr<std::unique_ptr<AddClientInitialMetadataFilter>> Create(
const ChannelArgs&, ChannelFilter::Args) { const ChannelArgs&, ChannelFilter::Args) {
return AddClientInitialMetadataFilter(); return absl::make_unique<AddClientInitialMetadataFilter>();
} }
}; };
using AddClientInitialMetadataFilterTest = using AddClientInitialMetadataFilterTest =
@ -104,9 +104,9 @@ class AddServerTrailingMetadataFilter final : public ChannelFilter {
}); });
} }
static absl::StatusOr<AddServerTrailingMetadataFilter> Create( static absl::StatusOr<std::unique_ptr<AddServerTrailingMetadataFilter>>
const ChannelArgs&, ChannelFilter::Args) { Create(const ChannelArgs&, ChannelFilter::Args) {
return AddServerTrailingMetadataFilter(); return absl::make_unique<AddServerTrailingMetadataFilter>();
} }
}; };
using AddServerTrailingMetadataFilterTest = using AddServerTrailingMetadataFilterTest =
@ -122,10 +122,9 @@ class AddServerInitialMetadataFilter final : public ChannelFilter {
}); });
return next(std::move(args)); return next(std::move(args));
} }
static absl::StatusOr<std::unique_ptr<AddServerInitialMetadataFilter>> Create(
static absl::StatusOr<AddServerInitialMetadataFilter> Create(
const ChannelArgs&, ChannelFilter::Args) { const ChannelArgs&, ChannelFilter::Args) {
return AddServerInitialMetadataFilter(); return absl::make_unique<AddServerInitialMetadataFilter>();
} }
}; };
using AddServerInitialMetadataFilterTest = using AddServerInitialMetadataFilterTest =

@ -15,6 +15,7 @@
#include "src/core/lib/surface/channel_init.h" #include "src/core/lib/surface/channel_init.h"
#include <map> #include <map>
#include <memory>
#include <string> #include <string>
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
@ -206,9 +207,10 @@ class TestFilter1 {
public: public:
explicit TestFilter1(int* p) : p_(p) {} explicit TestFilter1(int* p) : p_(p) {}
static absl::StatusOr<TestFilter1> Create(const ChannelArgs& args, Empty) { static absl::StatusOr<std::unique_ptr<TestFilter1>> Create(
const ChannelArgs& args, Empty) {
EXPECT_EQ(args.GetInt("foo"), 1); EXPECT_EQ(args.GetInt("foo"), 1);
return TestFilter1(args.GetPointer<int>("p")); return std::make_unique<TestFilter1>(args.GetPointer<int>("p"));
} }
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;

@ -22,15 +22,16 @@ class FakeStatsClientFilter : public ChannelFilter {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<FakeStatsClientFilter> Create( explicit FakeStatsClientFilter(
FakeClientCallTracerFactory* fake_client_call_tracer_factory);
static absl::StatusOr<std::unique_ptr<FakeStatsClientFilter>> Create(
const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/); const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/);
ArenaPromise<ServerMetadataHandle> MakeCallPromise( ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override; CallArgs call_args, NextPromiseFactory next_promise_factory) override;
private: private:
explicit FakeStatsClientFilter(
FakeClientCallTracerFactory* fake_client_call_tracer_factory);
FakeClientCallTracerFactory* const fake_client_call_tracer_factory_; FakeClientCallTracerFactory* const fake_client_call_tracer_factory_;
}; };
@ -38,13 +39,15 @@ const grpc_channel_filter FakeStatsClientFilter::kFilter =
MakePromiseBasedFilter<FakeStatsClientFilter, FilterEndpoint::kClient>( MakePromiseBasedFilter<FakeStatsClientFilter, FilterEndpoint::kClient>(
"fake_stats_client"); "fake_stats_client");
absl::StatusOr<FakeStatsClientFilter> FakeStatsClientFilter::Create( absl::StatusOr<std::unique_ptr<FakeStatsClientFilter>>
const ChannelArgs& args, ChannelFilter::Args /*filter_args*/) { FakeStatsClientFilter::Create(const ChannelArgs& args,
ChannelFilter::Args /*filter_args*/) {
auto* fake_client_call_tracer_factory = auto* fake_client_call_tracer_factory =
args.GetPointer<FakeClientCallTracerFactory>( args.GetPointer<FakeClientCallTracerFactory>(
GRPC_ARG_INJECT_FAKE_CLIENT_CALL_TRACER_FACTORY); GRPC_ARG_INJECT_FAKE_CLIENT_CALL_TRACER_FACTORY);
GPR_ASSERT(fake_client_call_tracer_factory != nullptr); GPR_ASSERT(fake_client_call_tracer_factory != nullptr);
return FakeStatsClientFilter(fake_client_call_tracer_factory); return std::make_unique<FakeStatsClientFilter>(
fake_client_call_tracer_factory);
} }
ArenaPromise<ServerMetadataHandle> FakeStatsClientFilter::MakeCallPromise( ArenaPromise<ServerMetadataHandle> FakeStatsClientFilter::MakeCallPromise(

@ -50,9 +50,15 @@ class AddLabelsFilter : public grpc_core::ChannelFilter {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<AddLabelsFilter> Create( explicit AddLabelsFilter(
std::map<grpc_core::ClientCallTracer::CallAttemptTracer::OptionalLabelKey,
grpc_core::RefCountedStringValue>
labels_to_inject)
: labels_to_inject_(std::move(labels_to_inject)) {}
static absl::StatusOr<std::unique_ptr<AddLabelsFilter>> Create(
const grpc_core::ChannelArgs& args, ChannelFilter::Args /*filter_args*/) { const grpc_core::ChannelArgs& args, ChannelFilter::Args /*filter_args*/) {
return AddLabelsFilter( return absl::make_unique<AddLabelsFilter>(
*args.GetPointer<std::map< *args.GetPointer<std::map<
grpc_core::ClientCallTracer::CallAttemptTracer::OptionalLabelKey, grpc_core::ClientCallTracer::CallAttemptTracer::OptionalLabelKey,
grpc_core::RefCountedStringValue>>(GRPC_ARG_LABELS_TO_INJECT)); grpc_core::RefCountedStringValue>>(GRPC_ARG_LABELS_TO_INJECT));
@ -73,12 +79,6 @@ class AddLabelsFilter : public grpc_core::ChannelFilter {
} }
private: private:
explicit AddLabelsFilter(
std::map<grpc_core::ClientCallTracer::CallAttemptTracer::OptionalLabelKey,
grpc_core::RefCountedStringValue>
labels_to_inject)
: labels_to_inject_(std::move(labels_to_inject)) {}
const std::map< const std::map<
grpc_core::ClientCallTracer::CallAttemptTracer::OptionalLabelKey, grpc_core::ClientCallTracer::CallAttemptTracer::OptionalLabelKey,
grpc_core::RefCountedStringValue> grpc_core::RefCountedStringValue>

Loading…
Cancel
Save