diff --git a/include/grpcpp/impl/codegen/client_context_impl.h b/include/grpcpp/impl/codegen/client_context_impl.h index 7a543055b4b..bf5afa44b10 100644 --- a/include/grpcpp/impl/codegen/client_context_impl.h +++ b/include/grpcpp/impl/codegen/client_context_impl.h @@ -310,11 +310,11 @@ class ClientContext { /// client’s identity, role, or whether it is authorized to make a particular /// call. /// + /// It is legal to call this only before initial metadata is sent. + /// /// \see https://grpc.io/docs/guides/auth.html void set_credentials( - const std::shared_ptr& creds) { - creds_ = creds; - } + const std::shared_ptr& creds); /// Return the compression algorithm the client call will request be used. /// Note that the gRPC runtime may decide to ignore this request, for example, diff --git a/src/cpp/client/client_context.cc b/src/cpp/client/client_context.cc index efb10a44006..fc44b654efe 100644 --- a/src/cpp/client/client_context.cc +++ b/src/cpp/client/client_context.cc @@ -72,6 +72,22 @@ ClientContext::~ClientContext() { g_client_callbacks->Destructor(this); } +void ClientContext::set_credentials( + const std::shared_ptr& creds) { + creds_ = creds; + // If call_ is set, we have already created the call, and set the call + // credentials. This should only be done before we have started the batch + // for sending initial metadata. + if (creds_ != nullptr && call_ != nullptr) { + if (!creds_->ApplyToCall(call_)) { + SendCancelToInterceptors(); + grpc_call_cancel_with_status(call_, GRPC_STATUS_CANCELLED, + "Failed to set credentials to rpc.", + nullptr); + } + } +} + std::unique_ptr ClientContext::FromServerContext( const grpc::ServerContext& context, PropagationOptions options) { std::unique_ptr ctx(new ClientContext); diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index a8592fd97c2..a4655329250 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -16,9 +16,6 @@ * */ -#include -#include - #include #include #include @@ -35,6 +32,9 @@ #include #include +#include +#include + #include "src/core/ext/filters/client_channel/backup_poller.h" #include "src/core/lib/iomgr/iomgr.h" #include "src/core/lib/security/credentials/credentials.h" @@ -338,7 +338,11 @@ class End2endTest : public ::testing::TestWithParam { kMaxMessageSize_); // For testing max message size. } - void ResetChannel() { + void ResetChannel( + std::vector< + std::unique_ptr> + interceptor_creators = std::vector>()) { if (!is_server_started_) { StartServer(std::shared_ptr()); } @@ -358,20 +362,27 @@ class End2endTest : public ::testing::TestWithParam { } else { channel_ = CreateCustomChannelWithInterceptors( server_address_.str(), channel_creds, args, - CreateDummyClientInterceptors()); + interceptor_creators.empty() ? CreateDummyClientInterceptors() + : std::move(interceptor_creators)); } } else { if (!GetParam().use_interceptors) { channel_ = server_->InProcessChannel(args); } else { channel_ = server_->experimental().InProcessChannelWithInterceptors( - args, CreateDummyClientInterceptors()); + args, interceptor_creators.empty() + ? CreateDummyClientInterceptors() + : std::move(interceptor_creators)); } } } - void ResetStub() { - ResetChannel(); + void ResetStub( + std::vector< + std::unique_ptr> + interceptor_creators = std::vector>()) { + ResetChannel(std::move(interceptor_creators)); if (GetParam().use_proxy) { proxy_service_.reset(new Proxy(channel_)); int port = grpc_pick_unused_port_or_die(); @@ -1802,6 +1813,60 @@ TEST_P(SecureEnd2endTest, SetPerCallCredentials) { "fake_selector")); } +class CredentialsInterceptor : public experimental::Interceptor { + public: + CredentialsInterceptor(experimental::ClientRpcInfo* info) : info_(info) {} + + void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + std::shared_ptr creds = + GoogleIAMCredentials("fake_token", "fake_selector"); + info_->client_context()->set_credentials(creds); + } + methods->Proceed(); + } + + private: + experimental::ClientRpcInfo* info_ = nullptr; +}; + +class CredentialsInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + CredentialsInterceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) { + return new CredentialsInterceptor(info); + } +}; + +TEST_P(SecureEnd2endTest, CallCredentialsInterception) { + MAYBE_SKIP_TEST; + if (!GetParam().use_interceptors) { + return; + } + std::vector> + interceptor_creators; + interceptor_creators.push_back(std::unique_ptr( + new CredentialsInterceptorFactory())); + ResetStub(std::move(interceptor_creators)); + EchoRequest request; + EchoResponse response; + ClientContext context; + + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + "fake_token")); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + "fake_selector")); +} + TEST_P(SecureEnd2endTest, OverridePerCallCredentials) { MAYBE_SKIP_TEST; ResetStub();