diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.cc b/src/core/lib/security/credentials/ssl/ssl_credentials.cc index d2bad595f78..e31225e82ff 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.cc +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.cc @@ -39,7 +39,6 @@ #include "src/core/lib/surface/api_trace.h" #include "src/core/tsi/ssl/session_cache/ssl_session_cache.h" #include "src/core/tsi/ssl_transport_security.h" -#include "src/core/tsi/transport_security_interface.h" // // SSL Channel Credentials. @@ -49,27 +48,6 @@ grpc_ssl_credentials::grpc_ssl_credentials( const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, const grpc_ssl_verify_peer_options* verify_options) { build_config(pem_root_certs, pem_key_cert_pair, verify_options); - // Use default (e.g. OS) root certificates if the user did not pass any root - // certificates. - if (config_.pem_root_certs == nullptr) { - const char* pem_root_certs = - grpc_core::DefaultSslRootStore::GetPemRootCerts(); - if (pem_root_certs == nullptr) { - gpr_log(GPR_ERROR, "Could not get default pem root certs."); - } else { - size_t root_len = strlen(pem_root_certs); - char* default_roots = strcpy(new char[root_len + 1], pem_root_certs); - config_.pem_root_certs = default_roots; - root_store_ = grpc_core::DefaultSslRootStore::GetRootStore(); - } - } else { - config_.pem_root_certs = config_.pem_root_certs; - root_store_ = nullptr; - } - - client_handshaker_initialization_status_ = InitializeClientHandshakerFactory( - &config_, config_.pem_root_certs, root_store_, nullptr, - &client_handshaker_factory_); } grpc_ssl_credentials::~grpc_ssl_credentials() { @@ -79,67 +57,26 @@ grpc_ssl_credentials::~grpc_ssl_credentials() { config_.verify_options.verify_peer_destruct( config_.verify_options.verify_peer_callback_userdata); } - tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_); } grpc_core::RefCountedPtr grpc_ssl_credentials::create_security_connector( grpc_core::RefCountedPtr call_creds, const char* target, grpc_core::ChannelArgs* args) { - if (config_.pem_root_certs == nullptr) { - gpr_log(GPR_ERROR, - "No root certs in config. Client-side security connector must have " - "root certs."); - return nullptr; - } absl::optional overridden_target_name = args->GetOwnedString(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG); auto* ssl_session_cache = args->GetObject(); - tsi_ssl_session_cache* session_cache = - ssl_session_cache == nullptr ? nullptr : ssl_session_cache->c_ptr(); - - grpc_core::RefCountedPtr security_connector = - nullptr; - if (session_cache != nullptr) { - // We need a separate factory and SSL_CTX if there's a cache in the channel - // args. SSL_CTX should live with the factory and that should live on the - // credentials. However, there is a way to configure a session cache in the - // channel args, so that prevents us from also keeping the session cache at - // the credentials level. In the case of a session cache, we still need to - // keep a separate factory and SSL_CTX at the subchannel/security_connector - // level. - tsi_ssl_client_handshaker_factory* factory_with_cache = nullptr; - grpc_security_status status = InitializeClientHandshakerFactory( - &config_, config_.pem_root_certs, root_store_, session_cache, - &factory_with_cache); - if (status != GRPC_SECURITY_OK) { - gpr_log(GPR_ERROR, - "InitializeClientHandshakerFactory returned bad " - "status."); - return nullptr; - } - security_connector = grpc_ssl_channel_security_connector_create( - this->Ref(), std::move(call_creds), &config_, target, - overridden_target_name.has_value() ? overridden_target_name->c_str() - : nullptr, - factory_with_cache); - tsi_ssl_client_handshaker_factory_unref(factory_with_cache); - } else { - if (client_handshaker_initialization_status_ != GRPC_SECURITY_OK) { - return nullptr; - } - security_connector = grpc_ssl_channel_security_connector_create( - this->Ref(), std::move(call_creds), &config_, target, - overridden_target_name.has_value() ? overridden_target_name->c_str() - : nullptr, - client_handshaker_factory_); - } - - if (security_connector == nullptr) { - return security_connector; + grpc_core::RefCountedPtr sc = + grpc_ssl_channel_security_connector_create( + this->Ref(), std::move(call_creds), &config_, target, + overridden_target_name.has_value() ? overridden_target_name->c_str() + : nullptr, + ssl_session_cache == nullptr ? nullptr : ssl_session_cache->c_ptr()); + if (sc == nullptr) { + return sc; } *args = args->Set(GRPC_ARG_HTTP2_SCHEME, "https"); - return security_connector; + return sc; } grpc_core::UniqueTypeName grpc_ssl_credentials::Type() { @@ -182,50 +119,6 @@ void grpc_ssl_credentials::set_max_tls_version( config_.max_tls_version = max_tls_version; } -grpc_security_status grpc_ssl_credentials::InitializeClientHandshakerFactory( - const grpc_ssl_config* config, const char* pem_root_certs, - const tsi_ssl_root_certs_store* root_store, - tsi_ssl_session_cache* ssl_session_cache, - tsi_ssl_client_handshaker_factory** handshaker_factory) { - // This class level factory can't have a session cache by design. If we want - // to init one with a cache we need to make a new one - if (client_handshaker_factory_ != nullptr && ssl_session_cache == nullptr) { - return GRPC_SECURITY_OK; - } - - bool has_key_cert_pair = config->pem_key_cert_pair != nullptr && - config->pem_key_cert_pair->private_key != nullptr && - config->pem_key_cert_pair->cert_chain != nullptr; - tsi_ssl_client_handshaker_options options; - if (pem_root_certs == nullptr) { - gpr_log( - GPR_ERROR, - "Handshaker factory creation failed. pem_root_certs cannot be nullptr"); - return GRPC_SECURITY_ERROR; - } - options.pem_root_certs = pem_root_certs; - options.root_store = root_store; - options.alpn_protocols = - grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); - if (has_key_cert_pair) { - options.pem_key_cert_pair = config->pem_key_cert_pair; - } - options.cipher_suites = grpc_get_ssl_cipher_suites(); - options.session_cache = ssl_session_cache; - options.min_tls_version = grpc_get_tsi_tls_version(config->min_tls_version); - options.max_tls_version = grpc_get_tsi_tls_version(config->max_tls_version); - const tsi_result result = - tsi_create_ssl_client_handshaker_factory_with_options(&options, - handshaker_factory); - gpr_free(options.alpn_protocols); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", - tsi_result_to_string(result)); - return GRPC_SECURITY_ERROR; - } - return GRPC_SECURITY_OK; -} - // Deprecated in favor of grpc_ssl_credentials_create_ex. Will be removed // once all of its call sites are migrated to grpc_ssl_credentials_create_ex. grpc_channel_credentials* grpc_ssl_credentials_create( diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.h b/src/core/lib/security/credentials/ssl/ssl_credentials.h index 01f6ece6581..2b2380d3946 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.h +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.h @@ -69,21 +69,7 @@ class grpc_ssl_credentials : public grpc_channel_credentials { grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, const grpc_ssl_verify_peer_options* verify_options); - // InitializeClientHandshakerFactory constructs a client handshaker factory - // that is stored on this credentials object. This handshaker factory will be - // used when creating handshakers using these credentials except in the case - // that there is a session cache. If a session cache is used, a new handshaker - // factory will be created and used that contains that session cache. - grpc_security_status InitializeClientHandshakerFactory( - const grpc_ssl_config* config, const char* pem_root_certs, - const tsi_ssl_root_certs_store* root_store, - tsi_ssl_session_cache* ssl_session_cache, - tsi_ssl_client_handshaker_factory** handshaker_factory); - grpc_ssl_config config_; - tsi_ssl_client_handshaker_factory* client_handshaker_factory_ = nullptr; - const tsi_ssl_root_certs_store* root_store_ = nullptr; - grpc_security_status client_handshaker_initialization_status_; }; struct grpc_ssl_server_certificate_config { diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc index 34fbd571b7d..4fc2aeb3c76 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc @@ -82,12 +82,10 @@ class grpc_ssl_channel_security_connector final grpc_core::RefCountedPtr channel_creds, grpc_core::RefCountedPtr request_metadata_creds, const grpc_ssl_config* config, const char* target_name, - const char* overridden_target_name, - tsi_ssl_client_handshaker_factory* client_handshaker_factory) + const char* overridden_target_name) : grpc_channel_security_connector(GRPC_SSL_URL_SCHEME, std::move(channel_creds), std::move(request_metadata_creds)), - client_handshaker_factory_(client_handshaker_factory), overridden_target_name_( overridden_target_name == nullptr ? "" : overridden_target_name), verify_options_(&config->verify_options) { @@ -101,6 +99,39 @@ class grpc_ssl_channel_security_connector final tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_); } + grpc_security_status InitializeHandshakerFactory( + const grpc_ssl_config* config, const char* pem_root_certs, + const tsi_ssl_root_certs_store* root_store, + tsi_ssl_session_cache* ssl_session_cache) { + bool has_key_cert_pair = + config->pem_key_cert_pair != nullptr && + config->pem_key_cert_pair->private_key != nullptr && + config->pem_key_cert_pair->cert_chain != nullptr; + tsi_ssl_client_handshaker_options options; + GPR_DEBUG_ASSERT(pem_root_certs != nullptr); + options.pem_root_certs = pem_root_certs; + options.root_store = root_store; + options.alpn_protocols = + grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); + if (has_key_cert_pair) { + options.pem_key_cert_pair = config->pem_key_cert_pair; + } + options.cipher_suites = grpc_get_ssl_cipher_suites(); + options.session_cache = ssl_session_cache; + options.min_tls_version = grpc_get_tsi_tls_version(config->min_tls_version); + options.max_tls_version = grpc_get_tsi_tls_version(config->max_tls_version); + const tsi_result result = + tsi_create_ssl_client_handshaker_factory_with_options( + &options, &client_handshaker_factory_); + gpr_free(options.alpn_protocols); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } + return GRPC_SECURITY_OK; + } + void add_handshakers(const grpc_core::ChannelArgs& args, grpc_pollset_set* /*interested_parties*/, grpc_core::HandshakeManager* handshake_mgr) override { @@ -174,7 +205,7 @@ class grpc_ssl_channel_security_connector final } private: - tsi_ssl_client_handshaker_factory* client_handshaker_factory_ = nullptr; + tsi_ssl_client_handshaker_factory* client_handshaker_factory_; std::string target_name_; std::string overridden_target_name_; const verify_peer_options* verify_options_; @@ -380,17 +411,36 @@ grpc_ssl_channel_security_connector_create( grpc_core::RefCountedPtr request_metadata_creds, const grpc_ssl_config* config, const char* target_name, const char* overridden_target_name, - tsi_ssl_client_handshaker_factory* client_factory) { + tsi_ssl_session_cache* ssl_session_cache) { if (config == nullptr || target_name == nullptr) { gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name."); return nullptr; } + const char* pem_root_certs; + const tsi_ssl_root_certs_store* root_store; + if (config->pem_root_certs == nullptr) { + // Use default root certificates. + pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); + if (pem_root_certs == nullptr) { + gpr_log(GPR_ERROR, "Could not get default pem root certs."); + return nullptr; + } + root_store = grpc_core::DefaultSslRootStore::GetRootStore(); + } else { + pem_root_certs = config->pem_root_certs; + root_store = nullptr; + } + grpc_core::RefCountedPtr c = grpc_core::MakeRefCounted( std::move(channel_creds), std::move(request_metadata_creds), config, - target_name, overridden_target_name, - tsi_ssl_client_handshaker_factory_ref(client_factory)); + target_name, overridden_target_name); + const grpc_security_status result = c->InitializeHandshakerFactory( + config, pem_root_certs, root_store, ssl_session_cache); + if (result != GRPC_SECURITY_OK) { + return nullptr; + } return c; } diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h index c5781d8d958..1e93592f718 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h @@ -57,7 +57,7 @@ grpc_ssl_channel_security_connector_create( grpc_core::RefCountedPtr request_metadata_creds, const grpc_ssl_config* config, const char* target_name, const char* overridden_target_name, - tsi_ssl_client_handshaker_factory* factory); + tsi_ssl_session_cache* ssl_session_cache); // Config for ssl servers. struct grpc_ssl_server_config { diff --git a/src/core/tsi/ssl_transport_security.cc b/src/core/tsi/ssl_transport_security.cc index f11e4a1418d..91519650b6d 100644 --- a/src/core/tsi/ssl_transport_security.cc +++ b/src/core/tsi/ssl_transport_security.cc @@ -1762,13 +1762,6 @@ void tsi_ssl_client_handshaker_factory_unref( tsi_ssl_handshaker_factory_unref(&factory->base); } -tsi_ssl_client_handshaker_factory* tsi_ssl_client_handshaker_factory_ref( - tsi_ssl_client_handshaker_factory* client_factory) { - if (client_factory == nullptr) return nullptr; - return reinterpret_cast( - tsi_ssl_handshaker_factory_ref(&client_factory->base)); -} - static void tsi_ssl_client_handshaker_factory_destroy( tsi_ssl_handshaker_factory* factory) { if (factory == nullptr) return; diff --git a/src/core/tsi/ssl_transport_security.h b/src/core/tsi/ssl_transport_security.h index 7248a1fe84a..75e549fc3cc 100644 --- a/src/core/tsi/ssl_transport_security.h +++ b/src/core/tsi/ssl_transport_security.h @@ -223,10 +223,6 @@ tsi_result tsi_ssl_client_handshaker_factory_create_handshaker( const char* server_name_indication, size_t network_bio_buf_size, size_t ssl_bio_buf_size, tsi_handshaker** handshaker); -// Increments reference count of the client handshaker factory. -tsi_ssl_client_handshaker_factory* tsi_ssl_client_handshaker_factory_ref( - tsi_ssl_client_handshaker_factory* client_factory); - // Decrements reference count of the handshaker factory. Handshaker factory will // be destroyed once no references exist. void tsi_ssl_client_handshaker_factory_unref( diff --git a/test/core/tsi/ssl_transport_security_test.cc b/test/core/tsi/ssl_transport_security_test.cc index 51703a2bb74..140d3729959 100644 --- a/test/core/tsi/ssl_transport_security_test.cc +++ b/test/core/tsi/ssl_transport_security_test.cc @@ -960,9 +960,6 @@ void test_tsi_ssl_client_handshaker_factory_refcounting() { TSI_OK); } - client_handshaker_factory = - tsi_ssl_client_handshaker_factory_ref(client_handshaker_factory); - tsi_handshaker_destroy(handshaker[1]); ASSERT_FALSE(handshaker_factory_destructor_called); @@ -973,10 +970,8 @@ void test_tsi_ssl_client_handshaker_factory_refcounting() { ASSERT_FALSE(handshaker_factory_destructor_called); tsi_handshaker_destroy(handshaker[2]); - ASSERT_FALSE(handshaker_factory_destructor_called); - - tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory); ASSERT_TRUE(handshaker_factory_destructor_called); + gpr_free(cert_chain); } diff --git a/test/cpp/end2end/ssl_credentials_test.cc b/test/cpp/end2end/ssl_credentials_test.cc index 91022228986..4ecd542571e 100644 --- a/test/cpp/end2end/ssl_credentials_test.cc +++ b/test/cpp/end2end/ssl_credentials_test.cc @@ -106,7 +106,8 @@ void DoRpc(const std::string& server_addr, grpc::testing::EchoResponse response; request.set_message(kMessage); ClientContext context; - context.set_deadline(grpc_timeout_seconds_to_deadline(/*time_s=*/10)); + context.set_deadline(grpc_timeout_milliseconds_to_deadline( + /*time_ms=*/5000 * grpc_test_slowdown_factor())); grpc::Status result = stub->Echo(&context, request, &response); EXPECT_TRUE(result.ok()); if (!result.ok()) {