Add reference counting to tsi_ssl_handshaker_factories

pull/12359/head
Justin Burke 7 years ago
parent 95cd9a46bd
commit 4984135a78
  1. 3
      src/core/lib/http/httpcli_security_connector.c
  2. 22
      src/core/lib/security/transport/security_connector.c
  3. 113
      src/core/tsi/ssl_transport_security.c
  4. 37
      src/core/tsi/ssl_transport_security.h
  5. 121
      test/core/tsi/ssl_transport_security_test.c

@ -43,7 +43,8 @@ static void httpcli_ssl_destroy(grpc_exec_ctx *exec_ctx,
grpc_httpcli_ssl_channel_security_connector *c =
(grpc_httpcli_ssl_channel_security_connector *)sc;
if (c->handshaker_factory != NULL) {
tsi_ssl_client_handshaker_factory_destroy(c->handshaker_factory);
tsi_ssl_client_handshaker_factory_unref(c->handshaker_factory);
c->handshaker_factory = NULL;
}
if (c->secure_peer_name != NULL) gpr_free(c->secure_peer_name);
gpr_free(sc);

@ -455,14 +455,14 @@ grpc_server_security_connector *grpc_fake_server_security_connector_create(
typedef struct {
grpc_channel_security_connector base;
tsi_ssl_client_handshaker_factory *handshaker_factory;
tsi_ssl_client_handshaker_factory *client_handshaker_factory;
char *target_name;
char *overridden_target_name;
} grpc_ssl_channel_security_connector;
typedef struct {
grpc_server_security_connector base;
tsi_ssl_server_handshaker_factory *handshaker_factory;
tsi_ssl_server_handshaker_factory *server_handshaker_factory;
} grpc_ssl_server_security_connector;
static void ssl_channel_destroy(grpc_exec_ctx *exec_ctx,
@ -470,9 +470,8 @@ static void ssl_channel_destroy(grpc_exec_ctx *exec_ctx,
grpc_ssl_channel_security_connector *c =
(grpc_ssl_channel_security_connector *)sc;
grpc_call_credentials_unref(exec_ctx, c->base.request_metadata_creds);
if (c->handshaker_factory != NULL) {
tsi_ssl_client_handshaker_factory_destroy(c->handshaker_factory);
}
tsi_ssl_client_handshaker_factory_unref(c->client_handshaker_factory);
c->client_handshaker_factory = NULL;
if (c->target_name != NULL) gpr_free(c->target_name);
if (c->overridden_target_name != NULL) gpr_free(c->overridden_target_name);
gpr_free(sc);
@ -482,9 +481,8 @@ static void ssl_server_destroy(grpc_exec_ctx *exec_ctx,
grpc_security_connector *sc) {
grpc_ssl_server_security_connector *c =
(grpc_ssl_server_security_connector *)sc;
if (c->handshaker_factory != NULL) {
tsi_ssl_server_handshaker_factory_destroy(c->handshaker_factory);
}
tsi_ssl_server_handshaker_factory_unref(c->server_handshaker_factory);
c->server_handshaker_factory = NULL;
gpr_free(sc);
}
@ -496,7 +494,7 @@ static void ssl_channel_add_handshakers(grpc_exec_ctx *exec_ctx,
// Instantiate TSI handshaker.
tsi_handshaker *tsi_hs = NULL;
tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
c->handshaker_factory,
c->client_handshaker_factory,
c->overridden_target_name != NULL ? c->overridden_target_name
: c->target_name,
&tsi_hs);
@ -521,7 +519,7 @@ static void ssl_server_add_handshakers(grpc_exec_ctx *exec_ctx,
// Instantiate TSI handshaker.
tsi_handshaker *tsi_hs = NULL;
tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker(
c->handshaker_factory, &tsi_hs);
c->server_handshaker_factory, &tsi_hs);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
tsi_result_to_string(result));
@ -852,7 +850,7 @@ grpc_security_status grpc_ssl_channel_security_connector_create(
result = tsi_create_ssl_client_handshaker_factory(
has_key_cert_pair ? &config->pem_key_cert_pair : NULL, pem_root_certs,
ssl_cipher_suites(), alpn_protocol_strings, (uint16_t)num_alpn_protocols,
&c->handshaker_factory);
&c->client_handshaker_factory);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
@ -897,7 +895,7 @@ grpc_security_status grpc_ssl_server_security_connector_create(
config->pem_root_certs, get_tsi_client_certificate_request_type(
config->client_certificate_request),
ssl_cipher_suites(), alpn_protocol_strings, (uint16_t)num_alpn_protocols,
&c->handshaker_factory);
&c->server_handshaker_factory);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));

