[Security] Move ownership of tsi_ssl_client_handshaker_factory to grpc_ssl_credentials. (#34180)

Move the SSL_CTX to the level of the credentials rather than the
subchannel.
The SSL_CTX should only get created once per credential rather than once
per subchannel.

We should observe no behavior change with this PR, only efficiency
gains.
pull/34355/head
Gregory Cooke 1 year ago committed by GitHub
parent 58f1c74383
commit 36dc5e7391
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 125
      src/core/lib/security/credentials/ssl/ssl_credentials.cc
  2. 14
      src/core/lib/security/credentials/ssl/ssl_credentials.h
  3. 64
      src/core/lib/security/security_connector/ssl/ssl_security_connector.cc
  4. 2
      src/core/lib/security/security_connector/ssl/ssl_security_connector.h
  5. 7
      src/core/tsi/ssl_transport_security.cc
  6. 4
      src/core/tsi/ssl_transport_security.h
  7. 7
      test/core/tsi/ssl_transport_security_test.cc
  8. 3
      test/cpp/end2end/ssl_credentials_test.cc

@ -39,6 +39,7 @@
#include "src/core/lib/surface/api_trace.h"
#include "src/core/tsi/ssl/session_cache/ssl_session_cache.h"
#include "src/core/tsi/ssl_transport_security.h"
#include "src/core/tsi/transport_security_interface.h"
//
// SSL Channel Credentials.
@ -48,6 +49,27 @@ grpc_ssl_credentials::grpc_ssl_credentials(
const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const grpc_ssl_verify_peer_options* verify_options) {
build_config(pem_root_certs, pem_key_cert_pair, verify_options);
// Use default (e.g. OS) root certificates if the user did not pass any root
// certificates.
if (config_.pem_root_certs == nullptr) {
const char* pem_root_certs =
grpc_core::DefaultSslRootStore::GetPemRootCerts();
if (pem_root_certs == nullptr) {
gpr_log(GPR_ERROR, "Could not get default pem root certs.");
} else {
size_t root_len = strlen(pem_root_certs);
char* default_roots = strcpy(new char[root_len + 1], pem_root_certs);
config_.pem_root_certs = default_roots;
root_store_ = grpc_core::DefaultSslRootStore::GetRootStore();
}
} else {
config_.pem_root_certs = config_.pem_root_certs;
root_store_ = nullptr;
}
client_handshaker_initialization_status_ = InitializeClientHandshakerFactory(
&config_, config_.pem_root_certs, root_store_, nullptr,
&client_handshaker_factory_);
}
grpc_ssl_credentials::~grpc_ssl_credentials() {
@ -57,26 +79,67 @@ grpc_ssl_credentials::~grpc_ssl_credentials() {
config_.verify_options.verify_peer_destruct(
config_.verify_options.verify_peer_callback_userdata);
}
tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_);
}
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_ssl_credentials::create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, grpc_core::ChannelArgs* args) {
if (config_.pem_root_certs == nullptr) {
gpr_log(GPR_ERROR,
"No root certs in config. Client-side security connector must have "
"root certs.");
return nullptr;
}
absl::optional<std::string> overridden_target_name =
args->GetOwnedString(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG);
auto* ssl_session_cache = args->GetObject<tsi::SslSessionLRUCache>();
grpc_core::RefCountedPtr<grpc_channel_security_connector> sc =
grpc_ssl_channel_security_connector_create(
this->Ref(), std::move(call_creds), &config_, target,
overridden_target_name.has_value() ? overridden_target_name->c_str()
: nullptr,
ssl_session_cache == nullptr ? nullptr : ssl_session_cache->c_ptr());
if (sc == nullptr) {
return sc;
tsi_ssl_session_cache* session_cache =
ssl_session_cache == nullptr ? nullptr : ssl_session_cache->c_ptr();
grpc_core::RefCountedPtr<grpc_channel_security_connector> security_connector =
nullptr;
if (session_cache != nullptr) {
// We need a separate factory and SSL_CTX if there's a cache in the channel
// args. SSL_CTX should live with the factory and that should live on the
// credentials. However, there is a way to configure a session cache in the
// channel args, so that prevents us from also keeping the session cache at
// the credentials level. In the case of a session cache, we still need to
// keep a separate factory and SSL_CTX at the subchannel/security_connector
// level.
tsi_ssl_client_handshaker_factory* factory_with_cache = nullptr;
grpc_security_status status = InitializeClientHandshakerFactory(
&config_, config_.pem_root_certs, root_store_, session_cache,
&factory_with_cache);
if (status != GRPC_SECURITY_OK) {
gpr_log(GPR_ERROR,
"InitializeClientHandshakerFactory returned bad "
"status.");
return nullptr;
}
security_connector = grpc_ssl_channel_security_connector_create(
this->Ref(), std::move(call_creds), &config_, target,
overridden_target_name.has_value() ? overridden_target_name->c_str()
: nullptr,
factory_with_cache);
tsi_ssl_client_handshaker_factory_unref(factory_with_cache);
} else {
if (client_handshaker_initialization_status_ != GRPC_SECURITY_OK) {
return nullptr;
}
security_connector = grpc_ssl_channel_security_connector_create(
this->Ref(), std::move(call_creds), &config_, target,
overridden_target_name.has_value() ? overridden_target_name->c_str()
: nullptr,
client_handshaker_factory_);
}
if (security_connector == nullptr) {
return security_connector;
}
*args = args->Set(GRPC_ARG_HTTP2_SCHEME, "https");
return sc;
return security_connector;
}
grpc_core::UniqueTypeName grpc_ssl_credentials::Type() {
@ -119,6 +182,50 @@ void grpc_ssl_credentials::set_max_tls_version(
config_.max_tls_version = max_tls_version;
}
grpc_security_status grpc_ssl_credentials::InitializeClientHandshakerFactory(
const grpc_ssl_config* config, const char* pem_root_certs,
const tsi_ssl_root_certs_store* root_store,
tsi_ssl_session_cache* ssl_session_cache,
tsi_ssl_client_handshaker_factory** handshaker_factory) {
// This class level factory can't have a session cache by design. If we want
// to init one with a cache we need to make a new one
if (client_handshaker_factory_ != nullptr && ssl_session_cache == nullptr) {
return GRPC_SECURITY_OK;
}
bool has_key_cert_pair = config->pem_key_cert_pair != nullptr &&
config->pem_key_cert_pair->private_key != nullptr &&
config->pem_key_cert_pair->cert_chain != nullptr;
tsi_ssl_client_handshaker_options options;
if (pem_root_certs == nullptr) {
gpr_log(
GPR_ERROR,
"Handshaker factory creation failed. pem_root_certs cannot be nullptr");
return GRPC_SECURITY_ERROR;
}
options.pem_root_certs = pem_root_certs;
options.root_store = root_store;
options.alpn_protocols =
grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols);
if (has_key_cert_pair) {
options.pem_key_cert_pair = config->pem_key_cert_pair;
}
options.cipher_suites = grpc_get_ssl_cipher_suites();
options.session_cache = ssl_session_cache;
options.min_tls_version = grpc_get_tsi_tls_version(config->min_tls_version);
options.max_tls_version = grpc_get_tsi_tls_version(config->max_tls_version);
const tsi_result result =
tsi_create_ssl_client_handshaker_factory_with_options(&options,
handshaker_factory);
gpr_free(options.alpn_protocols);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
return GRPC_SECURITY_ERROR;
}
return GRPC_SECURITY_OK;
}
// Deprecated in favor of grpc_ssl_credentials_create_ex. Will be removed
// once all of its call sites are migrated to grpc_ssl_credentials_create_ex.
grpc_channel_credentials* grpc_ssl_credentials_create(

@ -69,7 +69,21 @@ class grpc_ssl_credentials : public grpc_channel_credentials {
grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const grpc_ssl_verify_peer_options* verify_options);
// InitializeClientHandshakerFactory constructs a client handshaker factory
// that is stored on this credentials object. This handshaker factory will be
// used when creating handshakers using these credentials except in the case
// that there is a session cache. If a session cache is used, a new handshaker
// factory will be created and used that contains that session cache.
grpc_security_status InitializeClientHandshakerFactory(
const grpc_ssl_config* config, const char* pem_root_certs,
const tsi_ssl_root_certs_store* root_store,
tsi_ssl_session_cache* ssl_session_cache,
tsi_ssl_client_handshaker_factory** handshaker_factory);
grpc_ssl_config config_;
tsi_ssl_client_handshaker_factory* client_handshaker_factory_ = nullptr;
const tsi_ssl_root_certs_store* root_store_ = nullptr;
grpc_security_status client_handshaker_initialization_status_;
};
struct grpc_ssl_server_certificate_config {

@ -82,10 +82,12 @@ class grpc_ssl_channel_security_connector final
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_ssl_config* config, const char* target_name,
const char* overridden_target_name)
const char* overridden_target_name,
tsi_ssl_client_handshaker_factory* client_handshaker_factory)
: grpc_channel_security_connector(GRPC_SSL_URL_SCHEME,
std::move(channel_creds),
std::move(request_metadata_creds)),
client_handshaker_factory_(client_handshaker_factory),
overridden_target_name_(
overridden_target_name == nullptr ? "" : overridden_target_name),
verify_options_(&config->verify_options) {
@ -99,39 +101,6 @@ class grpc_ssl_channel_security_connector final
tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_);
}
grpc_security_status InitializeHandshakerFactory(
const grpc_ssl_config* config, const char* pem_root_certs,
const tsi_ssl_root_certs_store* root_store,
tsi_ssl_session_cache* ssl_session_cache) {
bool has_key_cert_pair =
config->pem_key_cert_pair != nullptr &&
config->pem_key_cert_pair->private_key != nullptr &&
config->pem_key_cert_pair->cert_chain != nullptr;
tsi_ssl_client_handshaker_options options;
GPR_DEBUG_ASSERT(pem_root_certs != nullptr);
options.pem_root_certs = pem_root_certs;
options.root_store = root_store;
options.alpn_protocols =
grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols);
if (has_key_cert_pair) {
options.pem_key_cert_pair = config->pem_key_cert_pair;
}
options.cipher_suites = grpc_get_ssl_cipher_suites();
options.session_cache = ssl_session_cache;
options.min_tls_version = grpc_get_tsi_tls_version(config->min_tls_version);
options.max_tls_version = grpc_get_tsi_tls_version(config->max_tls_version);
const tsi_result result =
tsi_create_ssl_client_handshaker_factory_with_options(
&options, &client_handshaker_factory_);
gpr_free(options.alpn_protocols);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
return GRPC_SECURITY_ERROR;
}
return GRPC_SECURITY_OK;
}
void add_handshakers(const grpc_core::ChannelArgs& args,
grpc_pollset_set* /*interested_parties*/,
grpc_core::HandshakeManager* handshake_mgr) override {
@ -205,7 +174,7 @@ class grpc_ssl_channel_security_connector final
}
private:
tsi_ssl_client_handshaker_factory* client_handshaker_factory_;
tsi_ssl_client_handshaker_factory* client_handshaker_factory_ = nullptr;
std::string target_name_;
std::string overridden_target_name_;
const verify_peer_options* verify_options_;
@ -411,36 +380,17 @@ grpc_ssl_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_ssl_config* config, const char* target_name,
const char* overridden_target_name,
tsi_ssl_session_cache* ssl_session_cache) {
tsi_ssl_client_handshaker_factory* client_factory) {
if (config == nullptr || target_name == nullptr) {
gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name.");
return nullptr;
}
const char* pem_root_certs;
const tsi_ssl_root_certs_store* root_store;
if (config->pem_root_certs == nullptr) {
// Use default root certificates.
pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts();
if (pem_root_certs == nullptr) {
gpr_log(GPR_ERROR, "Could not get default pem root certs.");
return nullptr;
}
root_store = grpc_core::DefaultSslRootStore::GetRootStore();
} else {
pem_root_certs = config->pem_root_certs;
root_store = nullptr;
}
grpc_core::RefCountedPtr<grpc_ssl_channel_security_connector> c =
grpc_core::MakeRefCounted<grpc_ssl_channel_security_connector>(
std::move(channel_creds), std::move(request_metadata_creds), config,
target_name, overridden_target_name);
const grpc_security_status result = c->InitializeHandshakerFactory(
config, pem_root_certs, root_store, ssl_session_cache);
if (result != GRPC_SECURITY_OK) {
return nullptr;
}
target_name, overridden_target_name,
tsi_ssl_client_handshaker_factory_ref(client_factory));
return c;
}

