From 60f060e0785ede3a292bf740de9594e3b4edb30f Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Thu, 14 Feb 2019 16:51:04 -0500 Subject: [PATCH] Let interop_client send additional metadata, controlled by a flag. --- test/cpp/interop/client.cc | 27 +++++- test/cpp/interop/client_helper.cc | 50 ++++++++++- test/cpp/interop/client_helper.h | 38 +++++++- test/cpp/util/create_test_channel.cc | 125 +++++++++++++++++++++------ test/cpp/util/create_test_channel.h | 33 +++++++ 5 files changed, 241 insertions(+), 32 deletions(-) diff --git a/test/cpp/interop/client.cc b/test/cpp/interop/client.cc index c9458ff40ce..8a934845ab4 100644 --- a/test/cpp/interop/client.cc +++ b/test/cpp/interop/client.cc @@ -92,6 +92,9 @@ DEFINE_int32(soak_iterations, 1000, DEFINE_int32(iteration_interval, 10, "The interval in seconds between rpcs. This is used by " "long_connection test"); +DEFINE_string(additional_metadata, "", + "Additional metadata to send in each request, as a " + "semicolon-separated list of key:value pairs."); using grpc::testing::CreateChannelForTestCase; using grpc::testing::GetServiceAccountJsonKey; @@ -101,8 +104,28 @@ int main(int argc, char** argv) { grpc::testing::InitTest(&argc, &argv, true); gpr_log(GPR_INFO, "Testing these cases: %s", FLAGS_test_case.c_str()); int ret = 0; - grpc::testing::ChannelCreationFunc channel_creation_func = - std::bind(&CreateChannelForTestCase, FLAGS_test_case); + + grpc::testing::ChannelCreationFunc channel_creation_func; + grpc::string test_case = FLAGS_test_case; + if (FLAGS_additional_metadata == "") { + channel_creation_func = [test_case]() { + return CreateChannelForTestCase(test_case); + }; + } else { + std::multimap additional_metadata = + grpc::testing::ParseAdditionalMetadataFlag(FLAGS_additional_metadata); + + channel_creation_func = [test_case, additional_metadata]() { + std::vector> + factories; + factories.emplace_back( + new grpc::testing::AdditionalMetadataInterceptorFactory( + additional_metadata)); + return CreateChannelForTestCase(test_case, std::move(factories)); + }; + } + grpc::testing::InteropClient client(channel_creation_func, true, FLAGS_do_not_abort_on_transient_failures); diff --git a/test/cpp/interop/client_helper.cc b/test/cpp/interop/client_helper.cc index fb7b7bb7d03..ff4fab9cb01 100644 --- a/test/cpp/interop/client_helper.cc +++ b/test/cpp/interop/client_helper.cc @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -79,7 +80,10 @@ void UpdateActions( std::unordered_map>* actions) {} std::shared_ptr CreateChannelForTestCase( - const grpc::string& test_case) { + const grpc::string& test_case, + std::vector< + std::unique_ptr> + interceptor_creators) { GPR_ASSERT(FLAGS_server_port); const int host_port_buf_size = 1024; char host_port[host_port_buf_size]; @@ -107,11 +111,51 @@ std::shared_ptr CreateChannelForTestCase( transport_security security_type = FLAGS_use_alts ? ALTS : (FLAGS_use_tls ? TLS : INSECURE); return CreateTestChannel(host_port, FLAGS_server_host_override, - security_type, !FLAGS_use_test_ca, creds); + security_type, !FLAGS_use_test_ca, creds, + std::move(interceptor_creators)); } else { - return CreateTestChannel(host_port, FLAGS_custom_credentials_type, creds); + if (interceptor_creators.empty()) { + return CreateTestChannel(host_port, FLAGS_custom_credentials_type, creds); + } else { + return CreateTestChannel(host_port, FLAGS_custom_credentials_type, creds, + std::move(interceptor_creators)); + } } } +std::multimap ParseAdditionalMetadataFlag( + const grpc::string& flag) { + std::multimap additional_metadata; + + // Key in group 1; value in group 2. + std::regex re("([-a-zA-Z0-9]+):([^;]*);?"); + auto metadata_entries_begin = std::sregex_iterator( + flag.begin(), flag.end(), re, std::regex_constants::match_continuous); + auto metadata_entries_end = std::sregex_iterator(); + + for (std::sregex_iterator i = metadata_entries_begin; + i != metadata_entries_end; ++i) { + std::smatch match = *i; + gpr_log(GPR_INFO, "Adding additional metadata with key %s and value %s", + match[1].str().c_str(), match[2].str().c_str()); + additional_metadata.insert({match[1].str(), match[2].str()}); + } + + return additional_metadata; +} + +void AdditionalMetadataInterceptor::Intercept( + experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + std::multimap* metadata = + methods->GetSendInitialMetadata(); + for (const auto& entry : additional_metadata_) { + metadata->insert(entry); + } + } + methods->Proceed(); +} + } // namespace testing } // namespace grpc diff --git a/test/cpp/interop/client_helper.h b/test/cpp/interop/client_helper.h index 7dee85cc980..895f7625baa 100644 --- a/test/cpp/interop/client_helper.h +++ b/test/cpp/interop/client_helper.h @@ -39,7 +39,16 @@ void UpdateActions( std::unordered_map>* actions); std::shared_ptr CreateChannelForTestCase( - const grpc::string& test_case); + const grpc::string& test_case, + std::vector< + std::unique_ptr> + interceptor_creators = {}); + +// Parse the contents of FLAGS_additional_metadata into a map. Allow +// alphanumeric characters and dashes in keys, and any character but semicolons +// in values. +std::multimap ParseAdditionalMetadataFlag( + const grpc::string& flag); class InteropClientContextInspector { public: @@ -59,6 +68,33 @@ class InteropClientContextInspector { const ::grpc::ClientContext& context_; }; +class AdditionalMetadataInterceptor : public experimental::Interceptor { + public: + AdditionalMetadataInterceptor( + std::multimap additional_metadata) + : additional_metadata_(std::move(additional_metadata)) {} + + void Intercept(experimental::InterceptorBatchMethods* methods) override; + + private: + const std::multimap additional_metadata_; +}; + +class AdditionalMetadataInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + AdditionalMetadataInterceptorFactory( + std::multimap additional_metadata) + : additional_metadata_(std::move(additional_metadata)) {} + + experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new AdditionalMetadataInterceptor(additional_metadata_); + } + + const std::multimap additional_metadata_; +}; + } // namespace testing } // namespace grpc diff --git a/test/cpp/util/create_test_channel.cc b/test/cpp/util/create_test_channel.cc index 0bcd4dbc844..e0c0bd064fc 100644 --- a/test/cpp/util/create_test_channel.cc +++ b/test/cpp/util/create_test_channel.cc @@ -71,10 +71,74 @@ std::shared_ptr CreateTestChannel( const grpc::string& override_hostname, bool use_prod_roots, const std::shared_ptr& creds, const ChannelArguments& args) { + return CreateTestChannel(server, cred_type, override_hostname, + use_prod_roots, creds, args, + /*interceptor_creators=*/{}); +} + +std::shared_ptr CreateTestChannel( + const grpc::string& server, const grpc::string& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr& creds, + const ChannelArguments& args) { + return CreateTestChannel(server, override_hostname, security_type, + use_prod_roots, creds, args, + /*interceptor_creators=*/{}); +} + +std::shared_ptr CreateTestChannel( + const grpc::string& server, const grpc::string& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr& creds) { + return CreateTestChannel(server, override_hostname, security_type, + use_prod_roots, creds, ChannelArguments()); +} + +std::shared_ptr CreateTestChannel( + const grpc::string& server, const grpc::string& override_hostname, + testing::transport_security security_type, bool use_prod_roots) { + return CreateTestChannel(server, override_hostname, security_type, + use_prod_roots, std::shared_ptr()); +} + +// Shortcut for end2end and interop tests. +std::shared_ptr CreateTestChannel( + const grpc::string& server, testing::transport_security security_type) { + return CreateTestChannel(server, "foo.test.google.fr", security_type, false); +} + +std::shared_ptr CreateTestChannel( + const grpc::string& server, const grpc::string& credential_type, + const std::shared_ptr& creds) { + ChannelArguments channel_args; + std::shared_ptr channel_creds = + testing::GetCredentialsProvider()->GetChannelCredentials(credential_type, + &channel_args); + GPR_ASSERT(channel_creds != nullptr); + if (creds.get()) { + channel_creds = CompositeChannelCredentials(channel_creds, creds); + } + return CreateCustomChannel(server, channel_creds, channel_args); +} + +std::shared_ptr CreateTestChannel( + const grpc::string& server, const grpc::string& cred_type, + const grpc::string& override_hostname, bool use_prod_roots, + const std::shared_ptr& creds, + const ChannelArguments& args, + std::vector< + std::unique_ptr> + interceptor_creators) { ChannelArguments channel_args(args); std::shared_ptr channel_creds; if (cred_type.empty()) { - return CreateCustomChannel(server, InsecureChannelCredentials(), args); + if (interceptor_creators.empty()) { + return CreateCustomChannel(server, InsecureChannelCredentials(), args); + } else { + return experimental::CreateCustomChannelWithInterceptors( + server, InsecureChannelCredentials(), args, + std::move(interceptor_creators)); + } } else if (cred_type == testing::kTlsCredentialsType) { // cred_type == "ssl" if (use_prod_roots) { gpr_once_init(&g_once_init_add_prod_ssl_provider, &AddProdSslType); @@ -95,54 +159,62 @@ std::shared_ptr CreateTestChannel( if (creds.get()) { channel_creds = CompositeChannelCredentials(channel_creds, creds); } - return CreateCustomChannel(connect_to, channel_creds, channel_args); + if (interceptor_creators.empty()) { + return CreateCustomChannel(connect_to, channel_creds, channel_args); + } else { + return experimental::CreateCustomChannelWithInterceptors( + connect_to, channel_creds, channel_args, + std::move(interceptor_creators)); + } } else { channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials( cred_type, &channel_args); GPR_ASSERT(channel_creds != nullptr); - return CreateCustomChannel(server, channel_creds, args); + if (interceptor_creators.empty()) { + return CreateCustomChannel(server, channel_creds, args); + } else { + return experimental::CreateCustomChannelWithInterceptors( + server, channel_creds, args, std::move(interceptor_creators)); + } } } std::shared_ptr CreateTestChannel( const grpc::string& server, const grpc::string& override_hostname, testing::transport_security security_type, bool use_prod_roots, - const std::shared_ptr& creds, - const ChannelArguments& args) { - grpc::string type = + const std::shared_ptr& creds, const ChannelArguments& args, + std::vector< + std::unique_ptr> + interceptor_creators) { + grpc::string credential_type = security_type == testing::ALTS ? testing::kAltsCredentialsType : (security_type == testing::TLS ? testing::kTlsCredentialsType : testing::kInsecureCredentialsType); - return CreateTestChannel(server, type, override_hostname, use_prod_roots, - creds, args); + return CreateTestChannel( + server, credential_type, override_hostname, use_prod_roots, creds, args, + std::move(interceptor_creators)); } std::shared_ptr CreateTestChannel( const grpc::string& server, const grpc::string& override_hostname, testing::transport_security security_type, bool use_prod_roots, - const std::shared_ptr& creds) { - return CreateTestChannel(server, override_hostname, security_type, - use_prod_roots, creds, ChannelArguments()); -} - -std::shared_ptr CreateTestChannel( - const grpc::string& server, const grpc::string& override_hostname, - testing::transport_security security_type, bool use_prod_roots) { - return CreateTestChannel(server, override_hostname, security_type, - use_prod_roots, std::shared_ptr()); -} - -// Shortcut for end2end and interop tests. -std::shared_ptr CreateTestChannel( - const grpc::string& server, testing::transport_security security_type) { - return CreateTestChannel(server, "foo.test.google.fr", security_type, false); + const std::shared_ptr& creds, + std::vector< + std::unique_ptr> + interceptor_creators) { + return CreateTestChannel( + server, override_hostname, security_type, use_prod_roots, creds, + ChannelArguments(), std::move(interceptor_creators)); } std::shared_ptr CreateTestChannel( const grpc::string& server, const grpc::string& credential_type, - const std::shared_ptr& creds) { + const std::shared_ptr& creds, + std::vector< + std::unique_ptr> + interceptor_creators) { ChannelArguments channel_args; std::shared_ptr channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials(credential_type, @@ -151,7 +223,8 @@ std::shared_ptr CreateTestChannel( if (creds.get()) { channel_creds = CompositeChannelCredentials(channel_creds, creds); } - return CreateCustomChannel(server, channel_creds, channel_args); + return experimental::CreateCustomChannelWithInterceptors( + server, channel_creds, channel_args, std::move(interceptor_creators)); } } // namespace grpc diff --git a/test/cpp/util/create_test_channel.h b/test/cpp/util/create_test_channel.h index c615fb76536..e706acc6072 100644 --- a/test/cpp/util/create_test_channel.h +++ b/test/cpp/util/create_test_channel.h @@ -21,6 +21,7 @@ #include +#include #include namespace grpc { @@ -60,6 +61,38 @@ std::shared_ptr CreateTestChannel( const grpc::string& server, const grpc::string& credential_type, const std::shared_ptr& creds); +std::shared_ptr CreateTestChannel( + const grpc::string& server, const grpc::string& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr& creds, + std::vector< + std::unique_ptr> + interceptor_creators); + +std::shared_ptr CreateTestChannel( + const grpc::string& server, const grpc::string& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr& creds, const ChannelArguments& args, + std::vector< + std::unique_ptr> + interceptor_creators); + +std::shared_ptr CreateTestChannel( + const grpc::string& server, const grpc::string& cred_type, + const grpc::string& override_hostname, bool use_prod_roots, + const std::shared_ptr& creds, + const ChannelArguments& args, + std::vector< + std::unique_ptr> + interceptor_creators); + +std::shared_ptr CreateTestChannel( + const grpc::string& server, const grpc::string& credential_type, + const std::shared_ptr& creds, + std::vector< + std::unique_ptr> + interceptor_creators); + } // namespace grpc #endif // GRPC_TEST_CPP_UTIL_CREATE_TEST_CHANNEL_H