diff --git a/src/core/lib/security/transport/auth_filters.h b/src/core/lib/security/transport/auth_filters.h index 7c69b8f1b57..765246646be 100644 --- a/src/core/lib/security/transport/auth_filters.h +++ b/src/core/lib/security/transport/auth_filters.h @@ -62,23 +62,56 @@ class ClientAuthFilter final : public ChannelFilter { grpc_call_credentials::GetRequestMetadataArgs args_; }; -class ServerAuthFilter final : public ChannelFilter { +class ServerAuthFilter final : public ImplementChannelFilter { + private: + ServerAuthFilter(RefCountedPtr server_credentials, + RefCountedPtr auth_context); + + class RunApplicationCode { + public: + RunApplicationCode(ServerAuthFilter* filter, ClientMetadata& metadata); + + RunApplicationCode(const RunApplicationCode&) = delete; + RunApplicationCode& operator=(const RunApplicationCode&) = delete; + RunApplicationCode(RunApplicationCode&& other) noexcept + : state_(std::exchange(other.state_, nullptr)) {} + RunApplicationCode& operator=(RunApplicationCode&& other) noexcept { + state_ = std::exchange(other.state_, nullptr); + return *this; + } + + Poll operator()(); + + private: + struct State; + State* state_; + }; + public: static const grpc_channel_filter kFilter; static absl::StatusOr Create(const ChannelArgs& args, ChannelFilter::Args); - // Construct a promise for one call. - ArenaPromise MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) override; + class Call { + public: + explicit Call(ServerAuthFilter* filter); + auto OnClientInitialMetadata(ClientMetadata& md, ServerAuthFilter* filter) { + return If( + filter->server_credentials_ == nullptr || + filter->server_credentials_->auth_metadata_processor().process == + nullptr, + ImmediateOkStatus(), + [filter, md = &md]() { return RunApplicationCode(filter, *md); }); + } + static const NoInterceptor OnServerInitialMetadata; + static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnServerToClientMessage; + static const NoInterceptor OnServerTrailingMetadata; + static const NoInterceptor OnFinalize; + }; private: - ServerAuthFilter(RefCountedPtr server_credentials, - RefCountedPtr auth_context); - - class RunApplicationCode; - ArenaPromise> GetCallCredsMetadata( CallArgs call_args); diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc index 765ddcdf42b..6be54bef11b 100644 --- a/src/core/lib/security/transport/server_auth_filter.cc +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -114,6 +114,7 @@ grpc_metadata_array MetadataBatchToMetadataArray( } // namespace +#if 0 class ServerAuthFilter::RunApplicationCode { public: // TODO(ctiller): Allocate state_ into a pool on the arena to reuse this @@ -204,28 +205,20 @@ class ServerAuthFilter::RunApplicationCode { State* state_; }; +#endif -ArenaPromise ServerAuthFilter::MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) { +ServerAuthFilter::Call::Call(ServerAuthFilter* filter) { // Create server security context. Set its auth context from channel // data and save it in the call context. grpc_server_security_context* server_ctx = grpc_server_security_context_create(GetContext()); server_ctx->auth_context = - auth_context_->Ref(DEBUG_LOCATION, "server_auth_filter"); + filter->auth_context_->Ref(DEBUG_LOCATION, "server_auth_filter"); grpc_call_context_element& context = GetContext()[GRPC_CONTEXT_SECURITY]; if (context.value != nullptr) context.destroy(context.value); context.value = server_ctx; context.destroy = grpc_server_security_context_destroy; - - if (server_credentials_ == nullptr || - server_credentials_->auth_metadata_processor().process == nullptr) { - return next_promise_factory(std::move(call_args)); - } - - return TrySeq(RunApplicationCode(this, std::move(call_args)), - std::move(next_promise_factory)); } ServerAuthFilter::ServerAuthFilter(