diff --git a/src/core/security/server_secure_chttp2.c b/src/core/security/server_secure_chttp2.c index 4dcd4b55244..b15c553b82c 100644 --- a/src/core/security/server_secure_chttp2.c +++ b/src/core/security/server_secure_chttp2.c @@ -37,6 +37,7 @@ #include "src/core/channel/http_filter.h" #include "src/core/channel/http_server_filter.h" +#include "src/core/iomgr/endpoint.h" #include "src/core/iomgr/resolve_address.h" #include "src/core/iomgr/tcp_server.h" #include "src/core/security/security_context.h" @@ -45,8 +46,29 @@ #include "src/core/transport/chttp2_transport.h" #include #include +#include #include +typedef struct grpc_server_secure_state { + grpc_server *server; + grpc_tcp_server *tcp; + grpc_security_context *ctx; + int is_shutdown; + gpr_mu mu; + gpr_refcount refcount; +} grpc_server_secure_state; + +static void state_ref(grpc_server_secure_state *state) { + gpr_ref(&state->refcount); +} + +static void state_unref(grpc_server_secure_state *state) { + if (gpr_unref(&state->refcount)) { + grpc_security_context_unref(state->ctx); + gpr_free(state); + } +} + static grpc_transport_setup_result setup_transport(void *server, grpc_transport *transport, grpc_mdctx *mdctx) { @@ -56,55 +78,63 @@ static grpc_transport_setup_result setup_transport(void *server, GPR_ARRAY_SIZE(extra_filters), mdctx); } -static void on_secure_transport_setup_done(void *server, +static void on_secure_transport_setup_done(void *statep, grpc_security_status status, grpc_endpoint *secure_endpoint) { + grpc_server_secure_state *state = statep; if (status == GRPC_SECURITY_OK) { - grpc_create_chttp2_transport( - setup_transport, server, grpc_server_get_channel_args(server), - secure_endpoint, NULL, 0, grpc_mdctx_create(), 0); + gpr_mu_lock(&state->mu); + if (!state->is_shutdown) { + grpc_create_chttp2_transport( + setup_transport, state->server, + grpc_server_get_channel_args(state->server), + secure_endpoint, NULL, 0, grpc_mdctx_create(), 0); + } else { + /* We need to consume this here, because the server may already have gone + * away. */ + grpc_endpoint_destroy(secure_endpoint); + } + gpr_mu_unlock(&state->mu); } else { gpr_log(GPR_ERROR, "Secure transport failed with error %d", status); } + state_unref(state); } -typedef struct { - grpc_tcp_server *tcp; - grpc_security_context *ctx; - grpc_server *server; -} secured_port; - -static void on_accept(void *spp, grpc_endpoint *tcp) { - secured_port *sp = spp; - grpc_setup_secure_transport(sp->ctx, tcp, on_secure_transport_setup_done, sp->server); +static void on_accept(void *statep, grpc_endpoint *tcp) { + grpc_server_secure_state *state = statep; + state_ref(state); + grpc_setup_secure_transport(state->ctx, tcp, on_secure_transport_setup_done, state); } /* Server callback: start listening on our ports */ -static void start(grpc_server *server, void *spp, grpc_pollset **pollsets, +static void start(grpc_server *server, void *statep, grpc_pollset **pollsets, size_t pollset_count) { - secured_port *sp = spp; - grpc_tcp_server_start(sp->tcp, pollsets, pollset_count, on_accept, sp); + grpc_server_secure_state *state = statep; + grpc_tcp_server_start(state->tcp, pollsets, pollset_count, on_accept, state); } /* Server callback: destroy the tcp listener (so we don't generate further callbacks) */ -static void destroy(grpc_server *server, void *spp) { - secured_port *sp = spp; - grpc_tcp_server_destroy(sp->tcp); - grpc_security_context_unref(sp->ctx); - gpr_free(sp); +static void destroy(grpc_server *server, void *statep) { + grpc_server_secure_state *state = statep; + gpr_mu_lock(&state->mu); + state->is_shutdown = 1; + grpc_tcp_server_destroy(state->tcp); + gpr_mu_unlock(&state->mu); + state_unref(state); } int grpc_server_add_secure_http2_port(grpc_server *server, const char *addr, grpc_server_credentials *creds) { grpc_resolved_addresses *resolved = NULL; grpc_tcp_server *tcp = NULL; + grpc_server_secure_state *state = NULL; size_t i; unsigned count = 0; int port_num = -1; int port_temp; grpc_security_status status = GRPC_SECURITY_ERROR; grpc_security_context *ctx = NULL; - secured_port *sp = NULL; /* create security context */ if (creds == NULL) goto error; @@ -161,13 +191,16 @@ int grpc_server_add_secure_http2_port(grpc_server *server, const char *addr, grp } grpc_resolved_addresses_destroy(resolved); - sp = gpr_malloc(sizeof(secured_port)); - sp->tcp = tcp; - sp->ctx = ctx; - sp->server = server; + state = gpr_malloc(sizeof(*state)); + state->server = server; + state->tcp = tcp; + state->ctx = ctx; + state->is_shutdown = 0; + gpr_mu_init(&state->mu); + gpr_ref_init(&state->refcount, 1); /* Register with the server only upon success */ - grpc_server_add_listener(server, sp, start, destroy); + grpc_server_add_listener(server, state, start, destroy); return port_num; @@ -182,8 +215,8 @@ error: if (tcp) { grpc_tcp_server_destroy(tcp); } - if (sp) { - gpr_free(sp); + if (state) { + gpr_free(state); } return 0; } diff --git a/src/core/surface/server_chttp2.c b/src/core/surface/server_chttp2.c index fd702593b89..27434b39e2d 100644 --- a/src/core/surface/server_chttp2.c +++ b/src/core/surface/server_chttp2.c @@ -53,6 +53,13 @@ static grpc_transport_setup_result setup_transport(void *server, } static void new_transport(void *server, grpc_endpoint *tcp) { + /* + * Beware that the call to grpc_create_chttp2_transport() has to happen before + * grpc_tcp_server_destroy(). This is fine here, but similar code + * asynchronously doing a handshake instead of calling grpc_tcp_server_start() + * (as in server_secure_chttp2.c) needs to add synchronization to avoid this + * case. + */ grpc_create_chttp2_transport(setup_transport, server, grpc_server_get_channel_args(server), tcp, NULL, 0, grpc_mdctx_create(), 0);