diff --git a/test/cpp/util/BUILD b/test/cpp/util/BUILD index 9b42bb28b1d..c3bfeb76156 100644 --- a/test/cpp/util/BUILD +++ b/test/cpp/util/BUILD @@ -177,6 +177,7 @@ grpc_cc_test( "//:grpc++_reflection", "//src/proto/grpc/testing:echo_messages_proto", "//src/proto/grpc/testing:echo_proto", + "//test/core/end2end:ssl_test_data", "//test/core/util:grpc_test_util", ], ) diff --git a/test/cpp/util/cli_credentials.cc b/test/cpp/util/cli_credentials.cc index aa4eafb7569..d14dc18f168 100644 --- a/test/cpp/util/cli_credentials.cc +++ b/test/cpp/util/cli_credentials.cc @@ -25,6 +25,10 @@ DEFINE_bool(use_auth, false, "Whether to create default google credentials."); DEFINE_string( access_token, "", "The access token that will be sent to the server to authenticate RPCs."); +DEFINE_string( + ssl_target, "", + "If not empty, treat the server host name as this for ssl/tls certificate " + "validation."); namespace grpc { namespace testing { @@ -58,7 +62,15 @@ const grpc::string CliCredentials::GetCredentialUsage() const { " --use_auth ; Set whether to create default google" " credentials\n" " --access_token ; Set the access token in metadata," - " overrides --use_auth\n"; + " overrides --use_auth\n" + " --ssl_target ; Set server host for tls validation\n"; +} + +const grpc::string CliCredentials::GetSslTargetNameOverride() const { + bool use_tls = + FLAGS_enable_ssl || (FLAGS_access_token.empty() && FLAGS_use_auth); + return use_tls ? FLAGS_ssl_target : ""; } + } // namespace testing } // namespace grpc diff --git a/test/cpp/util/cli_credentials.h b/test/cpp/util/cli_credentials.h index b1358e77d8b..8d662356de8 100644 --- a/test/cpp/util/cli_credentials.h +++ b/test/cpp/util/cli_credentials.h @@ -30,6 +30,7 @@ class CliCredentials { virtual ~CliCredentials() {} virtual std::shared_ptr GetCredentials() const; virtual const grpc::string GetCredentialUsage() const; + virtual const grpc::string GetSslTargetNameOverride() const; }; } // namespace testing diff --git a/test/cpp/util/grpc_tool.cc b/test/cpp/util/grpc_tool.cc index 840ca07d2bf..ccc60cca27a 100644 --- a/test/cpp/util/grpc_tool.cc +++ b/test/cpp/util/grpc_tool.cc @@ -206,6 +206,15 @@ void ReadResponse(CliCall* call, const grpc::string& method_name, } } +std::shared_ptr CreateCliChannel( + const grpc::string& server_address, const CliCredentials& cred) { + grpc::ChannelArguments args; + if (!cred.GetSslTargetNameOverride().empty()) { + args.SetSslTargetNameOverride(cred.GetSslTargetNameOverride()); + } + return grpc::CreateCustomChannel(server_address, cred.GetCredentials(), args); +} + struct Command { const char* command; std::function channel = - grpc::CreateChannel(server_address, cred.GetCredentials()); + CreateCliChannel(server_address, cred); grpc::ProtoReflectionDescriptorDatabase desc_db(channel); grpc::protobuf::DescriptorPool desc_pool(&desc_db); @@ -422,7 +431,7 @@ bool GrpcTool::PrintType(int argc, const char** argv, grpc::string server_address(argv[0]); std::shared_ptr channel = - grpc::CreateChannel(server_address, cred.GetCredentials()); + CreateCliChannel(server_address, cred); grpc::ProtoReflectionDescriptorDatabase desc_db(channel); grpc::protobuf::DescriptorPool desc_pool(&desc_db); @@ -469,7 +478,7 @@ bool GrpcTool::CallMethod(int argc, const char** argv, bool print_mode = false; std::shared_ptr channel = - grpc::CreateChannel(server_address, cred.GetCredentials()); + CreateCliChannel(server_address, cred); if (!FLAGS_binary_input || !FLAGS_binary_output) { parser.reset( @@ -820,7 +829,7 @@ bool GrpcTool::ParseMessage(int argc, const char** argv, if (!FLAGS_binary_input || !FLAGS_binary_output) { std::shared_ptr channel = - grpc::CreateChannel(server_address, cred.GetCredentials()); + CreateCliChannel(server_address, cred); parser.reset( new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr, FLAGS_proto_path, FLAGS_protofiles)); diff --git a/test/cpp/util/grpc_tool_test.cc b/test/cpp/util/grpc_tool_test.cc index 6574d1bb441..7e7f44551ef 100644 --- a/test/cpp/util/grpc_tool_test.cc +++ b/test/cpp/util/grpc_tool_test.cc @@ -35,6 +35,7 @@ #include "src/core/lib/gpr/env.h" #include "src/proto/grpc/testing/echo.grpc.pb.h" #include "src/proto/grpc/testing/echo.pb.h" +#include "test/core/end2end/data/ssl_test_data.h" #include "test/core/util/port.h" #include "test/core/util/test_config.h" #include "test/cpp/util/cli_credentials.h" @@ -80,6 +81,9 @@ using grpc::testing::EchoResponse; " peer: \"peer\"\n" \ "}\n\n" +DECLARE_bool(enable_ssl); +DECLARE_string(ssl_target); + namespace grpc { namespace testing { @@ -97,10 +101,18 @@ const int kServerDefaultResponseStreamsToSend = 3; class TestCliCredentials final : public grpc::testing::CliCredentials { public: + TestCliCredentials(bool secure = false) : secure_(secure) {} std::shared_ptr GetCredentials() const override { - return InsecureChannelCredentials(); + if (!secure_) { + return InsecureChannelCredentials(); + } + SslCredentialsOptions ssl_opts = {test_root_cert, "", ""}; + return SslCredentials(grpc::SslCredentialsOptions(ssl_opts)); } const grpc::string GetCredentialUsage() const override { return ""; } + + private: + const bool secure_; }; bool PrintStream(std::stringstream* ss, const grpc::string& output) { @@ -206,13 +218,24 @@ class GrpcToolTest : public ::testing::Test { // SetUpServer cannot be used with EXPECT_EXIT. grpc_pick_unused_port_or_die() // uses atexit() to free chosen ports, and it will spawn a new thread in // resolve_address_posix.c:192 at exit time. - const grpc::string SetUpServer() { + const grpc::string SetUpServer(bool secure = false) { std::ostringstream server_address; int port = grpc_pick_unused_port_or_die(); server_address << "localhost:" << port; // Setup server ServerBuilder builder; - builder.AddListeningPort(server_address.str(), InsecureServerCredentials()); + std::shared_ptr creds; + if (secure) { + SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key, + test_server1_cert}; + SslServerCredentialsOptions ssl_opts; + ssl_opts.pem_root_certs = ""; + ssl_opts.pem_key_cert_pairs.push_back(pkcp); + creds = SslServerCredentials(ssl_opts); + } else { + creds = InsecureServerCredentials(); + } + builder.AddListeningPort(server_address.str(), creds); builder.RegisterService(&service_); server_ = builder.BuildAndStart(); return server_address.str(); @@ -743,6 +766,29 @@ TEST_F(GrpcToolTest, CallCommandWithBadMetadata) { gpr_free(test_srcdir); } +TEST_F(GrpcToolTest, ListCommand_OverrideSslHostName) { + const grpc::string server_address = SetUpServer(true); + + // Test input "grpc_cli ls localhost: --enable_ssl + // --ssl_target=z.test.google.fr" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli", "ls", server_address.c_str()}; + FLAGS_l = false; + FLAGS_enable_ssl = true; + FLAGS_ssl_target = "z.test.google.fr"; + EXPECT_TRUE( + 0 == GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(true), + std::bind(PrintStream, &output_stream, std::placeholders::_1))); + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + "grpc.testing.EchoTestService\n" + "grpc.reflection.v1alpha.ServerReflection\n")); + + FLAGS_enable_ssl = false; + FLAGS_ssl_target = ""; + ShutdownServer(); +} + } // namespace testing } // namespace grpc