[xDS] support "tls" channel creds in bootstrap file (#33234)

Implements [gRFC
A65](https://github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md).

Fixes #32977.
pull/33468/head
Mark D. Roth 1 year ago committed by GitHub
parent d8db05a068
commit 1b31c6e0ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      doc/grpc_xds_features.md
  2. 9
      src/core/BUILD
  3. 63
      src/core/ext/xds/xds_bootstrap_grpc.cc
  4. 18
      src/core/ext/xds/xds_bootstrap_grpc.h
  5. 4
      src/core/ext/xds/xds_transport_grpc.cc
  6. 78
      src/core/lib/security/credentials/channel_creds_registry.h
  7. 178
      src/core/lib/security/credentials/channel_creds_registry_init.cc
  8. 2
      src/core/lib/security/credentials/composite/composite_credentials.cc
  9. 4
      src/core/lib/security/credentials/composite/composite_credentials.h
  10. 38
      src/core/lib/security/credentials/fake/fake_credentials.cc
  11. 28
      src/core/lib/security/credentials/fake/fake_credentials.h
  12. 2
      src/core/lib/security/credentials/tls/tls_credentials.cc
  13. 4
      src/core/lib/security/credentials/tls/tls_credentials.h
  14. 151
      test/core/security/channel_creds_registry_test.cc
  15. 26
      test/core/xds/xds_bootstrap_test.cc

@ -72,3 +72,4 @@ Support for [xDS v2 APIs](https://www.envoyproxy.io/docs/envoy/latest/api/api_su
[xDS Federation](https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md) | [A47](https://github.com/grpc/proposal/blob/master/A47-xds-federation.md) | v1.55.0 | | | |
[Client-Side Weighted Round Robin LB Policy](https://github.com/envoyproxy/envoy/blob/a6d46b6ac4750720eec9a49abe701f0df9bf8e0a/api/envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.proto#L36) | [A58](https://github.com/grpc/proposal/blob/master/A58-client-side-weighted-round-robin-lb-policy.md) | v1.55.0 | | | |
[StringMatcher for Header Matching](https://github.com/envoyproxy/envoy/blob/3fe4b8d335fa339ef6f17325c8d31f87ade7bb1a/api/envoy/config/route/v3/route_components.proto#L2280) | [A63](https://github.com/grpc/proposal/blob/master/A63-xds-string-matcher-in-header-matching.md) | v1.56.0 | | | |
mTLS Credentials in xDS Bootstrap File | [A65](https://github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md) | v1.57.0 | | | |

@ -1020,6 +1020,9 @@ grpc_cc_library(
language = "c++",
deps = [
"json",
"json_args",
"ref_counted",
"validation_errors",
"//:gpr_platform",
"//:ref_counted_ptr",
],
@ -4282,8 +4285,14 @@ grpc_cc_library(
"channel_creds_registry",
"grpc_fake_credentials",
"grpc_google_default_credentials",
"grpc_tls_credentials",
"json",
"json_args",
"json_object_loader",
"time",
"validation_errors",
"//:config",
"//:gpr",
"//:gpr_platform",
"//:grpc_security_base",
"//:ref_counted_ptr",

@ -78,20 +78,6 @@ const JsonLoaderInterface* GrpcXdsBootstrap::GrpcNode::JsonLoader(
return loader;
}
//
// GrpcXdsBootstrap::GrpcXdsServer::ChannelCreds
//
const JsonLoaderInterface*
GrpcXdsBootstrap::GrpcXdsServer::ChannelCreds::JsonLoader(const JsonArgs&) {
static const auto* loader =
JsonObjectLoader<ChannelCreds>()
.Field("type", &ChannelCreds::type)
.OptionalField("config", &ChannelCreds::config)
.Finish();
return loader;
}
//
// GrpcXdsBootstrap::GrpcXdsServer
//
@ -111,8 +97,8 @@ bool GrpcXdsBootstrap::GrpcXdsServer::IgnoreResourceDeletion() const {
bool GrpcXdsBootstrap::GrpcXdsServer::Equals(const XdsServer& other) const {
const auto& o = static_cast<const GrpcXdsServer&>(other);
return (server_uri_ == o.server_uri_ &&
channel_creds_.type == o.channel_creds_.type &&
channel_creds_.config == o.channel_creds_.config &&
channel_creds_config_->type() == o.channel_creds_config_->type() &&
channel_creds_config_->Equals(*o.channel_creds_config_) &&
server_features_ == o.server_features_);
}
@ -125,6 +111,24 @@ const JsonLoaderInterface* GrpcXdsBootstrap::GrpcXdsServer::JsonLoader(
return loader;
}
namespace {
struct ChannelCreds {
std::string type;
Json::Object config;
static const JsonLoaderInterface* JsonLoader(const JsonArgs&) {
static const auto* loader =
JsonObjectLoader<ChannelCreds>()
.Field("type", &ChannelCreds::type)
.OptionalField("config", &ChannelCreds::config)
.Finish();
return loader;
}
};
} // namespace
void GrpcXdsBootstrap::GrpcXdsServer::JsonPostLoad(const Json& json,
const JsonArgs& args,
ValidationErrors* errors) {
@ -136,21 +140,20 @@ void GrpcXdsBootstrap::GrpcXdsServer::JsonPostLoad(const Json& json,
for (size_t i = 0; i < channel_creds_list->size(); ++i) {
ValidationErrors::ScopedField field(errors, absl::StrCat("[", i, "]"));
auto& creds = (*channel_creds_list)[i];
// Select the first channel creds type that we support.
if (channel_creds_.type.empty() &&
CoreConfiguration::Get().channel_creds_registry().IsSupported(
// Select the first channel creds type that we support, but
// validate all entries.
if (CoreConfiguration::Get().channel_creds_registry().IsSupported(
creds.type)) {
if (!CoreConfiguration::Get().channel_creds_registry().IsValidConfig(
creds.type, Json::FromObject(creds.config))) {
errors->AddError(absl::StrCat(
"invalid config for channel creds type \"", creds.type, "\""));
continue;
ValidationErrors::ScopedField field(errors, ".config");
auto config =
CoreConfiguration::Get().channel_creds_registry().ParseConfig(
creds.type, Json::FromObject(creds.config), args, errors);
if (channel_creds_config_ == nullptr) {
channel_creds_config_ = std::move(config);
}
channel_creds_.type = std::move(creds.type);
channel_creds_.config = std::move(creds.config);
}
}
if (channel_creds_.type.empty()) {
if (channel_creds_config_ == nullptr) {
errors->AddError("no known creds type found");
}
}
@ -176,10 +179,10 @@ void GrpcXdsBootstrap::GrpcXdsServer::JsonPostLoad(const Json& json,
Json GrpcXdsBootstrap::GrpcXdsServer::ToJson() const {
Json::Object channel_creds_json{
{"type", Json::FromString(channel_creds_.type)},
{"type", Json::FromString(std::string(channel_creds_config_->type()))},
};
if (!channel_creds_.config.empty()) {
channel_creds_json["config"] = Json::FromObject(channel_creds_.config);
if (channel_creds_config_ != nullptr) {
channel_creds_json["config"] = channel_creds_config_->ToJson();
}
Json::Object json{
{"server_uri", Json::FromString(server_uri_)},

@ -35,10 +35,12 @@
#include "src/core/ext/xds/xds_cluster_specifier_plugin.h"
#include "src/core/ext/xds/xds_http_filters.h"
#include "src/core/ext/xds/xds_lb_policy_registry.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/validation_errors.h"
#include "src/core/lib/json/json.h"
#include "src/core/lib/json/json_args.h"
#include "src/core/lib/json/json_object_loader.h"
#include "src/core/lib/security/credentials/channel_creds_registry.h"
namespace grpc_core {
@ -82,11 +84,8 @@ class GrpcXdsBootstrap : public XdsBootstrap {
bool Equals(const XdsServer& other) const override;
const std::string& channel_creds_type() const {
return channel_creds_.type;
}
const Json::Object& channel_creds_config() const {
return channel_creds_.config;
RefCountedPtr<ChannelCredsConfig> channel_creds_config() const {
return channel_creds_config_;
}
static const JsonLoaderInterface* JsonLoader(const JsonArgs&);
@ -96,15 +95,8 @@ class GrpcXdsBootstrap : public XdsBootstrap {
Json ToJson() const;
private:
struct ChannelCreds {
std::string type;
Json::Object config;
static const JsonLoaderInterface* JsonLoader(const JsonArgs&);
};
std::string server_uri_;
ChannelCreds channel_creds_;
RefCountedPtr<ChannelCredsConfig> channel_creds_config_;
std::set<std::string> server_features_;
};

@ -47,7 +47,6 @@
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/iomgr/closure.h"
#include "src/core/lib/iomgr/pollset_set.h"
#include "src/core/lib/json/json.h"
#include "src/core/lib/security/credentials/channel_creds_registry.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/slice/slice.h"
@ -256,8 +255,7 @@ grpc_channel* CreateXdsChannel(const ChannelArgs& args,
const GrpcXdsBootstrap::GrpcXdsServer& server) {
RefCountedPtr<grpc_channel_credentials> channel_creds =
CoreConfiguration::Get().channel_creds_registry().CreateChannelCreds(
server.channel_creds_type(),
Json::FromObject(server.channel_creds_config()));
server.channel_creds_config());
return grpc_channel_create(server.server_uri().c_str(), channel_creds.get(),
args.ToC().get());
}

@ -21,81 +21,105 @@
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "absl/strings/string_view.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/validation_errors.h"
#include "src/core/lib/json/json.h"
#include "src/core/lib/json/json_args.h"
struct grpc_channel_credentials;
namespace grpc_core {
class ChannelCredsConfig : public RefCounted<ChannelCredsConfig> {
public:
virtual absl::string_view type() const = 0;
virtual bool Equals(const ChannelCredsConfig& other) const = 0;
virtual Json ToJson() const = 0;
};
template <typename T = grpc_channel_credentials>
class ChannelCredsFactory final {
public:
virtual ~ChannelCredsFactory() {}
virtual absl::string_view creds_type() const = delete;
virtual bool IsValidConfig(const Json& config) const = delete;
virtual RefCountedPtr<T> CreateChannelCreds(const Json& config) const =
delete;
virtual absl::string_view type() const = delete;
virtual RefCountedPtr<ChannelCredsConfig> ParseConfig(
const Json& config, const JsonArgs& args,
ValidationErrors* errors) const = delete;
virtual RefCountedPtr<T> CreateChannelCreds(
RefCountedPtr<ChannelCredsConfig> config) const = delete;
};
template <>
class ChannelCredsFactory<grpc_channel_credentials> {
public:
virtual ~ChannelCredsFactory() {}
virtual absl::string_view creds_type() const = 0;
virtual bool IsValidConfig(const Json& config) const = 0;
virtual absl::string_view type() const = 0;
virtual RefCountedPtr<ChannelCredsConfig> ParseConfig(
const Json& config, const JsonArgs& args,
ValidationErrors* errors) const = 0;
virtual RefCountedPtr<grpc_channel_credentials> CreateChannelCreds(
const Json& config) const = 0;
RefCountedPtr<ChannelCredsConfig> config) const = 0;
};
template <typename T = grpc_channel_credentials>
class ChannelCredsRegistry {
private:
using FactoryMap =
std::map<absl::string_view, std::unique_ptr<ChannelCredsFactory<T>>>;
public:
static_assert(std::is_base_of<grpc_channel_credentials, T>::value,
"ChannelCredsRegistry must be instantiated with "
"grpc_channel_credentials.");
class Builder {
public:
void RegisterChannelCredsFactory(
std::unique_ptr<ChannelCredsFactory<T>> factory) {
factories_[factory->creds_type()] = std::move(factory);
absl::string_view type = factory->type();
factories_[type] = std::move(factory);
}
ChannelCredsRegistry Build() {
ChannelCredsRegistry<T> registry;
registry.factories_.swap(factories_);
return registry;
return ChannelCredsRegistry<T>(std::move(factories_));
}
private:
std::map<absl::string_view, std::unique_ptr<ChannelCredsFactory<T>>>
factories_;
FactoryMap factories_;
};
bool IsSupported(const std::string& creds_type) const {
return factories_.find(creds_type) != factories_.end();
bool IsSupported(absl::string_view type) const {
return factories_.find(type) != factories_.end();
}
bool IsValidConfig(const std::string& creds_type, const Json& config) const {
const auto iter = factories_.find(creds_type);
return iter != factories_.cend() && iter->second->IsValidConfig(config);
RefCountedPtr<ChannelCredsConfig> ParseConfig(
absl::string_view type, const Json& config, const JsonArgs& args,
ValidationErrors* errors) const {
const auto it = factories_.find(type);
if (it == factories_.cend()) return nullptr;
return it->second->ParseConfig(config, args, errors);
}
RefCountedPtr<T> CreateChannelCreds(const std::string& creds_type,
const Json& config) const {
const auto iter = factories_.find(creds_type);
if (iter == factories_.cend()) return nullptr;
return iter->second->CreateChannelCreds(config);
RefCountedPtr<T> CreateChannelCreds(
RefCountedPtr<ChannelCredsConfig> config) const {
if (config == nullptr) return nullptr;
const auto it = factories_.find(config->type());
if (it == factories_.cend()) return nullptr;
return it->second->CreateChannelCreds(std::move(config));
}
private:
ChannelCredsRegistry() = default;
std::map<absl::string_view, std::unique_ptr<ChannelCredsFactory<T>>>
factories_;
explicit ChannelCredsRegistry(FactoryMap factories)
: factories_(std::move(factories)) {}
FactoryMap factories_;
};
} // namespace grpc_core

@ -18,59 +18,219 @@
#include <grpc/support/port_platform.h>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "absl/strings/string_view.h"
#include <grpc/grpc.h>
#include <grpc/grpc_security.h>
#include <grpc/support/json.h>
#include <grpc/support/time.h>
#include "src/core/lib/config/core_configuration.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/gprpp/validation_errors.h"
#include "src/core/lib/json/json.h"
#include "src/core/lib/json/json_args.h"
#include "src/core/lib/json/json_object_loader.h"
#include "src/core/lib/security/credentials/channel_creds_registry.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/credentials/fake/fake_credentials.h"
#include "src/core/lib/security/credentials/google_default/google_default_credentials.h" // IWYU pragma: keep
#include "src/core/lib/security/credentials/tls/grpc_tls_certificate_provider.h"
#include "src/core/lib/security/credentials/tls/grpc_tls_credentials_options.h"
#include "src/core/lib/security/credentials/tls/tls_credentials.h"
namespace grpc_core {
class GoogleDefaultChannelCredsFactory : public ChannelCredsFactory<> {
public:
absl::string_view creds_type() const override { return "google_default"; }
bool IsValidConfig(const Json& /*config*/) const override { return true; }
absl::string_view type() const override { return Type(); }
RefCountedPtr<ChannelCredsConfig> ParseConfig(
const Json& /*config*/, const JsonArgs& /*args*/,
ValidationErrors* /*errors*/) const override {
return MakeRefCounted<Config>();
}
RefCountedPtr<grpc_channel_credentials> CreateChannelCreds(
const Json& /*config*/) const override {
RefCountedPtr<ChannelCredsConfig> /*config*/) const override {
return RefCountedPtr<grpc_channel_credentials>(
grpc_google_default_credentials_create(nullptr));
}
private:
class Config : public ChannelCredsConfig {
public:
absl::string_view type() const override { return Type(); }
bool Equals(const ChannelCredsConfig&) const override { return true; }
Json ToJson() const override { return Json::FromObject({}); }
};
static absl::string_view Type() { return "google_default"; }
};
class TlsChannelCredsFactory : public ChannelCredsFactory<> {
public:
absl::string_view type() const override { return Type(); }
RefCountedPtr<ChannelCredsConfig> ParseConfig(
const Json& config, const JsonArgs& args,
ValidationErrors* errors) const override {
return LoadFromJson<RefCountedPtr<TlsConfig>>(config, args, errors);
}
RefCountedPtr<grpc_channel_credentials> CreateChannelCreds(
RefCountedPtr<ChannelCredsConfig> base_config) const override {
auto* config = static_cast<const TlsConfig*>(base_config.get());
auto options = MakeRefCounted<grpc_tls_credentials_options>();
if (!config->certificate_file().empty() ||
!config->ca_certificate_file().empty()) {
options->set_certificate_provider(
MakeRefCounted<FileWatcherCertificateProvider>(
config->private_key_file(), config->certificate_file(),
config->ca_certificate_file(),
config->refresh_interval().millis() / GPR_MS_PER_SEC));
}
options->set_watch_root_cert(!config->ca_certificate_file().empty());
options->set_watch_identity_pair(!config->certificate_file().empty());
return MakeRefCounted<TlsCredentials>(std::move(options));
}
private:
// TODO(roth): It would be nice to share most of this config with the
// xDS file watcher cert provider factory, but that would require
// adding a dependency from lib to ext.
class TlsConfig : public ChannelCredsConfig {
public:
absl::string_view type() const override { return Type(); }
bool Equals(const ChannelCredsConfig& other) const override {
auto& o = static_cast<const TlsConfig&>(other);
return certificate_file_ == o.certificate_file_ &&
private_key_file_ == o.private_key_file_ &&
ca_certificate_file_ == o.ca_certificate_file_ &&
refresh_interval_ == o.refresh_interval_;
}
Json ToJson() const override {
Json::Object obj;
if (!certificate_file_.empty()) {
obj["certificate_file"] = Json::FromString(certificate_file_);
}
if (!private_key_file_.empty()) {
obj["private_key_file"] = Json::FromString(private_key_file_);
}
if (!ca_certificate_file_.empty()) {
obj["ca_certificate_file"] = Json::FromString(ca_certificate_file_);
}
if (refresh_interval_ != kDefaultRefreshInterval) {
obj["refresh_interval"] =
Json::FromString(refresh_interval_.ToJsonString());
}
return Json::FromObject(std::move(obj));
}
const std::string& certificate_file() const { return certificate_file_; }
const std::string& private_key_file() const { return private_key_file_; }
const std::string& ca_certificate_file() const {
return ca_certificate_file_;
}
Duration refresh_interval() const { return refresh_interval_; }
static const JsonLoaderInterface* JsonLoader(const JsonArgs&) {
static const auto* loader =
JsonObjectLoader<TlsConfig>()
.OptionalField("certificate_file", &TlsConfig::certificate_file_)
.OptionalField("private_key_file", &TlsConfig::private_key_file_)
.OptionalField("ca_certificate_file",
&TlsConfig::ca_certificate_file_)
.OptionalField("refresh_interval", &TlsConfig::refresh_interval_)
.Finish();
return loader;
}
void JsonPostLoad(const Json& json, const JsonArgs& /*args*/,
ValidationErrors* errors) {
if ((json.object().find("certificate_file") == json.object().end()) !=
(json.object().find("private_key_file") == json.object().end())) {
errors->AddError(
"fields \"certificate_file\" and \"private_key_file\" must be "
"both set or both unset");
}
}
private:
static constexpr Duration kDefaultRefreshInterval = Duration::Minutes(10);
std::string certificate_file_;
std::string private_key_file_;
std::string ca_certificate_file_;
Duration refresh_interval_ = kDefaultRefreshInterval;
};
static absl::string_view Type() { return "tls"; }
};
constexpr Duration TlsChannelCredsFactory::TlsConfig::kDefaultRefreshInterval;
class InsecureChannelCredsFactory : public ChannelCredsFactory<> {
public:
absl::string_view creds_type() const override { return "insecure"; }
bool IsValidConfig(const Json& /*config*/) const override { return true; }
absl::string_view type() const override { return Type(); }
RefCountedPtr<ChannelCredsConfig> ParseConfig(
const Json& /*config*/, const JsonArgs& /*args*/,
ValidationErrors* /*errors*/) const override {
return MakeRefCounted<Config>();
}
RefCountedPtr<grpc_channel_credentials> CreateChannelCreds(
const Json& /*config*/) const override {
RefCountedPtr<ChannelCredsConfig> /*config*/) const override {
return RefCountedPtr<grpc_channel_credentials>(
grpc_insecure_credentials_create());
}
private:
class Config : public ChannelCredsConfig {
public:
absl::string_view type() const override { return Type(); }
bool Equals(const ChannelCredsConfig&) const override { return true; }
Json ToJson() const override { return Json::FromObject({}); }
};
static absl::string_view Type() { return "insecure"; }
};
class FakeChannelCredsFactory : public ChannelCredsFactory<> {
public:
absl::string_view creds_type() const override { return "fake"; }
bool IsValidConfig(const Json& /*config*/) const override { return true; }
absl::string_view type() const override { return Type(); }
RefCountedPtr<ChannelCredsConfig> ParseConfig(
const Json& /*config*/, const JsonArgs& /*args*/,
ValidationErrors* /*errors*/) const override {
return MakeRefCounted<Config>();
}
RefCountedPtr<grpc_channel_credentials> CreateChannelCreds(
const Json& /*config*/) const override {
RefCountedPtr<ChannelCredsConfig> /*config*/) const override {
return RefCountedPtr<grpc_channel_credentials>(
grpc_fake_transport_security_credentials_create());
}
private:
class Config : public ChannelCredsConfig {
public:
absl::string_view type() const override { return Type(); }
bool Equals(const ChannelCredsConfig&) const override { return true; }
Json ToJson() const override { return Json::FromObject({}); }
};
static absl::string_view Type() { return "fake"; }
};
void RegisterChannelDefaultCreds(CoreConfiguration::Builder* builder) {
builder->channel_creds_registry()->RegisterChannelCredsFactory(
std::make_unique<GoogleDefaultChannelCredsFactory>());
builder->channel_creds_registry()->RegisterChannelCredsFactory(
std::make_unique<TlsChannelCredsFactory>());
builder->channel_creds_registry()->RegisterChannelCredsFactory(
std::make_unique<InsecureChannelCredsFactory>());
builder->channel_creds_registry()->RegisterChannelCredsFactory(

@ -39,7 +39,7 @@
// grpc_composite_channel_credentials
//
grpc_core::UniqueTypeName grpc_composite_channel_credentials::type() const {
grpc_core::UniqueTypeName grpc_composite_channel_credentials::Type() {
static grpc_core::UniqueTypeName::Factory kFactory("Composite");
return kFactory.Create();
}

@ -68,7 +68,9 @@ class grpc_composite_channel_credentials : public grpc_channel_credentials {
return inner_creds_->update_arguments(std::move(args));
}
grpc_core::UniqueTypeName type() const override;
static grpc_core::UniqueTypeName Type();
grpc_core::UniqueTypeName type() const override { return Type(); }
const grpc_channel_credentials* inner_creds() const {
return inner_creds_.get();

@ -36,44 +36,36 @@
// -- Fake transport security credentials. --
namespace {
class grpc_fake_channel_credentials final : public grpc_channel_credentials {
public:
grpc_core::RefCountedPtr<grpc_channel_security_connector>
create_security_connector(
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_fake_channel_credentials::create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, grpc_core::ChannelArgs* args) override {
const char* target, grpc_core::ChannelArgs* args) {
return grpc_fake_channel_security_connector_create(
this->Ref(), std::move(call_creds), target, *args);
}
}
grpc_core::UniqueTypeName type() const override {
grpc_core::UniqueTypeName grpc_fake_channel_credentials::Type() {
static grpc_core::UniqueTypeName::Factory kFactory("Fake");
return kFactory.Create();
}
}
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
int grpc_fake_channel_credentials::cmp_impl(
const grpc_channel_credentials* other) const {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_channel_credentials*>(this), other);
}
};
}
class grpc_fake_server_credentials final : public grpc_server_credentials {
public:
grpc_core::RefCountedPtr<grpc_server_security_connector>
create_security_connector(const grpc_core::ChannelArgs& /*args*/) override {
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_fake_server_credentials::create_security_connector(
const grpc_core::ChannelArgs& /*args*/) {
return grpc_fake_server_security_connector_create(this->Ref());
}
}
grpc_core::UniqueTypeName type() const override {
grpc_core::UniqueTypeName grpc_fake_server_credentials::Type() {
static grpc_core::UniqueTypeName::Factory kFactory("Fake");
return kFactory.Create();
}
};
} // namespace
}
grpc_channel_credentials* grpc_fake_transport_security_credentials_create() {
return new grpc_fake_channel_credentials();

@ -29,10 +29,13 @@
#include <grpc/grpc_security.h>
#include <grpc/grpc_security_constants.h>
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/gpr/useful.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/unique_type_name.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/security_connector/security_connector.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/transport/transport.h"
@ -41,6 +44,31 @@
// -- Fake transport security credentials. --
class grpc_fake_channel_credentials final : public grpc_channel_credentials {
public:
grpc_core::RefCountedPtr<grpc_channel_security_connector>
create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, grpc_core::ChannelArgs* args) override;
static grpc_core::UniqueTypeName Type();
grpc_core::UniqueTypeName type() const override { return Type(); }
private:
int cmp_impl(const grpc_channel_credentials* other) const override;
};
class grpc_fake_server_credentials final : public grpc_server_credentials {
public:
grpc_core::RefCountedPtr<grpc_server_security_connector>
create_security_connector(const grpc_core::ChannelArgs& /*args*/) override;
static grpc_core::UniqueTypeName Type();
grpc_core::UniqueTypeName type() const override { return Type(); }
};
// Creates a fake transport security credentials object for testing.
grpc_channel_credentials* grpc_fake_transport_security_credentials_create(void);

@ -98,7 +98,7 @@ TlsCredentials::create_security_connector(
return sc;
}
grpc_core::UniqueTypeName TlsCredentials::type() const {
grpc_core::UniqueTypeName TlsCredentials::Type() {
static grpc_core::UniqueTypeName::Factory kFactory("Tls");
return kFactory.Create();
}

@ -41,7 +41,9 @@ class TlsCredentials final : public grpc_channel_credentials {
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target_name, grpc_core::ChannelArgs* args) override;
grpc_core::UniqueTypeName type() const override;
static grpc_core::UniqueTypeName Type();
grpc_core::UniqueTypeName type() const override { return Type(); }
grpc_tls_credentials_options* options() const { return options_.get(); }

@ -21,10 +21,15 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/optional.h"
#include <grpc/grpc.h>
#include "src/core/lib/config/core_configuration.h"
#include "src/core/lib/security/credentials/composite/composite_credentials.h"
#include "src/core/lib/security/credentials/fake/fake_credentials.h"
#include "src/core/lib/security/credentials/insecure/insecure_credentials.h"
#include "src/core/lib/security/credentials/tls/tls_credentials.h"
#include "test/core/util/test_config.h"
namespace grpc_core {
@ -33,52 +38,128 @@ namespace {
class TestChannelCredsFactory : public ChannelCredsFactory<> {
public:
absl::string_view creds_type() const override { return "test"; }
bool IsValidConfig(const Json& /*config*/) const override { return true; }
absl::string_view type() const override { return Type(); }
RefCountedPtr<ChannelCredsConfig> ParseConfig(
const Json& /*config*/, const JsonArgs& /*args*/,
ValidationErrors* /*errors*/) const override {
return MakeRefCounted<Config>();
}
RefCountedPtr<grpc_channel_credentials> CreateChannelCreds(
const Json& /*config*/) const override {
RefCountedPtr<ChannelCredsConfig> /*config*/) const override {
return RefCountedPtr<grpc_channel_credentials>(
grpc_fake_transport_security_credentials_create());
}
private:
class Config : public ChannelCredsConfig {
public:
absl::string_view type() const override { return Type(); }
bool Equals(const ChannelCredsConfig&) const override { return true; }
Json ToJson() const override { return Json::FromObject({}); }
};
static absl::string_view Type() { return "test"; }
};
class ChannelCredsRegistryTest : public ::testing::Test {
protected:
void SetUp() override {
CoreConfiguration::Reset();
grpc_init();
void SetUp() override { CoreConfiguration::Reset(); }
// Run a basic test for a given credential type.
// type is the string identifying the type in the registry.
// credential_type is the resulting type of the actual channel creds object;
// if nullopt, does not attempt to instantiate the credentials.
void TestCreds(absl::string_view type,
absl::optional<UniqueTypeName> credential_type,
Json json = Json::FromObject({})) {
EXPECT_TRUE(
CoreConfiguration::Get().channel_creds_registry().IsSupported(type));
ValidationErrors errors;
auto config = CoreConfiguration::Get().channel_creds_registry().ParseConfig(
type, json, JsonArgs(), &errors);
EXPECT_TRUE(errors.ok()) << errors.message("unexpected errors");
ASSERT_NE(config, nullptr);
EXPECT_EQ(config->type(), type);
if (credential_type.has_value()) {
auto creds =
CoreConfiguration::Get().channel_creds_registry().CreateChannelCreds(
std::move(config));
ASSERT_NE(creds, nullptr);
UniqueTypeName actual_type = creds->type();
// If we get composite creds, unwrap them.
// (This happens for GoogleDefaultCreds.)
if (creds->type() == grpc_composite_channel_credentials::Type()) {
actual_type =
static_cast<grpc_composite_channel_credentials*>(creds.get())
->inner_creds()
->type();
}
EXPECT_EQ(actual_type, *credential_type)
<< "Actual: " << actual_type.name()
<< "\nExpected: " << credential_type->name();
}
}
};
TEST_F(ChannelCredsRegistryTest, DefaultCreds) {
// Default creds.
EXPECT_TRUE(CoreConfiguration::Get().channel_creds_registry().IsSupported(
"google_default"));
EXPECT_TRUE(CoreConfiguration::Get().channel_creds_registry().IsSupported(
"insecure"));
EXPECT_TRUE(
CoreConfiguration::Get().channel_creds_registry().IsSupported("fake"));
TEST_F(ChannelCredsRegistryTest, GoogleDefaultCreds) {
// Don't actually instantiate the credentials, since that fails in
// some environments.
TestCreds("google_default", absl::nullopt);
}
// Non-default creds.
EXPECT_EQ(
CoreConfiguration::Get().channel_creds_registry().CreateChannelCreds(
"test", Json()),
nullptr);
EXPECT_EQ(
CoreConfiguration::Get().channel_creds_registry().CreateChannelCreds(
"", Json()),
nullptr);
TEST_F(ChannelCredsRegistryTest, InsecureCreds) {
TestCreds("insecure", InsecureCredentials::Type());
}
TEST_F(ChannelCredsRegistryTest, FakeCreds) {
TestCreds("fake", grpc_fake_channel_credentials::Type());
}
TEST_F(ChannelCredsRegistryTest, TlsCredsNoConfig) {
TestCreds("tls", TlsCredentials::Type());
}
TEST_F(ChannelCredsRegistryTest, TlsCredsFullConfig) {
Json json = Json::FromObject({
{"certificate_file", Json::FromString("/path/to/cert_file")},
{"private_key_file", Json::FromString("/path/to/private_key_file")},
{"ca_certificate_file", Json::FromString("/path/to/ca_cert_file")},
{"refresh_interval", Json::FromString("1s")},
});
TestCreds("tls", TlsCredentials::Type(), json);
}
TEST_F(ChannelCredsRegistryTest, TlsCredsConfigInvalid) {
Json json = Json::FromObject({
{"certificate_file", Json::FromObject({})},
{"private_key_file", Json::FromArray({})},
{"ca_certificate_file", Json::FromBool(true)},
{"refresh_interval", Json::FromNumber(1)},
});
ValidationErrors errors;
auto config = CoreConfiguration::Get().channel_creds_registry().ParseConfig(
"tls", json, JsonArgs(), &errors);
EXPECT_EQ(errors.message("errors"),
"errors: ["
"field:ca_certificate_file error:is not a string; "
"field:certificate_file error:is not a string; "
"field:private_key_file error:is not a string; "
"field:refresh_interval error:is not a string]");
}
TEST_F(ChannelCredsRegistryTest, Register) {
// Before registration.
EXPECT_FALSE(
CoreConfiguration::Get().channel_creds_registry().IsSupported("test"));
EXPECT_EQ(
ValidationErrors errors;
auto config = CoreConfiguration::Get().channel_creds_registry().ParseConfig(
"test", Json::FromObject({}), JsonArgs(), &errors);
EXPECT_TRUE(errors.ok()) << errors.message("unexpected errors");
EXPECT_EQ(config, nullptr);
auto creds =
CoreConfiguration::Get().channel_creds_registry().CreateChannelCreds(
"test", Json()),
nullptr);
std::move(config));
EXPECT_EQ(creds, nullptr);
// Registration.
CoreConfiguration::WithSubstituteBuilder builder(
[](CoreConfiguration::Builder* builder) {
@ -86,13 +167,18 @@ TEST_F(ChannelCredsRegistryTest, Register) {
builder->channel_creds_registry()->RegisterChannelCredsFactory(
std::make_unique<TestChannelCredsFactory>());
});
RefCountedPtr<grpc_channel_credentials> test_cred(
CoreConfiguration::Get().channel_creds_registry().CreateChannelCreds(
"test", Json()));
// After registration.
EXPECT_TRUE(
CoreConfiguration::Get().channel_creds_registry().IsSupported("test"));
EXPECT_NE(test_cred.get(), nullptr);
config = CoreConfiguration::Get().channel_creds_registry().ParseConfig(
"test", Json::FromObject({}), JsonArgs(), &errors);
EXPECT_TRUE(errors.ok()) << errors.message("unexpected errors");
EXPECT_NE(config, nullptr);
EXPECT_EQ(config->type(), "test");
creds = CoreConfiguration::Get().channel_creds_registry().CreateChannelCreds(
std::move(config));
ASSERT_NE(creds, nullptr);
EXPECT_EQ(creds->type(), grpc_fake_channel_credentials::Type());
}
} // namespace
@ -104,5 +190,6 @@ int main(int argc, char** argv) {
grpc::testing::TestEnvironment env(&argc, argv);
grpc_init();
auto result = RUN_ALL_TESTS();
grpc_shutdown();
return result;
}

@ -47,8 +47,8 @@
#include "src/core/lib/json/json_args.h"
#include "src/core/lib/json/json_object_loader.h"
#include "src/core/lib/json/json_reader.h"
#include "src/core/lib/json/json_writer.h"
#include "src/core/lib/security/certificate_provider/certificate_provider_factory.h"
#include "src/core/lib/security/credentials/channel_creds_registry.h"
#include "src/core/lib/security/credentials/tls/grpc_tls_certificate_provider.h"
#include "test/core/util/test_config.h"
@ -142,9 +142,8 @@ TEST(XdsBootstrapTest, Basic) {
auto* server =
&static_cast<const GrpcXdsBootstrap::GrpcXdsServer&>(bootstrap->server());
EXPECT_EQ(server->server_uri(), "fake:///lb");
EXPECT_EQ(server->channel_creds_type(), "fake");
EXPECT_TRUE(server->channel_creds_config().empty())
<< JsonDump(Json::FromObject(server->channel_creds_config()));
ASSERT_NE(server->channel_creds_config(), nullptr);
EXPECT_EQ(server->channel_creds_config()->type(), "fake");
EXPECT_EQ(bootstrap->authorities().size(), 2);
auto* authority = static_cast<const GrpcXdsBootstrap::GrpcAuthority*>(
bootstrap->LookupAuthority("xds.example.com"));
@ -156,9 +155,8 @@ TEST(XdsBootstrapTest, Basic) {
static_cast<const GrpcXdsBootstrap::GrpcXdsServer*>(authority->server());
ASSERT_NE(server, nullptr);
EXPECT_EQ(server->server_uri(), "fake:///xds_server");
EXPECT_EQ(server->channel_creds_type(), "fake");
EXPECT_TRUE(server->channel_creds_config().empty())
<< JsonDump(Json::FromObject(server->channel_creds_config()));
ASSERT_NE(server->channel_creds_config(), nullptr);
EXPECT_EQ(server->channel_creds_config()->type(), "fake");
authority = static_cast<const GrpcXdsBootstrap::GrpcAuthority*>(
bootstrap->LookupAuthority("xds.example2.com"));
ASSERT_NE(authority, nullptr);
@ -169,9 +167,8 @@ TEST(XdsBootstrapTest, Basic) {
static_cast<const GrpcXdsBootstrap::GrpcXdsServer*>(authority->server());
ASSERT_NE(server, nullptr);
EXPECT_EQ(server->server_uri(), "fake:///xds_server2");
EXPECT_EQ(server->channel_creds_type(), "fake");
EXPECT_TRUE(server->channel_creds_config().empty())
<< JsonDump(Json::FromObject(server->channel_creds_config()));
ASSERT_NE(server->channel_creds_config(), nullptr);
EXPECT_EQ(server->channel_creds_config()->type(), "fake");
ASSERT_NE(bootstrap->node(), nullptr);
EXPECT_EQ(bootstrap->node()->id(), "foo");
EXPECT_EQ(bootstrap->node()->cluster(), "bar");
@ -210,7 +207,8 @@ TEST(XdsBootstrapTest, ValidWithoutNode) {
auto* server =
&static_cast<const GrpcXdsBootstrap::GrpcXdsServer&>(bootstrap->server());
EXPECT_EQ(server->server_uri(), "fake:///lb");
EXPECT_EQ(server->channel_creds_type(), "fake");
ASSERT_NE(server->channel_creds_config(), nullptr);
EXPECT_EQ(server->channel_creds_config()->type(), "fake");
EXPECT_EQ(bootstrap->node(), nullptr);
}
@ -230,7 +228,8 @@ TEST(XdsBootstrapTest, InsecureCreds) {
auto* server =
&static_cast<const GrpcXdsBootstrap::GrpcXdsServer&>(bootstrap->server());
EXPECT_EQ(server->server_uri(), "fake:///lb");
EXPECT_EQ(server->channel_creds_type(), "insecure");
ASSERT_NE(server->channel_creds_config(), nullptr);
EXPECT_EQ(server->channel_creds_config()->type(), "insecure");
EXPECT_EQ(bootstrap->node(), nullptr);
}
@ -266,7 +265,8 @@ TEST(XdsBootstrapTest, GoogleDefaultCreds) {
auto* server =
&static_cast<const GrpcXdsBootstrap::GrpcXdsServer&>(bootstrap->server());
EXPECT_EQ(server->server_uri(), "fake:///lb");
EXPECT_EQ(server->channel_creds_type(), "google_default");
ASSERT_NE(server->channel_creds_config(), nullptr);
EXPECT_EQ(server->channel_creds_config()->type(), "google_default");
EXPECT_EQ(bootstrap->node(), nullptr);
}

Loading…
Cancel
Save