Merge pull request #12359 from justinburke/tsi_factory_locking

Add reference counting to tsi_ssl_handshaker_factories
pull/12631/head
Justin Burke 7 years ago committed by GitHub
commit 06576d5bc4
  1. 3
      src/core/lib/http/httpcli_security_connector.c
  2. 22
      src/core/lib/security/transport/security_connector.c
  3. 113
      src/core/tsi/ssl_transport_security.c
  4. 37
      src/core/tsi/ssl_transport_security.h
  5. 121
      test/core/tsi/ssl_transport_security_test.c

@ -43,7 +43,8 @@ static void httpcli_ssl_destroy(grpc_exec_ctx *exec_ctx,
grpc_httpcli_ssl_channel_security_connector *c = grpc_httpcli_ssl_channel_security_connector *c =
(grpc_httpcli_ssl_channel_security_connector *)sc; (grpc_httpcli_ssl_channel_security_connector *)sc;
if (c->handshaker_factory != NULL) { if (c->handshaker_factory != NULL) {
tsi_ssl_client_handshaker_factory_destroy(c->handshaker_factory); tsi_ssl_client_handshaker_factory_unref(c->handshaker_factory);
c->handshaker_factory = NULL;
} }
if (c->secure_peer_name != NULL) gpr_free(c->secure_peer_name); if (c->secure_peer_name != NULL) gpr_free(c->secure_peer_name);
gpr_free(sc); gpr_free(sc);

@ -455,14 +455,14 @@ grpc_server_security_connector *grpc_fake_server_security_connector_create(
typedef struct { typedef struct {
grpc_channel_security_connector base; grpc_channel_security_connector base;
tsi_ssl_client_handshaker_factory *handshaker_factory; tsi_ssl_client_handshaker_factory *client_handshaker_factory;
char *target_name; char *target_name;
char *overridden_target_name; char *overridden_target_name;
} grpc_ssl_channel_security_connector; } grpc_ssl_channel_security_connector;
typedef struct { typedef struct {
grpc_server_security_connector base; grpc_server_security_connector base;
tsi_ssl_server_handshaker_factory *handshaker_factory; tsi_ssl_server_handshaker_factory *server_handshaker_factory;
} grpc_ssl_server_security_connector; } grpc_ssl_server_security_connector;
static void ssl_channel_destroy(grpc_exec_ctx *exec_ctx, static void ssl_channel_destroy(grpc_exec_ctx *exec_ctx,
@ -470,9 +470,8 @@ static void ssl_channel_destroy(grpc_exec_ctx *exec_ctx,
grpc_ssl_channel_security_connector *c = grpc_ssl_channel_security_connector *c =
(grpc_ssl_channel_security_connector *)sc; (grpc_ssl_channel_security_connector *)sc;
grpc_call_credentials_unref(exec_ctx, c->base.request_metadata_creds); grpc_call_credentials_unref(exec_ctx, c->base.request_metadata_creds);
if (c->handshaker_factory != NULL) { tsi_ssl_client_handshaker_factory_unref(c->client_handshaker_factory);
tsi_ssl_client_handshaker_factory_destroy(c->handshaker_factory); c->client_handshaker_factory = NULL;
}
if (c->target_name != NULL) gpr_free(c->target_name); if (c->target_name != NULL) gpr_free(c->target_name);
if (c->overridden_target_name != NULL) gpr_free(c->overridden_target_name); if (c->overridden_target_name != NULL) gpr_free(c->overridden_target_name);
gpr_free(sc); gpr_free(sc);
@ -482,9 +481,8 @@ static void ssl_server_destroy(grpc_exec_ctx *exec_ctx,
grpc_security_connector *sc) { grpc_security_connector *sc) {
grpc_ssl_server_security_connector *c = grpc_ssl_server_security_connector *c =
(grpc_ssl_server_security_connector *)sc; (grpc_ssl_server_security_connector *)sc;
if (c->handshaker_factory != NULL) { tsi_ssl_server_handshaker_factory_unref(c->server_handshaker_factory);
tsi_ssl_server_handshaker_factory_destroy(c->handshaker_factory); c->server_handshaker_factory = NULL;
}
gpr_free(sc); gpr_free(sc);
} }
@ -496,7 +494,7 @@ static void ssl_channel_add_handshakers(grpc_exec_ctx *exec_ctx,
// Instantiate TSI handshaker. // Instantiate TSI handshaker.
tsi_handshaker *tsi_hs = NULL; tsi_handshaker *tsi_hs = NULL;
tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
c->handshaker_factory, c->client_handshaker_factory,
c->overridden_target_name != NULL ? c->overridden_target_name c->overridden_target_name != NULL ? c->overridden_target_name
: c->target_name, : c->target_name,
&tsi_hs); &tsi_hs);
@ -521,7 +519,7 @@ static void ssl_server_add_handshakers(grpc_exec_ctx *exec_ctx,
// Instantiate TSI handshaker. // Instantiate TSI handshaker.
tsi_handshaker *tsi_hs = NULL; tsi_handshaker *tsi_hs = NULL;
tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker( tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker(
c->handshaker_factory, &tsi_hs); c->server_handshaker_factory, &tsi_hs);
if (result != TSI_OK) { if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
tsi_result_to_string(result)); tsi_result_to_string(result));
@ -852,7 +850,7 @@ grpc_security_status grpc_ssl_channel_security_connector_create(
result = tsi_create_ssl_client_handshaker_factory( result = tsi_create_ssl_client_handshaker_factory(
has_key_cert_pair ? &config->pem_key_cert_pair : NULL, pem_root_certs, has_key_cert_pair ? &config->pem_key_cert_pair : NULL, pem_root_certs,
ssl_cipher_suites(), alpn_protocol_strings, (uint16_t)num_alpn_protocols, ssl_cipher_suites(), alpn_protocol_strings, (uint16_t)num_alpn_protocols,
&c->handshaker_factory); &c->client_handshaker_factory);
if (result != TSI_OK) { if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result)); tsi_result_to_string(result));
@ -897,7 +895,7 @@ grpc_security_status grpc_ssl_server_security_connector_create(
config->pem_root_certs, get_tsi_client_certificate_request_type( config->pem_root_certs, get_tsi_client_certificate_request_type(
config->client_certificate_request), config->client_certificate_request),
ssl_cipher_suites(), alpn_protocol_strings, (uint16_t)num_alpn_protocols, ssl_cipher_suites(), alpn_protocol_strings, (uint16_t)num_alpn_protocols,
&c->handshaker_factory); &c->server_handshaker_factory);
if (result != TSI_OK) { if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result)); tsi_result_to_string(result));