@ -57,7 +57,7 @@ grpc_ssl_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_ssl_config* config, const char* target_name,
const char* overridden_target_name,
tsi_ssl_session_cache* ssl_session_cache);
tsi_ssl_client_handshaker_factory* factory);
// Config for ssl servers.
struct grpc_ssl_server_config {

@ -1762,6 +1762,13 @@ void tsi_ssl_client_handshaker_factory_unref(
tsi_ssl_handshaker_factory_unref(&factory->base);
}
tsi_ssl_client_handshaker_factory* tsi_ssl_client_handshaker_factory_ref(
tsi_ssl_client_handshaker_factory* client_factory) {
if (client_factory == nullptr) return nullptr;
return reinterpret_cast<tsi_ssl_client_handshaker_factory*>(
tsi_ssl_handshaker_factory_ref(&client_factory->base));
}
static void tsi_ssl_client_handshaker_factory_destroy(
tsi_ssl_handshaker_factory* factory) {
if (factory == nullptr) return;

@ -223,6 +223,10 @@ tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
const char* server_name_indication, size_t network_bio_buf_size,
size_t ssl_bio_buf_size, tsi_handshaker** handshaker);
// Increments reference count of the client handshaker factory.
tsi_ssl_client_handshaker_factory* tsi_ssl_client_handshaker_factory_ref(
tsi_ssl_client_handshaker_factory* client_factory);
// Decrements reference count of the handshaker factory. Handshaker factory will
// be destroyed once no references exist.
void tsi_ssl_client_handshaker_factory_unref(

@ -960,6 +960,9 @@ void test_tsi_ssl_client_handshaker_factory_refcounting() {
TSI_OK);
}
client_handshaker_factory =
tsi_ssl_client_handshaker_factory_ref(client_handshaker_factory);
tsi_handshaker_destroy(handshaker[1]);
ASSERT_FALSE(handshaker_factory_destructor_called);
@ -970,8 +973,10 @@ void test_tsi_ssl_client_handshaker_factory_refcounting() {
ASSERT_FALSE(handshaker_factory_destructor_called);
tsi_handshaker_destroy(handshaker[2]);
ASSERT_TRUE(handshaker_factory_destructor_called);
ASSERT_FALSE(handshaker_factory_destructor_called);
tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory);
ASSERT_TRUE(handshaker_factory_destructor_called);
gpr_free(cert_chain);
}

@ -106,8 +106,7 @@ void DoRpc(const std::string& server_addr,
grpc::testing::EchoResponse response;
request.set_message(kMessage);
ClientContext context;
context.set_deadline(grpc_timeout_milliseconds_to_deadline(
/*time_ms=*/5000 * grpc_test_slowdown_factor()));
context.set_deadline(grpc_timeout_seconds_to_deadline(/*time_s=*/10));
grpc::Status result = stub->Echo(&context, request, &response);
EXPECT_TRUE(result.ok());
if (!result.ok()) {

Loading…
Cancel
Save