diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.cc b/src/core/lib/security/credentials/ssl/ssl_credentials.cc index e31225e82ff..d2bad595f78 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.cc +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.cc @@ -39,6 +39,7 @@ #include "src/core/lib/surface/api_trace.h" #include "src/core/tsi/ssl/session_cache/ssl_session_cache.h" #include "src/core/tsi/ssl_transport_security.h" +#include "src/core/tsi/transport_security_interface.h" // // SSL Channel Credentials. @@ -48,6 +49,27 @@ grpc_ssl_credentials::grpc_ssl_credentials( const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, const grpc_ssl_verify_peer_options* verify_options) { build_config(pem_root_certs, pem_key_cert_pair, verify_options); + // Use default (e.g. OS) root certificates if the user did not pass any root + // certificates. + if (config_.pem_root_certs == nullptr) { + const char* pem_root_certs = + grpc_core::DefaultSslRootStore::GetPemRootCerts(); + if (pem_root_certs == nullptr) { + gpr_log(GPR_ERROR, "Could not get default pem root certs."); + } else { + size_t root_len = strlen(pem_root_certs); + char* default_roots = strcpy(new char[root_len + 1], pem_root_certs); + config_.pem_root_certs = default_roots; + root_store_ = grpc_core::DefaultSslRootStore::GetRootStore(); + } + } else { + config_.pem_root_certs = config_.pem_root_certs; + root_store_ = nullptr; + } + + client_handshaker_initialization_status_ = InitializeClientHandshakerFactory( + &config_, config_.pem_root_certs, root_store_, nullptr, + &client_handshaker_factory_); } grpc_ssl_credentials::~grpc_ssl_credentials() { @@ -57,26 +79,67 @@ grpc_ssl_credentials::~grpc_ssl_credentials() { config_.verify_options.verify_peer_destruct( config_.verify_options.verify_peer_callback_userdata); } + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_); } grpc_core::RefCountedPtr grpc_ssl_credentials::create_security_connector( grpc_core::RefCountedPtr call_creds, const char* target, grpc_core::ChannelArgs* args) { + if (config_.pem_root_certs == nullptr) { + gpr_log(GPR_ERROR, + "No root certs in config. Client-side security connector must have " + "root certs."); + return nullptr; + } absl::optional overridden_target_name = args->GetOwnedString(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG); auto* ssl_session_cache = args->GetObject(); - grpc_core::RefCountedPtr sc = - grpc_ssl_channel_security_connector_create( - this->Ref(), std::move(call_creds), &config_, target, - overridden_target_name.has_value() ? overridden_target_name->c_str() - : nullptr, - ssl_session_cache == nullptr ? nullptr : ssl_session_cache->c_ptr()); - if (sc == nullptr) { - return sc; + tsi_ssl_session_cache* session_cache = + ssl_session_cache == nullptr ? nullptr : ssl_session_cache->c_ptr(); + + grpc_core::RefCountedPtr security_connector = + nullptr; + if (session_cache != nullptr) { + // We need a separate factory and SSL_CTX if there's a cache in the channel + // args. SSL_CTX should live with the factory and that should live on the + // credentials. However, there is a way to configure a session cache in the + // channel args, so that prevents us from also keeping the session cache at + // the credentials level. In the case of a session cache, we still need to + // keep a separate factory and SSL_CTX at the subchannel/security_connector + // level. + tsi_ssl_client_handshaker_factory* factory_with_cache = nullptr; + grpc_security_status status = InitializeClientHandshakerFactory( + &config_, config_.pem_root_certs, root_store_, session_cache, + &factory_with_cache); + if (status != GRPC_SECURITY_OK) { + gpr_log(GPR_ERROR, + "InitializeClientHandshakerFactory returned bad " + "status."); + return nullptr; + } + security_connector = grpc_ssl_channel_security_connector_create( + this->Ref(), std::move(call_creds), &config_, target, + overridden_target_name.has_value() ? overridden_target_name->c_str() + : nullptr, + factory_with_cache); + tsi_ssl_client_handshaker_factory_unref(factory_with_cache); + } else { + if (client_handshaker_initialization_status_ != GRPC_SECURITY_OK) { + return nullptr; + } + security_connector = grpc_ssl_channel_security_connector_create( + this->Ref(), std::move(call_creds), &config_, target, + overridden_target_name.has_value() ? overridden_target_name->c_str() + : nullptr, + client_handshaker_factory_); + } + + if (security_connector == nullptr) { + return security_connector; } *args = args->Set(GRPC_ARG_HTTP2_SCHEME, "https"); - return sc; + return security_connector; } grpc_core::UniqueTypeName grpc_ssl_credentials::Type() { @@ -119,6 +182,50 @@ void grpc_ssl_credentials::set_max_tls_version( config_.max_tls_version = max_tls_version; } +grpc_security_status grpc_ssl_credentials::InitializeClientHandshakerFactory( + const grpc_ssl_config* config, const char* pem_root_certs, + const tsi_ssl_root_certs_store* root_store, + tsi_ssl_session_cache* ssl_session_cache, + tsi_ssl_client_handshaker_factory** handshaker_factory) { + // This class level factory can't have a session cache by design. If we want + // to init one with a cache we need to make a new one + if (client_handshaker_factory_ != nullptr && ssl_session_cache == nullptr) { + return GRPC_SECURITY_OK; + } + + bool has_key_cert_pair = config->pem_key_cert_pair != nullptr && + config->pem_key_cert_pair->private_key != nullptr && + config->pem_key_cert_pair->cert_chain != nullptr; + tsi_ssl_client_handshaker_options options; + if (pem_root_certs == nullptr) { + gpr_log( + GPR_ERROR, + "Handshaker factory creation failed. pem_root_certs cannot be nullptr"); + return GRPC_SECURITY_ERROR; + } + options.pem_root_certs = pem_root_certs; + options.root_store = root_store; + options.alpn_protocols = + grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); + if (has_key_cert_pair) { + options.pem_key_cert_pair = config->pem_key_cert_pair; + } + options.cipher_suites = grpc_get_ssl_cipher_suites(); + options.session_cache = ssl_session_cache; + options.min_tls_version = grpc_get_tsi_tls_version(config->min_tls_version); + options.max_tls_version = grpc_get_tsi_tls_version(config->max_tls_version); + const tsi_result result = + tsi_create_ssl_client_handshaker_factory_with_options(&options, + handshaker_factory); + gpr_free(options.alpn_protocols); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } + return GRPC_SECURITY_OK; +} + // Deprecated in favor of grpc_ssl_credentials_create_ex. Will be removed // once all of its call sites are migrated to grpc_ssl_credentials_create_ex. grpc_channel_credentials* grpc_ssl_credentials_create( diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.h b/src/core/lib/security/credentials/ssl/ssl_credentials.h index 2b2380d3946..01f6ece6581 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.h +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.h @@ -69,7 +69,21 @@ class grpc_ssl_credentials : public grpc_channel_credentials { grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, const grpc_ssl_verify_peer_options* verify_options); + // InitializeClientHandshakerFactory constructs a client handshaker factory + // that is stored on this credentials object. This handshaker factory will be + // used when creating handshakers using these credentials except in the case + // that there is a session cache. If a session cache is used, a new handshaker + // factory will be created and used that contains that session cache. + grpc_security_status InitializeClientHandshakerFactory( + const grpc_ssl_config* config, const char* pem_root_certs, + const tsi_ssl_root_certs_store* root_store, + tsi_ssl_session_cache* ssl_session_cache, + tsi_ssl_client_handshaker_factory** handshaker_factory); + grpc_ssl_config config_; + tsi_ssl_client_handshaker_factory* client_handshaker_factory_ = nullptr; + const tsi_ssl_root_certs_store* root_store_ = nullptr; + grpc_security_status client_handshaker_initialization_status_; }; struct grpc_ssl_server_certificate_config { diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc index 4fc2aeb3c76..34fbd571b7d 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc @@ -82,10 +82,12 @@ class grpc_ssl_channel_security_connector final grpc_core::RefCountedPtr channel_creds, grpc_core::RefCountedPtr request_metadata_creds, const grpc_ssl_config* config, const char* target_name, - const char* overridden_target_name) + const char* overridden_target_name, + tsi_ssl_client_handshaker_factory* client_handshaker_factory) : grpc_channel_security_connector(GRPC_SSL_URL_SCHEME, std::move(channel_creds), std::move(request_metadata_creds)), + client_handshaker_factory_(client_handshaker_factory), overridden_target_name_( overridden_target_name == nullptr ? "" : overridden_target_name), verify_options_(&config->verify_options) { @@ -99,39 +101,6 @@ class grpc_ssl_channel_security_connector final tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_); } - grpc_security_status InitializeHandshakerFactory( - const grpc_ssl_config* config, const char* pem_root_certs, - const tsi_ssl_root_certs_store* root_store, - tsi_ssl_session_cache* ssl_session_cache) { - bool has_key_cert_pair = - config->pem_key_cert_pair != nullptr && - config->pem_key_cert_pair->private_key != nullptr && - config->pem_key_cert_pair->cert_chain != nullptr; - tsi_ssl_client_handshaker_options options; - GPR_DEBUG_ASSERT(pem_root_certs != nullptr); - options.pem_root_certs = pem_root_certs; - options.root_store = root_store; - options.alpn_protocols = - grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); - if (has_key_cert_pair) { - options.pem_key_cert_pair = config->pem_key_cert_pair; - } - options.cipher_suites = grpc_get_ssl_cipher_suites(); - options.session_cache = ssl_session_cache; - options.min_tls_version = grpc_get_tsi_tls_version(config->min_tls_version); - options.max_tls_version = grpc_get_tsi_tls_version(config->max_tls_version); - const tsi_result result = - tsi_create_ssl_client_handshaker_factory_with_options( - &options, &client_handshaker_factory_); - gpr_free(options.alpn_protocols); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", - tsi_result_to_string(result)); - return GRPC_SECURITY_ERROR; - } - return GRPC_SECURITY_OK; - } - void add_handshakers(const grpc_core::ChannelArgs& args, grpc_pollset_set* /*interested_parties*/, grpc_core::HandshakeManager* handshake_mgr) override { @@ -205,7 +174,7 @@ class grpc_ssl_channel_security_connector final } private: - tsi_ssl_client_handshaker_factory* client_handshaker_factory_; + tsi_ssl_client_handshaker_factory* client_handshaker_factory_ = nullptr; std::string target_name_; std::string overridden_target_name_; const verify_peer_options* verify_options_; @@ -411,36 +380,17 @@ grpc_ssl_channel_security_connector_create( grpc_core::RefCountedPtr request_metadata_creds, const grpc_ssl_config* config, const char* target_name, const char* overridden_target_name, - tsi_ssl_session_cache* ssl_session_cache) { + tsi_ssl_client_handshaker_factory* client_factory) { if (config == nullptr || target_name == nullptr) { gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name."); return nullptr; } - const char* pem_root_certs; - const tsi_ssl_root_certs_store* root_store; - if (config->pem_root_certs == nullptr) { - // Use default root certificates. - pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); - if (pem_root_certs == nullptr) { - gpr_log(GPR_ERROR, "Could not get default pem root certs."); - return nullptr; - } - root_store = grpc_core::DefaultSslRootStore::GetRootStore(); - } else { - pem_root_certs = config->pem_root_certs; - root_store = nullptr; - } - grpc_core::RefCountedPtr c = grpc_core::MakeRefCounted( std::move(channel_creds), std::move(request_metadata_creds), config, - target_name, overridden_target_name); - const grpc_security_status result = c->InitializeHandshakerFactory( - config, pem_root_certs, root_store, ssl_session_cache); - if (result != GRPC_SECURITY_OK) { - return nullptr; - } + target_name, overridden_target_name, + tsi_ssl_client_handshaker_factory_ref(client_factory)); return c; } diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h index 1e93592f718..c5781d8d958 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h @@ -57,7 +57,7 @@ grpc_ssl_channel_security_connector_create( grpc_core::RefCountedPtr request_metadata_creds, const grpc_ssl_config* config, const char* target_name, const char* overridden_target_name, - tsi_ssl_session_cache* ssl_session_cache); + tsi_ssl_client_handshaker_factory* factory); // Config for ssl servers. struct grpc_ssl_server_config { diff --git a/src/core/tsi/ssl_transport_security.cc b/src/core/tsi/ssl_transport_security.cc index 91519650b6d..f11e4a1418d 100644 --- a/src/core/tsi/ssl_transport_security.cc +++ b/src/core/tsi/ssl_transport_security.cc @@ -1762,6 +1762,13 @@ void tsi_ssl_client_handshaker_factory_unref( tsi_ssl_handshaker_factory_unref(&factory->base); } +tsi_ssl_client_handshaker_factory* tsi_ssl_client_handshaker_factory_ref( + tsi_ssl_client_handshaker_factory* client_factory) { + if (client_factory == nullptr) return nullptr; + return reinterpret_cast( + tsi_ssl_handshaker_factory_ref(&client_factory->base)); +} + static void tsi_ssl_client_handshaker_factory_destroy( tsi_ssl_handshaker_factory* factory) { if (factory == nullptr) return; diff --git a/src/core/tsi/ssl_transport_security.h b/src/core/tsi/ssl_transport_security.h index 75e549fc3cc..7248a1fe84a 100644 --- a/src/core/tsi/ssl_transport_security.h +++ b/src/core/tsi/ssl_transport_security.h @@ -223,6 +223,10 @@ tsi_result tsi_ssl_client_handshaker_factory_create_handshaker( const char* server_name_indication, size_t network_bio_buf_size, size_t ssl_bio_buf_size, tsi_handshaker** handshaker); +// Increments reference count of the client handshaker factory. +tsi_ssl_client_handshaker_factory* tsi_ssl_client_handshaker_factory_ref( + tsi_ssl_client_handshaker_factory* client_factory); + // Decrements reference count of the handshaker factory. Handshaker factory will // be destroyed once no references exist. void tsi_ssl_client_handshaker_factory_unref( diff --git a/test/core/tsi/ssl_transport_security_test.cc b/test/core/tsi/ssl_transport_security_test.cc index 140d3729959..51703a2bb74 100644 --- a/test/core/tsi/ssl_transport_security_test.cc +++ b/test/core/tsi/ssl_transport_security_test.cc @@ -960,6 +960,9 @@ void test_tsi_ssl_client_handshaker_factory_refcounting() { TSI_OK); } + client_handshaker_factory = + tsi_ssl_client_handshaker_factory_ref(client_handshaker_factory); + tsi_handshaker_destroy(handshaker[1]); ASSERT_FALSE(handshaker_factory_destructor_called); @@ -970,8 +973,10 @@ void test_tsi_ssl_client_handshaker_factory_refcounting() { ASSERT_FALSE(handshaker_factory_destructor_called); tsi_handshaker_destroy(handshaker[2]); - ASSERT_TRUE(handshaker_factory_destructor_called); + ASSERT_FALSE(handshaker_factory_destructor_called); + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory); + ASSERT_TRUE(handshaker_factory_destructor_called); gpr_free(cert_chain); } diff --git a/test/cpp/end2end/ssl_credentials_test.cc b/test/cpp/end2end/ssl_credentials_test.cc index 4ecd542571e..91022228986 100644 --- a/test/cpp/end2end/ssl_credentials_test.cc +++ b/test/cpp/end2end/ssl_credentials_test.cc @@ -106,8 +106,7 @@ void DoRpc(const std::string& server_addr, grpc::testing::EchoResponse response; request.set_message(kMessage); ClientContext context; - context.set_deadline(grpc_timeout_milliseconds_to_deadline( - /*time_ms=*/5000 * grpc_test_slowdown_factor())); + context.set_deadline(grpc_timeout_seconds_to_deadline(/*time_s=*/10)); grpc::Status result = stub->Echo(&context, request, &response); EXPECT_TRUE(result.ok()); if (!result.ok()) {