@ -67,7 +67,13 @@
/* --- Structure definitions. ---*/ /* --- Structure definitions. ---*/
struct tsi_ssl_handshaker_factory {
const tsi_ssl_handshaker_factory_vtable *vtable;
gpr_refcount refcount;
};
struct tsi_ssl_client_handshaker_factory { struct tsi_ssl_client_handshaker_factory {
tsi_ssl_handshaker_factory base;
SSL_CTX *ssl_context; SSL_CTX *ssl_context;
unsigned char *alpn_protocol_list; unsigned char *alpn_protocol_list;
size_t alpn_protocol_list_length; size_t alpn_protocol_list_length;
@ -77,6 +83,7 @@ struct tsi_ssl_server_handshaker_factory {
/* Several contexts to support SNI. /* Several contexts to support SNI.
The tsi_peer array contains the subject names of the server certificates The tsi_peer array contains the subject names of the server certificates
associated with the contexts at the same index. */ associated with the contexts at the same index. */
tsi_ssl_handshaker_factory base;
SSL_CTX **ssl_contexts; SSL_CTX **ssl_contexts;
tsi_peer *ssl_context_x509_subject_names; tsi_peer *ssl_context_x509_subject_names;
size_t ssl_context_count; size_t ssl_context_count;
@ -90,6 +97,7 @@ typedef struct {
BIO *into_ssl; BIO *into_ssl;
BIO *from_ssl; BIO *from_ssl;
tsi_result result; tsi_result result;
tsi_ssl_handshaker_factory *factory_ref;
} tsi_ssl_handshaker; } tsi_ssl_handshaker;
typedef struct { typedef struct {
@ -846,6 +854,47 @@ static const tsi_frame_protector_vtable frame_protector_vtable = {
ssl_protector_destroy, ssl_protector_destroy,
}; };
/* --- tsi_server_handshaker_factory methods implementation. --- */
static void tsi_ssl_handshaker_factory_destroy(
tsi_ssl_handshaker_factory *self) {
if (self == NULL) return;
if (self->vtable != NULL && self->vtable->destroy != NULL) {
self->vtable->destroy(self);
}
/* Note, we don't free(self) here because this object is always directly
* embedded in another object. If tsi_ssl_handshaker_factory_init allocates
* any memory, it should be free'd here. */
}
static tsi_ssl_handshaker_factory *tsi_ssl_handshaker_factory_ref(
tsi_ssl_handshaker_factory *self) {
if (self == NULL) return NULL;
gpr_refn(&self->refcount, 1);
return self;
}
static void tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory *self) {
if (self == NULL) return;
if (gpr_unref(&self->refcount)) {
tsi_ssl_handshaker_factory_destroy(self);
}
}
static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = {NULL};
/* Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for
* allocating memory for the factory. */
static void tsi_ssl_handshaker_factory_init(
tsi_ssl_handshaker_factory *factory) {
GPR_ASSERT(factory != NULL);
factory->vtable = &handshaker_factory_vtable;
gpr_ref_init(&factory->refcount, 1);
}
/* --- tsi_handshaker methods implementation. ---*/ /* --- tsi_handshaker methods implementation. ---*/
static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(tsi_handshaker *self, static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(tsi_handshaker *self,
@ -1013,6 +1062,7 @@ static tsi_result ssl_handshaker_create_frame_protector(
static void ssl_handshaker_destroy(tsi_handshaker *self) { static void ssl_handshaker_destroy(tsi_handshaker *self) {
tsi_ssl_handshaker *impl = (tsi_ssl_handshaker *)self; tsi_ssl_handshaker *impl = (tsi_ssl_handshaker *)self;
SSL_free(impl->ssl); /* The BIO objects are owned by ssl */ SSL_free(impl->ssl); /* The BIO objects are owned by ssl */
tsi_ssl_handshaker_factory_unref(impl->factory_ref);
gpr_free(impl); gpr_free(impl);
} }
@ -1030,6 +1080,7 @@ static const tsi_handshaker_vtable handshaker_vtable = {
static tsi_result create_tsi_ssl_handshaker(SSL_CTX *ctx, int is_client, static tsi_result create_tsi_ssl_handshaker(SSL_CTX *ctx, int is_client,
const char *server_name_indication, const char *server_name_indication,
tsi_ssl_handshaker_factory *factory,
tsi_handshaker **handshaker) { tsi_handshaker **handshaker) {
SSL *ssl = SSL_new(ctx); SSL *ssl = SSL_new(ctx);
BIO *into_ssl = NULL; BIO *into_ssl = NULL;
@ -1085,6 +1136,8 @@ static tsi_result create_tsi_ssl_handshaker(SSL_CTX *ctx, int is_client,
impl->from_ssl = from_ssl; impl->from_ssl = from_ssl;
impl->result = TSI_HANDSHAKE_IN_PROGRESS; impl->result = TSI_HANDSHAKE_IN_PROGRESS;
impl->base.vtable = &handshaker_vtable; impl->base.vtable = &handshaker_vtable;
impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory);
*handshaker = &impl->base; *handshaker = &impl->base;
return TSI_OK; return TSI_OK;
} }
@ -1121,11 +1174,20 @@ tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
tsi_ssl_client_handshaker_factory *self, const char *server_name_indication, tsi_ssl_client_handshaker_factory *self, const char *server_name_indication,
tsi_handshaker **handshaker) { tsi_handshaker **handshaker) {
return create_tsi_ssl_handshaker(self->ssl_context, 1, server_name_indication, return create_tsi_ssl_handshaker(self->ssl_context, 1, server_name_indication,
handshaker); &self->base, handshaker);
} }
void tsi_ssl_client_handshaker_factory_destroy( void tsi_ssl_client_handshaker_factory_unref(
tsi_ssl_client_handshaker_factory *self) { tsi_ssl_client_handshaker_factory *self) {
if (self == NULL) return;
tsi_ssl_handshaker_factory_unref(&self->base);
}
static void tsi_ssl_client_handshaker_factory_destroy(
tsi_ssl_handshaker_factory *factory) {
if (factory == NULL) return;
tsi_ssl_client_handshaker_factory *self =
(tsi_ssl_client_handshaker_factory *)factory;
if (self->ssl_context != NULL) SSL_CTX_free(self->ssl_context); if (self->ssl_context != NULL) SSL_CTX_free(self->ssl_context);
if (self->alpn_protocol_list != NULL) gpr_free(self->alpn_protocol_list); if (self->alpn_protocol_list != NULL) gpr_free(self->alpn_protocol_list);
gpr_free(self); gpr_free(self);
@ -1150,11 +1212,21 @@ tsi_result tsi_ssl_server_handshaker_factory_create_handshaker(
if (self->ssl_context_count == 0) return TSI_INVALID_ARGUMENT; if (self->ssl_context_count == 0) return TSI_INVALID_ARGUMENT;
/* Create the handshaker with the first context. We will switch if needed /* Create the handshaker with the first context. We will switch if needed
because of SNI in ssl_server_handshaker_factory_servername_callback. */ because of SNI in ssl_server_handshaker_factory_servername_callback. */
return create_tsi_ssl_handshaker(self->ssl_contexts[0], 0, NULL, handshaker); return create_tsi_ssl_handshaker(self->ssl_contexts[0], 0, NULL, &self->base,
handshaker);
} }
void tsi_ssl_server_handshaker_factory_destroy( void tsi_ssl_server_handshaker_factory_unref(
tsi_ssl_server_handshaker_factory *self) { tsi_ssl_server_handshaker_factory *self) {
if (self == NULL) return;
tsi_ssl_handshaker_factory_unref(&self->base);
}
static void tsi_ssl_server_handshaker_factory_destroy(
tsi_ssl_handshaker_factory *factory) {
if (factory == NULL) return;
tsi_ssl_server_handshaker_factory *self =
(tsi_ssl_server_handshaker_factory *)factory;
size_t i; size_t i;
for (i = 0; i < self->ssl_context_count; i++) { for (i = 0; i < self->ssl_context_count; i++) {
if (self->ssl_contexts[i] != NULL) { if (self->ssl_contexts[i] != NULL) {
@ -1263,6 +1335,9 @@ static int server_handshaker_factory_npn_advertised_callback(
/* --- tsi_ssl_handshaker_factory constructors. --- */ /* --- tsi_ssl_handshaker_factory constructors. --- */
static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = {
tsi_ssl_client_handshaker_factory_destroy};
tsi_result tsi_create_ssl_client_handshaker_factory( tsi_result tsi_create_ssl_client_handshaker_factory(
const tsi_ssl_pem_key_cert_pair *pem_key_cert_pair, const tsi_ssl_pem_key_cert_pair *pem_key_cert_pair,
const char *pem_root_certs, const char *cipher_suites, const char *pem_root_certs, const char *cipher_suites,
@ -1285,6 +1360,9 @@ tsi_result tsi_create_ssl_client_handshaker_factory(
} }
impl = gpr_zalloc(sizeof(*impl)); impl = gpr_zalloc(sizeof(*impl));
tsi_ssl_handshaker_factory_init(&impl->base);
impl->base.vtable = &client_handshaker_factory_vtable;
impl->ssl_context = ssl_context; impl->ssl_context = ssl_context;
do { do {
@ -1322,7 +1400,7 @@ tsi_result tsi_create_ssl_client_handshaker_factory(
} }
} while (0); } while (0);
if (result != TSI_OK) { if (result != TSI_OK) {
tsi_ssl_client_handshaker_factory_destroy(impl); tsi_ssl_handshaker_factory_unref(&impl->base);
return result; return result;
} }
SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, NULL); SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, NULL);
@ -1332,6 +1410,9 @@ tsi_result tsi_create_ssl_client_handshaker_factory(
return TSI_OK; return TSI_OK;
} }
static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = {
tsi_ssl_server_handshaker_factory_destroy};
tsi_result tsi_create_ssl_server_handshaker_factory( tsi_result tsi_create_ssl_server_handshaker_factory(
const tsi_ssl_pem_key_cert_pair *pem_key_cert_pairs, const tsi_ssl_pem_key_cert_pair *pem_key_cert_pairs,
size_t num_key_cert_pairs, const char *pem_client_root_certs, size_t num_key_cert_pairs, const char *pem_client_root_certs,
@ -1364,12 +1445,15 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex(
} }
impl = gpr_zalloc(sizeof(*impl)); impl = gpr_zalloc(sizeof(*impl));
tsi_ssl_handshaker_factory_init(&impl->base);
impl->base.vtable = &server_handshaker_factory_vtable;
impl->ssl_contexts = gpr_zalloc(num_key_cert_pairs * sizeof(SSL_CTX *)); impl->ssl_contexts = gpr_zalloc(num_key_cert_pairs * sizeof(SSL_CTX *));
impl->ssl_context_x509_subject_names = impl->ssl_context_x509_subject_names =
gpr_zalloc(num_key_cert_pairs * sizeof(tsi_peer)); gpr_zalloc(num_key_cert_pairs * sizeof(tsi_peer));
if (impl->ssl_contexts == NULL || if (impl->ssl_contexts == NULL ||
impl->ssl_context_x509_subject_names == NULL) { impl->ssl_context_x509_subject_names == NULL) {
tsi_ssl_server_handshaker_factory_destroy(impl); tsi_ssl_handshaker_factory_unref(&impl->base);
return TSI_OUT_OF_RESOURCES; return TSI_OUT_OF_RESOURCES;
} }
impl->ssl_context_count = num_key_cert_pairs; impl->ssl_context_count = num_key_cert_pairs;
@ -1379,7 +1463,7 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex(
&impl->alpn_protocol_list, &impl->alpn_protocol_list,
&impl->alpn_protocol_list_length); &impl->alpn_protocol_list_length);
if (result != TSI_OK) { if (result != TSI_OK) {
tsi_ssl_server_handshaker_factory_destroy(impl); tsi_ssl_handshaker_factory_unref(&impl->base);
return result; return result;
} }
} }
@ -1451,10 +1535,11 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex(
} while (0); } while (0);
if (result != TSI_OK) { if (result != TSI_OK) {
tsi_ssl_server_handshaker_factory_destroy(impl); tsi_ssl_handshaker_factory_unref(&impl->base);
return result; return result;
} }
} }
*factory = impl; *factory = impl;
return TSI_OK; return TSI_OK;
} }
@ -1501,3 +1586,15 @@ int tsi_ssl_peer_matches_name(const tsi_peer *peer, const char *name) {
return 0; /* Not found. */ return 0; /* Not found. */
} }
/* --- Testing support. --- */
const tsi_ssl_handshaker_factory_vtable *tsi_ssl_handshaker_factory_swap_vtable(
tsi_ssl_handshaker_factory *factory,
tsi_ssl_handshaker_factory_vtable *new_vtable) {
GPR_ASSERT(factory != NULL);
GPR_ASSERT(factory->vtable != NULL);
const tsi_ssl_handshaker_factory_vtable *orig_vtable = factory->vtable;
factory->vtable = new_vtable;
return orig_vtable;
}

@ -96,10 +96,10 @@ tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
tsi_ssl_client_handshaker_factory *self, const char *server_name_indication, tsi_ssl_client_handshaker_factory *self, const char *server_name_indication,
tsi_handshaker **handshaker); tsi_handshaker **handshaker);
/* Destroys the handshaker factory. WARNING: it is unsafe to destroy a factory /* Decrements reference count of the handshaker factory. Handshaker factory will
while handshakers created with this factory are still in use. */ * be destroyed once no references exist. */
void tsi_ssl_client_handshaker_factory_destroy( void tsi_ssl_client_handshaker_factory_unref(
tsi_ssl_client_handshaker_factory *self); tsi_ssl_client_handshaker_factory *factory);
/* --- tsi_ssl_server_handshaker_factory object --- /* --- tsi_ssl_server_handshaker_factory object ---
@ -158,9 +158,9 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex(
tsi_result tsi_ssl_server_handshaker_factory_create_handshaker( tsi_result tsi_ssl_server_handshaker_factory_create_handshaker(
tsi_ssl_server_handshaker_factory *self, tsi_handshaker **handshaker); tsi_ssl_server_handshaker_factory *self, tsi_handshaker **handshaker);
/* Destroys the handshaker factory. WARNING: it is unsafe to destroy a factory /* Decrements reference count of the handshaker factory. Handshaker factory will
while handshakers created with this factory are still in use. */ * be destroyed once no references exist. */
void tsi_ssl_server_handshaker_factory_destroy( void tsi_ssl_server_handshaker_factory_unref(
tsi_ssl_server_handshaker_factory *self); tsi_ssl_server_handshaker_factory *self);
/* Util that checks that an ssl peer matches a specific name. /* Util that checks that an ssl peer matches a specific name.
@ -170,6 +170,29 @@ void tsi_ssl_server_handshaker_factory_destroy(
- handle public suffix wildchar more strictly (e.g. *.co.uk) */ - handle public suffix wildchar more strictly (e.g. *.co.uk) */
int tsi_ssl_peer_matches_name(const tsi_peer *peer, const char *name); int tsi_ssl_peer_matches_name(const tsi_peer *peer, const char *name);
/* --- Testing support. ---
These functions and typedefs are not intended to be used outside of testing.
*/
/* Base type of client and server handshaker factories. */
typedef struct tsi_ssl_handshaker_factory tsi_ssl_handshaker_factory;
/* Function pointer to handshaker_factory destructor. */
typedef void (*tsi_ssl_handshaker_factory_destructor)(
tsi_ssl_handshaker_factory *factory);
/* Virtual table for tsi_ssl_handshaker_factory. */
typedef struct {
tsi_ssl_handshaker_factory_destructor destroy;
} tsi_ssl_handshaker_factory_vtable;
/* Set destructor of handshaker_factory to new_destructor, returns previous
destructor. */
const tsi_ssl_handshaker_factory_vtable *tsi_ssl_handshaker_factory_swap_vtable(
tsi_ssl_handshaker_factory *factory,
tsi_ssl_handshaker_factory_vtable *new_vtable);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

