Change main argument of call promise to be a struct (#29019)

* introduce call args

* bs

* x

* Automated change: Fix sanity tests

* fix

* Simplify naming

* tweak

Co-authored-by: ctiller <ctiller@users.noreply.github.com>
pull/29091/head
Craig Tiller 3 years ago committed by GitHub
parent 4682e041e1
commit 138c4667c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      BUILD
  2. 4
      build_autogenerated.yaml
  3. 4
      gRPC-C++.podspec
  4. 4
      gRPC-Core.podspec
  5. 2
      grpc.gemspec
  6. 2
      package.xml
  7. 18
      src/core/ext/filters/client_idle/client_idle_filter.cc
  8. 13
      src/core/ext/filters/http/client_authority_filter.cc
  9. 5
      src/core/ext/filters/http/client_authority_filter.h
  10. 5
      src/core/lib/channel/channel_stack.h
  11. 53
      src/core/lib/channel/promise_based_filter.cc
  12. 25
      src/core/lib/channel/promise_based_filter.h
  13. 13
      src/core/lib/security/authorization/grpc_server_authz_filter.cc
  14. 7
      src/core/lib/security/authorization/grpc_server_authz_filter.h
  15. 6
      src/core/lib/security/credentials/call_creds_util.cc
  16. 4
      src/core/lib/security/credentials/call_creds_util.h
  17. 6
      src/core/lib/security/credentials/composite/composite_credentials.cc
  18. 4
      src/core/lib/security/credentials/composite/composite_credentials.h
  19. 4
      src/core/lib/security/credentials/credentials.h
  20. 4
      src/core/lib/security/credentials/fake/fake_credentials.cc
  21. 4
      src/core/lib/security/credentials/fake/fake_credentials.h
  22. 4
      src/core/lib/security/credentials/iam/iam_credentials.cc
  23. 4
      src/core/lib/security/credentials/iam/iam_credentials.h
  24. 4
      src/core/lib/security/credentials/jwt/jwt_credentials.cc
  25. 4
      src/core/lib/security/credentials/jwt/jwt_credentials.h
  26. 10
      src/core/lib/security/credentials/oauth2/oauth2_credentials.cc
  27. 12
      src/core/lib/security/credentials/oauth2/oauth2_credentials.h
  28. 10
      src/core/lib/security/credentials/plugin/plugin_credentials.cc
  29. 12
      src/core/lib/security/credentials/plugin/plugin_credentials.h
  30. 9
      src/core/lib/security/transport/auth_filters.h
  31. 31
      src/core/lib/security/transport/client_auth_filter.cc
  32. 21
      src/core/lib/transport/transport.h
  33. 4
      src/core/lib/transport/transport_impl.h
  34. 40
      test/core/filters/client_authority_filter_test.cc
  35. 8
      test/core/security/credentials_test.cc
  36. 4
      test/core/security/oauth2_utils.cc
  37. 4
      test/cpp/cocoapods/test/server_context_test_spouse_test.mm
  38. 4
      test/cpp/test/server_context_test_spouse_test.cc
  39. 2
      tools/doxygen/Doxyfile.c++.internal
  40. 2
      tools/doxygen/Doxyfile.core.internal

@ -2182,6 +2182,7 @@ grpc_cc_library(
"grpc_trace",
"iomgr_port",
"json",
"latch",
"memory_quota",
"orphanable",
"promise",
@ -4172,6 +4173,7 @@ grpc_cc_library(
deps = [
"arena",
"arena_promise",
"capture",
"config",
"gpr_base",
"grpc_base",

@ -820,6 +820,8 @@ libs:
- src/core/lib/promise/detail/status.h
- src/core/lib/promise/detail/switch.h
- src/core/lib/promise/exec_ctx_wakeup_scheduler.h
- src/core/lib/promise/intra_activity_waiter.h
- src/core/lib/promise/latch.h
- src/core/lib/promise/loop.h
- src/core/lib/promise/map.h
- src/core/lib/promise/poll.h
@ -1991,6 +1993,8 @@ libs:
- src/core/lib/promise/detail/status.h
- src/core/lib/promise/detail/switch.h
- src/core/lib/promise/exec_ctx_wakeup_scheduler.h
- src/core/lib/promise/intra_activity_waiter.h
- src/core/lib/promise/latch.h
- src/core/lib/promise/loop.h
- src/core/lib/promise/map.h
- src/core/lib/promise/poll.h

4
gRPC-C++.podspec generated

@ -789,6 +789,8 @@ Pod::Spec.new do |s|
'src/core/lib/promise/detail/status.h',
'src/core/lib/promise/detail/switch.h',
'src/core/lib/promise/exec_ctx_wakeup_scheduler.h',
'src/core/lib/promise/intra_activity_waiter.h',
'src/core/lib/promise/latch.h',
'src/core/lib/promise/loop.h',
'src/core/lib/promise/map.h',
'src/core/lib/promise/poll.h',
@ -1590,6 +1592,8 @@ Pod::Spec.new do |s|
'src/core/lib/promise/detail/status.h',
'src/core/lib/promise/detail/switch.h',
'src/core/lib/promise/exec_ctx_wakeup_scheduler.h',
'src/core/lib/promise/intra_activity_waiter.h',
'src/core/lib/promise/latch.h',
'src/core/lib/promise/loop.h',
'src/core/lib/promise/map.h',
'src/core/lib/promise/poll.h',

4
gRPC-Core.podspec generated

@ -1292,6 +1292,8 @@ Pod::Spec.new do |s|
'src/core/lib/promise/detail/status.h',
'src/core/lib/promise/detail/switch.h',
'src/core/lib/promise/exec_ctx_wakeup_scheduler.h',
'src/core/lib/promise/intra_activity_waiter.h',
'src/core/lib/promise/latch.h',
'src/core/lib/promise/loop.h',
'src/core/lib/promise/map.h',
'src/core/lib/promise/poll.h',
@ -2188,6 +2190,8 @@ Pod::Spec.new do |s|
'src/core/lib/promise/detail/status.h',
'src/core/lib/promise/detail/switch.h',
'src/core/lib/promise/exec_ctx_wakeup_scheduler.h',
'src/core/lib/promise/intra_activity_waiter.h',
'src/core/lib/promise/latch.h',
'src/core/lib/promise/loop.h',
'src/core/lib/promise/map.h',
'src/core/lib/promise/poll.h',

2
grpc.gemspec generated

@ -1211,6 +1211,8 @@ Gem::Specification.new do |s|
s.files += %w( src/core/lib/promise/detail/status.h )
s.files += %w( src/core/lib/promise/detail/switch.h )
s.files += %w( src/core/lib/promise/exec_ctx_wakeup_scheduler.h )
s.files += %w( src/core/lib/promise/intra_activity_waiter.h )
s.files += %w( src/core/lib/promise/latch.h )
s.files += %w( src/core/lib/promise/loop.h )
s.files += %w( src/core/lib/promise/map.h )
s.files += %w( src/core/lib/promise/poll.h )

2
package.xml generated

@ -1191,6 +1191,8 @@
<file baseinstalldir="/" name="src/core/lib/promise/detail/status.h" role="src" />
<file baseinstalldir="/" name="src/core/lib/promise/detail/switch.h" role="src" />
<file baseinstalldir="/" name="src/core/lib/promise/exec_ctx_wakeup_scheduler.h" role="src" />
<file baseinstalldir="/" name="src/core/lib/promise/intra_activity_waiter.h" role="src" />
<file baseinstalldir="/" name="src/core/lib/promise/latch.h" role="src" />
<file baseinstalldir="/" name="src/core/lib/promise/loop.h" role="src" />
<file baseinstalldir="/" name="src/core/lib/promise/map.h" role="src" />
<file baseinstalldir="/" name="src/core/lib/promise/poll.h" role="src" />

@ -76,9 +76,8 @@ class ClientIdleFilter : public ChannelFilter {
ClientIdleFilter& operator=(ClientIdleFilter&&) = default;
// Construct a promise for one call.
ArenaPromise<TrailingMetadata> MakeCallPromise(
ClientInitialMetadata initial_metadata,
NextPromiseFactory next_promise_factory) override;
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;
bool StartTransportOp(grpc_transport_op* op) override;
@ -116,15 +115,14 @@ absl::StatusOr<ClientIdleFilter> ClientIdleFilter::Create(
}
// Construct a promise for one call.
ArenaPromise<TrailingMetadata> ClientIdleFilter::MakeCallPromise(
ClientInitialMetadata initial_metadata,
NextPromiseFactory next_promise_factory) {
ArenaPromise<ServerMetadataHandle> ClientIdleFilter::MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
using Decrementer = std::unique_ptr<ClientIdleFilter, CallCountDecreaser>;
IncreaseCallCount();
return ArenaPromise<TrailingMetadata>(Capture(
[](Decrementer*, ArenaPromise<TrailingMetadata>* next)
-> Poll<TrailingMetadata> { return (*next)(); },
Decrementer(this), next_promise_factory(std::move(initial_metadata))));
return ArenaPromise<ServerMetadataHandle>(
Capture([](Decrementer*, ArenaPromise<ServerMetadataHandle>* next)
-> Poll<ServerMetadataHandle> { return (*next)(); },
Decrementer(this), next_promise_factory(std::move(call_args))));
}
bool ClientIdleFilter::StartTransportOp(grpc_transport_op* op) {

@ -57,16 +57,17 @@ absl::StatusOr<ClientAuthorityFilter> ClientAuthorityFilter::Create(
return ClientAuthorityFilter(Slice::FromCopiedString(default_authority_str));
}
ArenaPromise<TrailingMetadata> ClientAuthorityFilter::MakeCallPromise(
ClientInitialMetadata initial_metadata,
NextPromiseFactory next_promise_factory) {
ArenaPromise<ServerMetadataHandle> ClientAuthorityFilter::MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
// If no authority is set, set the default authority.
if (initial_metadata->get_pointer(HttpAuthorityMetadata()) == nullptr) {
initial_metadata->Set(HttpAuthorityMetadata(), default_authority_.Ref());
if (call_args.client_initial_metadata->get_pointer(HttpAuthorityMetadata()) ==
nullptr) {
call_args.client_initial_metadata->Set(HttpAuthorityMetadata(),
default_authority_.Ref());
}
// We have no asynchronous work, so we can just ask the next promise to run,
// passing down initial_metadata.
return next_promise_factory(std::move(initial_metadata));
return next_promise_factory(std::move(call_args));
}
namespace {

@ -37,9 +37,8 @@ class ClientAuthorityFilter final : public ChannelFilter {
const grpc_channel_args* args, ChannelFilter::Args);
// Construct a promise for one call.
ArenaPromise<TrailingMetadata> MakeCallPromise(
ClientInitialMetadata initial_metadata,
NextPromiseFactory next_promise_factory) override;
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;
private:
explicit ClientAuthorityFilter(Slice default_authority)

@ -122,9 +122,8 @@ struct grpc_channel_filter {
- allocation of memory for call data
There is an on-going migration to move all filters to providing this, and
then to drop start_transport_stream_op_batch. */
grpc_core::ArenaPromise<grpc_core::TrailingMetadata> (*make_call_promise)(
grpc_channel_element* elem,
grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ArenaPromise<grpc_core::ServerMetadataHandle> (*make_call_promise)(
grpc_channel_element* elem, grpc_core::CallArgs call_args,
grpc_core::NextPromiseFactory next_promise_factory);
/* Called to handle channel level operations - e.g. new calls, or transport
closure.

@ -136,7 +136,7 @@ void ClientCallData::Cancel(grpc_error_handle error) {
GRPC_ERROR_UNREF(cancelled_error_);
cancelled_error_ = GRPC_ERROR_REF(error);
// Stop running the promise.
promise_ = ArenaPromise<TrailingMetadata>();
promise_ = ArenaPromise<ServerMetadataHandle>();
// If we have an op queued, fail that op.
// Record what we've done.
if (send_initial_state_ == SendInitialState::kQueued) {
@ -176,10 +176,12 @@ void ClientCallData::StartPromise() {
{
ScopedActivity activity(this);
promise_ = filter->MakeCallPromise(
WrapMetadata(send_initial_metadata_batch_->payload
->send_initial_metadata.send_initial_metadata),
[this](ClientInitialMetadata initial_metadata) {
return MakeNextPromise(std::move(initial_metadata));
CallArgs{
WrapMetadata(send_initial_metadata_batch_->payload
->send_initial_metadata.send_initial_metadata),
nullptr},
[this](CallArgs call_args) {
return MakeNextPromise(std::move(call_args));
});
}
// Poll once.
@ -203,12 +205,13 @@ void ClientCallData::HookRecvTrailingMetadata(
// Effectively:
// - put the modified initial metadata into the batch to be sent down.
// - return a wrapper around PollTrailingMetadata as the promise.
ArenaPromise<TrailingMetadata> ClientCallData::MakeNextPromise(
ClientInitialMetadata initial_metadata) {
ArenaPromise<ServerMetadataHandle> ClientCallData::MakeNextPromise(
CallArgs call_args) {
GPR_ASSERT(send_initial_state_ == SendInitialState::kQueued);
send_initial_metadata_batch_->payload->send_initial_metadata
.send_initial_metadata = UnwrapMetadata(std::move(initial_metadata));
return ArenaPromise<TrailingMetadata>(
.send_initial_metadata =
UnwrapMetadata(std::move(call_args.client_initial_metadata));
return ArenaPromise<ServerMetadataHandle>(
[this]() { return PollTrailingMetadata(); });
}
@ -216,7 +219,7 @@ ArenaPromise<TrailingMetadata> ClientCallData::MakeNextPromise(
// First poll: send the send_initial_metadata op down the stack.
// All polls: await receiving the trailing metadata, then return it to the
// application.
Poll<TrailingMetadata> ClientCallData::PollTrailingMetadata() {
Poll<ServerMetadataHandle> ClientCallData::PollTrailingMetadata() {
if (send_initial_state_ == SendInitialState::kQueued) {
// First poll: pass the send_initial_metadata op down the stack.
GPR_ASSERT(send_initial_metadata_batch_ != nullptr);
@ -274,7 +277,7 @@ void ClientCallData::RecvTrailingMetadataReady(grpc_error_handle error) {
WakeInsideCombiner();
}
// Given an error, fill in TrailingMetadata to represent that error.
// Given an error, fill in ServerMetadataHandle to represent that error.
void ClientCallData::SetStatusFromError(grpc_metadata_batch* metadata,
grpc_error_handle error) {
grpc_status_code status_code = GRPC_STATUS_UNKNOWN;
@ -298,13 +301,13 @@ void ClientCallData::WakeInsideCombiner() {
case SendInitialState::kQueued:
case SendInitialState::kForwarded: {
// Poll the promise once since we're waiting for it.
Poll<TrailingMetadata> poll;
Poll<ServerMetadataHandle> poll;
{
ScopedActivity activity(this);
poll = promise_();
}
if (auto* r = absl::get_if<TrailingMetadata>(&poll)) {
promise_ = ArenaPromise<TrailingMetadata>();
if (auto* r = absl::get_if<ServerMetadataHandle>(&poll)) {
promise_ = ArenaPromise<ServerMetadataHandle>();
auto* md = UnwrapMetadata(std::move(*r));
bool destroy_md = true;
if (recv_trailing_state_ == RecvTrailingState::kComplete) {
@ -505,7 +508,7 @@ void ServerCallData::Cancel(grpc_error_handle error) {
GRPC_ERROR_UNREF(cancelled_error_);
cancelled_error_ = GRPC_ERROR_REF(error);
// Stop running the promise.
promise_ = ArenaPromise<TrailingMetadata>();
promise_ = ArenaPromise<ServerMetadataHandle>();
if (send_trailing_state_ == SendTrailingState::kQueued) {
send_trailing_state_ = SendTrailingState::kCancelled;
struct FailBatch : public grpc_closure {
@ -534,20 +537,20 @@ void ServerCallData::Cancel(grpc_error_handle error) {
// Effectively:
// - put the modified initial metadata into the batch being sent up.
// - return a wrapper around PollTrailingMetadata as the promise.
ArenaPromise<TrailingMetadata> ServerCallData::MakeNextPromise(
ClientInitialMetadata initial_metadata) {
ArenaPromise<ServerMetadataHandle> ServerCallData::MakeNextPromise(
CallArgs call_args) {
GPR_ASSERT(recv_initial_state_ == RecvInitialState::kComplete);
GPR_ASSERT(UnwrapMetadata(std::move(initial_metadata)) ==
GPR_ASSERT(UnwrapMetadata(std::move(call_args.client_initial_metadata)) ==
recv_initial_metadata_);
forward_recv_initial_metadata_callback_ = true;
return ArenaPromise<TrailingMetadata>(
return ArenaPromise<ServerMetadataHandle>(
[this]() { return PollTrailingMetadata(); });
}
// Wrapper to make it look like we're calling the next filter as a promise.
// All polls: await sending the trailing metadata, then foward it down the
// stack.
Poll<TrailingMetadata> ServerCallData::PollTrailingMetadata() {
Poll<ServerMetadataHandle> ServerCallData::PollTrailingMetadata() {
switch (send_trailing_state_) {
case SendTrailingState::kInitial:
return Pending{};
@ -587,9 +590,9 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) {
// Construct the promise.
ChannelFilter* filter = static_cast<ChannelFilter*>(elem()->channel_data);
promise_ = filter->MakeCallPromise(
WrapMetadata(recv_initial_metadata_),
[this](ClientInitialMetadata initial_metadata) {
return MakeNextPromise(std::move(initial_metadata));
CallArgs{WrapMetadata(recv_initial_metadata_), nullptr},
[this](CallArgs call_args) {
return MakeNextPromise(std::move(call_args));
});
// Poll once.
bool own_error = false;
@ -610,12 +613,12 @@ void ServerCallData::WakeInsideCombiner(
bool forward_send_trailing_metadata = false;
is_polling_ = true;
if (recv_initial_state_ == RecvInitialState::kComplete) {
Poll<TrailingMetadata> poll;
Poll<ServerMetadataHandle> poll;
{
ScopedActivity activity(this);
poll = promise_();
}
if (auto* r = absl::get_if<TrailingMetadata>(&poll)) {
if (auto* r = absl::get_if<ServerMetadataHandle>(&poll)) {
auto* md = UnwrapMetadata(std::move(*r));
bool destroy_md = true;
switch (send_trailing_state_) {

@ -54,9 +54,8 @@ class ChannelFilter {
};
// Construct a promise for one call.
virtual ArenaPromise<TrailingMetadata> MakeCallPromise(
ClientInitialMetadata initial_metadata,
NextPromiseFactory next_promise_factory) = 0;
virtual ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) = 0;
// Start a legacy transport op
// Return true if the op was handled, false if it should be passed to the
@ -205,17 +204,16 @@ class ClientCallData : public BaseCallData {
// Effectively:
// - put the modified initial metadata into the batch to be sent down.
// - return a wrapper around PollTrailingMetadata as the promise.
ArenaPromise<TrailingMetadata> MakeNextPromise(
ClientInitialMetadata initial_metadata);
ArenaPromise<ServerMetadataHandle> MakeNextPromise(CallArgs call_args);
// Wrapper to make it look like we're calling the next filter as a promise.
// First poll: send the send_initial_metadata op down the stack.
// All polls: await receiving the trailing metadata, then return it to the
// application.
Poll<TrailingMetadata> PollTrailingMetadata();
Poll<ServerMetadataHandle> PollTrailingMetadata();
static void RecvTrailingMetadataReadyCallback(void* arg,
grpc_error_handle error);
void RecvTrailingMetadataReady(grpc_error_handle error);
// Given an error, fill in TrailingMetadata to represent that error.
// Given an error, fill in ServerMetadataHandle to represent that error.
void SetStatusFromError(grpc_metadata_batch* metadata,
grpc_error_handle error);
// Wakeup and poll the promise if appropriate.
@ -223,7 +221,7 @@ class ClientCallData : public BaseCallData {
void OnWakeup() override;
// Contained promise
ArenaPromise<TrailingMetadata> promise_;
ArenaPromise<ServerMetadataHandle> promise_;
// Queued batch containing at least a send_initial_metadata op.
grpc_transport_stream_op_batch* send_initial_metadata_batch_ = nullptr;
// Pointer to where trailing metadata will be stored.
@ -289,12 +287,11 @@ class ServerCallData : public BaseCallData {
// Effectively:
// - put the modified initial metadata into the batch being sent up.
// - return a wrapper around PollTrailingMetadata as the promise.
ArenaPromise<TrailingMetadata> MakeNextPromise(
ClientInitialMetadata initial_metadata);
ArenaPromise<ServerMetadataHandle> MakeNextPromise(CallArgs call_args);
// Wrapper to make it look like we're calling the next filter as a promise.
// All polls: await sending the trailing metadata, then foward it down the
// stack.
Poll<TrailingMetadata> PollTrailingMetadata();
Poll<ServerMetadataHandle> PollTrailingMetadata();
static void RecvInitialMetadataReadyCallback(void* arg,
grpc_error_handle error);
void RecvInitialMetadataReady(grpc_error_handle error);
@ -303,7 +300,7 @@ class ServerCallData : public BaseCallData {
void OnWakeup() override;
// Contained promise
ArenaPromise<TrailingMetadata> promise_;
ArenaPromise<ServerMetadataHandle> promise_;
// Pointer to where initial metadata will be stored.
grpc_metadata_batch* recv_initial_metadata_ = nullptr;
// Closure to call when we're done with the trailing metadata.
@ -366,10 +363,10 @@ MakePromiseBasedFilter(const char* name) {
static_cast<CallData*>(elem->call_data)->StartBatch(batch);
},
// make_call_promise
[](grpc_channel_element* elem, ClientInitialMetadata initial_metadata,
[](grpc_channel_element* elem, CallArgs call_args,
NextPromiseFactory next_promise_factory) {
return static_cast<F*>(elem->channel_data)
->MakeCallPromise(std::move(initial_metadata),
->MakeCallPromise(std::move(call_args),
std::move(next_promise_factory));
},
// start_transport_op

@ -49,7 +49,7 @@ absl::StatusOr<GrpcServerAuthzFilter> GrpcServerAuthzFilter::Create(
}
bool GrpcServerAuthzFilter::IsAuthorized(
const ClientInitialMetadata& initial_metadata) {
const ClientMetadataHandle& initial_metadata) {
EvaluateArgs args(initial_metadata.get(), &per_channel_evaluate_args_);
if (GRPC_TRACE_FLAG_ENABLED(grpc_authz_trace)) {
gpr_log(GPR_DEBUG,
@ -92,14 +92,13 @@ bool GrpcServerAuthzFilter::IsAuthorized(
return false;
}
ArenaPromise<TrailingMetadata> GrpcServerAuthzFilter::MakeCallPromise(
ClientInitialMetadata initial_metadata,
NextPromiseFactory next_promise_factory) {
if (!IsAuthorized(initial_metadata)) {
return ArenaPromise<TrailingMetadata>(Immediate(TrailingMetadata(
ArenaPromise<ServerMetadataHandle> GrpcServerAuthzFilter::MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
if (!IsAuthorized(call_args.client_initial_metadata)) {
return ArenaPromise<ServerMetadataHandle>(Immediate(ServerMetadataHandle(
absl::PermissionDeniedError("Unauthorized RPC request rejected."))));
}
return next_promise_factory(std::move(initial_metadata));
return next_promise_factory(std::move(call_args));
}
const grpc_channel_filter GrpcServerAuthzFilter::kFilterVtable =

@ -30,16 +30,15 @@ class GrpcServerAuthzFilter final : public ChannelFilter {
static absl::StatusOr<GrpcServerAuthzFilter> Create(
const grpc_channel_args* args, ChannelFilter::Args);
ArenaPromise<TrailingMetadata> MakeCallPromise(
ClientInitialMetadata initial_metadata,
NextPromiseFactory next_promise_factory) override;
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;
private:
GrpcServerAuthzFilter(
RefCountedPtr<grpc_auth_context> auth_context, grpc_endpoint* endpoint,
RefCountedPtr<grpc_authorization_policy_provider> provider);
bool IsAuthorized(const ClientInitialMetadata& initial_metadata);
bool IsAuthorized(const ClientMetadataHandle& initial_metadata);
RefCountedPtr<grpc_auth_context> auth_context_;
EvaluateArgs::PerChannelArgs per_channel_evaluate_args_;

@ -31,7 +31,7 @@ struct ServiceUrlAndMethod {
};
ServiceUrlAndMethod MakeServiceUrlAndMethod(
const ClientInitialMetadata& initial_metadata,
const ClientMetadataHandle& initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs* args) {
auto service =
initial_metadata->get_pointer(HttpPathMetadata())->as_string_view();
@ -65,13 +65,13 @@ ServiceUrlAndMethod MakeServiceUrlAndMethod(
} // namespace
std::string MakeJwtServiceUrl(
const ClientInitialMetadata& initial_metadata,
const ClientMetadataHandle& initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs* args) {
return MakeServiceUrlAndMethod(initial_metadata, args).service_url;
}
grpc_auth_metadata_context MakePluginAuthMetadataContext(
const ClientInitialMetadata& initial_metadata,
const ClientMetadataHandle& initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs* args) {
auto fields = MakeServiceUrlAndMethod(initial_metadata, args);
grpc_auth_metadata_context ctx;

@ -29,12 +29,12 @@ namespace grpc_core {
// Helper function to construct service URL for jwt call creds.
std::string MakeJwtServiceUrl(
const ClientInitialMetadata& initial_metadata,
const ClientMetadataHandle& initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs* args);
// Helper function to construct context for plugin call creds.
grpc_auth_metadata_context MakePluginAuthMetadataContext(
const ClientInitialMetadata& initial_metadata,
const ClientMetadataHandle& initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs* args);
} // namespace grpc_core

@ -43,15 +43,15 @@ const char kCredentialsTypeComposite[] = "composite";
/* -- Composite call credentials. -- */
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
grpc_composite_call_credentials::GetRequestMetadata(
grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ClientMetadataHandle initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs* args) {
auto self = Ref();
return TrySeqIter(
inner_.begin(), inner_.end(), std::move(initial_metadata),
[self, args](const grpc_core::RefCountedPtr<grpc_call_credentials>& creds,
grpc_core::ClientInitialMetadata initial_metadata) {
grpc_core::ClientMetadataHandle initial_metadata) {
return creds->GetRequestMetadata(std::move(initial_metadata), args);
});
}

@ -90,8 +90,8 @@ class grpc_composite_call_credentials : public grpc_call_credentials {
grpc_core::RefCountedPtr<grpc_call_credentials> creds2);
~grpc_composite_call_credentials() override = default;
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata,
const GetRequestMetadataArgs* args) override;
grpc_security_level min_security_level() const override {

@ -216,8 +216,8 @@ struct grpc_call_credentials
~grpc_call_credentials() override = default;
virtual grpc_core::ArenaPromise<
absl::StatusOr<grpc_core::ClientInitialMetadata>>
GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata,
absl::StatusOr<grpc_core::ClientMetadataHandle>>
GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata,
const GetRequestMetadataArgs* args) = 0;
virtual grpc_security_level min_security_level() const {

@ -97,9 +97,9 @@ const char* grpc_fake_transport_get_expected_targets(
/* -- Metadata-only test credentials. -- */
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
grpc_md_only_test_credentials::GetRequestMetadata(
grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ClientMetadataHandle initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs*) {
initial_metadata->Append(
key_.as_string_view(), value_.Ref(),

@ -65,8 +65,8 @@ class grpc_md_only_test_credentials : public grpc_call_credentials {
key_(grpc_core::Slice::FromCopiedString(md_key)),
value_(grpc_core::Slice::FromCopiedString(md_value)) {}
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata,
const GetRequestMetadataArgs* args) override;
std::string debug_string() override { return "MD only Test Credentials"; };

@ -31,9 +31,9 @@
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/surface/api_trace.h"
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
grpc_google_iam_credentials::GetRequestMetadata(
grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ClientMetadataHandle initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs*) {
if (token_.has_value()) {
initial_metadata->Append(

@ -30,8 +30,8 @@ class grpc_google_iam_credentials : public grpc_call_credentials {
grpc_google_iam_credentials(const char* token,
const char* authority_selector);
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata,
const GetRequestMetadataArgs* args) override;
std::string debug_string() override { return debug_string_; }

@ -49,9 +49,9 @@ grpc_service_account_jwt_access_credentials::
gpr_mu_destroy(&cache_mu_);
}
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
grpc_service_account_jwt_access_credentials::GetRequestMetadata(
grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ClientMetadataHandle initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs* args) {
gpr_timespec refresh_threshold = gpr_time_from_seconds(
GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, GPR_TIMESPAN);

@ -38,8 +38,8 @@ class grpc_service_account_jwt_access_credentials
gpr_timespec token_lifetime);
~grpc_service_account_jwt_access_credentials() override;
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata,
const GetRequestMetadataArgs* args) override;
const gpr_timespec& jwt_lifetime() const { return jwt_lifetime_; }

@ -281,9 +281,9 @@ void grpc_oauth2_token_fetcher_credentials::on_http_response(
delete r;
}
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
grpc_oauth2_token_fetcher_credentials::GetRequestMetadata(
grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ClientMetadataHandle initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs*) {
// Check if we can use the cached token.
absl::optional<grpc_core::Slice> cached_access_token_value;
@ -328,7 +328,7 @@ grpc_oauth2_token_fetcher_credentials::GetRequestMetadata(
}
return
[pending_request]()
-> grpc_core::Poll<absl::StatusOr<grpc_core::ClientInitialMetadata>> {
-> grpc_core::Poll<absl::StatusOr<grpc_core::ClientMetadataHandle>> {
if (!pending_request->done.load(std::memory_order_acquire)) {
return grpc_core::Pending{};
}
@ -696,9 +696,9 @@ grpc_call_credentials* grpc_sts_credentials_create(
// Oauth2 Access Token credentials.
//
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
grpc_access_token_credentials::GetRequestMetadata(
grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ClientMetadataHandle initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs*) {
initial_metadata->Append(
GRPC_AUTHORIZATION_METADATA_KEY, access_token_value_.Ref(),

@ -79,9 +79,9 @@ struct grpc_oauth2_pending_get_request_metadata
std::atomic<bool> done{false};
grpc_core::Waker waker;
grpc_polling_entity* pollent;
grpc_core::ClientInitialMetadata md;
grpc_core::ClientMetadataHandle md;
struct grpc_oauth2_pending_get_request_metadata* next;
absl::StatusOr<grpc_core::ClientInitialMetadata> result;
absl::StatusOr<grpc_core::ClientMetadataHandle> result;
};
// -- Oauth2 Token Fetcher credentials --
@ -94,8 +94,8 @@ class grpc_oauth2_token_fetcher_credentials : public grpc_call_credentials {
grpc_oauth2_token_fetcher_credentials();
~grpc_oauth2_token_fetcher_credentials() override;
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata,
const GetRequestMetadataArgs* args) override;
void on_http_response(grpc_credentials_metadata_request* r,
@ -152,8 +152,8 @@ class grpc_access_token_credentials final : public grpc_call_credentials {
public:
explicit grpc_access_token_credentials(const char* access_token);
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata,
const GetRequestMetadataArgs* args) override;
std::string debug_string() override;

@ -60,7 +60,7 @@ std::string grpc_plugin_credentials::debug_string() {
return debug_str;
}
absl::StatusOr<grpc_core::ClientInitialMetadata>
absl::StatusOr<grpc_core::ClientMetadataHandle>
grpc_plugin_credentials::PendingRequest::ProcessPluginResult(
const grpc_metadata* md, size_t num_md, grpc_status_code status,
const char* error_details) {
@ -96,12 +96,12 @@ grpc_plugin_credentials::PendingRequest::ProcessPluginResult(
});
}
if (!error.ok()) return std::move(error);
return grpc_core::ClientInitialMetadata(std::move(md_));
return grpc_core::ClientMetadataHandle(std::move(md_));
}
}
}
grpc_core::Poll<absl::StatusOr<grpc_core::ClientInitialMetadata>>
grpc_core::Poll<absl::StatusOr<grpc_core::ClientMetadataHandle>>
grpc_plugin_credentials::PendingRequest::PollAsyncResult() {
if (!ready_.load(std::memory_order_acquire)) {
return grpc_core::Pending{};
@ -137,9 +137,9 @@ void grpc_plugin_credentials::PendingRequest::RequestMetadataReady(
r->waker_.Wakeup();
}
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
grpc_plugin_credentials::GetRequestMetadata(
grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ClientMetadataHandle initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs* args) {
if (plugin_.get_metadata == nullptr) {
return grpc_core::Immediate(std::move(initial_metadata));

@ -35,8 +35,8 @@ struct grpc_plugin_credentials final : public grpc_call_credentials {
grpc_security_level min_security_level);
~grpc_plugin_credentials() override;
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientInitialMetadata>>
GetRequestMetadata(grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ArenaPromise<absl::StatusOr<grpc_core::ClientMetadataHandle>>
GetRequestMetadata(grpc_core::ClientMetadataHandle initial_metadata,
const GetRequestMetadataArgs* args) override;
std::string debug_string() override;
@ -45,7 +45,7 @@ struct grpc_plugin_credentials final : public grpc_call_credentials {
class PendingRequest : public grpc_core::RefCounted<PendingRequest> {
public:
PendingRequest(grpc_core::RefCountedPtr<grpc_plugin_credentials> creds,
grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ClientMetadataHandle initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs* args)
: call_creds_(std::move(creds)),
context_(
@ -60,11 +60,11 @@ struct grpc_plugin_credentials final : public grpc_call_credentials {
}
}
absl::StatusOr<grpc_core::ClientInitialMetadata> ProcessPluginResult(
absl::StatusOr<grpc_core::ClientMetadataHandle> ProcessPluginResult(
const grpc_metadata* md, size_t num_md, grpc_status_code status,
const char* error_details);
grpc_core::Poll<absl::StatusOr<grpc_core::ClientInitialMetadata>>
grpc_core::Poll<absl::StatusOr<grpc_core::ClientMetadataHandle>>
PollAsyncResult();
static void RequestMetadataReady(void* request, const grpc_metadata* md,
@ -80,7 +80,7 @@ struct grpc_plugin_credentials final : public grpc_call_credentials {
grpc_core::Activity::current()->MakeNonOwningWaker()};
grpc_core::RefCountedPtr<grpc_plugin_credentials> call_creds_;
grpc_auth_metadata_context context_;
grpc_core::ClientInitialMetadata md_;
grpc_core::ClientMetadataHandle md_;
// final status
absl::InlinedVector<grpc_metadata, 2> metadata_;
std::string error_details_;

@ -41,17 +41,16 @@ class ClientAuthFilter final : public ChannelFilter {
ChannelFilter::Args);
// Construct a promise for one call.
ArenaPromise<TrailingMetadata> MakeCallPromise(
ClientInitialMetadata initial_metadata,
NextPromiseFactory next_promise_factory) override;
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;
private:
ClientAuthFilter(
RefCountedPtr<grpc_channel_security_connector> security_connector,
RefCountedPtr<grpc_auth_context> auth_context);
ArenaPromise<absl::StatusOr<ClientInitialMetadata>> GetCallCredsMetadata(
ClientInitialMetadata initial_metadata);
ArenaPromise<absl::StatusOr<CallArgs>> GetCallCredsMetadata(
CallArgs call_args);
// Contains refs to security connector and auth context.
grpc_call_credentials::GetRequestMetadataArgs args_;

@ -31,6 +31,7 @@
#include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/channel/promise_based_filter.h"
#include "src/core/lib/gpr/string.h"
#include "src/core/lib/gprpp/capture.h"
#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/profiling/timers.h"
#include "src/core/lib/promise/promise.h"
@ -98,8 +99,8 @@ ClientAuthFilter::ClientAuthFilter(
RefCountedPtr<grpc_auth_context> auth_context)
: args_{std::move(security_connector), std::move(auth_context)} {}
ArenaPromise<absl::StatusOr<ClientInitialMetadata>>
ClientAuthFilter::GetCallCredsMetadata(ClientInitialMetadata initial_metadata) {
ArenaPromise<absl::StatusOr<CallArgs>> ClientAuthFilter::GetCallCredsMetadata(
CallArgs call_args) {
auto* ctx = static_cast<grpc_client_security_context*>(
GetContext<grpc_call_context_element>()[GRPC_CONTEXT_SECURITY].value);
grpc_call_credentials* channel_call_creds =
@ -108,8 +109,7 @@ ClientAuthFilter::GetCallCredsMetadata(ClientInitialMetadata initial_metadata) {
if (channel_call_creds == nullptr && !call_creds_has_md) {
/* Skip sending metadata altogether. */
return Immediate(
absl::StatusOr<ClientInitialMetadata>(std::move(initial_metadata)));
return Immediate(absl::StatusOr<CallArgs>(std::move(call_args)));
}
RefCountedPtr<grpc_call_credentials> creds;
@ -148,12 +148,20 @@ ClientAuthFilter::GetCallCredsMetadata(ClientInitialMetadata initial_metadata) {
"transfer call credential."));
}
return creds->GetRequestMetadata(std::move(initial_metadata), &args_);
auto client_initial_metadata = std::move(call_args.client_initial_metadata);
return TrySeq(
creds->GetRequestMetadata(std::move(client_initial_metadata), &args_),
Capture(
[](CallArgs* rest_of_args, ClientMetadataHandle new_metadata) {
rest_of_args->client_initial_metadata = std::move(new_metadata);
return Immediate<absl::StatusOr<CallArgs>>(
absl::StatusOr<CallArgs>(std::move(*rest_of_args)));
},
std::move(call_args)));
}
ArenaPromise<TrailingMetadata> ClientAuthFilter::MakeCallPromise(
ClientInitialMetadata initial_metadata,
NextPromiseFactory next_promise_factory) {
ArenaPromise<ServerMetadataHandle> ClientAuthFilter::MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
auto* legacy_ctx = GetContext<grpc_call_context_element>();
if (legacy_ctx[GRPC_CONTEXT_SECURITY].value == nullptr) {
legacy_ctx[GRPC_CONTEXT_SECURITY].value =
@ -166,13 +174,14 @@ ArenaPromise<TrailingMetadata> ClientAuthFilter::MakeCallPromise(
legacy_ctx[GRPC_CONTEXT_SECURITY].value)
->auth_context = args_.auth_context;
auto* host = initial_metadata->get_pointer(HttpAuthorityMetadata());
auto* host =
call_args.client_initial_metadata->get_pointer(HttpAuthorityMetadata());
if (host == nullptr) {
return next_promise_factory(std::move(initial_metadata));
return next_promise_factory(std::move(call_args));
}
return TrySeq(args_.security_connector->CheckCallHost(
host->as_string_view(), args_.auth_context.get()),
GetCallCredsMetadata(std::move(initial_metadata)),
GetCallCredsMetadata(std::move(call_args)),
next_promise_factory);
}

@ -31,6 +31,7 @@
#include "src/core/lib/iomgr/pollset.h"
#include "src/core/lib/iomgr/pollset_set.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/latch.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/lib/transport/byte_stream.h"
@ -99,21 +100,33 @@ class MetadataHandle {
// Trailing metadata type
// TODO(ctiller): This should be a bespoke instance of MetadataMap<>
using TrailingMetadata = MetadataHandle<grpc_metadata_batch>;
using ServerMetadata = grpc_metadata_batch;
using ServerMetadataHandle = MetadataHandle<ServerMetadata>;
// Ok/not-ok check for trailing metadata, so that it can be used as result types
// for TrySeq.
inline bool IsStatusOk(const TrailingMetadata& m) {
inline bool IsStatusOk(const ServerMetadataHandle& m) {
return m->get(GrpcStatusMetadata()).value_or(GRPC_STATUS_UNKNOWN) ==
GRPC_STATUS_OK;
}
// Client initial metadata type
// TODO(ctiller): This should be a bespoke instance of MetadataMap<>
using ClientInitialMetadata = MetadataHandle<grpc_metadata_batch>;
using ClientMetadata = grpc_metadata_batch;
using ClientMetadataHandle = MetadataHandle<ClientMetadata>;
// Server initial metadata type
// TODO(ctiller): This should be a bespoke instance of MetadataMap<>
using ServerMetadataHandle = MetadataHandle<grpc_metadata_batch>;
struct CallArgs {
ClientMetadataHandle client_initial_metadata;
Latch<ServerMetadata*>* server_initial_metadata;
};
using NextPromiseFactory =
std::function<ArenaPromise<TrailingMetadata>(ClientInitialMetadata)>;
std::function<ArenaPromise<ServerMetadataHandle>(CallArgs)>;
} // namespace grpc_core
/* forward declarations */

@ -45,8 +45,8 @@ typedef struct grpc_transport_vtable {
- allocation of memory for call data (sizeof_stream may be ignored)
There is an on-going migration to move all filters to providing this, and
then to drop perform_stream_op. */
grpc_core::ArenaPromise<grpc_core::TrailingMetadata> (*make_call_promise)(
grpc_transport* self, grpc_core::ClientInitialMetadata initial_metadata,
grpc_core::ArenaPromise<grpc_core::ServerMetadataHandle> (*make_call_promise)(
grpc_transport* self, grpc_core::ClientMetadataHandle initial_metadata,
grpc_core::NextPromiseFactory next_promise_factory);
/* implementation of grpc_transport_set_pollset */

@ -71,18 +71,24 @@ TEST(ClientAuthorityFilterTest, PromiseCompletesImmediatelyAndSetsAuthority) {
// TODO(ctiller): use Activity here, once it's ready.
TestContext<Arena> context(arena.get());
auto promise = filter.MakeCallPromise(
ClientInitialMetadata::TestOnlyWrap(&initial_metadata_batch),
[&](ClientInitialMetadata initial_metadata) {
EXPECT_EQ(initial_metadata->get_pointer(HttpAuthorityMetadata())
CallArgs{
ClientMetadataHandle::TestOnlyWrap(&initial_metadata_batch),
nullptr,
},
[&](CallArgs call_args) {
EXPECT_EQ(call_args.client_initial_metadata
->get_pointer(HttpAuthorityMetadata())
->as_string_view(),
"foo.test.google.au");
seen = true;
return ArenaPromise<TrailingMetadata>([&]() -> Poll<TrailingMetadata> {
return TrailingMetadata::TestOnlyWrap(&trailing_metadata_batch);
});
return ArenaPromise<ServerMetadataHandle>(
[&]() -> Poll<ServerMetadataHandle> {
return ServerMetadataHandle::TestOnlyWrap(
&trailing_metadata_batch);
});
});
auto result = promise();
EXPECT_TRUE(absl::get_if<TrailingMetadata>(&result) != nullptr);
EXPECT_TRUE(absl::get_if<ServerMetadataHandle>(&result) != nullptr);
EXPECT_TRUE(seen);
}
@ -99,18 +105,24 @@ TEST(ClientAuthorityFilterTest,
// TODO(ctiller): use Activity here, once it's ready.
TestContext<Arena> context(arena.get());
auto promise = filter.MakeCallPromise(
ClientInitialMetadata::TestOnlyWrap(&initial_metadata_batch),
[&](ClientInitialMetadata initial_metadata) {
EXPECT_EQ(initial_metadata->get_pointer(HttpAuthorityMetadata())
CallArgs{
ClientMetadataHandle::TestOnlyWrap(&initial_metadata_batch),
nullptr,
},
[&](CallArgs call_args) {
EXPECT_EQ(call_args.client_initial_metadata
->get_pointer(HttpAuthorityMetadata())
->as_string_view(),
"bar.test.google.au");
seen = true;
return ArenaPromise<TrailingMetadata>([&]() -> Poll<TrailingMetadata> {
return TrailingMetadata::TestOnlyWrap(&trailing_metadata_batch);
});
return ArenaPromise<ServerMetadataHandle>(
[&]() -> Poll<ServerMetadataHandle> {
return ServerMetadataHandle::TestOnlyWrap(
&trailing_metadata_batch);
});
});
auto result = promise();
EXPECT_TRUE(absl::get_if<TrailingMetadata>(&result) != nullptr);
EXPECT_TRUE(absl::get_if<ServerMetadataHandle>(&result) != nullptr);
EXPECT_TRUE(seen);
}

@ -450,9 +450,9 @@ class RequestMetadataState : public RefCounted<RequestMetadataState> {
activity_ = MakeActivity(
[this, creds] {
return Seq(creds->GetRequestMetadata(
ClientInitialMetadata::TestOnlyWrap(&md_),
ClientMetadataHandle::TestOnlyWrap(&md_),
&get_request_metadata_args_),
[this](absl::StatusOr<ClientInitialMetadata> metadata) {
[this](absl::StatusOr<ClientMetadataHandle> metadata) {
if (metadata.ok()) {
GPR_ASSERT(metadata->get() == &md_);
}
@ -1771,8 +1771,8 @@ struct fake_call_creds : public grpc_call_credentials {
public:
fake_call_creds() : grpc_call_credentials("fake") {}
ArenaPromise<absl::StatusOr<ClientInitialMetadata>> GetRequestMetadata(
ClientInitialMetadata initial_metadata,
ArenaPromise<absl::StatusOr<ClientMetadataHandle>> GetRequestMetadata(
ClientMetadataHandle initial_metadata,
const grpc_call_credentials::GetRequestMetadataArgs*) override {
initial_metadata->Append("foo", Slice::FromStaticString("oof"),
[](absl::string_view, const Slice&) { abort(); });

@ -57,10 +57,10 @@ char* grpc_test_fetch_oauth2_token_with_credentials(
[creds, &initial_metadata, &get_request_metadata_args]() {
return grpc_core::Map(
creds->GetRequestMetadata(
grpc_core::ClientInitialMetadata::TestOnlyWrap(
grpc_core::ClientMetadataHandle::TestOnlyWrap(
&initial_metadata),
&get_request_metadata_args),
[](const absl::StatusOr<grpc_core::ClientInitialMetadata>& s) {
[](const absl::StatusOr<grpc_core::ClientMetadataHandle>& s) {
return s.status();
});
},

@ -55,7 +55,7 @@ bool ClientMetadataContains(const grpc::ServerContext& context, const grpc::stri
@implementation ServerContextTestSpouseTest
TEST(ServerContextTestSpouseTest, ClientMetadata) {
TEST(ServerContextTestSpouseTest, ClientMetadataHandle) {
grpc::ServerContext context;
grpc::testing::ServerContextTestSpouse spouse(&context);
@ -81,7 +81,7 @@ TEST(ServerContextTestSpouseTest, InitialMetadata) {
ASSERT_EQ(metadata, spouse.GetInitialMetadata());
}
TEST(ServerContextTestSpouseTest, TrailingMetadata) {
TEST(ServerContextTestSpouseTest, ServerMetadataHandle) {
grpc::ServerContext context;
grpc::testing::ServerContextTestSpouse spouse(&context);
std::multimap<std::string, std::string> metadata;

@ -47,7 +47,7 @@ bool ClientMetadataContains(const ServerContext& context,
return false;
}
TEST(ServerContextTestSpouseTest, ClientMetadata) {
TEST(ServerContextTestSpouseTest, ClientMetadataHandle) {
ServerContext context;
ServerContextTestSpouse spouse(&context);
@ -73,7 +73,7 @@ TEST(ServerContextTestSpouseTest, InitialMetadata) {
ASSERT_EQ(metadata, spouse.GetInitialMetadata());
}
TEST(ServerContextTestSpouseTest, TrailingMetadata) {
TEST(ServerContextTestSpouseTest, ServerMetadataHandle) {
ServerContext context;
ServerContextTestSpouse spouse(&context);
std::multimap<std::string, std::string> metadata;

@ -2190,6 +2190,8 @@ src/core/lib/promise/detail/promise_like.h \
src/core/lib/promise/detail/status.h \
src/core/lib/promise/detail/switch.h \
src/core/lib/promise/exec_ctx_wakeup_scheduler.h \
src/core/lib/promise/intra_activity_waiter.h \
src/core/lib/promise/latch.h \
src/core/lib/promise/loop.h \
src/core/lib/promise/map.h \
src/core/lib/promise/poll.h \

@ -1985,6 +1985,8 @@ src/core/lib/promise/detail/promise_like.h \
src/core/lib/promise/detail/status.h \
src/core/lib/promise/detail/switch.h \
src/core/lib/promise/exec_ctx_wakeup_scheduler.h \
src/core/lib/promise/intra_activity_waiter.h \
src/core/lib/promise/latch.h \
src/core/lib/promise/loop.h \
src/core/lib/promise/map.h \
src/core/lib/promise/poll.h \

Loading…
Cancel
Save