diff --git a/examples/pubsub/main.cc b/examples/pubsub/main.cc index 68620e64c5c..b1898f18d9a 100644 --- a/examples/pubsub/main.cc +++ b/examples/pubsub/main.cc @@ -71,7 +71,7 @@ int main(int argc, char** argv) { ss << FLAGS_server_host << ":" << FLAGS_server_port; - std::unique_ptr creds = grpc::GoogleDefaultCredentials(); + std::shared_ptr creds = grpc::GoogleDefaultCredentials(); std::shared_ptr channel = grpc::CreateChannel(ss.str(), creds, grpc::ChannelArguments()); diff --git a/include/grpc++/client_context.h b/include/grpc++/client_context.h index a58e9872e60..6d9015f278c 100644 --- a/include/grpc++/client_context.h +++ b/include/grpc++/client_context.h @@ -51,6 +51,7 @@ namespace grpc { class CallOpBuffer; class ChannelInterface; class CompletionQueue; +class Credentials; class RpcMethod; class Status; template @@ -102,6 +103,11 @@ class ClientContext { void set_authority(const grpc::string& authority) { authority_ = authority; } + // Set credentials for the rpc. + void set_credentials(const std::shared_ptr& creds) { + creds_ = creds; + } + void TryCancel(); private: @@ -127,11 +133,8 @@ class ClientContext { friend class ::grpc::ClientAsyncResponseReader; grpc_call* call() { return call_; } - void set_call(grpc_call* call, const std::shared_ptr& channel) { - GPR_ASSERT(call_ == nullptr); - call_ = call; - channel_ = channel; - } + void set_call(grpc_call* call, + const std::shared_ptr& channel); grpc_completion_queue* cq() { return cq_; } void set_cq(grpc_completion_queue* cq) { cq_ = cq; } @@ -144,6 +147,7 @@ class ClientContext { grpc_completion_queue* cq_; gpr_timespec deadline_; grpc::string authority_; + std::shared_ptr creds_; std::multimap send_initial_metadata_; std::multimap recv_initial_metadata_; std::multimap trailing_metadata_; diff --git a/include/grpc++/create_channel.h b/include/grpc++/create_channel.h index da375b97db4..424a93a64c5 100644 --- a/include/grpc++/create_channel.h +++ b/include/grpc++/create_channel.h @@ -45,7 +45,7 @@ class ChannelInterface; // If creds does not hold an object or is invalid, a lame channel is returned. std::shared_ptr CreateChannel( - const grpc::string& target, const std::unique_ptr& creds, + const grpc::string& target, const std::shared_ptr& creds, const ChannelArguments& args); } // namespace grpc diff --git a/include/grpc++/credentials.h b/include/grpc++/credentials.h index 61c40946910..7a40cd199d8 100644 --- a/include/grpc++/credentials.h +++ b/include/grpc++/credentials.h @@ -47,17 +47,18 @@ class SecureCredentials; class Credentials : public GrpcLibrary { public: ~Credentials() GRPC_OVERRIDE; + virtual bool ApplyToCall(grpc_call* call) = 0; protected: - friend std::unique_ptr CompositeCredentials( - const std::unique_ptr& creds1, - const std::unique_ptr& creds2); + friend std::shared_ptr CompositeCredentials( + const std::shared_ptr& creds1, + const std::shared_ptr& creds2); virtual SecureCredentials* AsSecureCredentials() = 0; private: friend std::shared_ptr CreateChannel( - const grpc::string& target, const std::unique_ptr& creds, + const grpc::string& target, const std::shared_ptr& creds, const ChannelArguments& args); virtual std::shared_ptr CreateChannel( @@ -80,20 +81,20 @@ struct SslCredentialsOptions { }; // Factories for building different types of Credentials -// The functions may return empty unique_ptr when credentials cannot be created. +// The functions may return empty shared_ptr when credentials cannot be created. // If a Credentials pointer is returned, it can still be invalid when used to // create a channel. A lame channel will be created then and all rpcs will // fail on it. // Builds credentials with reasonable defaults. -std::unique_ptr GoogleDefaultCredentials(); +std::shared_ptr GoogleDefaultCredentials(); // Builds SSL Credentials given SSL specific options -std::unique_ptr SslCredentials( +std::shared_ptr SslCredentials( const SslCredentialsOptions& options); // Builds credentials for use when running in GCE -std::unique_ptr ComputeEngineCredentials(); +std::shared_ptr ComputeEngineCredentials(); // Builds service account credentials. // json_key is the JSON key string containing the client's private key. @@ -101,7 +102,7 @@ std::unique_ptr ComputeEngineCredentials(); // token_lifetime_seconds is the lifetime in seconds of each token acquired // through this service account credentials. It should be positive and should // not exceed grpc_max_auth_token_lifetime or will be cropped to this value. -std::unique_ptr ServiceAccountCredentials( +std::shared_ptr ServiceAccountCredentials( const grpc::string& json_key, const grpc::string& scope, long token_lifetime_seconds); @@ -110,27 +111,27 @@ std::unique_ptr ServiceAccountCredentials( // token_lifetime_seconds is the lifetime in seconds of each Json Web Token // (JWT) created with this credentials. It should not exceed // grpc_max_auth_token_lifetime or will be cropped to this value. -std::unique_ptr JWTCredentials( - const grpc::string& json_key, long token_lifetime_seconds); +std::shared_ptr JWTCredentials(const grpc::string& json_key, + long token_lifetime_seconds); // Builds refresh token credentials. // json_refresh_token is the JSON string containing the refresh token along // with a client_id and client_secret. -std::unique_ptr RefreshTokenCredentials( +std::shared_ptr RefreshTokenCredentials( const grpc::string& json_refresh_token); // Builds IAM credentials. -std::unique_ptr IAMCredentials( +std::shared_ptr IAMCredentials( const grpc::string& authorization_token, const grpc::string& authority_selector); // Combines two credentials objects into a composite credentials -std::unique_ptr CompositeCredentials( - const std::unique_ptr& creds1, - const std::unique_ptr& creds2); +std::shared_ptr CompositeCredentials( + const std::shared_ptr& creds1, + const std::shared_ptr& creds2); // Credentials for an unencrypted, unauthenticated channel -std::unique_ptr InsecureCredentials(); +std::shared_ptr InsecureCredentials(); } // namespace grpc diff --git a/src/cpp/client/client_context.cc b/src/cpp/client/client_context.cc index f38a694734a..72cdd49d195 100644 --- a/src/cpp/client/client_context.cc +++ b/src/cpp/client/client_context.cc @@ -34,6 +34,7 @@ #include #include +#include #include namespace grpc { @@ -63,6 +64,17 @@ void ClientContext::AddMetadata(const grpc::string& meta_key, send_initial_metadata_.insert(std::make_pair(meta_key, meta_value)); } +void ClientContext::set_call(grpc_call* call, + const std::shared_ptr& channel) { + GPR_ASSERT(call_ == nullptr); + call_ = call; + channel_ = channel; + if (creds_ && !creds_->ApplyToCall(call_)) { + grpc_call_cancel_with_status(call, GRPC_STATUS_CANCELLED, + "Failed to set credentials to rpc."); + } +} + void ClientContext::TryCancel() { if (call_) { grpc_call_cancel(call_); diff --git a/src/cpp/client/create_channel.cc b/src/cpp/client/create_channel.cc index 301430572a2..510af2bb00a 100644 --- a/src/cpp/client/create_channel.cc +++ b/src/cpp/client/create_channel.cc @@ -41,7 +41,7 @@ namespace grpc { class ChannelArguments; std::shared_ptr CreateChannel( - const grpc::string& target, const std::unique_ptr& creds, + const grpc::string& target, const std::shared_ptr& creds, const ChannelArguments& args) { return creds ? creds->CreateChannel(target, args) : std::shared_ptr( diff --git a/src/cpp/client/insecure_credentials.cc b/src/cpp/client/insecure_credentials.cc index 8945b038dec..668ea2e8733 100644 --- a/src/cpp/client/insecure_credentials.cc +++ b/src/cpp/client/insecure_credentials.cc @@ -52,12 +52,14 @@ class InsecureCredentialsImpl GRPC_FINAL : public Credentials { target, grpc_channel_create(target.c_str(), &channel_args))); } + bool ApplyToCall(grpc_call* call) GRPC_OVERRIDE { return true; } + SecureCredentials* AsSecureCredentials() GRPC_OVERRIDE { return nullptr; } }; } // namespace -std::unique_ptr InsecureCredentials() { - return std::unique_ptr(new InsecureCredentialsImpl()); +std::shared_ptr InsecureCredentials() { + return std::shared_ptr(new InsecureCredentialsImpl()); } } // namespace grpc diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc index 48bf7430b27..b5134b31403 100644 --- a/src/cpp/client/secure_credentials.cc +++ b/src/cpp/client/secure_credentials.cc @@ -49,20 +49,24 @@ std::shared_ptr SecureCredentials::CreateChannel( grpc_secure_channel_create(c_creds_, target.c_str(), &channel_args))); } +bool SecureCredentials::ApplyToCall(grpc_call* call) { + return grpc_call_set_credentials(call, c_creds_) == GRPC_CALL_OK; +} + namespace { -std::unique_ptr WrapCredentials(grpc_credentials* creds) { +std::shared_ptr WrapCredentials(grpc_credentials* creds) { return creds == nullptr ? nullptr - : std::unique_ptr(new SecureCredentials(creds)); + : std::shared_ptr(new SecureCredentials(creds)); } } // namespace -std::unique_ptr GoogleDefaultCredentials() { +std::shared_ptr GoogleDefaultCredentials() { return WrapCredentials(grpc_google_default_credentials_create()); } // Builds SSL Credentials given SSL specific options -std::unique_ptr SslCredentials( +std::shared_ptr SslCredentials( const SslCredentialsOptions& options) { grpc_ssl_pem_key_cert_pair pem_key_cert_pair = { options.pem_private_key.c_str(), options.pem_cert_chain.c_str()}; @@ -74,12 +78,12 @@ std::unique_ptr SslCredentials( } // Builds credentials for use when running in GCE -std::unique_ptr ComputeEngineCredentials() { +std::shared_ptr ComputeEngineCredentials() { return WrapCredentials(grpc_compute_engine_credentials_create()); } // Builds service account credentials. -std::unique_ptr ServiceAccountCredentials( +std::shared_ptr ServiceAccountCredentials( const grpc::string& json_key, const grpc::string& scope, long token_lifetime_seconds) { if (token_lifetime_seconds <= 0) { @@ -94,8 +98,8 @@ std::unique_ptr ServiceAccountCredentials( } // Builds JWT credentials. -std::unique_ptr JWTCredentials( - const grpc::string& json_key, long token_lifetime_seconds) { +std::shared_ptr JWTCredentials(const grpc::string& json_key, + long token_lifetime_seconds) { if (token_lifetime_seconds <= 0) { gpr_log(GPR_ERROR, "Trying to create JWTCredentials with non-positive lifetime"); @@ -107,14 +111,14 @@ std::unique_ptr JWTCredentials( } // Builds refresh token credentials. -std::unique_ptr RefreshTokenCredentials( +std::shared_ptr RefreshTokenCredentials( const grpc::string& json_refresh_token) { return WrapCredentials( grpc_refresh_token_credentials_create(json_refresh_token.c_str())); } // Builds IAM credentials. -std::unique_ptr IAMCredentials( +std::shared_ptr IAMCredentials( const grpc::string& authorization_token, const grpc::string& authority_selector) { return WrapCredentials(grpc_iam_credentials_create( @@ -122,10 +126,10 @@ std::unique_ptr IAMCredentials( } // Combines two credentials objects into a composite credentials. -std::unique_ptr CompositeCredentials( - const std::unique_ptr& creds1, - const std::unique_ptr& creds2) { - // Note that we are not saving unique_ptrs to the two credentials +std::shared_ptr CompositeCredentials( + const std::shared_ptr& creds1, + const std::shared_ptr& creds2) { + // Note that we are not saving shared_ptrs to the two credentials // passed in here. This is OK because the underlying C objects (i.e., // creds1 and creds2) into grpc_composite_credentials_create will see their // refcounts incremented. diff --git a/src/cpp/client/secure_credentials.h b/src/cpp/client/secure_credentials.h index 77d575813ea..ddf69911b5f 100644 --- a/src/cpp/client/secure_credentials.h +++ b/src/cpp/client/secure_credentials.h @@ -46,6 +46,7 @@ class SecureCredentials GRPC_FINAL : public Credentials { explicit SecureCredentials(grpc_credentials* c_creds) : c_creds_(c_creds) {} ~SecureCredentials() GRPC_OVERRIDE { grpc_credentials_release(c_creds_); } grpc_credentials* GetRawCreds() { return c_creds_; } + bool ApplyToCall(grpc_call* call) GRPC_OVERRIDE; std::shared_ptr CreateChannel( const string& target, const grpc::ChannelArguments& args) GRPC_OVERRIDE; diff --git a/test/cpp/client/credentials_test.cc b/test/cpp/client/credentials_test.cc index 6840418989c..ee94f455a43 100644 --- a/test/cpp/client/credentials_test.cc +++ b/test/cpp/client/credentials_test.cc @@ -46,8 +46,7 @@ class CredentialsTest : public ::testing::Test { }; TEST_F(CredentialsTest, InvalidServiceAccountCreds) { - std::unique_ptr bad1 = - ServiceAccountCredentials("", "", 1); + std::shared_ptr bad1 = ServiceAccountCredentials("", "", 1); EXPECT_EQ(nullptr, bad1.get()); } diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index f35b16fe555..a98d7c23b97 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -429,7 +429,7 @@ TEST_F(End2endTest, DiffPackageServices) { // rpc and stream should fail on bad credentials. TEST_F(End2endTest, BadCredentials) { - std::unique_ptr bad_creds = ServiceAccountCredentials("", "", 1); + std::shared_ptr bad_creds = ServiceAccountCredentials("", "", 1); EXPECT_EQ(nullptr, bad_creds.get()); std::shared_ptr channel = CreateChannel(server_address_.str(), bad_creds, ChannelArguments()); @@ -588,6 +588,54 @@ TEST_F(End2endTest, RpcMaxMessageSize) { EXPECT_FALSE(s.IsOk()); } +TEST_F(End2endTest, SetPerCallCredentials) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + std::shared_ptr creds = + IAMCredentials("fake_token", "fake_selector"); + context.set_credentials(creds); + grpc::string msg("Hello"); + + Status s = stub_->Echo(&context, request, &response); + // TODO(yangg) verify creds at the server side. + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.IsOk()); +} + +TEST_F(End2endTest, InsecurePerCallCredentials) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + std::shared_ptr creds = InsecureCredentials(); + context.set_credentials(creds); + grpc::string msg("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.IsOk()); +} + +TEST_F(End2endTest, OverridePerCallCredentials) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + std::shared_ptr creds1 = InsecureCredentials(); + context.set_credentials(creds1); + std::shared_ptr creds2 = + IAMCredentials("fake_token", "fake_selector"); + context.set_credentials(creds2); + grpc::string msg("Hello"); + + Status s = stub_->Echo(&context, request, &response); + // TODO(yangg) verify creds at the server side. + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.IsOk()); +} + } // namespace testing } // namespace grpc diff --git a/test/cpp/interop/client_helper.cc b/test/cpp/interop/client_helper.cc index a1dea383e6a..09fd1c8913f 100644 --- a/test/cpp/interop/client_helper.cc +++ b/test/cpp/interop/client_helper.cc @@ -82,7 +82,7 @@ std::shared_ptr CreateChannelForTestCase( FLAGS_server_port); if (test_case == "service_account_creds") { - std::unique_ptr creds; + std::shared_ptr creds; GPR_ASSERT(FLAGS_enable_ssl); grpc::string json_key = GetServiceAccountJsonKey(); std::chrono::seconds token_lifetime = std::chrono::hours(1); @@ -91,13 +91,13 @@ std::shared_ptr CreateChannelForTestCase( return CreateTestChannel(host_port, FLAGS_server_host_override, FLAGS_enable_ssl, FLAGS_use_prod_roots, creds); } else if (test_case == "compute_engine_creds") { - std::unique_ptr creds; + std::shared_ptr creds; GPR_ASSERT(FLAGS_enable_ssl); creds = ComputeEngineCredentials(); return CreateTestChannel(host_port, FLAGS_server_host_override, FLAGS_enable_ssl, FLAGS_use_prod_roots, creds); } else if (test_case == "jwt_token_creds") { - std::unique_ptr creds; + std::shared_ptr creds; GPR_ASSERT(FLAGS_enable_ssl); grpc::string json_key = GetServiceAccountJsonKey(); std::chrono::seconds token_lifetime = std::chrono::hours(1); diff --git a/test/cpp/util/create_test_channel.cc b/test/cpp/util/create_test_channel.cc index f040acc4b1f..dc48fa2d87f 100644 --- a/test/cpp/util/create_test_channel.cc +++ b/test/cpp/util/create_test_channel.cc @@ -58,13 +58,13 @@ namespace grpc { std::shared_ptr CreateTestChannel( const grpc::string& server, const grpc::string& override_hostname, bool enable_ssl, bool use_prod_roots, - const std::unique_ptr& creds) { + const std::shared_ptr& creds) { ChannelArguments channel_args; if (enable_ssl) { const char* roots_certs = use_prod_roots ? "" : test_root_cert; SslCredentialsOptions ssl_opts = {roots_certs, "", ""}; - std::unique_ptr channel_creds = SslCredentials(ssl_opts); + std::shared_ptr channel_creds = SslCredentials(ssl_opts); if (!server.empty() && !override_hostname.empty()) { channel_args.SetSslTargetNameOverride(override_hostname); @@ -84,7 +84,7 @@ std::shared_ptr CreateTestChannel( const grpc::string& server, const grpc::string& override_hostname, bool enable_ssl, bool use_prod_roots) { return CreateTestChannel(server, override_hostname, enable_ssl, - use_prod_roots, std::unique_ptr()); + use_prod_roots, std::shared_ptr()); } // Shortcut for end2end and interop tests. diff --git a/test/cpp/util/create_test_channel.h b/test/cpp/util/create_test_channel.h index 5c298ce8504..5f2609ddd81 100644 --- a/test/cpp/util/create_test_channel.h +++ b/test/cpp/util/create_test_channel.h @@ -52,7 +52,7 @@ std::shared_ptr CreateTestChannel( std::shared_ptr CreateTestChannel( const grpc::string& server, const grpc::string& override_hostname, bool enable_ssl, bool use_prod_roots, - const std::unique_ptr& creds); + const std::shared_ptr& creds); } // namespace grpc diff --git a/test/cpp/util/grpc_cli.cc b/test/cpp/util/grpc_cli.cc index d71a7a0b778..ad3c0af8775 100644 --- a/test/cpp/util/grpc_cli.cc +++ b/test/cpp/util/grpc_cli.cc @@ -104,7 +104,7 @@ int main(int argc, char** argv) { std::stringstream input_stream; input_stream << input_file.rdbuf(); - std::unique_ptr creds; + std::shared_ptr creds; if (!FLAGS_enable_ssl) { creds = grpc::InsecureCredentials(); } else {