@ -23,7 +23,9 @@
#include "src/core/lib/iomgr/load_file.h" #include "src/core/lib/iomgr/load_file.h"
#include "src/core/lib/security/transport/security_connector.h" #include "src/core/lib/security/transport/security_connector.h"
#include "src/core/tsi/ssl_transport_security.h" #include "src/core/tsi/ssl_transport_security.h"
#include "src/core/tsi/transport_security.h"
#include "src/core/tsi/transport_security_adapter.h" #include "src/core/tsi/transport_security_adapter.h"
#include "src/core/tsi/transport_security_interface.h"
#include "test/core/tsi/transport_security_test_lib.h" #include "test/core/tsi/transport_security_test_lib.h"
#include "test/core/util/test_config.h" #include "test/core/util/test_config.h"
@ -312,10 +314,10 @@ static void ssl_test_destruct(tsi_test_fixture *fixture) {
key_cert_lib->bad_client_pem_key_cert_pair); key_cert_lib->bad_client_pem_key_cert_pair);
gpr_free(key_cert_lib->root_cert); gpr_free(key_cert_lib->root_cert);
gpr_free(key_cert_lib); gpr_free(key_cert_lib);
/* Destroy others. */ /* Unreference others. */
tsi_ssl_server_handshaker_factory_destroy( tsi_ssl_server_handshaker_factory_unref(
ssl_fixture->server_handshaker_factory); ssl_fixture->server_handshaker_factory);
tsi_ssl_client_handshaker_factory_destroy( tsi_ssl_client_handshaker_factory_unref(
ssl_fixture->client_handshaker_factory); ssl_fixture->client_handshaker_factory);
} }
@ -536,6 +538,118 @@ void ssl_tsi_test_do_round_trip_odd_buffer_size() {
} }
} }
static const tsi_ssl_handshaker_factory_vtable *original_vtable;
static bool handshaker_factory_destructor_called;
static void ssl_tsi_test_handshaker_factory_destructor(
tsi_ssl_handshaker_factory *factory) {
GPR_ASSERT(factory != NULL);
handshaker_factory_destructor_called = true;
if (original_vtable != NULL && original_vtable->destroy != NULL) {
original_vtable->destroy(factory);
}
}
static tsi_ssl_handshaker_factory_vtable test_handshaker_factory_vtable = {
ssl_tsi_test_handshaker_factory_destructor};
void test_tsi_ssl_client_handshaker_factory_refcounting() {
int i;
const char *cert_chain =
load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "client.pem");
tsi_ssl_client_handshaker_factory *client_handshaker_factory;
GPR_ASSERT(tsi_create_ssl_client_handshaker_factory(
NULL, cert_chain, NULL, NULL, 0, &client_handshaker_factory) ==
TSI_OK);
handshaker_factory_destructor_called = false;
original_vtable = tsi_ssl_handshaker_factory_swap_vtable(
(tsi_ssl_handshaker_factory *)client_handshaker_factory,
&test_handshaker_factory_vtable);
tsi_handshaker *handshaker[3];
for (i = 0; i < 3; ++i) {
GPR_ASSERT(tsi_ssl_client_handshaker_factory_create_handshaker(
client_handshaker_factory, "google.com", &handshaker[i]) ==
TSI_OK);
}
tsi_handshaker_destroy(handshaker[1]);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_handshaker_destroy(handshaker[0]);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_handshaker_destroy(handshaker[2]);
GPR_ASSERT(handshaker_factory_destructor_called);
gpr_free((void *)cert_chain);
}
void test_tsi_ssl_server_handshaker_factory_refcounting() {
int i;
tsi_ssl_server_handshaker_factory *server_handshaker_factory;
tsi_handshaker *handshaker[3];
const char *cert_chain =
load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.pem");
tsi_ssl_pem_key_cert_pair cert_pair;
cert_pair.cert_chain = cert_chain;
cert_pair.private_key =
load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.key");
GPR_ASSERT(tsi_create_ssl_server_handshaker_factory(
&cert_pair, 1, cert_chain, 0, NULL, NULL, 0,
&server_handshaker_factory) == TSI_OK);
handshaker_factory_destructor_called = false;
original_vtable = tsi_ssl_handshaker_factory_swap_vtable(
(tsi_ssl_handshaker_factory *)server_handshaker_factory,
&test_handshaker_factory_vtable);
for (i = 0; i < 3; ++i) {
GPR_ASSERT(tsi_ssl_server_handshaker_factory_create_handshaker(
server_handshaker_factory, &handshaker[i]) == TSI_OK);
}
tsi_handshaker_destroy(handshaker[1]);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_handshaker_destroy(handshaker[0]);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_handshaker_destroy(handshaker[2]);
GPR_ASSERT(handshaker_factory_destructor_called);
ssl_test_pem_key_cert_pair_destroy(cert_pair);
}
/* Attempting to create a handshaker factory with invalid parameters should fail
* but not crash. */
void test_tsi_ssl_client_handshaker_factory_bad_params() {
const char *cert_chain = "This is not a valid PEM file.";
tsi_ssl_client_handshaker_factory *client_handshaker_factory;
GPR_ASSERT(tsi_create_ssl_client_handshaker_factory(
NULL, cert_chain, NULL, NULL, 0, &client_handshaker_factory) ==
TSI_INVALID_ARGUMENT);
tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory);
}
void ssl_tsi_test_handshaker_factory_internals() {
test_tsi_ssl_client_handshaker_factory_refcounting();
test_tsi_ssl_server_handshaker_factory_refcounting();
test_tsi_ssl_client_handshaker_factory_bad_params();
}
int main(int argc, char **argv) { int main(int argc, char **argv) {
grpc_test_init(argc, argv); grpc_test_init(argc, argv);
grpc_init(); grpc_init();
@ -553,6 +667,7 @@ int main(int argc, char **argv) {
ssl_tsi_test_do_handshake_alpn_client_server_ok(); ssl_tsi_test_do_handshake_alpn_client_server_ok();
ssl_tsi_test_do_round_trip_for_all_configs(); ssl_tsi_test_do_round_trip_for_all_configs();
ssl_tsi_test_do_round_trip_odd_buffer_size(); ssl_tsi_test_do_round_trip_odd_buffer_size();
ssl_tsi_test_handshaker_factory_internals();
grpc_shutdown(); grpc_shutdown();
return 0; return 0;
} }

Loading…
Cancel
Save