diff --git a/src/core/lib/http/httpcli_security_connector.c b/src/core/lib/http/httpcli_security_connector.c index 97c28865258..c553fa3981e 100644 --- a/src/core/lib/http/httpcli_security_connector.c +++ b/src/core/lib/http/httpcli_security_connector.c @@ -43,7 +43,8 @@ static void httpcli_ssl_destroy(grpc_exec_ctx *exec_ctx, grpc_httpcli_ssl_channel_security_connector *c = (grpc_httpcli_ssl_channel_security_connector *)sc; if (c->handshaker_factory != NULL) { - tsi_ssl_client_handshaker_factory_destroy(c->handshaker_factory); + tsi_ssl_client_handshaker_factory_unref(c->handshaker_factory); + c->handshaker_factory = NULL; } if (c->secure_peer_name != NULL) gpr_free(c->secure_peer_name); gpr_free(sc); diff --git a/src/core/lib/security/transport/security_connector.c b/src/core/lib/security/transport/security_connector.c index a7568b995f5..2a9e939d40a 100644 --- a/src/core/lib/security/transport/security_connector.c +++ b/src/core/lib/security/transport/security_connector.c @@ -455,14 +455,14 @@ grpc_server_security_connector *grpc_fake_server_security_connector_create( typedef struct { grpc_channel_security_connector base; - tsi_ssl_client_handshaker_factory *handshaker_factory; + tsi_ssl_client_handshaker_factory *client_handshaker_factory; char *target_name; char *overridden_target_name; } grpc_ssl_channel_security_connector; typedef struct { grpc_server_security_connector base; - tsi_ssl_server_handshaker_factory *handshaker_factory; + tsi_ssl_server_handshaker_factory *server_handshaker_factory; } grpc_ssl_server_security_connector; static void ssl_channel_destroy(grpc_exec_ctx *exec_ctx, @@ -470,9 +470,8 @@ static void ssl_channel_destroy(grpc_exec_ctx *exec_ctx, grpc_ssl_channel_security_connector *c = (grpc_ssl_channel_security_connector *)sc; grpc_call_credentials_unref(exec_ctx, c->base.request_metadata_creds); - if (c->handshaker_factory != NULL) { - tsi_ssl_client_handshaker_factory_destroy(c->handshaker_factory); - } + tsi_ssl_client_handshaker_factory_unref(c->client_handshaker_factory); + c->client_handshaker_factory = NULL; if (c->target_name != NULL) gpr_free(c->target_name); if (c->overridden_target_name != NULL) gpr_free(c->overridden_target_name); gpr_free(sc); @@ -482,9 +481,8 @@ static void ssl_server_destroy(grpc_exec_ctx *exec_ctx, grpc_security_connector *sc) { grpc_ssl_server_security_connector *c = (grpc_ssl_server_security_connector *)sc; - if (c->handshaker_factory != NULL) { - tsi_ssl_server_handshaker_factory_destroy(c->handshaker_factory); - } + tsi_ssl_server_handshaker_factory_unref(c->server_handshaker_factory); + c->server_handshaker_factory = NULL; gpr_free(sc); } @@ -496,7 +494,7 @@ static void ssl_channel_add_handshakers(grpc_exec_ctx *exec_ctx, // Instantiate TSI handshaker. tsi_handshaker *tsi_hs = NULL; tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( - c->handshaker_factory, + c->client_handshaker_factory, c->overridden_target_name != NULL ? c->overridden_target_name : c->target_name, &tsi_hs); @@ -521,7 +519,7 @@ static void ssl_server_add_handshakers(grpc_exec_ctx *exec_ctx, // Instantiate TSI handshaker. tsi_handshaker *tsi_hs = NULL; tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker( - c->handshaker_factory, &tsi_hs); + c->server_handshaker_factory, &tsi_hs); if (result != TSI_OK) { gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", tsi_result_to_string(result)); @@ -852,7 +850,7 @@ grpc_security_status grpc_ssl_channel_security_connector_create( result = tsi_create_ssl_client_handshaker_factory( has_key_cert_pair ? &config->pem_key_cert_pair : NULL, pem_root_certs, ssl_cipher_suites(), alpn_protocol_strings, (uint16_t)num_alpn_protocols, - &c->handshaker_factory); + &c->client_handshaker_factory); if (result != TSI_OK) { gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", tsi_result_to_string(result)); @@ -897,7 +895,7 @@ grpc_security_status grpc_ssl_server_security_connector_create( config->pem_root_certs, get_tsi_client_certificate_request_type( config->client_certificate_request), ssl_cipher_suites(), alpn_protocol_strings, (uint16_t)num_alpn_protocols, - &c->handshaker_factory); + &c->server_handshaker_factory); if (result != TSI_OK) { gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", tsi_result_to_string(result)); diff --git a/src/core/tsi/ssl_transport_security.c b/src/core/tsi/ssl_transport_security.c index 1fd65928f9c..7ebf9dd96f0 100644 --- a/src/core/tsi/ssl_transport_security.c +++ b/src/core/tsi/ssl_transport_security.c @@ -67,7 +67,13 @@ /* --- Structure definitions. ---*/ +struct tsi_ssl_handshaker_factory { + const tsi_ssl_handshaker_factory_vtable *vtable; + gpr_refcount refcount; +}; + struct tsi_ssl_client_handshaker_factory { + tsi_ssl_handshaker_factory base; SSL_CTX *ssl_context; unsigned char *alpn_protocol_list; size_t alpn_protocol_list_length; @@ -77,6 +83,7 @@ struct tsi_ssl_server_handshaker_factory { /* Several contexts to support SNI. The tsi_peer array contains the subject names of the server certificates associated with the contexts at the same index. */ + tsi_ssl_handshaker_factory base; SSL_CTX **ssl_contexts; tsi_peer *ssl_context_x509_subject_names; size_t ssl_context_count; @@ -90,6 +97,7 @@ typedef struct { BIO *into_ssl; BIO *from_ssl; tsi_result result; + tsi_ssl_handshaker_factory *factory_ref; } tsi_ssl_handshaker; typedef struct { @@ -846,6 +854,47 @@ static const tsi_frame_protector_vtable frame_protector_vtable = { ssl_protector_destroy, }; +/* --- tsi_server_handshaker_factory methods implementation. --- */ + +static void tsi_ssl_handshaker_factory_destroy( + tsi_ssl_handshaker_factory *self) { + if (self == NULL) return; + + if (self->vtable != NULL && self->vtable->destroy != NULL) { + self->vtable->destroy(self); + } + /* Note, we don't free(self) here because this object is always directly + * embedded in another object. If tsi_ssl_handshaker_factory_init allocates + * any memory, it should be free'd here. */ +} + +static tsi_ssl_handshaker_factory *tsi_ssl_handshaker_factory_ref( + tsi_ssl_handshaker_factory *self) { + if (self == NULL) return NULL; + gpr_refn(&self->refcount, 1); + return self; +} + +static void tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory *self) { + if (self == NULL) return; + + if (gpr_unref(&self->refcount)) { + tsi_ssl_handshaker_factory_destroy(self); + } +} + +static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = {NULL}; + +/* Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for + * allocating memory for the factory. */ +static void tsi_ssl_handshaker_factory_init( + tsi_ssl_handshaker_factory *factory) { + GPR_ASSERT(factory != NULL); + + factory->vtable = &handshaker_factory_vtable; + gpr_ref_init(&factory->refcount, 1); +} + /* --- tsi_handshaker methods implementation. ---*/ static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(tsi_handshaker *self, @@ -1013,6 +1062,7 @@ static tsi_result ssl_handshaker_create_frame_protector( static void ssl_handshaker_destroy(tsi_handshaker *self) { tsi_ssl_handshaker *impl = (tsi_ssl_handshaker *)self; SSL_free(impl->ssl); /* The BIO objects are owned by ssl */ + tsi_ssl_handshaker_factory_unref(impl->factory_ref); gpr_free(impl); } @@ -1030,6 +1080,7 @@ static const tsi_handshaker_vtable handshaker_vtable = { static tsi_result create_tsi_ssl_handshaker(SSL_CTX *ctx, int is_client, const char *server_name_indication, + tsi_ssl_handshaker_factory *factory, tsi_handshaker **handshaker) { SSL *ssl = SSL_new(ctx); BIO *into_ssl = NULL; @@ -1085,6 +1136,8 @@ static tsi_result create_tsi_ssl_handshaker(SSL_CTX *ctx, int is_client, impl->from_ssl = from_ssl; impl->result = TSI_HANDSHAKE_IN_PROGRESS; impl->base.vtable = &handshaker_vtable; + impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory); + *handshaker = &impl->base; return TSI_OK; } @@ -1121,11 +1174,20 @@ tsi_result tsi_ssl_client_handshaker_factory_create_handshaker( tsi_ssl_client_handshaker_factory *self, const char *server_name_indication, tsi_handshaker **handshaker) { return create_tsi_ssl_handshaker(self->ssl_context, 1, server_name_indication, - handshaker); + &self->base, handshaker); } -void tsi_ssl_client_handshaker_factory_destroy( +void tsi_ssl_client_handshaker_factory_unref( tsi_ssl_client_handshaker_factory *self) { + if (self == NULL) return; + tsi_ssl_handshaker_factory_unref(&self->base); +} + +static void tsi_ssl_client_handshaker_factory_destroy( + tsi_ssl_handshaker_factory *factory) { + if (factory == NULL) return; + tsi_ssl_client_handshaker_factory *self = + (tsi_ssl_client_handshaker_factory *)factory; if (self->ssl_context != NULL) SSL_CTX_free(self->ssl_context); if (self->alpn_protocol_list != NULL) gpr_free(self->alpn_protocol_list); gpr_free(self); @@ -1150,11 +1212,21 @@ tsi_result tsi_ssl_server_handshaker_factory_create_handshaker( if (self->ssl_context_count == 0) return TSI_INVALID_ARGUMENT; /* Create the handshaker with the first context. We will switch if needed because of SNI in ssl_server_handshaker_factory_servername_callback. */ - return create_tsi_ssl_handshaker(self->ssl_contexts[0], 0, NULL, handshaker); + return create_tsi_ssl_handshaker(self->ssl_contexts[0], 0, NULL, &self->base, + handshaker); } -void tsi_ssl_server_handshaker_factory_destroy( +void tsi_ssl_server_handshaker_factory_unref( tsi_ssl_server_handshaker_factory *self) { + if (self == NULL) return; + tsi_ssl_handshaker_factory_unref(&self->base); +} + +static void tsi_ssl_server_handshaker_factory_destroy( + tsi_ssl_handshaker_factory *factory) { + if (factory == NULL) return; + tsi_ssl_server_handshaker_factory *self = + (tsi_ssl_server_handshaker_factory *)factory; size_t i; for (i = 0; i < self->ssl_context_count; i++) { if (self->ssl_contexts[i] != NULL) { @@ -1263,6 +1335,9 @@ static int server_handshaker_factory_npn_advertised_callback( /* --- tsi_ssl_handshaker_factory constructors. --- */ +static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = { + tsi_ssl_client_handshaker_factory_destroy}; + tsi_result tsi_create_ssl_client_handshaker_factory( const tsi_ssl_pem_key_cert_pair *pem_key_cert_pair, const char *pem_root_certs, const char *cipher_suites, @@ -1285,6 +1360,9 @@ tsi_result tsi_create_ssl_client_handshaker_factory( } impl = gpr_zalloc(sizeof(*impl)); + tsi_ssl_handshaker_factory_init(&impl->base); + impl->base.vtable = &client_handshaker_factory_vtable; + impl->ssl_context = ssl_context; do { @@ -1322,7 +1400,7 @@ tsi_result tsi_create_ssl_client_handshaker_factory( } } while (0); if (result != TSI_OK) { - tsi_ssl_client_handshaker_factory_destroy(impl); + tsi_ssl_handshaker_factory_unref(&impl->base); return result; } SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, NULL); @@ -1332,6 +1410,9 @@ tsi_result tsi_create_ssl_client_handshaker_factory( return TSI_OK; } +static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = { + tsi_ssl_server_handshaker_factory_destroy}; + tsi_result tsi_create_ssl_server_handshaker_factory( const tsi_ssl_pem_key_cert_pair *pem_key_cert_pairs, size_t num_key_cert_pairs, const char *pem_client_root_certs, @@ -1364,12 +1445,15 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( } impl = gpr_zalloc(sizeof(*impl)); + tsi_ssl_handshaker_factory_init(&impl->base); + impl->base.vtable = &server_handshaker_factory_vtable; + impl->ssl_contexts = gpr_zalloc(num_key_cert_pairs * sizeof(SSL_CTX *)); impl->ssl_context_x509_subject_names = gpr_zalloc(num_key_cert_pairs * sizeof(tsi_peer)); if (impl->ssl_contexts == NULL || impl->ssl_context_x509_subject_names == NULL) { - tsi_ssl_server_handshaker_factory_destroy(impl); + tsi_ssl_handshaker_factory_unref(&impl->base); return TSI_OUT_OF_RESOURCES; } impl->ssl_context_count = num_key_cert_pairs; @@ -1379,7 +1463,7 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( &impl->alpn_protocol_list, &impl->alpn_protocol_list_length); if (result != TSI_OK) { - tsi_ssl_server_handshaker_factory_destroy(impl); + tsi_ssl_handshaker_factory_unref(&impl->base); return result; } } @@ -1451,10 +1535,11 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( } while (0); if (result != TSI_OK) { - tsi_ssl_server_handshaker_factory_destroy(impl); + tsi_ssl_handshaker_factory_unref(&impl->base); return result; } } + *factory = impl; return TSI_OK; } @@ -1501,3 +1586,15 @@ int tsi_ssl_peer_matches_name(const tsi_peer *peer, const char *name) { return 0; /* Not found. */ } + +/* --- Testing support. --- */ +const tsi_ssl_handshaker_factory_vtable *tsi_ssl_handshaker_factory_swap_vtable( + tsi_ssl_handshaker_factory *factory, + tsi_ssl_handshaker_factory_vtable *new_vtable) { + GPR_ASSERT(factory != NULL); + GPR_ASSERT(factory->vtable != NULL); + + const tsi_ssl_handshaker_factory_vtable *orig_vtable = factory->vtable; + factory->vtable = new_vtable; + return orig_vtable; +} diff --git a/src/core/tsi/ssl_transport_security.h b/src/core/tsi/ssl_transport_security.h index 177599930bc..3abfdf5ed87 100644 --- a/src/core/tsi/ssl_transport_security.h +++ b/src/core/tsi/ssl_transport_security.h @@ -96,10 +96,10 @@ tsi_result tsi_ssl_client_handshaker_factory_create_handshaker( tsi_ssl_client_handshaker_factory *self, const char *server_name_indication, tsi_handshaker **handshaker); -/* Destroys the handshaker factory. WARNING: it is unsafe to destroy a factory - while handshakers created with this factory are still in use. */ -void tsi_ssl_client_handshaker_factory_destroy( - tsi_ssl_client_handshaker_factory *self); +/* Decrements reference count of the handshaker factory. Handshaker factory will + * be destroyed once no references exist. */ +void tsi_ssl_client_handshaker_factory_unref( + tsi_ssl_client_handshaker_factory *factory); /* --- tsi_ssl_server_handshaker_factory object --- @@ -158,9 +158,9 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( tsi_result tsi_ssl_server_handshaker_factory_create_handshaker( tsi_ssl_server_handshaker_factory *self, tsi_handshaker **handshaker); -/* Destroys the handshaker factory. WARNING: it is unsafe to destroy a factory - while handshakers created with this factory are still in use. */ -void tsi_ssl_server_handshaker_factory_destroy( +/* Decrements reference count of the handshaker factory. Handshaker factory will + * be destroyed once no references exist. */ +void tsi_ssl_server_handshaker_factory_unref( tsi_ssl_server_handshaker_factory *self); /* Util that checks that an ssl peer matches a specific name. @@ -170,6 +170,29 @@ void tsi_ssl_server_handshaker_factory_destroy( - handle public suffix wildchar more strictly (e.g. *.co.uk) */ int tsi_ssl_peer_matches_name(const tsi_peer *peer, const char *name); +/* --- Testing support. --- + + These functions and typedefs are not intended to be used outside of testing. + */ + +/* Base type of client and server handshaker factories. */ +typedef struct tsi_ssl_handshaker_factory tsi_ssl_handshaker_factory; + +/* Function pointer to handshaker_factory destructor. */ +typedef void (*tsi_ssl_handshaker_factory_destructor)( + tsi_ssl_handshaker_factory *factory); + +/* Virtual table for tsi_ssl_handshaker_factory. */ +typedef struct { + tsi_ssl_handshaker_factory_destructor destroy; +} tsi_ssl_handshaker_factory_vtable; + +/* Set destructor of handshaker_factory to new_destructor, returns previous + destructor. */ +const tsi_ssl_handshaker_factory_vtable *tsi_ssl_handshaker_factory_swap_vtable( + tsi_ssl_handshaker_factory *factory, + tsi_ssl_handshaker_factory_vtable *new_vtable); + #ifdef __cplusplus } #endif diff --git a/test/core/tsi/ssl_transport_security_test.c b/test/core/tsi/ssl_transport_security_test.c index 364dfa1b73f..2399b054b1b 100644 --- a/test/core/tsi/ssl_transport_security_test.c +++ b/test/core/tsi/ssl_transport_security_test.c @@ -23,7 +23,9 @@ #include "src/core/lib/iomgr/load_file.h" #include "src/core/lib/security/transport/security_connector.h" #include "src/core/tsi/ssl_transport_security.h" +#include "src/core/tsi/transport_security.h" #include "src/core/tsi/transport_security_adapter.h" +#include "src/core/tsi/transport_security_interface.h" #include "test/core/tsi/transport_security_test_lib.h" #include "test/core/util/test_config.h" @@ -312,10 +314,10 @@ static void ssl_test_destruct(tsi_test_fixture *fixture) { key_cert_lib->bad_client_pem_key_cert_pair); gpr_free(key_cert_lib->root_cert); gpr_free(key_cert_lib); - /* Destroy others. */ - tsi_ssl_server_handshaker_factory_destroy( + /* Unreference others. */ + tsi_ssl_server_handshaker_factory_unref( ssl_fixture->server_handshaker_factory); - tsi_ssl_client_handshaker_factory_destroy( + tsi_ssl_client_handshaker_factory_unref( ssl_fixture->client_handshaker_factory); } @@ -536,6 +538,118 @@ void ssl_tsi_test_do_round_trip_odd_buffer_size() { } } +static const tsi_ssl_handshaker_factory_vtable *original_vtable; +static bool handshaker_factory_destructor_called; + +static void ssl_tsi_test_handshaker_factory_destructor( + tsi_ssl_handshaker_factory *factory) { + GPR_ASSERT(factory != NULL); + handshaker_factory_destructor_called = true; + if (original_vtable != NULL && original_vtable->destroy != NULL) { + original_vtable->destroy(factory); + } +} + +static tsi_ssl_handshaker_factory_vtable test_handshaker_factory_vtable = { + ssl_tsi_test_handshaker_factory_destructor}; + +void test_tsi_ssl_client_handshaker_factory_refcounting() { + int i; + const char *cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "client.pem"); + + tsi_ssl_client_handshaker_factory *client_handshaker_factory; + GPR_ASSERT(tsi_create_ssl_client_handshaker_factory( + NULL, cert_chain, NULL, NULL, 0, &client_handshaker_factory) == + TSI_OK); + + handshaker_factory_destructor_called = false; + original_vtable = tsi_ssl_handshaker_factory_swap_vtable( + (tsi_ssl_handshaker_factory *)client_handshaker_factory, + &test_handshaker_factory_vtable); + + tsi_handshaker *handshaker[3]; + + for (i = 0; i < 3; ++i) { + GPR_ASSERT(tsi_ssl_client_handshaker_factory_create_handshaker( + client_handshaker_factory, "google.com", &handshaker[i]) == + TSI_OK); + } + + tsi_handshaker_destroy(handshaker[1]); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_handshaker_destroy(handshaker[0]); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_handshaker_destroy(handshaker[2]); + GPR_ASSERT(handshaker_factory_destructor_called); + + gpr_free((void *)cert_chain); +} + +void test_tsi_ssl_server_handshaker_factory_refcounting() { + int i; + tsi_ssl_server_handshaker_factory *server_handshaker_factory; + tsi_handshaker *handshaker[3]; + const char *cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.pem"); + tsi_ssl_pem_key_cert_pair cert_pair; + + cert_pair.cert_chain = cert_chain; + cert_pair.private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.key"); + + GPR_ASSERT(tsi_create_ssl_server_handshaker_factory( + &cert_pair, 1, cert_chain, 0, NULL, NULL, 0, + &server_handshaker_factory) == TSI_OK); + + handshaker_factory_destructor_called = false; + original_vtable = tsi_ssl_handshaker_factory_swap_vtable( + (tsi_ssl_handshaker_factory *)server_handshaker_factory, + &test_handshaker_factory_vtable); + + for (i = 0; i < 3; ++i) { + GPR_ASSERT(tsi_ssl_server_handshaker_factory_create_handshaker( + server_handshaker_factory, &handshaker[i]) == TSI_OK); + } + + tsi_handshaker_destroy(handshaker[1]); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_handshaker_destroy(handshaker[0]); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_handshaker_destroy(handshaker[2]); + GPR_ASSERT(handshaker_factory_destructor_called); + + ssl_test_pem_key_cert_pair_destroy(cert_pair); +} + +/* Attempting to create a handshaker factory with invalid parameters should fail + * but not crash. */ +void test_tsi_ssl_client_handshaker_factory_bad_params() { + const char *cert_chain = "This is not a valid PEM file."; + + tsi_ssl_client_handshaker_factory *client_handshaker_factory; + GPR_ASSERT(tsi_create_ssl_client_handshaker_factory( + NULL, cert_chain, NULL, NULL, 0, &client_handshaker_factory) == + TSI_INVALID_ARGUMENT); + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory); +} + +void ssl_tsi_test_handshaker_factory_internals() { + test_tsi_ssl_client_handshaker_factory_refcounting(); + test_tsi_ssl_server_handshaker_factory_refcounting(); + test_tsi_ssl_client_handshaker_factory_bad_params(); +} + int main(int argc, char **argv) { grpc_test_init(argc, argv); grpc_init(); @@ -553,6 +667,7 @@ int main(int argc, char **argv) { ssl_tsi_test_do_handshake_alpn_client_server_ok(); ssl_tsi_test_do_round_trip_for_all_configs(); ssl_tsi_test_do_round_trip_odd_buffer_size(); + ssl_tsi_test_handshaker_factory_internals(); grpc_shutdown(); return 0; }