From 138c4667c9caeef50c193ed18019edeacfe6a2cc Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Mon, 14 Mar 2022 10:05:01 -0700 Subject: [PATCH] Change main argument of call promise to be a struct (#29019) * introduce call args * bs * x * Automated change: Fix sanity tests * fix * Simplify naming * tweak Co-authored-by: ctiller --- BUILD | 2 + build_autogenerated.yaml | 4 ++ gRPC-C++.podspec | 4 ++ gRPC-Core.podspec | 4 ++ grpc.gemspec | 2 + package.xml | 2 + .../filters/client_idle/client_idle_filter.cc | 18 +++---- .../filters/http/client_authority_filter.cc | 13 ++--- .../filters/http/client_authority_filter.h | 5 +- src/core/lib/channel/channel_stack.h | 5 +- src/core/lib/channel/promise_based_filter.cc | 53 ++++++++++--------- src/core/lib/channel/promise_based_filter.h | 25 ++++----- .../authorization/grpc_server_authz_filter.cc | 13 +++-- .../authorization/grpc_server_authz_filter.h | 7 ++- .../security/credentials/call_creds_util.cc | 6 +-- .../security/credentials/call_creds_util.h | 4 +- .../composite/composite_credentials.cc | 6 +-- .../composite/composite_credentials.h | 4 +- .../lib/security/credentials/credentials.h | 4 +- .../credentials/fake/fake_credentials.cc | 4 +- .../credentials/fake/fake_credentials.h | 4 +- .../credentials/iam/iam_credentials.cc | 4 +- .../credentials/iam/iam_credentials.h | 4 +- .../credentials/jwt/jwt_credentials.cc | 4 +- .../credentials/jwt/jwt_credentials.h | 4 +- .../credentials/oauth2/oauth2_credentials.cc | 10 ++-- .../credentials/oauth2/oauth2_credentials.h | 12 ++--- .../credentials/plugin/plugin_credentials.cc | 10 ++-- .../credentials/plugin/plugin_credentials.h | 12 ++--- .../lib/security/transport/auth_filters.h | 9 ++-- .../security/transport/client_auth_filter.cc | 31 +++++++---- src/core/lib/transport/transport.h | 21 ++++++-- src/core/lib/transport/transport_impl.h | 4 +- .../filters/client_authority_filter_test.cc | 40 +++++++++----- test/core/security/credentials_test.cc | 8 +-- test/core/security/oauth2_utils.cc | 4 +- .../test/server_context_test_spouse_test.mm | 4 +- .../test/server_context_test_spouse_test.cc | 4 +- tools/doxygen/Doxyfile.c++.internal | 2 + tools/doxygen/Doxyfile.core.internal | 2 + 40 files changed, 214 insertions(+), 164 deletions(-) diff --git a/BUILD b/BUILD index 981e00c9119..dba48fbbaa1 100644 --- a/BUILD +++ b/BUILD @@ -2182,6 +2182,7 @@ grpc_cc_library( "grpc_trace", "iomgr_port", "json", + "latch", "memory_quota", "orphanable", "promise", @@ -4172,6 +4173,7 @@ grpc_cc_library( deps = [ "arena", "arena_promise", + "capture", "config", "gpr_base", "grpc_base", diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 9e943074f09..1dffac01dbc 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -820,6 +820,8 @@ libs: - src/core/lib/promise/detail/status.h - src/core/lib/promise/detail/switch.h - src/core/lib/promise/exec_ctx_wakeup_scheduler.h + - src/core/lib/promise/intra_activity_waiter.h + - src/core/lib/promise/latch.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h - src/core/lib/promise/poll.h @@ -1991,6 +1993,8 @@ libs: - src/core/lib/promise/detail/status.h - src/core/lib/promise/detail/switch.h - src/core/lib/promise/exec_ctx_wakeup_scheduler.h + - src/core/lib/promise/intra_activity_waiter.h + - src/core/lib/promise/latch.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h - src/core/lib/promise/poll.h diff --git a/gRPC-C++.podspec b/gRPC-C++.podspec index a99be0ba6b4..85f5127fec8 100644 --- a/gRPC-C++.podspec +++ b/gRPC-C++.podspec @@ -789,6 +789,8 @@ Pod::Spec.new do |s| 'src/core/lib/promise/detail/status.h', 'src/core/lib/promise/detail/switch.h', 'src/core/lib/promise/exec_ctx_wakeup_scheduler.h', + 'src/core/lib/promise/intra_activity_waiter.h', + 'src/core/lib/promise/latch.h', 'src/core/lib/promise/loop.h', 'src/core/lib/promise/map.h', 'src/core/lib/promise/poll.h', @@ -1590,6 +1592,8 @@ Pod::Spec.new do |s| 'src/core/lib/promise/detail/status.h', 'src/core/lib/promise/detail/switch.h', 'src/core/lib/promise/exec_ctx_wakeup_scheduler.h', + 'src/core/lib/promise/intra_activity_waiter.h', + 'src/core/lib/promise/latch.h', 'src/core/lib/promise/loop.h', 'src/core/lib/promise/map.h', 'src/core/lib/promise/poll.h', diff --git a/gRPC-Core.podspec b/gRPC-Core.podspec index 90c796fb176..8b900661035 100644 --- a/gRPC-Core.podspec +++ b/gRPC-Core.podspec @@ -1292,6 +1292,8 @@ Pod::Spec.new do |s| 'src/core/lib/promise/detail/status.h', 'src/core/lib/promise/detail/switch.h', 'src/core/lib/promise/exec_ctx_wakeup_scheduler.h', + 'src/core/lib/promise/intra_activity_waiter.h', + 'src/core/lib/promise/latch.h', 'src/core/lib/promise/loop.h', 'src/core/lib/promise/map.h', 'src/core/lib/promise/poll.h', @@ -2188,6 +2190,8 @@ Pod::Spec.new do |s| 'src/core/lib/promise/detail/status.h', 'src/core/lib/promise/detail/switch.h', 'src/core/lib/promise/exec_ctx_wakeup_scheduler.h', + 'src/core/lib/promise/intra_activity_waiter.h', + 'src/core/lib/promise/latch.h', 'src/core/lib/promise/loop.h', 'src/core/lib/promise/map.h', 'src/core/lib/promise/poll.h', diff --git a/grpc.gemspec b/grpc.gemspec index 42acad70ab8..6ee6c58294c 100644 --- a/grpc.gemspec +++ b/grpc.gemspec @@ -1211,6 +1211,8 @@ Gem::Specification.new do |s| s.files += %w( src/core/lib/promise/detail/status.h ) s.files += %w( src/core/lib/promise/detail/switch.h ) s.files += %w( src/core/lib/promise/exec_ctx_wakeup_scheduler.h ) + s.files += %w( src/core/lib/promise/intra_activity_waiter.h ) + s.files += %w( src/core/lib/promise/latch.h ) s.files += %w( src/core/lib/promise/loop.h ) s.files += %w( src/core/lib/promise/map.h ) s.files += %w( src/core/lib/promise/poll.h ) diff --git a/package.xml b/package.xml index 14d840c20d7..d064ac37ba6 100644 --- a/package.xml +++ b/package.xml @@ -1191,6 +1191,8 @@ + + diff --git a/src/core/ext/filters/client_idle/client_idle_filter.cc b/src/core/ext/filters/client_idle/client_idle_filter.cc index 631499ade95..9c1c7b3ae34 100644 --- a/src/core/ext/filters/client_idle/client_idle_filter.cc +++ b/src/core/ext/filters/client_idle/client_idle_filter.cc @@ -76,9 +76,8 @@ class ClientIdleFilter : public ChannelFilter { ClientIdleFilter& operator=(ClientIdleFilter&&) = default; // Construct a promise for one call. - ArenaPromise MakeCallPromise( - ClientInitialMetadata initial_metadata, - NextPromiseFactory next_promise_factory) override; + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) override; bool StartTransportOp(grpc_transport_op* op) override; @@ -116,15 +115,14 @@ absl::StatusOr ClientIdleFilter::Create( } // Construct a promise for one call. -ArenaPromise ClientIdleFilter::MakeCallPromise( - ClientInitialMetadata initial_metadata, - NextPromiseFactory next_promise_factory) { +ArenaPromise ClientIdleFilter::MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) { using Decrementer = std::unique_ptr; IncreaseCallCount(); - return ArenaPromise(Capture( - [](Decrementer*, ArenaPromise* next) - -> Poll { return (*next)(); }, - Decrementer(this), next_promise_factory(std::move(initial_metadata)))); + return ArenaPromise( + Capture([](Decrementer*, ArenaPromise* next) + -> Poll { return (*next)(); }, + Decrementer(this), next_promise_factory(std::move(call_args)))); } bool ClientIdleFilter::StartTransportOp(grpc_transport_op* op) { diff --git a/src/core/ext/filters/http/client_authority_filter.cc b/src/core/ext/filters/http/client_authority_filter.cc index 318f5ebc76d..548d66528d5 100644 --- a/src/core/ext/filters/http/client_authority_filter.cc +++ b/src/core/ext/filters/http/client_authority_filter.cc @@ -57,16 +57,17 @@ absl::StatusOr ClientAuthorityFilter::Create( return ClientAuthorityFilter(Slice::FromCopiedString(default_authority_str)); } -ArenaPromise ClientAuthorityFilter::MakeCallPromise( - ClientInitialMetadata initial_metadata, - NextPromiseFactory next_promise_factory) { +ArenaPromise ClientAuthorityFilter::MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) { // If no authority is set, set the default authority. - if (initial_metadata->get_pointer(HttpAuthorityMetadata()) == nullptr) { - initial_metadata->Set(HttpAuthorityMetadata(), default_authority_.Ref()); + if (call_args.client_initial_metadata->get_pointer(HttpAuthorityMetadata()) == + nullptr) { + call_args.client_initial_metadata->Set(HttpAuthorityMetadata(), + default_authority_.Ref()); } // We have no asynchronous work, so we can just ask the next promise to run, // passing down initial_metadata. - return next_promise_factory(std::move(initial_metadata)); + return next_promise_factory(std::move(call_args)); } namespace { diff --git a/src/core/ext/filters/http/client_authority_filter.h b/src/core/ext/filters/http/client_authority_filter.h index 9d16f00a619..e27e5a93df8 100644 --- a/src/core/ext/filters/http/client_authority_filter.h +++ b/src/core/ext/filters/http/client_authority_filter.h @@ -37,9 +37,8 @@ class ClientAuthorityFilter final : public ChannelFilter { const grpc_channel_args* args, ChannelFilter::Args); // Construct a promise for one call. - ArenaPromise MakeCallPromise( - ClientInitialMetadata initial_metadata, - NextPromiseFactory next_promise_factory) override; + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) override; private: explicit ClientAuthorityFilter(Slice default_authority) diff --git a/src/core/lib/channel/channel_stack.h b/src/core/lib/channel/channel_stack.h index 00a09db96fd..68111965aac 100644 --- a/src/core/lib/channel/channel_stack.h +++ b/src/core/lib/channel/channel_stack.h @@ -122,9 +122,8 @@ struct grpc_channel_filter { - allocation of memory for call data There is an on-going migration to move all filters to providing this, and then to drop start_transport_stream_op_batch. */ - grpc_core::ArenaPromise (*make_call_promise)( - grpc_channel_element* elem, - grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ArenaPromise (*make_call_promise)( + grpc_channel_element* elem, grpc_core::CallArgs call_args, grpc_core::NextPromiseFactory next_promise_factory); /* Called to handle channel level operations - e.g. new calls, or transport closure. diff --git a/src/core/lib/channel/promise_based_filter.cc b/src/core/lib/channel/promise_based_filter.cc index 89c507f7747..77f10ff0566 100644 --- a/src/core/lib/channel/promise_based_filter.cc +++ b/src/core/lib/channel/promise_based_filter.cc @@ -136,7 +136,7 @@ void ClientCallData::Cancel(grpc_error_handle error) { GRPC_ERROR_UNREF(cancelled_error_); cancelled_error_ = GRPC_ERROR_REF(error); // Stop running the promise. - promise_ = ArenaPromise(); + promise_ = ArenaPromise(); // If we have an op queued, fail that op. // Record what we've done. if (send_initial_state_ == SendInitialState::kQueued) { @@ -176,10 +176,12 @@ void ClientCallData::StartPromise() { { ScopedActivity activity(this); promise_ = filter->MakeCallPromise( - WrapMetadata(send_initial_metadata_batch_->payload - ->send_initial_metadata.send_initial_metadata), - [this](ClientInitialMetadata initial_metadata) { - return MakeNextPromise(std::move(initial_metadata)); + CallArgs{ + WrapMetadata(send_initial_metadata_batch_->payload + ->send_initial_metadata.send_initial_metadata), + nullptr}, + [this](CallArgs call_args) { + return MakeNextPromise(std::move(call_args)); }); } // Poll once. @@ -203,12 +205,13 @@ void ClientCallData::HookRecvTrailingMetadata( // Effectively: // - put the modified initial metadata into the batch to be sent down. // - return a wrapper around PollTrailingMetadata as the promise. -ArenaPromise ClientCallData::MakeNextPromise( - ClientInitialMetadata initial_metadata) { +ArenaPromise ClientCallData::MakeNextPromise( + CallArgs call_args) { GPR_ASSERT(send_initial_state_ == SendInitialState::kQueued); send_initial_metadata_batch_->payload->send_initial_metadata - .send_initial_metadata = UnwrapMetadata(std::move(initial_metadata)); - return ArenaPromise( + .send_initial_metadata = + UnwrapMetadata(std::move(call_args.client_initial_metadata)); + return ArenaPromise( [this]() { return PollTrailingMetadata(); }); } @@ -216,7 +219,7 @@ ArenaPromise ClientCallData::MakeNextPromise( // First poll: send the send_initial_metadata op down the stack. // All polls: await receiving the trailing metadata, then return it to the // application. -Poll ClientCallData::PollTrailingMetadata() { +Poll ClientCallData::PollTrailingMetadata() { if (send_initial_state_ == SendInitialState::kQueued) { // First poll: pass the send_initial_metadata op down the stack. GPR_ASSERT(send_initial_metadata_batch_ != nullptr); @@ -274,7 +277,7 @@ void ClientCallData::RecvTrailingMetadataReady(grpc_error_handle error) { WakeInsideCombiner(); } -// Given an error, fill in TrailingMetadata to represent that error. +// Given an error, fill in ServerMetadataHandle to represent that error. void ClientCallData::SetStatusFromError(grpc_metadata_batch* metadata, grpc_error_handle error) { grpc_status_code status_code = GRPC_STATUS_UNKNOWN; @@ -298,13 +301,13 @@ void ClientCallData::WakeInsideCombiner() { case SendInitialState::kQueued: case SendInitialState::kForwarded: { // Poll the promise once since we're waiting for it. - Poll poll; + Poll poll; { ScopedActivity activity(this); poll = promise_(); } - if (auto* r = absl::get_if(&poll)) { - promise_ = ArenaPromise(); + if (auto* r = absl::get_if(&poll)) { + promise_ = ArenaPromise(); auto* md = UnwrapMetadata(std::move(*r)); bool destroy_md = true; if (recv_trailing_state_ == RecvTrailingState::kComplete) { @@ -505,7 +508,7 @@ void ServerCallData::Cancel(grpc_error_handle error) { GRPC_ERROR_UNREF(cancelled_error_); cancelled_error_ = GRPC_ERROR_REF(error); // Stop running the promise. - promise_ = ArenaPromise(); + promise_ = ArenaPromise(); if (send_trailing_state_ == SendTrailingState::kQueued) { send_trailing_state_ = SendTrailingState::kCancelled; struct FailBatch : public grpc_closure { @@ -534,20 +537,20 @@ void ServerCallData::Cancel(grpc_error_handle error) { // Effectively: // - put the modified initial metadata into the batch being sent up. // - return a wrapper around PollTrailingMetadata as the promise. -ArenaPromise ServerCallData::MakeNextPromise( - ClientInitialMetadata initial_metadata) { +ArenaPromise ServerCallData::MakeNextPromise( + CallArgs call_args) { GPR_ASSERT(recv_initial_state_ == RecvInitialState::kComplete); - GPR_ASSERT(UnwrapMetadata(std::move(initial_metadata)) == + GPR_ASSERT(UnwrapMetadata(std::move(call_args.client_initial_metadata)) == recv_initial_metadata_); forward_recv_initial_metadata_callback_ = true; - return ArenaPromise( + return ArenaPromise( [this]() { return PollTrailingMetadata(); }); } // Wrapper to make it look like we're calling the next filter as a promise. // All polls: await sending the trailing metadata, then foward it down the // stack. -Poll ServerCallData::PollTrailingMetadata() { +Poll ServerCallData::PollTrailingMetadata() { switch (send_trailing_state_) { case SendTrailingState::kInitial: return Pending{}; @@ -587,9 +590,9 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) { // Construct the promise. ChannelFilter* filter = static_cast(elem()->channel_data); promise_ = filter->MakeCallPromise( - WrapMetadata(recv_initial_metadata_), - [this](ClientInitialMetadata initial_metadata) { - return MakeNextPromise(std::move(initial_metadata)); + CallArgs{WrapMetadata(recv_initial_metadata_), nullptr}, + [this](CallArgs call_args) { + return MakeNextPromise(std::move(call_args)); }); // Poll once. bool own_error = false; @@ -610,12 +613,12 @@ void ServerCallData::WakeInsideCombiner( bool forward_send_trailing_metadata = false; is_polling_ = true; if (recv_initial_state_ == RecvInitialState::kComplete) { - Poll poll; + Poll poll; { ScopedActivity activity(this); poll = promise_(); } - if (auto* r = absl::get_if(&poll)) { + if (auto* r = absl::get_if(&poll)) { auto* md = UnwrapMetadata(std::move(*r)); bool destroy_md = true; switch (send_trailing_state_) { diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index 1e18bf70217..7042885bfd9 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -54,9 +54,8 @@ class ChannelFilter { }; // Construct a promise for one call. - virtual ArenaPromise MakeCallPromise( - ClientInitialMetadata initial_metadata, - NextPromiseFactory next_promise_factory) = 0; + virtual ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) = 0; // Start a legacy transport op // Return true if the op was handled, false if it should be passed to the @@ -205,17 +204,16 @@ class ClientCallData : public BaseCallData { // Effectively: // - put the modified initial metadata into the batch to be sent down. // - return a wrapper around PollTrailingMetadata as the promise. - ArenaPromise MakeNextPromise( - ClientInitialMetadata initial_metadata); + ArenaPromise MakeNextPromise(CallArgs call_args); // Wrapper to make it look like we're calling the next filter as a promise. // First poll: send the send_initial_metadata op down the stack. // All polls: await receiving the trailing metadata, then return it to the // application. - Poll PollTrailingMetadata(); + Poll PollTrailingMetadata(); static void RecvTrailingMetadataReadyCallback(void* arg, grpc_error_handle error); void RecvTrailingMetadataReady(grpc_error_handle error); - // Given an error, fill in TrailingMetadata to represent that error. + // Given an error, fill in ServerMetadataHandle to represent that error. void SetStatusFromError(grpc_metadata_batch* metadata, grpc_error_handle error); // Wakeup and poll the promise if appropriate. @@ -223,7 +221,7 @@ class ClientCallData : public BaseCallData { void OnWakeup() override; // Contained promise - ArenaPromise promise_; + ArenaPromise promise_; // Queued batch containing at least a send_initial_metadata op. grpc_transport_stream_op_batch* send_initial_metadata_batch_ = nullptr; // Pointer to where trailing metadata will be stored. @@ -289,12 +287,11 @@ class ServerCallData : public BaseCallData { // Effectively: // - put the modified initial metadata into the batch being sent up. // - return a wrapper around PollTrailingMetadata as the promise. - ArenaPromise MakeNextPromise( - ClientInitialMetadata initial_metadata); + ArenaPromise MakeNextPromise(CallArgs call_args); // Wrapper to make it look like we're calling the next filter as a promise. // All polls: await sending the trailing metadata, then foward it down the // stack. - Poll PollTrailingMetadata(); + Poll PollTrailingMetadata(); static void RecvInitialMetadataReadyCallback(void* arg, grpc_error_handle error); void RecvInitialMetadataReady(grpc_error_handle error); @@ -303,7 +300,7 @@ class ServerCallData : public BaseCallData { void OnWakeup() override; // Contained promise - ArenaPromise promise_; + ArenaPromise promise_; // Pointer to where initial metadata will be stored. grpc_metadata_batch* recv_initial_metadata_ = nullptr; // Closure to call when we're done with the trailing metadata. @@ -366,10 +363,10 @@ MakePromiseBasedFilter(const char* name) { static_cast(elem->call_data)->StartBatch(batch); }, // make_call_promise - [](grpc_channel_element* elem, ClientInitialMetadata initial_metadata, + [](grpc_channel_element* elem, CallArgs call_args, NextPromiseFactory next_promise_factory) { return static_cast(elem->channel_data) - ->MakeCallPromise(std::move(initial_metadata), + ->MakeCallPromise(std::move(call_args), std::move(next_promise_factory)); }, // start_transport_op diff --git a/src/core/lib/security/authorization/grpc_server_authz_filter.cc b/src/core/lib/security/authorization/grpc_server_authz_filter.cc index 12a500b9309..3000f4cc799 100644 --- a/src/core/lib/security/authorization/grpc_server_authz_filter.cc +++ b/src/core/lib/security/authorization/grpc_server_authz_filter.cc @@ -49,7 +49,7 @@ absl::StatusOr GrpcServerAuthzFilter::Create( } bool GrpcServerAuthzFilter::IsAuthorized( - const ClientInitialMetadata& initial_metadata) { + const ClientMetadataHandle& initial_metadata) { EvaluateArgs args(initial_metadata.get(), &per_channel_evaluate_args_); if (GRPC_TRACE_FLAG_ENABLED(grpc_authz_trace)) { gpr_log(GPR_DEBUG, @@ -92,14 +92,13 @@ bool GrpcServerAuthzFilter::IsAuthorized( return false; } -ArenaPromise GrpcServerAuthzFilter::MakeCallPromise( - ClientInitialMetadata initial_metadata, - NextPromiseFactory next_promise_factory) { - if (!IsAuthorized(initial_metadata)) { - return ArenaPromise(Immediate(TrailingMetadata( +ArenaPromise GrpcServerAuthzFilter::MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) { + if (!IsAuthorized(call_args.client_initial_metadata)) { + return ArenaPromise(Immediate(ServerMetadataHandle( absl::PermissionDeniedError("Unauthorized RPC request rejected.")))); } - return next_promise_factory(std::move(initial_metadata)); + return next_promise_factory(std::move(call_args)); } const grpc_channel_filter GrpcServerAuthzFilter::kFilterVtable = diff --git a/src/core/lib/security/authorization/grpc_server_authz_filter.h b/src/core/lib/security/authorization/grpc_server_authz_filter.h index 866c1ad5420..e4bd0e46be7 100644 --- a/src/core/lib/security/authorization/grpc_server_authz_filter.h +++ b/src/core/lib/security/authorization/grpc_server_authz_filter.h @@ -30,16 +30,15 @@ class GrpcServerAuthzFilter final : public ChannelFilter { static absl::StatusOr Create( const grpc_channel_args* args, ChannelFilter::Args); - ArenaPromise MakeCallPromise( - ClientInitialMetadata initial_metadata, - NextPromiseFactory next_promise_factory) override; + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) override; private: GrpcServerAuthzFilter( RefCountedPtr auth_context, grpc_endpoint* endpoint, RefCountedPtr provider); - bool IsAuthorized(const ClientInitialMetadata& initial_metadata); + bool IsAuthorized(const ClientMetadataHandle& initial_metadata); RefCountedPtr auth_context_; EvaluateArgs::PerChannelArgs per_channel_evaluate_args_; diff --git a/src/core/lib/security/credentials/call_creds_util.cc b/src/core/lib/security/credentials/call_creds_util.cc index ecb79a2b839..c04a52f072a 100644 --- a/src/core/lib/security/credentials/call_creds_util.cc +++ b/src/core/lib/security/credentials/call_creds_util.cc @@ -31,7 +31,7 @@ struct ServiceUrlAndMethod { }; ServiceUrlAndMethod MakeServiceUrlAndMethod( - const ClientInitialMetadata& initial_metadata, + const ClientMetadataHandle& initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs* args) { auto service = initial_metadata->get_pointer(HttpPathMetadata())->as_string_view(); @@ -65,13 +65,13 @@ ServiceUrlAndMethod MakeServiceUrlAndMethod( } // namespace std::string MakeJwtServiceUrl( - const ClientInitialMetadata& initial_metadata, + const ClientMetadataHandle& initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs* args) { return MakeServiceUrlAndMethod(initial_metadata, args).service_url; } grpc_auth_metadata_context MakePluginAuthMetadataContext( - const ClientInitialMetadata& initial_metadata, + const ClientMetadataHandle& initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs* args) { auto fields = MakeServiceUrlAndMethod(initial_metadata, args); grpc_auth_metadata_context ctx; diff --git a/src/core/lib/security/credentials/call_creds_util.h b/src/core/lib/security/credentials/call_creds_util.h index 84072ba4bd9..30b70dad97c 100644 --- a/src/core/lib/security/credentials/call_creds_util.h +++ b/src/core/lib/security/credentials/call_creds_util.h @@ -29,12 +29,12 @@ namespace grpc_core { // Helper function to construct service URL for jwt call creds. std::string MakeJwtServiceUrl( - const ClientInitialMetadata& initial_metadata, + const ClientMetadataHandle& initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs* args); // Helper function to construct context for plugin call creds. grpc_auth_metadata_context MakePluginAuthMetadataContext( - const ClientInitialMetadata& initial_metadata, + const ClientMetadataHandle& initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs* args); } // namespace grpc_core diff --git a/src/core/lib/security/credentials/composite/composite_credentials.cc b/src/core/lib/security/credentials/composite/composite_credentials.cc index cd63139bcbd..3fafa805c42 100644 --- a/src/core/lib/security/credentials/composite/composite_credentials.cc +++ b/src/core/lib/security/credentials/composite/composite_credentials.cc @@ -43,15 +43,15 @@ const char kCredentialsTypeComposite[] = "composite"; /* -- Composite call credentials. -- */ -grpc_core::ArenaPromise> +grpc_core::ArenaPromise> grpc_composite_call_credentials::GetRequestMetadata( - grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ClientMetadataHandle initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs* args) { auto self = Ref(); return TrySeqIter( inner_.begin(), inner_.end(), std::move(initial_metadata), [self, args](const grpc_core::RefCountedPtr& creds, - grpc_core::ClientInitialMetadata initial_metadata) { + grpc_core::ClientMetadataHandle initial_metadata) { return creds->GetRequestMetadata(std::move(initial_metadata), args); }); } diff --git a/src/core/lib/security/credentials/composite/composite_credentials.h b/src/core/lib/security/credentials/composite/composite_credentials.h index e200d9a3a9a..24a036fa3e4 100644 --- a/src/core/lib/security/credentials/composite/composite_credentials.h +++ b/src/core/lib/security/credentials/composite/composite_credentials.h @@ -90,8 +90,8 @@ class grpc_composite_call_credentials : public grpc_call_credentials { grpc_core::RefCountedPtr creds2); ~grpc_composite_call_credentials() override = default; - grpc_core::ArenaPromise> - GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ArenaPromise> + GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs* args) override; grpc_security_level min_security_level() const override { diff --git a/src/core/lib/security/credentials/credentials.h b/src/core/lib/security/credentials/credentials.h index de8a42891d9..2d75cadbedf 100644 --- a/src/core/lib/security/credentials/credentials.h +++ b/src/core/lib/security/credentials/credentials.h @@ -216,8 +216,8 @@ struct grpc_call_credentials ~grpc_call_credentials() override = default; virtual grpc_core::ArenaPromise< - absl::StatusOr> - GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata, + absl::StatusOr> + GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs* args) = 0; virtual grpc_security_level min_security_level() const { diff --git a/src/core/lib/security/credentials/fake/fake_credentials.cc b/src/core/lib/security/credentials/fake/fake_credentials.cc index 52d00a45523..b8371017841 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.cc +++ b/src/core/lib/security/credentials/fake/fake_credentials.cc @@ -97,9 +97,9 @@ const char* grpc_fake_transport_get_expected_targets( /* -- Metadata-only test credentials. -- */ -grpc_core::ArenaPromise> +grpc_core::ArenaPromise> grpc_md_only_test_credentials::GetRequestMetadata( - grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ClientMetadataHandle initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs*) { initial_metadata->Append( key_.as_string_view(), value_.Ref(), diff --git a/src/core/lib/security/credentials/fake/fake_credentials.h b/src/core/lib/security/credentials/fake/fake_credentials.h index f55e9863aa9..df5ae25c82b 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.h +++ b/src/core/lib/security/credentials/fake/fake_credentials.h @@ -65,8 +65,8 @@ class grpc_md_only_test_credentials : public grpc_call_credentials { key_(grpc_core::Slice::FromCopiedString(md_key)), value_(grpc_core::Slice::FromCopiedString(md_value)) {} - grpc_core::ArenaPromise> - GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ArenaPromise> + GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs* args) override; std::string debug_string() override { return "MD only Test Credentials"; }; diff --git a/src/core/lib/security/credentials/iam/iam_credentials.cc b/src/core/lib/security/credentials/iam/iam_credentials.cc index f010e51210e..f986897df04 100644 --- a/src/core/lib/security/credentials/iam/iam_credentials.cc +++ b/src/core/lib/security/credentials/iam/iam_credentials.cc @@ -31,9 +31,9 @@ #include "src/core/lib/promise/promise.h" #include "src/core/lib/surface/api_trace.h" -grpc_core::ArenaPromise> +grpc_core::ArenaPromise> grpc_google_iam_credentials::GetRequestMetadata( - grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ClientMetadataHandle initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs*) { if (token_.has_value()) { initial_metadata->Append( diff --git a/src/core/lib/security/credentials/iam/iam_credentials.h b/src/core/lib/security/credentials/iam/iam_credentials.h index 2724ad80c7a..7aa8889fa7e 100644 --- a/src/core/lib/security/credentials/iam/iam_credentials.h +++ b/src/core/lib/security/credentials/iam/iam_credentials.h @@ -30,8 +30,8 @@ class grpc_google_iam_credentials : public grpc_call_credentials { grpc_google_iam_credentials(const char* token, const char* authority_selector); - grpc_core::ArenaPromise> - GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ArenaPromise> + GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs* args) override; std::string debug_string() override { return debug_string_; } diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.cc b/src/core/lib/security/credentials/jwt/jwt_credentials.cc index 8b3e76eeaa2..dda77ef2641 100644 --- a/src/core/lib/security/credentials/jwt/jwt_credentials.cc +++ b/src/core/lib/security/credentials/jwt/jwt_credentials.cc @@ -49,9 +49,9 @@ grpc_service_account_jwt_access_credentials:: gpr_mu_destroy(&cache_mu_); } -grpc_core::ArenaPromise> +grpc_core::ArenaPromise> grpc_service_account_jwt_access_credentials::GetRequestMetadata( - grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ClientMetadataHandle initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs* args) { gpr_timespec refresh_threshold = gpr_time_from_seconds( GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, GPR_TIMESPAN); diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.h b/src/core/lib/security/credentials/jwt/jwt_credentials.h index e2df0040f29..374254d83a2 100644 --- a/src/core/lib/security/credentials/jwt/jwt_credentials.h +++ b/src/core/lib/security/credentials/jwt/jwt_credentials.h @@ -38,8 +38,8 @@ class grpc_service_account_jwt_access_credentials gpr_timespec token_lifetime); ~grpc_service_account_jwt_access_credentials() override; - grpc_core::ArenaPromise> - GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ArenaPromise> + GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs* args) override; const gpr_timespec& jwt_lifetime() const { return jwt_lifetime_; } diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc index 77a76130e94..ed304c36ca7 100644 --- a/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc +++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc @@ -281,9 +281,9 @@ void grpc_oauth2_token_fetcher_credentials::on_http_response( delete r; } -grpc_core::ArenaPromise> +grpc_core::ArenaPromise> grpc_oauth2_token_fetcher_credentials::GetRequestMetadata( - grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ClientMetadataHandle initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs*) { // Check if we can use the cached token. absl::optional cached_access_token_value; @@ -328,7 +328,7 @@ grpc_oauth2_token_fetcher_credentials::GetRequestMetadata( } return [pending_request]() - -> grpc_core::Poll> { + -> grpc_core::Poll> { if (!pending_request->done.load(std::memory_order_acquire)) { return grpc_core::Pending{}; } @@ -696,9 +696,9 @@ grpc_call_credentials* grpc_sts_credentials_create( // Oauth2 Access Token credentials. // -grpc_core::ArenaPromise> +grpc_core::ArenaPromise> grpc_access_token_credentials::GetRequestMetadata( - grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ClientMetadataHandle initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs*) { initial_metadata->Append( GRPC_AUTHORIZATION_METADATA_KEY, access_token_value_.Ref(), diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h index 30fc30899bb..d3d9e0a3c6e 100644 --- a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h +++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h @@ -79,9 +79,9 @@ struct grpc_oauth2_pending_get_request_metadata std::atomic done{false}; grpc_core::Waker waker; grpc_polling_entity* pollent; - grpc_core::ClientInitialMetadata md; + grpc_core::ClientMetadataHandle md; struct grpc_oauth2_pending_get_request_metadata* next; - absl::StatusOr result; + absl::StatusOr result; }; // -- Oauth2 Token Fetcher credentials -- @@ -94,8 +94,8 @@ class grpc_oauth2_token_fetcher_credentials : public grpc_call_credentials { grpc_oauth2_token_fetcher_credentials(); ~grpc_oauth2_token_fetcher_credentials() override; - grpc_core::ArenaPromise> - GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ArenaPromise> + GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs* args) override; void on_http_response(grpc_credentials_metadata_request* r, @@ -152,8 +152,8 @@ class grpc_access_token_credentials final : public grpc_call_credentials { public: explicit grpc_access_token_credentials(const char* access_token); - grpc_core::ArenaPromise> - GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ArenaPromise> + GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs* args) override; std::string debug_string() override; diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.cc b/src/core/lib/security/credentials/plugin/plugin_credentials.cc index 2414d322951..c9691f97845 100644 --- a/src/core/lib/security/credentials/plugin/plugin_credentials.cc +++ b/src/core/lib/security/credentials/plugin/plugin_credentials.cc @@ -60,7 +60,7 @@ std::string grpc_plugin_credentials::debug_string() { return debug_str; } -absl::StatusOr +absl::StatusOr grpc_plugin_credentials::PendingRequest::ProcessPluginResult( const grpc_metadata* md, size_t num_md, grpc_status_code status, const char* error_details) { @@ -96,12 +96,12 @@ grpc_plugin_credentials::PendingRequest::ProcessPluginResult( }); } if (!error.ok()) return std::move(error); - return grpc_core::ClientInitialMetadata(std::move(md_)); + return grpc_core::ClientMetadataHandle(std::move(md_)); } } } -grpc_core::Poll> +grpc_core::Poll> grpc_plugin_credentials::PendingRequest::PollAsyncResult() { if (!ready_.load(std::memory_order_acquire)) { return grpc_core::Pending{}; @@ -137,9 +137,9 @@ void grpc_plugin_credentials::PendingRequest::RequestMetadataReady( r->waker_.Wakeup(); } -grpc_core::ArenaPromise> +grpc_core::ArenaPromise> grpc_plugin_credentials::GetRequestMetadata( - grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ClientMetadataHandle initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs* args) { if (plugin_.get_metadata == nullptr) { return grpc_core::Immediate(std::move(initial_metadata)); diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.h b/src/core/lib/security/credentials/plugin/plugin_credentials.h index 2ed2cd977cd..4c7c2f064ac 100644 --- a/src/core/lib/security/credentials/plugin/plugin_credentials.h +++ b/src/core/lib/security/credentials/plugin/plugin_credentials.h @@ -35,8 +35,8 @@ struct grpc_plugin_credentials final : public grpc_call_credentials { grpc_security_level min_security_level); ~grpc_plugin_credentials() override; - grpc_core::ArenaPromise> - GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ArenaPromise> + GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs* args) override; std::string debug_string() override; @@ -45,7 +45,7 @@ struct grpc_plugin_credentials final : public grpc_call_credentials { class PendingRequest : public grpc_core::RefCounted { public: PendingRequest(grpc_core::RefCountedPtr creds, - grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ClientMetadataHandle initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs* args) : call_creds_(std::move(creds)), context_( @@ -60,11 +60,11 @@ struct grpc_plugin_credentials final : public grpc_call_credentials { } } - absl::StatusOr ProcessPluginResult( + absl::StatusOr ProcessPluginResult( const grpc_metadata* md, size_t num_md, grpc_status_code status, const char* error_details); - grpc_core::Poll> + grpc_core::Poll> PollAsyncResult(); static void RequestMetadataReady(void* request, const grpc_metadata* md, @@ -80,7 +80,7 @@ struct grpc_plugin_credentials final : public grpc_call_credentials { grpc_core::Activity::current()->MakeNonOwningWaker()}; grpc_core::RefCountedPtr call_creds_; grpc_auth_metadata_context context_; - grpc_core::ClientInitialMetadata md_; + grpc_core::ClientMetadataHandle md_; // final status absl::InlinedVector metadata_; std::string error_details_; diff --git a/src/core/lib/security/transport/auth_filters.h b/src/core/lib/security/transport/auth_filters.h index 084e7125e70..2a338a062d2 100644 --- a/src/core/lib/security/transport/auth_filters.h +++ b/src/core/lib/security/transport/auth_filters.h @@ -41,17 +41,16 @@ class ClientAuthFilter final : public ChannelFilter { ChannelFilter::Args); // Construct a promise for one call. - ArenaPromise MakeCallPromise( - ClientInitialMetadata initial_metadata, - NextPromiseFactory next_promise_factory) override; + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) override; private: ClientAuthFilter( RefCountedPtr security_connector, RefCountedPtr auth_context); - ArenaPromise> GetCallCredsMetadata( - ClientInitialMetadata initial_metadata); + ArenaPromise> GetCallCredsMetadata( + CallArgs call_args); // Contains refs to security connector and auth context. grpc_call_credentials::GetRequestMetadataArgs args_; diff --git a/src/core/lib/security/transport/client_auth_filter.cc b/src/core/lib/security/transport/client_auth_filter.cc index f4da0d63a8d..79178972fbb 100644 --- a/src/core/lib/security/transport/client_auth_filter.cc +++ b/src/core/lib/security/transport/client_auth_filter.cc @@ -31,6 +31,7 @@ #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/promise_based_filter.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/capture.h" #include "src/core/lib/iomgr/error.h" #include "src/core/lib/profiling/timers.h" #include "src/core/lib/promise/promise.h" @@ -98,8 +99,8 @@ ClientAuthFilter::ClientAuthFilter( RefCountedPtr auth_context) : args_{std::move(security_connector), std::move(auth_context)} {} -ArenaPromise> -ClientAuthFilter::GetCallCredsMetadata(ClientInitialMetadata initial_metadata) { +ArenaPromise> ClientAuthFilter::GetCallCredsMetadata( + CallArgs call_args) { auto* ctx = static_cast( GetContext()[GRPC_CONTEXT_SECURITY].value); grpc_call_credentials* channel_call_creds = @@ -108,8 +109,7 @@ ClientAuthFilter::GetCallCredsMetadata(ClientInitialMetadata initial_metadata) { if (channel_call_creds == nullptr && !call_creds_has_md) { /* Skip sending metadata altogether. */ - return Immediate( - absl::StatusOr(std::move(initial_metadata))); + return Immediate(absl::StatusOr(std::move(call_args))); } RefCountedPtr creds; @@ -148,12 +148,20 @@ ClientAuthFilter::GetCallCredsMetadata(ClientInitialMetadata initial_metadata) { "transfer call credential.")); } - return creds->GetRequestMetadata(std::move(initial_metadata), &args_); + auto client_initial_metadata = std::move(call_args.client_initial_metadata); + return TrySeq( + creds->GetRequestMetadata(std::move(client_initial_metadata), &args_), + Capture( + [](CallArgs* rest_of_args, ClientMetadataHandle new_metadata) { + rest_of_args->client_initial_metadata = std::move(new_metadata); + return Immediate>( + absl::StatusOr(std::move(*rest_of_args))); + }, + std::move(call_args))); } -ArenaPromise ClientAuthFilter::MakeCallPromise( - ClientInitialMetadata initial_metadata, - NextPromiseFactory next_promise_factory) { +ArenaPromise ClientAuthFilter::MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) { auto* legacy_ctx = GetContext(); if (legacy_ctx[GRPC_CONTEXT_SECURITY].value == nullptr) { legacy_ctx[GRPC_CONTEXT_SECURITY].value = @@ -166,13 +174,14 @@ ArenaPromise ClientAuthFilter::MakeCallPromise( legacy_ctx[GRPC_CONTEXT_SECURITY].value) ->auth_context = args_.auth_context; - auto* host = initial_metadata->get_pointer(HttpAuthorityMetadata()); + auto* host = + call_args.client_initial_metadata->get_pointer(HttpAuthorityMetadata()); if (host == nullptr) { - return next_promise_factory(std::move(initial_metadata)); + return next_promise_factory(std::move(call_args)); } return TrySeq(args_.security_connector->CheckCallHost( host->as_string_view(), args_.auth_context.get()), - GetCallCredsMetadata(std::move(initial_metadata)), + GetCallCredsMetadata(std::move(call_args)), next_promise_factory); } diff --git a/src/core/lib/transport/transport.h b/src/core/lib/transport/transport.h index 72b7e697f54..b392d1bd6bd 100644 --- a/src/core/lib/transport/transport.h +++ b/src/core/lib/transport/transport.h @@ -31,6 +31,7 @@ #include "src/core/lib/iomgr/pollset.h" #include "src/core/lib/iomgr/pollset_set.h" #include "src/core/lib/promise/arena_promise.h" +#include "src/core/lib/promise/latch.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice_internal.h" #include "src/core/lib/transport/byte_stream.h" @@ -99,21 +100,33 @@ class MetadataHandle { // Trailing metadata type // TODO(ctiller): This should be a bespoke instance of MetadataMap<> -using TrailingMetadata = MetadataHandle; +using ServerMetadata = grpc_metadata_batch; +using ServerMetadataHandle = MetadataHandle; // Ok/not-ok check for trailing metadata, so that it can be used as result types // for TrySeq. -inline bool IsStatusOk(const TrailingMetadata& m) { +inline bool IsStatusOk(const ServerMetadataHandle& m) { return m->get(GrpcStatusMetadata()).value_or(GRPC_STATUS_UNKNOWN) == GRPC_STATUS_OK; } // Client initial metadata type // TODO(ctiller): This should be a bespoke instance of MetadataMap<> -using ClientInitialMetadata = MetadataHandle; +using ClientMetadata = grpc_metadata_batch; +using ClientMetadataHandle = MetadataHandle; + +// Server initial metadata type +// TODO(ctiller): This should be a bespoke instance of MetadataMap<> +using ServerMetadataHandle = MetadataHandle; + +struct CallArgs { + ClientMetadataHandle client_initial_metadata; + Latch* server_initial_metadata; +}; using NextPromiseFactory = - std::function(ClientInitialMetadata)>; + std::function(CallArgs)>; + } // namespace grpc_core /* forward declarations */ diff --git a/src/core/lib/transport/transport_impl.h b/src/core/lib/transport/transport_impl.h index 11e99935c66..9bd7cfc758e 100644 --- a/src/core/lib/transport/transport_impl.h +++ b/src/core/lib/transport/transport_impl.h @@ -45,8 +45,8 @@ typedef struct grpc_transport_vtable { - allocation of memory for call data (sizeof_stream may be ignored) There is an on-going migration to move all filters to providing this, and then to drop perform_stream_op. */ - grpc_core::ArenaPromise (*make_call_promise)( - grpc_transport* self, grpc_core::ClientInitialMetadata initial_metadata, + grpc_core::ArenaPromise (*make_call_promise)( + grpc_transport* self, grpc_core::ClientMetadataHandle initial_metadata, grpc_core::NextPromiseFactory next_promise_factory); /* implementation of grpc_transport_set_pollset */ diff --git a/test/core/filters/client_authority_filter_test.cc b/test/core/filters/client_authority_filter_test.cc index 023c64addf6..5d02aaa2b2f 100644 --- a/test/core/filters/client_authority_filter_test.cc +++ b/test/core/filters/client_authority_filter_test.cc @@ -71,18 +71,24 @@ TEST(ClientAuthorityFilterTest, PromiseCompletesImmediatelyAndSetsAuthority) { // TODO(ctiller): use Activity here, once it's ready. TestContext context(arena.get()); auto promise = filter.MakeCallPromise( - ClientInitialMetadata::TestOnlyWrap(&initial_metadata_batch), - [&](ClientInitialMetadata initial_metadata) { - EXPECT_EQ(initial_metadata->get_pointer(HttpAuthorityMetadata()) + CallArgs{ + ClientMetadataHandle::TestOnlyWrap(&initial_metadata_batch), + nullptr, + }, + [&](CallArgs call_args) { + EXPECT_EQ(call_args.client_initial_metadata + ->get_pointer(HttpAuthorityMetadata()) ->as_string_view(), "foo.test.google.au"); seen = true; - return ArenaPromise([&]() -> Poll { - return TrailingMetadata::TestOnlyWrap(&trailing_metadata_batch); - }); + return ArenaPromise( + [&]() -> Poll { + return ServerMetadataHandle::TestOnlyWrap( + &trailing_metadata_batch); + }); }); auto result = promise(); - EXPECT_TRUE(absl::get_if(&result) != nullptr); + EXPECT_TRUE(absl::get_if(&result) != nullptr); EXPECT_TRUE(seen); } @@ -99,18 +105,24 @@ TEST(ClientAuthorityFilterTest, // TODO(ctiller): use Activity here, once it's ready. TestContext context(arena.get()); auto promise = filter.MakeCallPromise( - ClientInitialMetadata::TestOnlyWrap(&initial_metadata_batch), - [&](ClientInitialMetadata initial_metadata) { - EXPECT_EQ(initial_metadata->get_pointer(HttpAuthorityMetadata()) + CallArgs{ + ClientMetadataHandle::TestOnlyWrap(&initial_metadata_batch), + nullptr, + }, + [&](CallArgs call_args) { + EXPECT_EQ(call_args.client_initial_metadata + ->get_pointer(HttpAuthorityMetadata()) ->as_string_view(), "bar.test.google.au"); seen = true; - return ArenaPromise([&]() -> Poll { - return TrailingMetadata::TestOnlyWrap(&trailing_metadata_batch); - }); + return ArenaPromise( + [&]() -> Poll { + return ServerMetadataHandle::TestOnlyWrap( + &trailing_metadata_batch); + }); }); auto result = promise(); - EXPECT_TRUE(absl::get_if(&result) != nullptr); + EXPECT_TRUE(absl::get_if(&result) != nullptr); EXPECT_TRUE(seen); } diff --git a/test/core/security/credentials_test.cc b/test/core/security/credentials_test.cc index 47777036d20..dc396a8b67e 100644 --- a/test/core/security/credentials_test.cc +++ b/test/core/security/credentials_test.cc @@ -450,9 +450,9 @@ class RequestMetadataState : public RefCounted { activity_ = MakeActivity( [this, creds] { return Seq(creds->GetRequestMetadata( - ClientInitialMetadata::TestOnlyWrap(&md_), + ClientMetadataHandle::TestOnlyWrap(&md_), &get_request_metadata_args_), - [this](absl::StatusOr metadata) { + [this](absl::StatusOr metadata) { if (metadata.ok()) { GPR_ASSERT(metadata->get() == &md_); } @@ -1771,8 +1771,8 @@ struct fake_call_creds : public grpc_call_credentials { public: fake_call_creds() : grpc_call_credentials("fake") {} - ArenaPromise> GetRequestMetadata( - ClientInitialMetadata initial_metadata, + ArenaPromise> GetRequestMetadata( + ClientMetadataHandle initial_metadata, const grpc_call_credentials::GetRequestMetadataArgs*) override { initial_metadata->Append("foo", Slice::FromStaticString("oof"), [](absl::string_view, const Slice&) { abort(); }); diff --git a/test/core/security/oauth2_utils.cc b/test/core/security/oauth2_utils.cc index 82d9696569f..183580156bc 100644 --- a/test/core/security/oauth2_utils.cc +++ b/test/core/security/oauth2_utils.cc @@ -57,10 +57,10 @@ char* grpc_test_fetch_oauth2_token_with_credentials( [creds, &initial_metadata, &get_request_metadata_args]() { return grpc_core::Map( creds->GetRequestMetadata( - grpc_core::ClientInitialMetadata::TestOnlyWrap( + grpc_core::ClientMetadataHandle::TestOnlyWrap( &initial_metadata), &get_request_metadata_args), - [](const absl::StatusOr& s) { + [](const absl::StatusOr& s) { return s.status(); }); }, diff --git a/test/cpp/cocoapods/test/server_context_test_spouse_test.mm b/test/cpp/cocoapods/test/server_context_test_spouse_test.mm index f1ffe8cffa7..e26942147a6 100644 --- a/test/cpp/cocoapods/test/server_context_test_spouse_test.mm +++ b/test/cpp/cocoapods/test/server_context_test_spouse_test.mm @@ -55,7 +55,7 @@ bool ClientMetadataContains(const grpc::ServerContext& context, const grpc::stri @implementation ServerContextTestSpouseTest -TEST(ServerContextTestSpouseTest, ClientMetadata) { +TEST(ServerContextTestSpouseTest, ClientMetadataHandle) { grpc::ServerContext context; grpc::testing::ServerContextTestSpouse spouse(&context); @@ -81,7 +81,7 @@ TEST(ServerContextTestSpouseTest, InitialMetadata) { ASSERT_EQ(metadata, spouse.GetInitialMetadata()); } -TEST(ServerContextTestSpouseTest, TrailingMetadata) { +TEST(ServerContextTestSpouseTest, ServerMetadataHandle) { grpc::ServerContext context; grpc::testing::ServerContextTestSpouse spouse(&context); std::multimap metadata; diff --git a/test/cpp/test/server_context_test_spouse_test.cc b/test/cpp/test/server_context_test_spouse_test.cc index 42e9e89d001..cb11dcbf0e4 100644 --- a/test/cpp/test/server_context_test_spouse_test.cc +++ b/test/cpp/test/server_context_test_spouse_test.cc @@ -47,7 +47,7 @@ bool ClientMetadataContains(const ServerContext& context, return false; } -TEST(ServerContextTestSpouseTest, ClientMetadata) { +TEST(ServerContextTestSpouseTest, ClientMetadataHandle) { ServerContext context; ServerContextTestSpouse spouse(&context); @@ -73,7 +73,7 @@ TEST(ServerContextTestSpouseTest, InitialMetadata) { ASSERT_EQ(metadata, spouse.GetInitialMetadata()); } -TEST(ServerContextTestSpouseTest, TrailingMetadata) { +TEST(ServerContextTestSpouseTest, ServerMetadataHandle) { ServerContext context; ServerContextTestSpouse spouse(&context); std::multimap metadata; diff --git a/tools/doxygen/Doxyfile.c++.internal b/tools/doxygen/Doxyfile.c++.internal index 15cfc35c1ea..71749c3a920 100644 --- a/tools/doxygen/Doxyfile.c++.internal +++ b/tools/doxygen/Doxyfile.c++.internal @@ -2190,6 +2190,8 @@ src/core/lib/promise/detail/promise_like.h \ src/core/lib/promise/detail/status.h \ src/core/lib/promise/detail/switch.h \ src/core/lib/promise/exec_ctx_wakeup_scheduler.h \ +src/core/lib/promise/intra_activity_waiter.h \ +src/core/lib/promise/latch.h \ src/core/lib/promise/loop.h \ src/core/lib/promise/map.h \ src/core/lib/promise/poll.h \ diff --git a/tools/doxygen/Doxyfile.core.internal b/tools/doxygen/Doxyfile.core.internal index c84f9d34609..7d4b7f074a8 100644 --- a/tools/doxygen/Doxyfile.core.internal +++ b/tools/doxygen/Doxyfile.core.internal @@ -1985,6 +1985,8 @@ src/core/lib/promise/detail/promise_like.h \ src/core/lib/promise/detail/status.h \ src/core/lib/promise/detail/switch.h \ src/core/lib/promise/exec_ctx_wakeup_scheduler.h \ +src/core/lib/promise/intra_activity_waiter.h \ +src/core/lib/promise/latch.h \ src/core/lib/promise/loop.h \ src/core/lib/promise/map.h \ src/core/lib/promise/poll.h \