diff --git a/test/cpp/util/test_credentials_provider.cc b/test/cpp/util/test_credentials_provider.cc index cfd3ebbb111..65d32057673 100644 --- a/test/cpp/util/test_credentials_provider.cc +++ b/test/cpp/util/test_credentials_provider.cc @@ -34,6 +34,8 @@ #include "test/cpp/util/test_credentials_provider.h" +#include + #include #include @@ -48,12 +50,36 @@ using grpc::InsecureServerCredentials; using grpc::ServerCredentials; using grpc::SslCredentialsOptions; using grpc::SslServerCredentialsOptions; -using grpc::testing::CredentialsProvider; +using grpc::testing::CredentialTypeProvider; + +// Provide test credentials. Thread-safe. +class CredentialsProvider { + public: + virtual ~CredentialsProvider() {} + + virtual void AddSecureType( + const grpc::string& type, + std::unique_ptr type_provider) = 0; + virtual std::shared_ptr GetChannelCredentials( + const grpc::string& type, ChannelArguments* args) = 0; + virtual std::shared_ptr GetServerCredentials( + const grpc::string& type) = 0; + virtual std::vector GetSecureCredentialsTypeList() = 0; +}; class DefaultCredentialsProvider : public CredentialsProvider { public: ~DefaultCredentialsProvider() override {} + void AddSecureType( + const grpc::string& type, + std::unique_ptr type_provider) override { + // This clobbers any existing entry for type, except the defaults, which + // can't be clobbered. + grpc::unique_lock lock(mu_); + added_secure_types_[type] = std::move(type_provider); + } + std::shared_ptr GetChannelCredentials( const grpc::string& type, ChannelArguments* args) override { if (type == grpc::testing::kInsecureCredentialsType) { @@ -63,9 +89,14 @@ class DefaultCredentialsProvider : public CredentialsProvider { args->SetSslTargetNameOverride("foo.test.google.fr"); return SslCredentials(ssl_opts); } else { - gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + grpc::unique_lock lock(mu_); + auto it(added_secure_types_.find(type)); + if (it == added_secure_types_.end()) { + gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + return nullptr; + } + return it->second->GetChannelCredentials(args); } - return nullptr; } std::shared_ptr GetServerCredentials( @@ -80,35 +111,40 @@ class DefaultCredentialsProvider : public CredentialsProvider { ssl_opts.pem_key_cert_pairs.push_back(pkcp); return SslServerCredentials(ssl_opts); } else { - gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + grpc::unique_lock lock(mu_); + auto it(added_secure_types_.find(type)); + if (it == added_secure_types_.end()) { + gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + return nullptr; + } + return it->second->GetServerCredentials(); } - return nullptr; } std::vector GetSecureCredentialsTypeList() override { std::vector types; types.push_back(grpc::testing::kTlsCredentialsType); + grpc::unique_lock lock(mu_); + for (const auto& type_pair : added_secure_types_) { + types.push_back(type_pair.first); + } return types; } + + private: + grpc::mutex mu_; + std::unordered_map > + added_secure_types_; }; -gpr_once g_once_init_provider_mu = GPR_ONCE_INIT; -grpc::mutex* g_provider_mu = nullptr; +gpr_once g_once_init_provider = GPR_ONCE_INIT; CredentialsProvider* g_provider = nullptr; -void InitProviderMu() { - g_provider_mu = new grpc::mutex; -} - -grpc::mutex& GetMu() { - gpr_once_init(&g_once_init_provider_mu, &InitProviderMu); - return *g_provider_mu; +void CreateDefaultProvider() { + g_provider = new DefaultCredentialsProvider; } CredentialsProvider* GetProvider() { - grpc::unique_lock lock(GetMu()); - if (g_provider == nullptr) { - g_provider = new DefaultCredentialsProvider; - } + gpr_once_init(&g_once_init_provider, &CreateDefaultProvider); return g_provider; } @@ -117,15 +153,9 @@ CredentialsProvider* GetProvider() { namespace grpc { namespace testing { -// Note that it is not thread-safe to set a provider while concurrently using -// the previously set provider, as this deletes and replaces it. nullptr may be -// given to reset to the default. -void SetTestCredentialsProvider(std::unique_ptr provider) { - grpc::unique_lock lock(GetMu()); - if (g_provider != nullptr) { - delete g_provider; - } - g_provider = provider.release(); +void AddSecureType(const grpc::string& type, + std::unique_ptr type_provider) { + GetProvider()->AddSecureType(type, std::move(type_provider)); } std::shared_ptr GetChannelCredentials( diff --git a/test/cpp/util/test_credentials_provider.h b/test/cpp/util/test_credentials_provider.h index a6b547cb070..50fadb53a24 100644 --- a/test/cpp/util/test_credentials_provider.h +++ b/test/cpp/util/test_credentials_provider.h @@ -46,20 +46,21 @@ namespace testing { const char kInsecureCredentialsType[] = "INSECURE_CREDENTIALS"; const char kTlsCredentialsType[] = "TLS_CREDENTIALS"; -class CredentialsProvider { +// Provide test credentials of a particular type. +class CredentialTypeProvider { public: - virtual ~CredentialsProvider() {} + virtual ~CredentialTypeProvider() {} virtual std::shared_ptr GetChannelCredentials( - const grpc::string& type, ChannelArguments* args) = 0; - virtual std::shared_ptr GetServerCredentials( - const grpc::string& type) = 0; - virtual std::vector GetSecureCredentialsTypeList() = 0; + ChannelArguments* args) = 0; + virtual std::shared_ptr GetServerCredentials() = 0; }; -// Set the CredentialsProvider used by the other functions in this file. If this -// is not set, a default provider will be used. -void SetTestCredentialsProvider(std::unique_ptr provider); +// Add a secure type in addition to the defaults above +// (kInsecureCredentialsType, kTlsCredentialsType) that can be returned from the +// functions below. +void AddSecureType(const grpc::string& type, + std::unique_ptr type_provider); // Provide channel credentials according to the given type. Alter the channel // arguments if needed.