diff --git a/test/cpp/util/cli_credentials.cc b/test/cpp/util/cli_credentials.cc index d14dc18f168..acf4ef8ef10 100644 --- a/test/cpp/util/cli_credentials.cc +++ b/test/cpp/util/cli_credentials.cc @@ -20,8 +20,12 @@ #include -DEFINE_bool(enable_ssl, false, "Whether to use ssl/tls."); -DEFINE_bool(use_auth, false, "Whether to create default google credentials."); +DEFINE_bool( + enable_ssl, false, + "Whether to use ssl/tls. Deprecated. Use --channel_creds_type=ssl."); +DEFINE_bool(use_auth, false, + "Whether to create default google credentials. Deprecated. Use " + "--channel_creds_type=gdc."); DEFINE_string( access_token, "", "The access token that will be sent to the server to authenticate RPCs."); @@ -29,47 +33,109 @@ DEFINE_string( ssl_target, "", "If not empty, treat the server host name as this for ssl/tls certificate " "validation."); +DEFINE_string( + channel_creds_type, "", + "The channel creds type: insecure, ssl, gdc (Google Default Credentials) " + "or alts."); namespace grpc { namespace testing { -std::shared_ptr CliCredentials::GetCredentials() +grpc::string CliCredentials::GetDefaultChannelCredsType() const { + // Compatibility logic for --enable_ssl. + if (FLAGS_enable_ssl) { + fprintf(stderr, + "warning: --enable_ssl is deprecated. Use " + "--channel_creds_type=ssl.\n"); + return "ssl"; + } + // Compatibility logic for --use_auth. + if (FLAGS_access_token.empty() && FLAGS_use_auth) { + fprintf(stderr, + "warning: --use_auth is deprecated. Use " + "--channel_creds_type=gdc.\n"); + return "gdc"; + } + return "insecure"; +} + +std::shared_ptr +CliCredentials::GetChannelCredentials() const { + if (FLAGS_channel_creds_type.compare("insecure") == 0) { + return grpc::InsecureChannelCredentials(); + } else if (FLAGS_channel_creds_type.compare("ssl") == 0) { + return grpc::SslCredentials(grpc::SslCredentialsOptions()); + } else if (FLAGS_channel_creds_type.compare("gdc") == 0) { + return grpc::GoogleDefaultCredentials(); + } else if (FLAGS_channel_creds_type.compare("alts") == 0) { + return grpc::experimental::AltsCredentials( + grpc::experimental::AltsCredentialsOptions()); + } + fprintf(stderr, + "--channel_creds_type=%s invalid; must be insecure, ssl, gdc or " + "alts.\n", + FLAGS_channel_creds_type.c_str()); + return std::shared_ptr(); +} + +std::shared_ptr CliCredentials::GetCallCredentials() const { if (!FLAGS_access_token.empty()) { if (FLAGS_use_auth) { fprintf(stderr, "warning: use_auth is ignored when access_token is provided."); } - - return grpc::CompositeChannelCredentials( - grpc::SslCredentials(grpc::SslCredentialsOptions()), - grpc::AccessTokenCredentials(FLAGS_access_token)); + return grpc::AccessTokenCredentials(FLAGS_access_token); } + return std::shared_ptr(); +} - if (FLAGS_use_auth) { - return grpc::GoogleDefaultCredentials(); +std::shared_ptr CliCredentials::GetCredentials() + const { + if (FLAGS_channel_creds_type.empty()) { + FLAGS_channel_creds_type = GetDefaultChannelCredsType(); + } else if (FLAGS_enable_ssl && FLAGS_channel_creds_type.compare("ssl") != 0) { + fprintf(stderr, + "warning: ignoring --enable_ssl because " + "--channel_creds_type already set to %s.\n", + FLAGS_channel_creds_type.c_str()); + } else if (FLAGS_use_auth && FLAGS_channel_creds_type.compare("gdc") != 0) { + fprintf(stderr, + "warning: ignoring --use_auth because " + "--channel_creds_type already set to %s.\n", + FLAGS_channel_creds_type.c_str()); } - - if (FLAGS_enable_ssl) { - return grpc::SslCredentials(grpc::SslCredentialsOptions()); + // Legacy transport upgrade logic for insecure requests. + if (!FLAGS_access_token.empty() && + FLAGS_channel_creds_type.compare("insecure") == 0) { + fprintf(stderr, + "warning: --channel_creds_type=insecure upgraded to ssl because " + "an access token was provided.\n"); + FLAGS_channel_creds_type = "ssl"; } - - return grpc::InsecureChannelCredentials(); + std::shared_ptr channel_creds = + GetChannelCredentials(); + // Composite any call-type credentials on top of the base channel. + std::shared_ptr call_creds = GetCallCredentials(); + return (channel_creds == nullptr || call_creds == nullptr) + ? channel_creds + : grpc::CompositeChannelCredentials(channel_creds, call_creds); } const grpc::string CliCredentials::GetCredentialUsage() const { - return " --enable_ssl ; Set whether to use tls\n" + return " --enable_ssl ; Set whether to use ssl (deprecated)\n" " --use_auth ; Set whether to create default google" " credentials\n" " --access_token ; Set the access token in metadata," " overrides --use_auth\n" - " --ssl_target ; Set server host for tls validation\n"; + " --ssl_target ; Set server host for ssl validation\n" + " --channel_creds_type ; Set to insecure, ssl, gdc, or alts\n"; } const grpc::string CliCredentials::GetSslTargetNameOverride() const { - bool use_tls = - FLAGS_enable_ssl || (FLAGS_access_token.empty() && FLAGS_use_auth); - return use_tls ? FLAGS_ssl_target : ""; + bool use_ssl = FLAGS_channel_creds_type.compare("ssl") == 0 || + FLAGS_channel_creds_type.compare("gdc") == 0; + return use_ssl ? FLAGS_ssl_target : ""; } } // namespace testing diff --git a/test/cpp/util/cli_credentials.h b/test/cpp/util/cli_credentials.h index 8d662356de8..4636d3ca149 100644 --- a/test/cpp/util/cli_credentials.h +++ b/test/cpp/util/cli_credentials.h @@ -28,9 +28,22 @@ namespace testing { class CliCredentials { public: virtual ~CliCredentials() {} - virtual std::shared_ptr GetCredentials() const; + std::shared_ptr GetCredentials() const; virtual const grpc::string GetCredentialUsage() const; virtual const grpc::string GetSslTargetNameOverride() const; + + protected: + // Returns the appropriate channel_creds_type value for the set of legacy + // flag arguments. + virtual grpc::string GetDefaultChannelCredsType() const; + // Returns the base transport channel credentials. Child classes can override + // to support additional channel_creds_types unknown to this base class. + virtual std::shared_ptr GetChannelCredentials() + const; + // Returns call credentials to composite onto the base transport channel + // credentials. Child classes can override to support additional + // authentication flags unknown to this base class. + virtual std::shared_ptr GetCallCredentials() const; }; } // namespace testing diff --git a/test/cpp/util/grpc_tool_test.cc b/test/cpp/util/grpc_tool_test.cc index 7e7f44551ef..3aae090e818 100644 --- a/test/cpp/util/grpc_tool_test.cc +++ b/test/cpp/util/grpc_tool_test.cc @@ -81,7 +81,7 @@ using grpc::testing::EchoResponse; " peer: \"peer\"\n" \ "}\n\n" -DECLARE_bool(enable_ssl); +DECLARE_string(channel_creds_type); DECLARE_string(ssl_target); namespace grpc { @@ -102,7 +102,8 @@ const int kServerDefaultResponseStreamsToSend = 3; class TestCliCredentials final : public grpc::testing::CliCredentials { public: TestCliCredentials(bool secure = false) : secure_(secure) {} - std::shared_ptr GetCredentials() const override { + std::shared_ptr GetChannelCredentials() + const override { if (!secure_) { return InsecureChannelCredentials(); } @@ -769,12 +770,12 @@ TEST_F(GrpcToolTest, CallCommandWithBadMetadata) { TEST_F(GrpcToolTest, ListCommand_OverrideSslHostName) { const grpc::string server_address = SetUpServer(true); - // Test input "grpc_cli ls localhost: --enable_ssl + // Test input "grpc_cli ls localhost: --channel_creds_type=ssl // --ssl_target=z.test.google.fr" std::stringstream output_stream; const char* argv[] = {"grpc_cli", "ls", server_address.c_str()}; FLAGS_l = false; - FLAGS_enable_ssl = true; + FLAGS_channel_creds_type = "ssl"; FLAGS_ssl_target = "z.test.google.fr"; EXPECT_TRUE( 0 == GrpcToolMainLib( @@ -784,7 +785,7 @@ TEST_F(GrpcToolTest, ListCommand_OverrideSslHostName) { "grpc.testing.EchoTestService\n" "grpc.reflection.v1alpha.ServerReflection\n")); - FLAGS_enable_ssl = false; + FLAGS_channel_creds_type = ""; FLAGS_ssl_target = ""; ShutdownServer(); }