Merge pull request #9317 from yang-g/test_credentials

manual revert of #8901
pull/9149/head^2
Yang Gao 8 years ago committed by GitHub
commit 1d8db03492
  1. 1
      build.yaml
  2. 19
      test/cpp/end2end/async_end2end_test.cc
  3. 20
      test/cpp/end2end/end2end_test.cc
  4. 1
      test/cpp/interop/client.cc
  5. 10
      test/cpp/interop/client_helper.cc
  6. 1
      test/cpp/interop/interop_server.cc
  7. 18
      test/cpp/interop/server_helper.cc
  8. 1
      test/cpp/interop/stress_test.cc
  9. 64
      test/cpp/util/create_test_channel.cc
  10. 4
      test/cpp/util/create_test_channel.h
  11. 52
      test/cpp/util/test_credentials_provider.cc
  12. 50
      test/cpp/util/test_credentials_provider.h
  13. 1
      tools/run_tests/generated/sources_and_headers.json
  14. 3
      vsprojects/vcxproj/interop_server_helper/interop_server_helper.vcxproj

@ -1259,6 +1259,7 @@ libs:
src: src:
- test/cpp/interop/server_helper.cc - test/cpp/interop/server_helper.cc
deps: deps:
- grpc++_test_util
- grpc_test_util - grpc_test_util
- grpc++ - grpc++
- grpc - grpc

