diff --git a/src/core/lib/channel/promise_based_filter.cc b/src/core/lib/channel/promise_based_filter.cc index fd334201a2e..60eaf3e4422 100644 --- a/src/core/lib/channel/promise_based_filter.cc +++ b/src/core/lib/channel/promise_based_filter.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "absl/base/attributes.h" @@ -49,14 +50,21 @@ namespace promise_filter_detail { namespace { class FakeActivity final : public Activity { public: + explicit FakeActivity(Activity* wake_activity) + : wake_activity_(wake_activity) {} void Orphan() override {} void ForceImmediateRepoll() override {} - Waker MakeOwningWaker() override { abort(); } - Waker MakeNonOwningWaker() override { abort(); } + Waker MakeOwningWaker() override { return wake_activity_->MakeOwningWaker(); } + Waker MakeNonOwningWaker() override { + return wake_activity_->MakeNonOwningWaker(); + } void Run(absl::FunctionRef f) { ScopedActivity activity(this); f(); } + + private: + Activity* const wake_activity_; }; absl::Status StatusFromMetadata(const ServerMetadata& md) { @@ -103,7 +111,7 @@ BaseCallData::BaseCallData( } BaseCallData::~BaseCallData() { - FakeActivity().Run([this] { + FakeActivity(this).Run([this] { if (send_message_ != nullptr) { send_message_->~SendMessage(); } @@ -2279,7 +2287,7 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) { ScopedContext context(this); // Construct the promise. ChannelFilter* filter = static_cast(elem()->channel_data); - FakeActivity().Run([this, filter] { + FakeActivity(this).Run([this, filter] { promise_ = filter->MakeCallPromise( CallArgs{WrapMetadata(recv_initial_metadata_), server_initial_metadata_pipe() == nullptr @@ -2297,11 +2305,6 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) { }); // Poll once. WakeInsideCombiner(&flusher); - if (auto* closure = - std::exchange(original_recv_initial_metadata_ready_, nullptr)) { - flusher.AddClosure(closure, absl::OkStatus(), - "original_recv_initial_metadata"); - } } std::string ServerCallData::DebugString() const { @@ -2481,9 +2484,20 @@ void ServerCallData::WakeInsideCombiner(Flusher* flusher) { } } } + if (std::exchange(forward_recv_initial_metadata_callback_, false)) { + if (auto* closure = + std::exchange(original_recv_initial_metadata_ready_, nullptr)) { + flusher->AddClosure(closure, absl::OkStatus(), + "original_recv_initial_metadata"); + } + } } -void ServerCallData::OnWakeup() { abort(); } // not implemented +void ServerCallData::OnWakeup() { + Flusher flusher(this); + ScopedContext context(this); + WakeInsideCombiner(&flusher); +} } // namespace promise_filter_detail } // namespace grpc_core diff --git a/src/core/lib/security/transport/auth_filters.h b/src/core/lib/security/transport/auth_filters.h index cdf1c0a8b4b..7c69b8f1b57 100644 --- a/src/core/lib/security/transport/auth_filters.h +++ b/src/core/lib/security/transport/auth_filters.h @@ -36,8 +36,6 @@ #include "src/core/lib/security/security_connector/security_connector.h" #include "src/core/lib/transport/transport.h" -extern const grpc_channel_filter grpc_server_auth_filter; - namespace grpc_core { // Handles calling out to credentials to fill in metadata per call. @@ -64,6 +62,30 @@ class ClientAuthFilter final : public ChannelFilter { grpc_call_credentials::GetRequestMetadataArgs args_; }; +class ServerAuthFilter final : public ChannelFilter { + public: + static const grpc_channel_filter kFilter; + + static absl::StatusOr Create(const ChannelArgs& args, + ChannelFilter::Args); + + // Construct a promise for one call. + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) override; + + private: + ServerAuthFilter(RefCountedPtr server_credentials, + RefCountedPtr auth_context); + + class RunApplicationCode; + + ArenaPromise> GetCallCredsMetadata( + CallArgs call_args); + + RefCountedPtr server_credentials_; + RefCountedPtr auth_context_; +}; + } // namespace grpc_core // Exposed for testing purposes only. diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc index a933524ce0d..f5938e99e6c 100644 --- a/src/core/lib/security/transport/server_auth_filter.cc +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -21,28 +21,37 @@ #include #include -#include +#include +#include +#include +#include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include #include #include #include #include -#include #include +#include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_fwd.h" #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/context.h" +#include "src/core/lib/channel/promise_based_filter.h" #include "src/core/lib/gprpp/debug_location.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/status_helper.h" -#include "src/core/lib/iomgr/call_combiner.h" -#include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/error.h" #include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/arena_promise.h" +#include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/try_seq.h" +#include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/transport/auth_filters.h" // IWYU pragma: keep @@ -51,88 +60,33 @@ #include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" -static void recv_initial_metadata_ready(void* arg, grpc_error_handle error); -static void recv_trailing_metadata_ready(void* user_data, - grpc_error_handle error); +namespace grpc_core { -namespace { -enum async_state { - STATE_INIT = 0, - STATE_DONE, - STATE_CANCELLED, -}; - -struct channel_data { - channel_data(grpc_auth_context* auth_context, grpc_server_credentials* creds) - : auth_context(auth_context->Ref()), creds(creds->Ref()) {} - ~channel_data() { auth_context.reset(DEBUG_LOCATION, "server_auth_filter"); } - - grpc_core::RefCountedPtr auth_context; - grpc_core::RefCountedPtr creds; -}; +const grpc_channel_filter ServerAuthFilter::kFilter = + MakePromiseBasedFilter( + "server-auth"); -struct call_data { - call_data(grpc_call_element* elem, const grpc_call_element_args& args) - : call_combiner(args.call_combiner), owning_call(args.call_stack) { - GRPC_CLOSURE_INIT(&recv_initial_metadata_ready, - ::recv_initial_metadata_ready, elem, - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready, - ::recv_trailing_metadata_ready, elem, - grpc_schedule_on_exec_ctx); - // Create server security context. Set its auth context from channel - // data and save it in the call context. - grpc_server_security_context* server_ctx = - grpc_server_security_context_create(args.arena); - channel_data* chand = static_cast(elem->channel_data); - server_ctx->auth_context = - chand->auth_context->Ref(DEBUG_LOCATION, "server_auth_filter"); - if (args.context[GRPC_CONTEXT_SECURITY].value != nullptr) { - args.context[GRPC_CONTEXT_SECURITY].destroy( - args.context[GRPC_CONTEXT_SECURITY].value); - } - args.context[GRPC_CONTEXT_SECURITY].value = server_ctx; - args.context[GRPC_CONTEXT_SECURITY].destroy = - grpc_server_security_context_destroy; - } - - ~call_data() {} - - grpc_core::CallCombiner* call_combiner; - grpc_call_stack* owning_call; - grpc_transport_stream_op_batch* recv_initial_metadata_batch; - grpc_closure* original_recv_initial_metadata_ready; - grpc_closure recv_initial_metadata_ready; - grpc_error_handle recv_initial_metadata_error; - grpc_closure recv_trailing_metadata_ready; - grpc_closure* original_recv_trailing_metadata_ready; - grpc_error_handle recv_trailing_metadata_error; - bool seen_recv_trailing_metadata_ready = false; - grpc_metadata_array md; - grpc_closure cancel_closure; - gpr_atm state = STATE_INIT; // async_state -}; +namespace { class ArrayEncoder { public: explicit ArrayEncoder(grpc_metadata_array* result) : result_(result) {} - void Encode(const grpc_core::Slice& key, const grpc_core::Slice& value) { + void Encode(const Slice& key, const Slice& value) { Append(key.Ref(), value.Ref()); } template void Encode(Which, const typename Which::ValueType& value) { - Append(grpc_core::Slice( - grpc_core::StaticSlice::FromStaticString(Which::key())), - grpc_core::Slice(Which::Encode(value))); + Append(Slice(StaticSlice::FromStaticString(Which::key())), + Slice(Which::Encode(value))); } - void Encode(grpc_core::HttpMethodMetadata, - const typename grpc_core::HttpMethodMetadata::ValueType&) {} + void Encode(HttpMethodMetadata, + const typename HttpMethodMetadata::ValueType&) {} private: - void Append(grpc_core::Slice key, grpc_core::Slice value) { + void Append(Slice key, Slice value) { if (result_->count == result_->capacity) { result_->capacity = std::max(result_->capacity + 8, result_->capacity * 2); @@ -147,9 +101,9 @@ class ArrayEncoder { grpc_metadata_array* result_; }; -} // namespace - -static grpc_metadata_array metadata_batch_to_md_array( +// TODO(ctiller): seek out all users of this functionality and change API so +// that this unilateral format conversion IS NOT REQUIRED. +grpc_metadata_array MetadataBatchToMetadataArray( const grpc_metadata_batch* batch) { grpc_metadata_array result; grpc_metadata_array_init(&result); @@ -158,202 +112,117 @@ static grpc_metadata_array metadata_batch_to_md_array( return result; } -static void on_md_processing_done_inner(grpc_call_element* elem, - const grpc_metadata* consumed_md, - size_t num_consumed_md, - const grpc_metadata* response_md, - size_t num_response_md, - grpc_error_handle error) { - call_data* calld = static_cast(elem->call_data); - grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch; - // TODO(ZhenLian): Implement support for response_md. - if (response_md != nullptr && num_response_md > 0) { - gpr_log(GPR_ERROR, - "response_md in auth metadata processing not supported for now. " - "Ignoring..."); +} // namespace + +class ServerAuthFilter::RunApplicationCode { + public: + // TODO(ctiller): Allocate state_ into a pool on the arena to reuse this + // memory later + RunApplicationCode(ServerAuthFilter* filter, CallArgs call_args) + : state_(GetContext()->ManagedNew(std::move(call_args))) { + filter->server_credentials_->auth_metadata_processor().process( + filter->server_credentials_->auth_metadata_processor().state, + filter->auth_context_.get(), state_->md.metadata, state_->md.count, + OnMdProcessingDone, state_); } - if (error.ok()) { - for (size_t i = 0; i < num_consumed_md; i++) { - batch->payload->recv_initial_metadata.recv_initial_metadata->Remove( - grpc_core::StringViewFromSlice(consumed_md[i].key)); + + Poll> operator()() { + if (state_->done.load(std::memory_order_acquire)) { + return Poll>(std::move(state_->call_args)); } + return Pending{}; } - calld->recv_initial_metadata_error = error; - grpc_closure* closure = calld->original_recv_initial_metadata_ready; - calld->original_recv_initial_metadata_ready = nullptr; - if (calld->seen_recv_trailing_metadata_ready) { - GRPC_CALL_COMBINER_START(calld->call_combiner, - &calld->recv_trailing_metadata_ready, - calld->recv_trailing_metadata_error, - "continue recv_trailing_metadata_ready"); - } - grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, error); -} -// Called from application code. -static void on_md_processing_done( - void* user_data, const grpc_metadata* consumed_md, size_t num_consumed_md, - const grpc_metadata* response_md, size_t num_response_md, - grpc_status_code status, const char* error_details) { - grpc_call_element* elem = static_cast(user_data); - call_data* calld = static_cast(elem->call_data); - grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; - grpc_core::ExecCtx exec_ctx; - // If the call was not cancelled while we were in flight, process the result. - if (gpr_atm_full_cas(&calld->state, static_cast(STATE_INIT), - static_cast(STATE_DONE))) { - grpc_error_handle error; - if (status != GRPC_STATUS_OK) { + private: + struct State { + explicit State(CallArgs call_args) : call_args(std::move(call_args)) {} + Waker waker{Activity::current()->MakeOwningWaker()}; + absl::StatusOr call_args; + grpc_metadata_array md = + MetadataBatchToMetadataArray(call_args->client_initial_metadata.get()); + std::atomic done{false}; + }; + + // Called from application code. + static void OnMdProcessingDone( + void* user_data, const grpc_metadata* consumed_md, size_t num_consumed_md, + const grpc_metadata* response_md, size_t num_response_md, + grpc_status_code status, const char* error_details) { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + + auto* state = static_cast(user_data); + + // TODO(ZhenLian): Implement support for response_md. + if (response_md != nullptr && num_response_md > 0) { + gpr_log(GPR_ERROR, + "response_md in auth metadata processing not supported for now. " + "Ignoring..."); + } + + if (status == GRPC_STATUS_OK) { + ClientMetadataHandle& md = state->call_args->client_initial_metadata; + for (size_t i = 0; i < num_consumed_md; i++) { + md->Remove(StringViewFromSlice(consumed_md[i].key)); + } + } else { if (error_details == nullptr) { error_details = "Authentication metadata processing failed."; } - error = - grpc_error_set_int(GRPC_ERROR_CREATE(error_details), - grpc_core::StatusIntProperty::kRpcStatus, status); + state->call_args = grpc_error_set_int( + absl::Status(static_cast(status), error_details), + StatusIntProperty::kRpcStatus, status); } - on_md_processing_done_inner(elem, consumed_md, num_consumed_md, response_md, - num_response_md, error); - } - // Clean up. - for (size_t i = 0; i < calld->md.count; i++) { - grpc_core::CSliceUnref(calld->md.metadata[i].key); - grpc_core::CSliceUnref(calld->md.metadata[i].value); - } - grpc_metadata_array_destroy(&calld->md); - GRPC_CALL_STACK_UNREF(calld->owning_call, "server_auth_metadata"); -} -static void cancel_call(void* arg, grpc_error_handle error) { - grpc_call_element* elem = static_cast(arg); - call_data* calld = static_cast(elem->call_data); - // If the result was not already processed, invoke the callback now. - if (!error.ok() && - gpr_atm_full_cas(&calld->state, static_cast(STATE_INIT), - static_cast(STATE_CANCELLED))) { - on_md_processing_done_inner(elem, nullptr, 0, nullptr, 0, error); - } - GRPC_CALL_STACK_UNREF(calld->owning_call, "cancel_call"); -} - -static void recv_initial_metadata_ready(void* arg, grpc_error_handle error) { - grpc_call_element* elem = static_cast(arg); - channel_data* chand = static_cast(elem->channel_data); - call_data* calld = static_cast(elem->call_data); - grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch; - if (error.ok()) { - if (chand->creds != nullptr && - chand->creds->auth_metadata_processor().process != nullptr) { - // We're calling out to the application, so we need to make sure - // to drop the call combiner early if we get cancelled. - // TODO(yashykt): We would not need this ref if call combiners used - // Closure::Run() instead of ExecCtx::Run() - GRPC_CALL_STACK_REF(calld->owning_call, "cancel_call"); - GRPC_CLOSURE_INIT(&calld->cancel_closure, cancel_call, elem, - grpc_schedule_on_exec_ctx); - calld->call_combiner->SetNotifyOnCancel(&calld->cancel_closure); - GRPC_CALL_STACK_REF(calld->owning_call, "server_auth_metadata"); - calld->md = metadata_batch_to_md_array( - batch->payload->recv_initial_metadata.recv_initial_metadata); - chand->creds->auth_metadata_processor().process( - chand->creds->auth_metadata_processor().state, - chand->auth_context.get(), calld->md.metadata, calld->md.count, - on_md_processing_done, elem); - return; + // Clean up. + for (size_t i = 0; i < state->md.count; i++) { + CSliceUnref(state->md.metadata[i].key); + CSliceUnref(state->md.metadata[i].value); } - } - grpc_closure* closure = calld->original_recv_initial_metadata_ready; - calld->original_recv_initial_metadata_ready = nullptr; - if (calld->seen_recv_trailing_metadata_ready) { - GRPC_CALL_COMBINER_START(calld->call_combiner, - &calld->recv_trailing_metadata_ready, - calld->recv_trailing_metadata_error, - "continue recv_trailing_metadata_ready"); - } - grpc_core::Closure::Run(DEBUG_LOCATION, closure, error); -} + grpc_metadata_array_destroy(&state->md); -static void recv_trailing_metadata_ready(void* user_data, - grpc_error_handle err) { - grpc_call_element* elem = static_cast(user_data); - call_data* calld = static_cast(elem->call_data); - if (calld->original_recv_initial_metadata_ready != nullptr) { - calld->recv_trailing_metadata_error = err; - calld->seen_recv_trailing_metadata_ready = true; - GRPC_CALL_COMBINER_STOP(calld->call_combiner, - "deferring recv_trailing_metadata_ready until " - "after recv_initial_metadata_ready"); - return; + auto waker = std::move(state->waker); + state->done.store(true, std::memory_order_release); + waker.Wakeup(); } - err = grpc_error_add_child(err, calld->recv_initial_metadata_error); - grpc_core::Closure::Run(DEBUG_LOCATION, - calld->original_recv_trailing_metadata_ready, err); -} -static void server_auth_start_transport_stream_op_batch( - grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { - call_data* calld = static_cast(elem->call_data); - if (batch->recv_initial_metadata) { - // Inject our callback. - calld->recv_initial_metadata_batch = batch; - calld->original_recv_initial_metadata_ready = - batch->payload->recv_initial_metadata.recv_initial_metadata_ready; - batch->payload->recv_initial_metadata.recv_initial_metadata_ready = - &calld->recv_initial_metadata_ready; - } - if (batch->recv_trailing_metadata) { - calld->original_recv_trailing_metadata_ready = - batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; - batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = - &calld->recv_trailing_metadata_ready; + State* state_; +}; + +ArenaPromise ServerAuthFilter::MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) { + // Create server security context. Set its auth context from channel + // data and save it in the call context. + grpc_server_security_context* server_ctx = + grpc_server_security_context_create(GetContext()); + server_ctx->auth_context = + auth_context_->Ref(DEBUG_LOCATION, "server_auth_filter"); + grpc_call_context_element& context = + GetContext()[GRPC_CONTEXT_SECURITY]; + if (context.value != nullptr) context.destroy(context.value); + context.value = server_ctx; + context.destroy = grpc_server_security_context_destroy; + + if (server_credentials_ == nullptr || + server_credentials_->auth_metadata_processor().process == nullptr) { + return next_promise_factory(std::move(call_args)); } - grpc_call_next_op(elem, batch); -} -// Constructor for call_data -static grpc_error_handle server_auth_init_call_elem( - grpc_call_element* elem, const grpc_call_element_args* args) { - new (elem->call_data) call_data(elem, *args); - return absl::OkStatus(); + return TrySeq(RunApplicationCode(this, std::move(call_args)), + std::move(next_promise_factory)); } -// Destructor for call_data -static void server_auth_destroy_call_elem( - grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, - grpc_closure* /*ignored*/) { - call_data* calld = static_cast(elem->call_data); - calld->~call_data(); -} +ServerAuthFilter::ServerAuthFilter( + RefCountedPtr server_credentials, + RefCountedPtr auth_context) + : server_credentials_(server_credentials), auth_context_(auth_context) {} -// Constructor for channel_data -static grpc_error_handle server_auth_init_channel_elem( - grpc_channel_element* elem, grpc_channel_element_args* args) { - GPR_ASSERT(!args->is_last); - grpc_auth_context* auth_context = - grpc_find_auth_context_in_args(args->channel_args); +absl::StatusOr ServerAuthFilter::Create( + const ChannelArgs& args, ChannelFilter::Args) { + auto auth_context = args.GetObjectRef(); GPR_ASSERT(auth_context != nullptr); - grpc_server_credentials* creds = - grpc_find_server_credentials_in_args(args->channel_args); - new (elem->channel_data) channel_data(auth_context, creds); - return absl::OkStatus(); -} - -// Destructor for channel data -static void server_auth_destroy_channel_elem(grpc_channel_element* elem) { - channel_data* chand = static_cast(elem->channel_data); - chand->~channel_data(); + auto creds = args.GetObjectRef(); + return ServerAuthFilter(std::move(creds), std::move(auth_context)); } -const grpc_channel_filter grpc_server_auth_filter = { - server_auth_start_transport_stream_op_batch, - nullptr, - grpc_channel_next_op, - sizeof(call_data), - server_auth_init_call_elem, - grpc_call_stack_ignore_set_pollset_or_pollset_set, - server_auth_destroy_call_elem, - sizeof(channel_data), - server_auth_init_channel_elem, - grpc_channel_stack_no_post_init, - server_auth_destroy_channel_elem, - grpc_channel_next_get_info, - "server-auth"}; +} // namespace grpc_core diff --git a/src/core/lib/surface/init.cc b/src/core/lib/surface/init.cc index 9af5938255b..32d72fb2ae2 100644 --- a/src/core/lib/surface/init.cc +++ b/src/core/lib/surface/init.cc @@ -80,7 +80,7 @@ static bool maybe_prepend_client_auth_filter( static bool maybe_prepend_server_auth_filter( grpc_core::ChannelStackBuilder* builder) { if (builder->channel_args().Contains(GRPC_SERVER_CREDENTIALS_ARG)) { - builder->PrependFilter(&grpc_server_auth_filter); + builder->PrependFilter(&grpc_core::ServerAuthFilter::kFilter); } return true; }