@ -67,7 +67,13 @@
/* --- Structure definitions. ---*/
struct tsi_ssl_handshaker_factory {
const tsi_ssl_handshaker_factory_vtable *vtable;
gpr_refcount refcount;
};
struct tsi_ssl_client_handshaker_factory {
tsi_ssl_handshaker_factory base;
SSL_CTX *ssl_context;
unsigned char *alpn_protocol_list;
size_t alpn_protocol_list_length;
@ -77,6 +83,7 @@ struct tsi_ssl_server_handshaker_factory {
/* Several contexts to support SNI.
The tsi_peer array contains the subject names of the server certificates
associated with the contexts at the same index. */
tsi_ssl_handshaker_factory base;
SSL_CTX **ssl_contexts;
tsi_peer *ssl_context_x509_subject_names;
size_t ssl_context_count;
@ -90,6 +97,7 @@ typedef struct {
BIO *into_ssl;
BIO *from_ssl;
tsi_result result;
tsi_ssl_handshaker_factory *factory_ref;
} tsi_ssl_handshaker;
typedef struct {
@ -846,6 +854,47 @@ static const tsi_frame_protector_vtable frame_protector_vtable = {
ssl_protector_destroy,
};
/* --- tsi_server_handshaker_factory methods implementation. --- */
static void tsi_ssl_handshaker_factory_destroy(
tsi_ssl_handshaker_factory *self) {
if (self == NULL) return;
if (self->vtable != NULL && self->vtable->destroy != NULL) {
self->vtable->destroy(self);
}
/* Note, we don't free(self) here because this object is always directly
* embedded in another object. If tsi_ssl_handshaker_factory_init allocates
* any memory, it should be free'd here. */
}
static tsi_ssl_handshaker_factory *tsi_ssl_handshaker_factory_ref(
tsi_ssl_handshaker_factory *self) {
if (self == NULL) return NULL;
gpr_refn(&self->refcount, 1);
return self;
}
static void tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory *self) {
if (self == NULL) return;
if (gpr_unref(&self->refcount)) {
tsi_ssl_handshaker_factory_destroy(self);
}
}
static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = {NULL};
/* Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for
* allocating memory for the factory. */
static void tsi_ssl_handshaker_factory_init(
tsi_ssl_handshaker_factory *factory) {
GPR_ASSERT(factory != NULL);
factory->vtable = &handshaker_factory_vtable;
gpr_ref_init(&factory->refcount, 1);
}
/* --- tsi_handshaker methods implementation. ---*/
static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(tsi_handshaker *self,
@ -1013,6 +1062,7 @@ static tsi_result ssl_handshaker_create_frame_protector(
static void ssl_handshaker_destroy(tsi_handshaker *self) {
tsi_ssl_handshaker *impl = (tsi_ssl_handshaker *)self;
SSL_free(impl->ssl); /* The BIO objects are owned by ssl */
tsi_ssl_handshaker_factory_unref(impl->factory_ref);
gpr_free(impl);
}
@ -1030,6 +1080,7 @@ static const tsi_handshaker_vtable handshaker_vtable = {
static tsi_result create_tsi_ssl_handshaker(SSL_CTX *ctx, int is_client,
const char *server_name_indication,
tsi_ssl_handshaker_factory *factory,
tsi_handshaker **handshaker) {
SSL *ssl = SSL_new(ctx);
BIO *into_ssl = NULL;
@ -1085,6 +1136,8 @@ static tsi_result create_tsi_ssl_handshaker(SSL_CTX *ctx, int is_client,
impl->from_ssl = from_ssl;
impl->result = TSI_HANDSHAKE_IN_PROGRESS;
impl->base.vtable = &handshaker_vtable;
impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory);
*handshaker = &impl->base;
return TSI_OK;
}
@ -1121,11 +1174,20 @@ tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
tsi_ssl_client_handshaker_factory *self, const char *server_name_indication,
tsi_handshaker **handshaker) {
return create_tsi_ssl_handshaker(self->ssl_context, 1, server_name_indication,
handshaker);
&self->base, handshaker);
}
void tsi_ssl_client_handshaker_factory_destroy(
void tsi_ssl_client_handshaker_factory_unref(
tsi_ssl_client_handshaker_factory *self) {
if (self == NULL) return;
tsi_ssl_handshaker_factory_unref(&self->base);
}
static void tsi_ssl_client_handshaker_factory_destroy(
tsi_ssl_handshaker_factory *factory) {
if (factory == NULL) return;
tsi_ssl_client_handshaker_factory *self =
(tsi_ssl_client_handshaker_factory *)factory;
if (self->ssl_context != NULL) SSL_CTX_free(self->ssl_context);
if (self->alpn_protocol_list != NULL) gpr_free(self->alpn_protocol_list);
gpr_free(self);
@ -1150,11 +1212,21 @@ tsi_result tsi_ssl_server_handshaker_factory_create_handshaker(
if (self->ssl_context_count == 0) return TSI_INVALID_ARGUMENT;
/* Create the handshaker with the first context. We will switch if needed
because of SNI in ssl_server_handshaker_factory_servername_callback. */
return create_tsi_ssl_handshaker(self->ssl_contexts[0], 0, NULL, handshaker);
return create_tsi_ssl_handshaker(self->ssl_contexts[0], 0, NULL, &self->base,
handshaker);
}
void tsi_ssl_server_handshaker_factory_destroy(
void tsi_ssl_server_handshaker_factory_unref(
tsi_ssl_server_handshaker_factory *self) {
if (self == NULL) return;
tsi_ssl_handshaker_factory_unref(&self->base);
}
static void tsi_ssl_server_handshaker_factory_destroy(
tsi_ssl_handshaker_factory *factory) {
if (factory == NULL) return;
tsi_ssl_server_handshaker_factory *self =
(tsi_ssl_server_handshaker_factory *)factory;
size_t i;
for (i = 0; i < self->ssl_context_count; i++) {
if (self->ssl_contexts[i] != NULL) {
@ -1263,6 +1335,9 @@ static int server_handshaker_factory_npn_advertised_callback(
/* --- tsi_ssl_handshaker_factory constructors. --- */
static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = {
tsi_ssl_client_handshaker_factory_destroy};
tsi_result tsi_create_ssl_client_handshaker_factory(
const tsi_ssl_pem_key_cert_pair *pem_key_cert_pair,
const char *pem_root_certs, const char *cipher_suites,
@ -1285,6 +1360,9 @@ tsi_result tsi_create_ssl_client_handshaker_factory(
}
impl = gpr_zalloc(sizeof(*impl));
tsi_ssl_handshaker_factory_init(&impl->base);
impl->base.vtable = &client_handshaker_factory_vtable;
impl->ssl_context = ssl_context;
do {
@ -1322,7 +1400,7 @@ tsi_result tsi_create_ssl_client_handshaker_factory(
}
} while (0);
if (result != TSI_OK) {
tsi_ssl_client_handshaker_factory_destroy(impl);
tsi_ssl_handshaker_factory_unref(&impl->base);
return result;
}
SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, NULL);
@ -1332,6 +1410,9 @@ tsi_result tsi_create_ssl_client_handshaker_factory(
return TSI_OK;
}
static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = {
tsi_ssl_server_handshaker_factory_destroy};
tsi_result tsi_create_ssl_server_handshaker_factory(
const tsi_ssl_pem_key_cert_pair *pem_key_cert_pairs,
size_t num_key_cert_pairs, const char *pem_client_root_certs,
@ -1364,12 +1445,15 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex(
}
impl = gpr_zalloc(sizeof(*impl));
tsi_ssl_handshaker_factory_init(&impl->base);
impl->base.vtable = &server_handshaker_factory_vtable;
impl->ssl_contexts = gpr_zalloc(num_key_cert_pairs * sizeof(SSL_CTX *));
impl->ssl_context_x509_subject_names =
gpr_zalloc(num_key_cert_pairs * sizeof(tsi_peer));
if (impl->ssl_contexts == NULL ||
impl->ssl_context_x509_subject_names == NULL) {
tsi_ssl_server_handshaker_factory_destroy(impl);
tsi_ssl_handshaker_factory_unref(&impl->base);
return TSI_OUT_OF_RESOURCES;
}
impl->ssl_context_count = num_key_cert_pairs;
@ -1379,7 +1463,7 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex(
&impl->alpn_protocol_list,
&impl->alpn_protocol_list_length);
if (result != TSI_OK) {
tsi_ssl_server_handshaker_factory_destroy(impl);
tsi_ssl_handshaker_factory_unref(&impl->base);
return result;
}
}
@ -1451,10 +1535,11 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex(
} while (0);
if (result != TSI_OK) {
tsi_ssl_server_handshaker_factory_destroy(impl);
tsi_ssl_handshaker_factory_unref(&impl->base);
return result;
}
}
*factory = impl;
return TSI_OK;
}
@ -1501,3 +1586,15 @@ int tsi_ssl_peer_matches_name(const tsi_peer *peer, const char *name) {
return 0; /* Not found. */
}
/* --- Testing support. --- */
const tsi_ssl_handshaker_factory_vtable *tsi_ssl_handshaker_factory_swap_vtable(
tsi_ssl_handshaker_factory *factory,
tsi_ssl_handshaker_factory_vtable *new_vtable) {
GPR_ASSERT(factory != NULL);
GPR_ASSERT(factory->vtable != NULL);
const tsi_ssl_handshaker_factory_vtable *orig_vtable = factory->vtable;
factory->vtable = new_vtable;
return orig_vtable;
}

@ -96,10 +96,10 @@ tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
tsi_ssl_client_handshaker_factory *self, const char *server_name_indication,
tsi_handshaker **handshaker);
/* Destroys the handshaker factory. WARNING: it is unsafe to destroy a factory
while handshakers created with this factory are still in use. */
void tsi_ssl_client_handshaker_factory_destroy(
tsi_ssl_client_handshaker_factory *self);
/* Decrements reference count of the handshaker factory. Handshaker factory will
* be destroyed once no references exist. */
void tsi_ssl_client_handshaker_factory_unref(
tsi_ssl_client_handshaker_factory *factory);
/* --- tsi_ssl_server_handshaker_factory object ---
@ -158,9 +158,9 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex(
tsi_result tsi_ssl_server_handshaker_factory_create_handshaker(
tsi_ssl_server_handshaker_factory *self, tsi_handshaker **handshaker);
/* Destroys the handshaker factory. WARNING: it is unsafe to destroy a factory
while handshakers created with this factory are still in use. */
void tsi_ssl_server_handshaker_factory_destroy(
/* Decrements reference count of the handshaker factory. Handshaker factory will
* be destroyed once no references exist. */
void tsi_ssl_server_handshaker_factory_unref(
tsi_ssl_server_handshaker_factory *self);
/* Util that checks that an ssl peer matches a specific name.
@ -170,6 +170,29 @@ void tsi_ssl_server_handshaker_factory_destroy(
- handle public suffix wildchar more strictly (e.g. *.co.uk) */
int tsi_ssl_peer_matches_name(const tsi_peer *peer, const char *name);
/* --- Testing support. ---
These functions and typedefs are not intended to be used outside of testing.
*/
/* Base type of client and server handshaker factories. */
typedef struct tsi_ssl_handshaker_factory tsi_ssl_handshaker_factory;
/* Function pointer to handshaker_factory destructor. */
typedef void (*tsi_ssl_handshaker_factory_destructor)(
tsi_ssl_handshaker_factory *factory);
/* Virtual table for tsi_ssl_handshaker_factory. */
typedef struct {
tsi_ssl_handshaker_factory_destructor destroy;
} tsi_ssl_handshaker_factory_vtable;
/* Set destructor of handshaker_factory to new_destructor, returns previous
destructor. */
const tsi_ssl_handshaker_factory_vtable *tsi_ssl_handshaker_factory_swap_vtable(
tsi_ssl_handshaker_factory *factory,
tsi_ssl_handshaker_factory_vtable *new_vtable);
#ifdef __cplusplus
}
#endif

@ -23,7 +23,9 @@
#include "src/core/lib/iomgr/load_file.h"
#include "src/core/lib/security/transport/security_connector.h"
#include "src/core/tsi/ssl_transport_security.h"
#include "src/core/tsi/transport_security.h"
#include "src/core/tsi/transport_security_adapter.h"
#include "src/core/tsi/transport_security_interface.h"
#include "test/core/tsi/transport_security_test_lib.h"
#include "test/core/util/test_config.h"
@ -312,10 +314,10 @@ static void ssl_test_destruct(tsi_test_fixture *fixture) {
key_cert_lib->bad_client_pem_key_cert_pair);
gpr_free(key_cert_lib->root_cert);
gpr_free(key_cert_lib);
/* Destroy others. */
tsi_ssl_server_handshaker_factory_destroy(
/* Unreference others. */
tsi_ssl_server_handshaker_factory_unref(
ssl_fixture->server_handshaker_factory);
tsi_ssl_client_handshaker_factory_destroy(
tsi_ssl_client_handshaker_factory_unref(
ssl_fixture->client_handshaker_factory);
}
@ -536,6 +538,118 @@ void ssl_tsi_test_do_round_trip_odd_buffer_size() {
}
}
static const tsi_ssl_handshaker_factory_vtable *original_vtable;
static bool handshaker_factory_destructor_called;
static void ssl_tsi_test_handshaker_factory_destructor(
tsi_ssl_handshaker_factory *factory) {
GPR_ASSERT(factory != NULL);
handshaker_factory_destructor_called = true;
if (original_vtable != NULL && original_vtable->destroy != NULL) {
original_vtable->destroy(factory);
}
}
static tsi_ssl_handshaker_factory_vtable test_handshaker_factory_vtable = {
ssl_tsi_test_handshaker_factory_destructor};
void test_tsi_ssl_client_handshaker_factory_refcounting() {
int i;
const char *cert_chain =
load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "client.pem");
tsi_ssl_client_handshaker_factory *client_handshaker_factory;
GPR_ASSERT(tsi_create_ssl_client_handshaker_factory(
NULL, cert_chain, NULL, NULL, 0, &client_handshaker_factory) ==
TSI_OK);
handshaker_factory_destructor_called = false;
original_vtable = tsi_ssl_handshaker_factory_swap_vtable(
(tsi_ssl_handshaker_factory *)client_handshaker_factory,
&test_handshaker_factory_vtable);
tsi_handshaker *handshaker[3];
for (i = 0; i < 3; ++i) {
GPR_ASSERT(tsi_ssl_client_handshaker_factory_create_handshaker(
client_handshaker_factory, "google.com", &handshaker[i]) ==
TSI_OK);
}
tsi_handshaker_destroy(handshaker[1]);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_handshaker_destroy(handshaker[0]);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_handshaker_destroy(handshaker[2]);
GPR_ASSERT(handshaker_factory_destructor_called);
gpr_free((void *)cert_chain);
}
void test_tsi_ssl_server_handshaker_factory_refcounting() {
int i;
tsi_ssl_server_handshaker_factory *server_handshaker_factory;
tsi_handshaker *handshaker[3];
const char *cert_chain =
load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.pem");
tsi_ssl_pem_key_cert_pair cert_pair;
cert_pair.cert_chain = cert_chain;
cert_pair.private_key =
load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.key");
GPR_ASSERT(tsi_create_ssl_server_handshaker_factory(
&cert_pair, 1, cert_chain, 0, NULL, NULL, 0,
&server_handshaker_factory) == TSI_OK);
handshaker_factory_destructor_called = false;
original_vtable = tsi_ssl_handshaker_factory_swap_vtable(
(tsi_ssl_handshaker_factory *)server_handshaker_factory,
&test_handshaker_factory_vtable);
for (i = 0; i < 3; ++i) {
GPR_ASSERT(tsi_ssl_server_handshaker_factory_create_handshaker(
server_handshaker_factory, &handshaker[i]) == TSI_OK);
}
tsi_handshaker_destroy(handshaker[1]);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_handshaker_destroy(handshaker[0]);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory);
GPR_ASSERT(!handshaker_factory_destructor_called);
tsi_handshaker_destroy(handshaker[2]);
GPR_ASSERT(handshaker_factory_destructor_called);
ssl_test_pem_key_cert_pair_destroy(cert_pair);
}
/* Attempting to create a handshaker factory with invalid parameters should fail
* but not crash. */
void test_tsi_ssl_client_handshaker_factory_bad_params() {
const char *cert_chain = "This is not a valid PEM file.";
tsi_ssl_client_handshaker_factory *client_handshaker_factory;
GPR_ASSERT(tsi_create_ssl_client_handshaker_factory(
NULL, cert_chain, NULL, NULL, 0, &client_handshaker_factory) ==
TSI_INVALID_ARGUMENT);
tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory);
}
void ssl_tsi_test_handshaker_factory_internals() {
test_tsi_ssl_client_handshaker_factory_refcounting();
test_tsi_ssl_server_handshaker_factory_refcounting();
test_tsi_ssl_client_handshaker_factory_bad_params();
}
int main(int argc, char **argv) {
grpc_test_init(argc, argv);
grpc_init();
@ -553,6 +667,7 @@ int main(int argc, char **argv) {
ssl_tsi_test_do_handshake_alpn_client_server_ok();
ssl_tsi_test_do_round_trip_for_all_configs();
ssl_tsi_test_do_round_trip_odd_buffer_size();
ssl_tsi_test_handshaker_factory_internals();
grpc_shutdown();
return 0;
}

Loading…
Cancel
Save