@ -254,7 +254,8 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<TestScenario> {
// Setup server // Setup server
ServerBuilder builder; ServerBuilder builder;
auto server_creds = GetServerCredentials(GetParam().credentials_type); auto server_creds = GetCredentialsProvider()->GetServerCredentials(
GetParam().credentials_type);
builder.AddListeningPort(server_address_.str(), server_creds); builder.AddListeningPort(server_address_.str(), server_creds);
builder.RegisterService(&service_); builder.RegisterService(&service_);
cq_ = builder.AddCompletionQueue(); cq_ = builder.AddCompletionQueue();
@ -283,8 +284,8 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<TestScenario> {
void ResetStub() { void ResetStub() {
ChannelArguments args; ChannelArguments args;
auto channel_creds = auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
GetChannelCredentials(GetParam().credentials_type, &args); GetParam().credentials_type, &args);
std::shared_ptr<Channel> channel = std::shared_ptr<Channel> channel =
CreateCustomChannel(server_address_.str(), channel_creds, args); CreateCustomChannel(server_address_.str(), channel_creds, args);
stub_ = grpc::testing::EchoTestService::NewStub(channel); stub_ = grpc::testing::EchoTestService::NewStub(channel);
@ -892,8 +893,8 @@ TEST_P(AsyncEnd2endTest, ServerCheckDone) {
TEST_P(AsyncEnd2endTest, UnimplementedRpc) { TEST_P(AsyncEnd2endTest, UnimplementedRpc) {
ChannelArguments args; ChannelArguments args;
auto channel_creds = auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
GetChannelCredentials(GetParam().credentials_type, &args); GetParam().credentials_type, &args);
std::shared_ptr<Channel> channel = std::shared_ptr<Channel> channel =
CreateCustomChannel(server_address_.str(), channel_creds, args); CreateCustomChannel(server_address_.str(), channel_creds, args);
std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub; std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub;
@ -1404,11 +1405,15 @@ std::vector<TestScenario> CreateTestScenarios(bool test_disable_blocking,
std::vector<grpc::string> credentials_types; std::vector<grpc::string> credentials_types;
std::vector<grpc::string> messages; std::vector<grpc::string> messages;
credentials_types.push_back(kInsecureCredentialsType); if (GetCredentialsProvider()->GetChannelCredentials(kInsecureCredentialsType,
auto sec_list = GetSecureCredentialsTypeList(); nullptr) != nullptr) {
credentials_types.push_back(kInsecureCredentialsType);
}
auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList();
for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) { for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) {
credentials_types.push_back(*sec); credentials_types.push_back(*sec);
} }
GPR_ASSERT(!credentials_types.empty());
messages.push_back("Hello"); messages.push_back("Hello");
for (int sz = 1; sz < test_big_limit; sz *= 2) { for (int sz = 1; sz < test_big_limit; sz *= 2) {

@ -242,7 +242,8 @@ class End2endTest : public ::testing::TestWithParam<TestScenario> {
// Setup server // Setup server
ServerBuilder builder; ServerBuilder builder;
ConfigureServerBuilder(&builder); ConfigureServerBuilder(&builder);
auto server_creds = GetServerCredentials(GetParam().credentials_type); auto server_creds = GetCredentialsProvider()->GetServerCredentials(
GetParam().credentials_type);
if (GetParam().credentials_type != kInsecureCredentialsType) { if (GetParam().credentials_type != kInsecureCredentialsType) {
server_creds->SetAuthMetadataProcessor(processor); server_creds->SetAuthMetadataProcessor(processor);
} }
@ -270,8 +271,8 @@ class End2endTest : public ::testing::TestWithParam<TestScenario> {
} }
EXPECT_TRUE(is_server_started_); EXPECT_TRUE(is_server_started_);
ChannelArguments args; ChannelArguments args;
auto channel_creds = auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
GetChannelCredentials(GetParam().credentials_type, &args); GetParam().credentials_type, &args);
if (!user_agent_prefix_.empty()) { if (!user_agent_prefix_.empty()) {
args.SetUserAgentPrefix(user_agent_prefix_); args.SetUserAgentPrefix(user_agent_prefix_);
} }
@ -1520,11 +1521,18 @@ std::vector<TestScenario> CreateTestScenarios(bool use_proxy,
std::vector<TestScenario> scenarios; std::vector<TestScenario> scenarios;
std::vector<grpc::string> credentials_types; std::vector<grpc::string> credentials_types;
if (test_secure) { if (test_secure) {
credentials_types = GetSecureCredentialsTypeList(); credentials_types =
GetCredentialsProvider()->GetSecureCredentialsTypeList();
} }
if (test_insecure) { if (test_insecure) {
credentials_types.push_back(kInsecureCredentialsType); // Only add insecure credentials type when it is registered with the
// provider. User may create providers that do not have insecure.
if (GetCredentialsProvider()->GetChannelCredentials(
kInsecureCredentialsType, nullptr) != nullptr) {
credentials_types.push_back(kInsecureCredentialsType);
}
} }
GPR_ASSERT(!credentials_types.empty());
for (auto it = credentials_types.begin(); it != credentials_types.end(); for (auto it = credentials_types.begin(); it != credentials_types.end();
++it) { ++it) {
scenarios.emplace_back(false, *it); scenarios.emplace_back(false, *it);
@ -1541,7 +1549,7 @@ INSTANTIATE_TEST_CASE_P(End2end, End2endTest,
INSTANTIATE_TEST_CASE_P(End2endServerTryCancel, End2endServerTryCancelTest, INSTANTIATE_TEST_CASE_P(End2endServerTryCancel, End2endServerTryCancelTest,
::testing::ValuesIn(CreateTestScenarios(false, true, ::testing::ValuesIn(CreateTestScenarios(false, true,
false))); true)));
INSTANTIATE_TEST_CASE_P(ProxyEnd2end, ProxyEnd2endTest, INSTANTIATE_TEST_CASE_P(ProxyEnd2end, ProxyEnd2endTest,
::testing::ValuesIn(CreateTestScenarios(true, true, ::testing::ValuesIn(CreateTestScenarios(true, true,

@ -49,6 +49,7 @@
#include "test/cpp/util/test_config.h" #include "test/cpp/util/test_config.h"
DEFINE_bool(use_tls, false, "Whether to use tls."); DEFINE_bool(use_tls, false, "Whether to use tls.");
DEFINE_string(custom_credentials_type, "", "User provided credentials type.");
DEFINE_bool(use_test_ca, false, "False to use SSL roots for google"); DEFINE_bool(use_test_ca, false, "False to use SSL roots for google");
DEFINE_int32(server_port, 0, "Server port."); DEFINE_int32(server_port, 0, "Server port.");
DEFINE_string(server_host, "127.0.0.1", "Server host to connect to"); DEFINE_string(server_host, "127.0.0.1", "Server host to connect to");

@ -50,8 +50,10 @@
#include "src/cpp/client/secure_credentials.h" #include "src/cpp/client/secure_credentials.h"
#include "test/core/security/oauth2_utils.h" #include "test/core/security/oauth2_utils.h"
#include "test/cpp/util/create_test_channel.h" #include "test/cpp/util/create_test_channel.h"
#include "test/cpp/util/test_credentials_provider.h"
DECLARE_bool(use_tls); DECLARE_bool(use_tls);
DECLARE_string(custom_credentials_type);
DECLARE_bool(use_test_ca); DECLARE_bool(use_test_ca);
DECLARE_int32(server_port); DECLARE_int32(server_port);
DECLARE_string(server_host); DECLARE_string(server_host);
@ -114,8 +116,12 @@ std::shared_ptr<Channel> CreateChannelForTestCase(
creds = AccessTokenCredentials(raw_token); creds = AccessTokenCredentials(raw_token);
GPR_ASSERT(creds); GPR_ASSERT(creds);
} }
return CreateTestChannel(host_port, FLAGS_server_host_override, FLAGS_use_tls, if (FLAGS_custom_credentials_type.empty()) {
!FLAGS_use_test_ca, creds); return CreateTestChannel(host_port, FLAGS_server_host_override,
FLAGS_use_tls, !FLAGS_use_test_ca, creds);
} else {
return CreateTestChannel(host_port, FLAGS_custom_credentials_type, creds);
}
} }
} // namespace testing } // namespace testing

@ -56,6 +56,7 @@
#include "test/cpp/util/test_config.h" #include "test/cpp/util/test_config.h"
DEFINE_bool(use_tls, false, "Whether to use tls."); DEFINE_bool(use_tls, false, "Whether to use tls.");
DEFINE_string(custom_credentials_type, "", "User provided credentials type.");
DEFINE_int32(port, 0, "Server port."); DEFINE_int32(port, 0, "Server port.");
DEFINE_int32(max_send_message_size, -1, "The maximum send message size."); DEFINE_int32(max_send_message_size, -1, "The maximum send message size.");

@ -39,23 +39,23 @@
#include <grpc++/security/server_credentials.h> #include <grpc++/security/server_credentials.h>
#include "src/core/lib/surface/call_test_only.h" #include "src/core/lib/surface/call_test_only.h"
#include "test/core/end2end/data/ssl_test_data.h" #include "test/cpp/util/test_credentials_provider.h"
DECLARE_bool(use_tls); DECLARE_bool(use_tls);
DECLARE_string(custom_credentials_type);
namespace grpc { namespace grpc {
namespace testing { namespace testing {
std::shared_ptr<ServerCredentials> CreateInteropServerCredentials() { std::shared_ptr<ServerCredentials> CreateInteropServerCredentials() {
if (FLAGS_use_tls) { if (!FLAGS_custom_credentials_type.empty()) {
SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key, return GetCredentialsProvider()->GetServerCredentials(
test_server1_cert}; FLAGS_custom_credentials_type);
SslServerCredentialsOptions ssl_opts; } else if (FLAGS_use_tls) {
ssl_opts.pem_root_certs = ""; return GetCredentialsProvider()->GetServerCredentials(kTlsCredentialsType);
ssl_opts.pem_key_cert_pairs.push_back(pkcp);
return SslServerCredentials(ssl_opts);
} else { } else {
return InsecureServerCredentials(); return GetCredentialsProvider()->GetServerCredentials(
kInsecureCredentialsType);
} }
} }

@ -147,6 +147,7 @@ DEFINE_bool(do_not_abort_on_transient_failures, true,
// Options from client.cc (for compatibility with interop test). // Options from client.cc (for compatibility with interop test).
// TODO(sreek): Consolidate overlapping options // TODO(sreek): Consolidate overlapping options
DEFINE_bool(use_tls, false, "Whether to use tls."); DEFINE_bool(use_tls, false, "Whether to use tls.");
DEFINE_string(custom_credentials_type, "", "User provided credentials type.");
DEFINE_bool(use_test_ca, false, "False to use SSL roots for google"); DEFINE_bool(use_test_ca, false, "False to use SSL roots for google");
DEFINE_int32(server_port, 0, "Server port."); DEFINE_int32(server_port, 0, "Server port.");
DEFINE_string(server_host, "127.0.0.1", "Server host to connect to"); DEFINE_string(server_host, "127.0.0.1", "Server host to connect to");

@ -35,11 +35,37 @@
#include <grpc++/create_channel.h> #include <grpc++/create_channel.h>
#include <grpc++/security/credentials.h> #include <grpc++/security/credentials.h>
#include <grpc/support/log.h>
#include "test/core/end2end/data/ssl_test_data.h" #include "test/cpp/util/test_credentials_provider.h"
namespace grpc { namespace grpc {
namespace {
const char kProdTlsCredentialsType[] = "prod_ssl";
class SslCredentialProvider : public testing::CredentialTypeProvider {
public:
std::shared_ptr<ChannelCredentials> GetChannelCredentials(
grpc::ChannelArguments* args) override {
return SslCredentials(SslCredentialsOptions());
}
std::shared_ptr<ServerCredentials> GetServerCredentials() override {
return nullptr;
}
};
gpr_once g_once_init_add_prod_ssl_provider = GPR_ONCE_INIT;
// Register ssl with non-test roots type to the credentials provider.
void AddProdSslType() {
testing::GetCredentialsProvider()->AddSecureType(
kProdTlsCredentialsType, std::unique_ptr<testing::CredentialTypeProvider>(
new SslCredentialProvider));
}
} // namespace
// When ssl is enabled, if server is empty, override_hostname is used to // When ssl is enabled, if server is empty, override_hostname is used to
// create channel. Otherwise, connect to server and override hostname if // create channel. Otherwise, connect to server and override hostname if
// override_hostname is provided. // override_hostname is provided.
@ -61,16 +87,22 @@ std::shared_ptr<Channel> CreateTestChannel(
const std::shared_ptr<CallCredentials>& creds, const std::shared_ptr<CallCredentials>& creds,
const ChannelArguments& args) { const ChannelArguments& args) {
ChannelArguments channel_args(args); ChannelArguments channel_args(args);
std::shared_ptr<ChannelCredentials> channel_creds;
if (enable_ssl) { if (enable_ssl) {
const char* roots_certs = use_prod_roots ? "" : test_root_cert; if (use_prod_roots) {
SslCredentialsOptions ssl_opts = {roots_certs, "", ""}; gpr_once_init(&g_once_init_add_prod_ssl_provider, &AddProdSslType);
channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials(
std::shared_ptr<ChannelCredentials> channel_creds = kProdTlsCredentialsType, &channel_args);
SslCredentials(ssl_opts); if (!server.empty() && !override_hostname.empty()) {
channel_args.SetSslTargetNameOverride(override_hostname);
if (!server.empty() && !override_hostname.empty()) { }
channel_args.SetSslTargetNameOverride(override_hostname); } else {
// override_hostname is discarded as the provider handles it.
channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials(
testing::kTlsCredentialsType, &channel_args);
} }
GPR_ASSERT(channel_creds != nullptr);
const grpc::string& connect_to = const grpc::string& connect_to =
server.empty() ? override_hostname : server; server.empty() ? override_hostname : server;
if (creds.get()) { if (creds.get()) {
@ -103,4 +135,18 @@ std::shared_ptr<Channel> CreateTestChannel(const grpc::string& server,
return CreateTestChannel(server, "foo.test.google.fr", enable_ssl, false); return CreateTestChannel(server, "foo.test.google.fr", enable_ssl, false);
} }
std::shared_ptr<Channel> CreateTestChannel(
const grpc::string& server, const grpc::string& credential_type,
const std::shared_ptr<CallCredentials>& creds) {
ChannelArguments channel_args;
std::shared_ptr<ChannelCredentials> channel_creds =
testing::GetCredentialsProvider()->GetChannelCredentials(credential_type,
&channel_args);
GPR_ASSERT(channel_creds != nullptr);
if (creds.get()) {
channel_creds = CompositeChannelCredentials(channel_creds, creds);
}
return CreateCustomChannel(server, channel_creds, channel_args);
}
} // namespace grpc } // namespace grpc

@ -59,6 +59,10 @@ std::shared_ptr<Channel> CreateTestChannel(
const std::shared_ptr<CallCredentials>& creds, const std::shared_ptr<CallCredentials>& creds,
const ChannelArguments& args); const ChannelArguments& args);
std::shared_ptr<Channel> CreateTestChannel(
const grpc::string& server, const grpc::string& credential_type,
const std::shared_ptr<CallCredentials>& creds);
} // namespace grpc } // namespace grpc
#endif // GRPC_TEST_CPP_UTIL_CREATE_TEST_CHANNEL_H #endif // GRPC_TEST_CPP_UTIL_CREATE_TEST_CHANNEL_H

@ -43,25 +43,9 @@
#include "test/core/end2end/data/ssl_test_data.h" #include "test/core/end2end/data/ssl_test_data.h"
namespace grpc { namespace grpc {
namespace testing {
namespace { namespace {
using grpc::testing::CredentialTypeProvider;
// Provide test credentials. Thread-safe.
class CredentialsProvider {
public:
virtual ~CredentialsProvider() {}
virtual void AddSecureType(
const grpc::string& type,
std::unique_ptr<CredentialTypeProvider> type_provider) = 0;
virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials(
const grpc::string& type, ChannelArguments* args) = 0;
virtual std::shared_ptr<ServerCredentials> GetServerCredentials(
const grpc::string& type) = 0;
virtual std::vector<grpc::string> GetSecureCredentialsTypeList() = 0;
};
class DefaultCredentialsProvider : public CredentialsProvider { class DefaultCredentialsProvider : public CredentialsProvider {
public: public:
~DefaultCredentialsProvider() override {} ~DefaultCredentialsProvider() override {}
@ -145,37 +129,21 @@ class DefaultCredentialsProvider : public CredentialsProvider {
added_secure_type_providers_; added_secure_type_providers_;
}; };
gpr_once g_once_init_provider = GPR_ONCE_INIT;
CredentialsProvider* g_provider = nullptr; CredentialsProvider* g_provider = nullptr;
void CreateDefaultProvider() { g_provider = new DefaultCredentialsProvider; }
CredentialsProvider* GetProvider() {
gpr_once_init(&g_once_init_provider, &CreateDefaultProvider);
return g_provider;
}
} // namespace } // namespace
namespace testing { CredentialsProvider* GetCredentialsProvider() {
if (g_provider == nullptr) {
void AddSecureType(const grpc::string& type, g_provider = new DefaultCredentialsProvider;
std::unique_ptr<CredentialTypeProvider> type_provider) { }
GetProvider()->AddSecureType(type, std::move(type_provider)); return g_provider;
}
std::shared_ptr<ChannelCredentials> GetChannelCredentials(
const grpc::string& type, ChannelArguments* args) {
return GetProvider()->GetChannelCredentials(type, args);
}
std::shared_ptr<ServerCredentials> GetServerCredentials(
const grpc::string& type) {
return GetProvider()->GetServerCredentials(type);
} }
std::vector<grpc::string> GetSecureCredentialsTypeList() { void SetCredentialsProvider(CredentialsProvider* provider) {
return GetProvider()->GetSecureCredentialsTypeList(); // For now, forbids overriding provider.
GPR_ASSERT(g_provider == nullptr);
g_provider = provider;
} }
} // namespace testing } // namespace testing

@ -59,23 +59,39 @@ class CredentialTypeProvider {
virtual std::shared_ptr<ServerCredentials> GetServerCredentials() = 0; virtual std::shared_ptr<ServerCredentials> GetServerCredentials() = 0;
}; };
// Add a secure type in addition to the defaults above // Provide test credentials. Thread-safe.
// (kInsecureCredentialsType, kTlsCredentialsType) that can be returned from the class CredentialsProvider {
// functions below. public:
void AddSecureType(const grpc::string& type, virtual ~CredentialsProvider() {}
std::unique_ptr<CredentialTypeProvider> type_provider);
// Add a secure type in addition to the defaults. The default provider has
// Provide channel credentials according to the given type. Alter the channel // (kInsecureCredentialsType, kTlsCredentialsType).
// arguments if needed. virtual void AddSecureType(
std::shared_ptr<ChannelCredentials> GetChannelCredentials( const grpc::string& type,
const grpc::string& type, ChannelArguments* args); std::unique_ptr<CredentialTypeProvider> type_provider) = 0;
// Provide server credentials according to the given type. // Provide channel credentials according to the given type. Alter the channel
std::shared_ptr<ServerCredentials> GetServerCredentials( // arguments if needed. Return nullptr if type is not registered.
const grpc::string& type); virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials(
const grpc::string& type, ChannelArguments* args) = 0;
// Provide a list of secure credentials type.
std::vector<grpc::string> GetSecureCredentialsTypeList(); // Provide server credentials according to the given type.
// Return nullptr if type is not registered.
virtual std::shared_ptr<ServerCredentials> GetServerCredentials(
const grpc::string& type) = 0;
// Provide a list of secure credentials type.
virtual std::vector<grpc::string> GetSecureCredentialsTypeList() = 0;
};
// Get the current provider. Create a default one if not set.
// Not thread-safe.
CredentialsProvider* GetCredentialsProvider();
// Set the global provider. Takes ownership. The previous set provider will be
// destroyed.
// Not thread-safe.
void SetCredentialsProvider(CredentialsProvider* provider);
} // namespace testing } // namespace testing
} // namespace grpc } // namespace grpc

@ -5447,6 +5447,7 @@
"gpr", "gpr",
"grpc", "grpc",
"grpc++", "grpc++",
"grpc++_test_util",
"grpc_test_util" "grpc_test_util"
], ],
"headers": [ "headers": [

@ -154,6 +154,9 @@
</ClCompile> </ClCompile>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ProjectReference Include="$(SolutionDir)\..\vsprojects\vcxproj\.\grpc++_test_util\grpc++_test_util.vcxproj">
<Project>{0BE77741-552A-929B-A497-4EF7ECE17A64}</Project>
</ProjectReference>
<ProjectReference Include="$(SolutionDir)\..\vsprojects\vcxproj\.\grpc_test_util\grpc_test_util.vcxproj"> <ProjectReference Include="$(SolutionDir)\..\vsprojects\vcxproj\.\grpc_test_util\grpc_test_util.vcxproj">
<Project>{17BCAFC0-5FDC-4C94-AEB9-95F3E220614B}</Project> <Project>{17BCAFC0-5FDC-4C94-AEB9-95F3E220614B}</Project>
</ProjectReference> </ProjectReference>

Loading…
Cancel
Save