From 195a30bb8bc05f7fb1c2873f639621d6fea2948d Mon Sep 17 00:00:00 2001 From: Arjun Roy Date: Wed, 16 Jan 2019 16:30:39 -0800 Subject: [PATCH] Grpc: Change grpc_handshake and grpc_handshake_mgr to use CPP implementations. grpc_handshake is renamed to GrpcHandshake, using C++ class definitions instead of C-style vtable classes. Update callers to use new interfaces. We use RefCountedPtr to simplify reference tracking. --- BUILD | 1 - CMakeLists.txt | 6 - Makefile | 6 - build.yaml | 1 - config.m4 | 1 - config.w32 | 1 - gRPC-Core.podspec | 1 - grpc.gemspec | 1 - grpc.gyp | 4 - package.xml | 1 - .../client_channel/http_connect_handshaker.cc | 302 +++++----- .../chttp2/client/chttp2_connector.cc | 22 +- .../transport/chttp2/server/chttp2_server.cc | 37 +- src/core/lib/channel/handshaker.cc | 355 +++++------ src/core/lib/channel/handshaker.h | 211 +++---- src/core/lib/channel/handshaker_factory.cc | 42 -- src/core/lib/channel/handshaker_factory.h | 30 +- src/core/lib/channel/handshaker_registry.cc | 116 ++-- src/core/lib/channel/handshaker_registry.h | 37 +- .../lib/http/httpcli_security_connector.cc | 27 +- .../alts/alts_security_connector.cc | 18 +- .../fake/fake_security_connector.cc | 16 +- .../local/local_security_connector.cc | 18 +- .../security_connector/security_connector.h | 4 +- .../ssl/ssl_security_connector.cc | 10 +- .../security/transport/security_handshaker.cc | 567 +++++++++--------- .../security/transport/security_handshaker.h | 13 +- src/core/lib/surface/init.cc | 4 +- src/core/lib/surface/init_secure.cc | 2 +- src/python/grpcio/grpc_core_dependencies.py | 1 - .../readahead_handshaker_server_ssl.cc | 65 +- test/core/security/ssl_server_fuzzer.cc | 15 +- tools/doxygen/Doxyfile.core.internal | 1 - .../generated/sources_and_headers.json | 1 - 34 files changed, 882 insertions(+), 1055 deletions(-) delete mode 100644 src/core/lib/channel/handshaker_factory.cc diff --git a/BUILD b/BUILD index 3f1e735466d..ebb03580bb4 100644 --- a/BUILD +++ b/BUILD @@ -701,7 +701,6 @@ grpc_cc_library( "src/core/lib/channel/channelz_registry.cc", "src/core/lib/channel/connected_channel.cc", "src/core/lib/channel/handshaker.cc", - "src/core/lib/channel/handshaker_factory.cc", "src/core/lib/channel/handshaker_registry.cc", "src/core/lib/channel/status_util.cc", "src/core/lib/compression/compression.cc", diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f1a0f6af9b..b2de3f6fde5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -971,7 +971,6 @@ add_library(grpc src/core/lib/channel/channelz_registry.cc src/core/lib/channel/connected_channel.cc src/core/lib/channel/handshaker.cc - src/core/lib/channel/handshaker_factory.cc src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/status_util.cc src/core/lib/compression/compression.cc @@ -1397,7 +1396,6 @@ add_library(grpc_cronet src/core/lib/channel/channelz_registry.cc src/core/lib/channel/connected_channel.cc src/core/lib/channel/handshaker.cc - src/core/lib/channel/handshaker_factory.cc src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/status_util.cc src/core/lib/compression/compression.cc @@ -1808,7 +1806,6 @@ add_library(grpc_test_util src/core/lib/channel/channelz_registry.cc src/core/lib/channel/connected_channel.cc src/core/lib/channel/handshaker.cc - src/core/lib/channel/handshaker_factory.cc src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/status_util.cc src/core/lib/compression/compression.cc @@ -2134,7 +2131,6 @@ add_library(grpc_test_util_unsecure src/core/lib/channel/channelz_registry.cc src/core/lib/channel/connected_channel.cc src/core/lib/channel/handshaker.cc - src/core/lib/channel/handshaker_factory.cc src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/status_util.cc src/core/lib/compression/compression.cc @@ -2436,7 +2432,6 @@ add_library(grpc_unsecure src/core/lib/channel/channelz_registry.cc src/core/lib/channel/connected_channel.cc src/core/lib/channel/handshaker.cc - src/core/lib/channel/handshaker_factory.cc src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/status_util.cc src/core/lib/compression/compression.cc @@ -3324,7 +3319,6 @@ add_library(grpc++_cronet src/core/lib/channel/channelz_registry.cc src/core/lib/channel/connected_channel.cc src/core/lib/channel/handshaker.cc - src/core/lib/channel/handshaker_factory.cc src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/status_util.cc src/core/lib/compression/compression.cc diff --git a/Makefile b/Makefile index e41c0584c7d..069d001d3be 100644 --- a/Makefile +++ b/Makefile @@ -3497,7 +3497,6 @@ LIBGRPC_SRC = \ src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/handshaker.cc \ - src/core/lib/channel/handshaker_factory.cc \ src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/status_util.cc \ src/core/lib/compression/compression.cc \ @@ -3917,7 +3916,6 @@ LIBGRPC_CRONET_SRC = \ src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/handshaker.cc \ - src/core/lib/channel/handshaker_factory.cc \ src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/status_util.cc \ src/core/lib/compression/compression.cc \ @@ -4321,7 +4319,6 @@ LIBGRPC_TEST_UTIL_SRC = \ src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/handshaker.cc \ - src/core/lib/channel/handshaker_factory.cc \ src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/status_util.cc \ src/core/lib/compression/compression.cc \ @@ -4634,7 +4631,6 @@ LIBGRPC_TEST_UTIL_UNSECURE_SRC = \ src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/handshaker.cc \ - src/core/lib/channel/handshaker_factory.cc \ src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/status_util.cc \ src/core/lib/compression/compression.cc \ @@ -4910,7 +4906,6 @@ LIBGRPC_UNSECURE_SRC = \ src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/handshaker.cc \ - src/core/lib/channel/handshaker_factory.cc \ src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/status_util.cc \ src/core/lib/compression/compression.cc \ @@ -5775,7 +5770,6 @@ LIBGRPC++_CRONET_SRC = \ src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/handshaker.cc \ - src/core/lib/channel/handshaker_factory.cc \ src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/status_util.cc \ src/core/lib/compression/compression.cc \ diff --git a/build.yaml b/build.yaml index ec00450f28a..f96b0cbcf22 100644 --- a/build.yaml +++ b/build.yaml @@ -242,7 +242,6 @@ filegroups: - src/core/lib/channel/channelz_registry.cc - src/core/lib/channel/connected_channel.cc - src/core/lib/channel/handshaker.cc - - src/core/lib/channel/handshaker_factory.cc - src/core/lib/channel/handshaker_registry.cc - src/core/lib/channel/status_util.cc - src/core/lib/compression/compression.cc diff --git a/config.m4 b/config.m4 index 1874f3ba1b0..5746caf694a 100644 --- a/config.m4 +++ b/config.m4 @@ -94,7 +94,6 @@ if test "$PHP_GRPC" != "no"; then src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/handshaker.cc \ - src/core/lib/channel/handshaker_factory.cc \ src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/status_util.cc \ src/core/lib/compression/compression.cc \ diff --git a/config.w32 b/config.w32 index 452e8fd18b1..5659d8b8408 100644 --- a/config.w32 +++ b/config.w32 @@ -69,7 +69,6 @@ if (PHP_GRPC != "no") { "src\\core\\lib\\channel\\channelz_registry.cc " + "src\\core\\lib\\channel\\connected_channel.cc " + "src\\core\\lib\\channel\\handshaker.cc " + - "src\\core\\lib\\channel\\handshaker_factory.cc " + "src\\core\\lib\\channel\\handshaker_registry.cc " + "src\\core\\lib\\channel\\status_util.cc " + "src\\core\\lib\\compression\\compression.cc " + diff --git a/gRPC-Core.podspec b/gRPC-Core.podspec index da48fe7e953..625d1a9a50c 100644 --- a/gRPC-Core.podspec +++ b/gRPC-Core.podspec @@ -543,7 +543,6 @@ Pod::Spec.new do |s| 'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/handshaker.cc', - 'src/core/lib/channel/handshaker_factory.cc', 'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/status_util.cc', 'src/core/lib/compression/compression.cc', diff --git a/grpc.gemspec b/grpc.gemspec index 9a3c657cc85..a4e25d7bb22 100644 --- a/grpc.gemspec +++ b/grpc.gemspec @@ -477,7 +477,6 @@ Gem::Specification.new do |s| s.files += %w( src/core/lib/channel/channelz_registry.cc ) s.files += %w( src/core/lib/channel/connected_channel.cc ) s.files += %w( src/core/lib/channel/handshaker.cc ) - s.files += %w( src/core/lib/channel/handshaker_factory.cc ) s.files += %w( src/core/lib/channel/handshaker_registry.cc ) s.files += %w( src/core/lib/channel/status_util.cc ) s.files += %w( src/core/lib/compression/compression.cc ) diff --git a/grpc.gyp b/grpc.gyp index 6a0a2718c8e..113c17f0d09 100644 --- a/grpc.gyp +++ b/grpc.gyp @@ -276,7 +276,6 @@ 'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/handshaker.cc', - 'src/core/lib/channel/handshaker_factory.cc', 'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/status_util.cc', 'src/core/lib/compression/compression.cc', @@ -643,7 +642,6 @@ 'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/handshaker.cc', - 'src/core/lib/channel/handshaker_factory.cc', 'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/status_util.cc', 'src/core/lib/compression/compression.cc', @@ -889,7 +887,6 @@ 'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/handshaker.cc', - 'src/core/lib/channel/handshaker_factory.cc', 'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/status_util.cc', 'src/core/lib/compression/compression.cc', @@ -1111,7 +1108,6 @@ 'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/handshaker.cc', - 'src/core/lib/channel/handshaker_factory.cc', 'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/status_util.cc', 'src/core/lib/compression/compression.cc', diff --git a/package.xml b/package.xml index 69b6fdfa671..aa2bf62411c 100644 --- a/package.xml +++ b/package.xml @@ -482,7 +482,6 @@ - diff --git a/src/core/ext/filters/client_channel/http_connect_handshaker.cc b/src/core/ext/filters/client_channel/http_connect_handshaker.cc index 0716e468181..fa5aaa9e7ce 100644 --- a/src/core/ext/filters/client_channel/http_connect_handshaker.cc +++ b/src/core/ext/filters/client_channel/http_connect_handshaker.cc @@ -33,151 +33,160 @@ #include "src/core/lib/channel/handshaker_registry.h" #include "src/core/lib/gpr/env.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/mutex_lock.h" #include "src/core/lib/http/format_request.h" #include "src/core/lib/http/parser.h" #include "src/core/lib/slice/slice_internal.h" #include "src/core/lib/uri/uri_parser.h" -typedef struct http_connect_handshaker { - // Base class. Must be first. - grpc_handshaker base; +namespace grpc_core { - gpr_refcount refcount; - gpr_mu mu; +namespace { - bool shutdown; +class HttpConnectHandshaker : public Handshaker { + public: + HttpConnectHandshaker(); + void Shutdown(grpc_error* why) override; + void DoHandshake(grpc_tcp_server_acceptor* acceptor, + grpc_closure* on_handshake_done, + HandshakerArgs* args) override; + const char* name() const override { return "http_connect"; } + + private: + virtual ~HttpConnectHandshaker(); + void CleanupArgsForFailureLocked(); + void HandshakeFailedLocked(grpc_error* error); + static void OnWriteDone(void* arg, grpc_error* error); + static void OnReadDone(void* arg, grpc_error* error); + + gpr_mu mu_; + + bool is_shutdown_ = false; // Endpoint and read buffer to destroy after a shutdown. - grpc_endpoint* endpoint_to_destroy; - grpc_slice_buffer* read_buffer_to_destroy; + grpc_endpoint* endpoint_to_destroy_ = nullptr; + grpc_slice_buffer* read_buffer_to_destroy_ = nullptr; // State saved while performing the handshake. - grpc_handshaker_args* args; - grpc_closure* on_handshake_done; + HandshakerArgs* args_ = nullptr; + grpc_closure* on_handshake_done_ = nullptr; // Objects for processing the HTTP CONNECT request and response. - grpc_slice_buffer write_buffer; - grpc_closure request_done_closure; - grpc_closure response_read_closure; - grpc_http_parser http_parser; - grpc_http_response http_response; -} http_connect_handshaker; + grpc_slice_buffer write_buffer_; + grpc_closure request_done_closure_; + grpc_closure response_read_closure_; + grpc_http_parser http_parser_; + grpc_http_response http_response_; +}; -// Unref and clean up handshaker. -static void http_connect_handshaker_unref(http_connect_handshaker* handshaker) { - if (gpr_unref(&handshaker->refcount)) { - gpr_mu_destroy(&handshaker->mu); - if (handshaker->endpoint_to_destroy != nullptr) { - grpc_endpoint_destroy(handshaker->endpoint_to_destroy); - } - if (handshaker->read_buffer_to_destroy != nullptr) { - grpc_slice_buffer_destroy_internal(handshaker->read_buffer_to_destroy); - gpr_free(handshaker->read_buffer_to_destroy); - } - grpc_slice_buffer_destroy_internal(&handshaker->write_buffer); - grpc_http_parser_destroy(&handshaker->http_parser); - grpc_http_response_destroy(&handshaker->http_response); - gpr_free(handshaker); +HttpConnectHandshaker::~HttpConnectHandshaker() { + gpr_mu_destroy(&mu_); + if (endpoint_to_destroy_ != nullptr) { + grpc_endpoint_destroy(endpoint_to_destroy_); + } + if (read_buffer_to_destroy_ != nullptr) { + grpc_slice_buffer_destroy_internal(read_buffer_to_destroy_); + gpr_free(read_buffer_to_destroy_); } + grpc_slice_buffer_destroy_internal(&write_buffer_); + grpc_http_parser_destroy(&http_parser_); + grpc_http_response_destroy(&http_response_); } // Set args fields to nullptr, saving the endpoint and read buffer for // later destruction. -static void cleanup_args_for_failure_locked( - http_connect_handshaker* handshaker) { - handshaker->endpoint_to_destroy = handshaker->args->endpoint; - handshaker->args->endpoint = nullptr; - handshaker->read_buffer_to_destroy = handshaker->args->read_buffer; - handshaker->args->read_buffer = nullptr; - grpc_channel_args_destroy(handshaker->args->args); - handshaker->args->args = nullptr; +void HttpConnectHandshaker::CleanupArgsForFailureLocked() { + endpoint_to_destroy_ = args_->endpoint; + args_->endpoint = nullptr; + read_buffer_to_destroy_ = args_->read_buffer; + args_->read_buffer = nullptr; + grpc_channel_args_destroy(args_->args); + args_->args = nullptr; } // If the handshake failed or we're shutting down, clean up and invoke the // callback with the error. -static void handshake_failed_locked(http_connect_handshaker* handshaker, - grpc_error* error) { +void HttpConnectHandshaker::HandshakeFailedLocked(grpc_error* error) { if (error == GRPC_ERROR_NONE) { // If we were shut down after an endpoint operation succeeded but // before the endpoint callback was invoked, we need to generate our // own error. error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown"); } - if (!handshaker->shutdown) { + if (!is_shutdown_) { // TODO(ctiller): It is currently necessary to shutdown endpoints // before destroying them, even if we know that there are no // pending read/write callbacks. This should be fixed, at which // point this can be removed. - grpc_endpoint_shutdown(handshaker->args->endpoint, GRPC_ERROR_REF(error)); + grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(error)); // Not shutting down, so the handshake failed. Clean up before // invoking the callback. - cleanup_args_for_failure_locked(handshaker); + CleanupArgsForFailureLocked(); // Set shutdown to true so that subsequent calls to // http_connect_handshaker_shutdown() do nothing. - handshaker->shutdown = true; + is_shutdown_ = true; } // Invoke callback. - GRPC_CLOSURE_SCHED(handshaker->on_handshake_done, error); + GRPC_CLOSURE_SCHED(on_handshake_done_, error); } // Callback invoked when finished writing HTTP CONNECT request. -static void on_write_done(void* arg, grpc_error* error) { - http_connect_handshaker* handshaker = - static_cast(arg); - gpr_mu_lock(&handshaker->mu); - if (error != GRPC_ERROR_NONE || handshaker->shutdown) { +void HttpConnectHandshaker::OnWriteDone(void* arg, grpc_error* error) { + auto* handshaker = static_cast(arg); + gpr_mu_lock(&handshaker->mu_); + if (error != GRPC_ERROR_NONE || handshaker->is_shutdown_) { // If the write failed or we're shutting down, clean up and invoke the // callback with the error. - handshake_failed_locked(handshaker, GRPC_ERROR_REF(error)); - gpr_mu_unlock(&handshaker->mu); - http_connect_handshaker_unref(handshaker); + handshaker->HandshakeFailedLocked(GRPC_ERROR_REF(error)); + gpr_mu_unlock(&handshaker->mu_); + handshaker->Unref(); } else { // Otherwise, read the response. // The read callback inherits our ref to the handshaker. - grpc_endpoint_read(handshaker->args->endpoint, - handshaker->args->read_buffer, - &handshaker->response_read_closure); - gpr_mu_unlock(&handshaker->mu); + grpc_endpoint_read(handshaker->args_->endpoint, + handshaker->args_->read_buffer, + &handshaker->response_read_closure_); + gpr_mu_unlock(&handshaker->mu_); } } // Callback invoked for reading HTTP CONNECT response. -static void on_read_done(void* arg, grpc_error* error) { - http_connect_handshaker* handshaker = - static_cast(arg); - gpr_mu_lock(&handshaker->mu); - if (error != GRPC_ERROR_NONE || handshaker->shutdown) { +void HttpConnectHandshaker::OnReadDone(void* arg, grpc_error* error) { + auto* handshaker = static_cast(arg); + + gpr_mu_lock(&handshaker->mu_); + if (error != GRPC_ERROR_NONE || handshaker->is_shutdown_) { // If the read failed or we're shutting down, clean up and invoke the // callback with the error. - handshake_failed_locked(handshaker, GRPC_ERROR_REF(error)); + handshaker->HandshakeFailedLocked(GRPC_ERROR_REF(error)); goto done; } // Add buffer to parser. - for (size_t i = 0; i < handshaker->args->read_buffer->count; ++i) { - if (GRPC_SLICE_LENGTH(handshaker->args->read_buffer->slices[i]) > 0) { + for (size_t i = 0; i < handshaker->args_->read_buffer->count; ++i) { + if (GRPC_SLICE_LENGTH(handshaker->args_->read_buffer->slices[i]) > 0) { size_t body_start_offset = 0; - error = grpc_http_parser_parse(&handshaker->http_parser, - handshaker->args->read_buffer->slices[i], + error = grpc_http_parser_parse(&handshaker->http_parser_, + handshaker->args_->read_buffer->slices[i], &body_start_offset); if (error != GRPC_ERROR_NONE) { - handshake_failed_locked(handshaker, error); + handshaker->HandshakeFailedLocked(error); goto done; } - if (handshaker->http_parser.state == GRPC_HTTP_BODY) { + if (handshaker->http_parser_.state == GRPC_HTTP_BODY) { // Remove the data we've already read from the read buffer, // leaving only the leftover bytes (if any). grpc_slice_buffer tmp_buffer; grpc_slice_buffer_init(&tmp_buffer); if (body_start_offset < - GRPC_SLICE_LENGTH(handshaker->args->read_buffer->slices[i])) { + GRPC_SLICE_LENGTH(handshaker->args_->read_buffer->slices[i])) { grpc_slice_buffer_add( &tmp_buffer, - grpc_slice_split_tail(&handshaker->args->read_buffer->slices[i], + grpc_slice_split_tail(&handshaker->args_->read_buffer->slices[i], body_start_offset)); } grpc_slice_buffer_addn(&tmp_buffer, - &handshaker->args->read_buffer->slices[i + 1], - handshaker->args->read_buffer->count - i - 1); - grpc_slice_buffer_swap(handshaker->args->read_buffer, &tmp_buffer); + &handshaker->args_->read_buffer->slices[i + 1], + handshaker->args_->read_buffer->count - i - 1); + grpc_slice_buffer_swap(handshaker->args_->read_buffer, &tmp_buffer); grpc_slice_buffer_destroy_internal(&tmp_buffer); break; } @@ -194,64 +203,53 @@ static void on_read_done(void* arg, grpc_error* error) { // need to fix the HTTP parser to understand when the body is // complete (e.g., handling chunked transfer encoding or looking // at the Content-Length: header). - if (handshaker->http_parser.state != GRPC_HTTP_BODY) { - grpc_slice_buffer_reset_and_unref_internal(handshaker->args->read_buffer); - grpc_endpoint_read(handshaker->args->endpoint, - handshaker->args->read_buffer, - &handshaker->response_read_closure); - gpr_mu_unlock(&handshaker->mu); + if (handshaker->http_parser_.state != GRPC_HTTP_BODY) { + grpc_slice_buffer_reset_and_unref_internal(handshaker->args_->read_buffer); + grpc_endpoint_read(handshaker->args_->endpoint, + handshaker->args_->read_buffer, + &handshaker->response_read_closure_); + gpr_mu_unlock(&handshaker->mu_); return; } // Make sure we got a 2xx response. - if (handshaker->http_response.status < 200 || - handshaker->http_response.status >= 300) { + if (handshaker->http_response_.status < 200 || + handshaker->http_response_.status >= 300) { char* msg; gpr_asprintf(&msg, "HTTP proxy returned response code %d", - handshaker->http_response.status); + handshaker->http_response_.status); error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); gpr_free(msg); - handshake_failed_locked(handshaker, error); + handshaker->HandshakeFailedLocked(error); goto done; } // Success. Invoke handshake-done callback. - GRPC_CLOSURE_SCHED(handshaker->on_handshake_done, error); + GRPC_CLOSURE_SCHED(handshaker->on_handshake_done_, error); done: // Set shutdown to true so that subsequent calls to // http_connect_handshaker_shutdown() do nothing. - handshaker->shutdown = true; - gpr_mu_unlock(&handshaker->mu); - http_connect_handshaker_unref(handshaker); + handshaker->is_shutdown_ = true; + gpr_mu_unlock(&handshaker->mu_); + handshaker->Unref(); } // // Public handshaker methods // -static void http_connect_handshaker_destroy(grpc_handshaker* handshaker_in) { - http_connect_handshaker* handshaker = - reinterpret_cast(handshaker_in); - http_connect_handshaker_unref(handshaker); -} - -static void http_connect_handshaker_shutdown(grpc_handshaker* handshaker_in, - grpc_error* why) { - http_connect_handshaker* handshaker = - reinterpret_cast(handshaker_in); - gpr_mu_lock(&handshaker->mu); - if (!handshaker->shutdown) { - handshaker->shutdown = true; - grpc_endpoint_shutdown(handshaker->args->endpoint, GRPC_ERROR_REF(why)); - cleanup_args_for_failure_locked(handshaker); +void HttpConnectHandshaker::Shutdown(grpc_error* why) { + gpr_mu_lock(&mu_); + if (!is_shutdown_) { + is_shutdown_ = true; + grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(why)); + CleanupArgsForFailureLocked(); } - gpr_mu_unlock(&handshaker->mu); + gpr_mu_unlock(&mu_); GRPC_ERROR_UNREF(why); } -static void http_connect_handshaker_do_handshake( - grpc_handshaker* handshaker_in, grpc_tcp_server_acceptor* acceptor, - grpc_closure* on_handshake_done, grpc_handshaker_args* args) { - http_connect_handshaker* handshaker = - reinterpret_cast(handshaker_in); +void HttpConnectHandshaker::DoHandshake(grpc_tcp_server_acceptor* acceptor, + grpc_closure* on_handshake_done, + HandshakerArgs* args) { // Check for HTTP CONNECT channel arg. // If not found, invoke on_handshake_done without doing anything. const grpc_arg* arg = @@ -260,9 +258,9 @@ static void http_connect_handshaker_do_handshake( if (server_name == nullptr) { // Set shutdown to true so that subsequent calls to // http_connect_handshaker_shutdown() do nothing. - gpr_mu_lock(&handshaker->mu); - handshaker->shutdown = true; - gpr_mu_unlock(&handshaker->mu); + gpr_mu_lock(&mu_); + is_shutdown_ = true; + gpr_mu_unlock(&mu_); GRPC_CLOSURE_SCHED(on_handshake_done, GRPC_ERROR_NONE); return; } @@ -280,6 +278,7 @@ static void http_connect_handshaker_do_handshake( gpr_malloc(sizeof(grpc_http_header) * num_header_strings)); for (size_t i = 0; i < num_header_strings; ++i) { char* sep = strchr(header_strings[i], ':'); + if (sep == nullptr) { gpr_log(GPR_ERROR, "skipping unparseable HTTP CONNECT header: %s", header_strings[i]); @@ -292,9 +291,9 @@ static void http_connect_handshaker_do_handshake( } } // Save state in the handshaker object. - gpr_mu_lock(&handshaker->mu); - handshaker->args = args; - handshaker->on_handshake_done = on_handshake_done; + MutexLock lock(&mu_); + args_ = args; + on_handshake_done_ = on_handshake_done; // Log connection via proxy. char* proxy_name = grpc_endpoint_get_peer(args->endpoint); gpr_log(GPR_INFO, "Connecting to server %s via HTTP proxy %s", server_name, @@ -302,15 +301,18 @@ static void http_connect_handshaker_do_handshake( gpr_free(proxy_name); // Construct HTTP CONNECT request. grpc_httpcli_request request; - memset(&request, 0, sizeof(request)); request.host = server_name; + request.ssl_host_override = nullptr; request.http.method = (char*)"CONNECT"; request.http.path = server_name; + request.http.version = GRPC_HTTP_HTTP10; // Set by OnReadDone request.http.hdrs = headers; request.http.hdr_count = num_headers; + request.http.body_length = 0; + request.http.body = nullptr; request.handshaker = &grpc_httpcli_plaintext; grpc_slice request_slice = grpc_httpcli_format_connect_request(&request); - grpc_slice_buffer_add(&handshaker->write_buffer, request_slice); + grpc_slice_buffer_add(&write_buffer_, request_slice); // Clean up. gpr_free(headers); for (size_t i = 0; i < num_header_strings; ++i) { @@ -318,54 +320,42 @@ static void http_connect_handshaker_do_handshake( } gpr_free(header_strings); // Take a new ref to be held by the write callback. - gpr_ref(&handshaker->refcount); - grpc_endpoint_write(args->endpoint, &handshaker->write_buffer, - &handshaker->request_done_closure, nullptr); - gpr_mu_unlock(&handshaker->mu); + Ref().release(); + grpc_endpoint_write(args->endpoint, &write_buffer_, &request_done_closure_, + nullptr); } -static const grpc_handshaker_vtable http_connect_handshaker_vtable = { - http_connect_handshaker_destroy, http_connect_handshaker_shutdown, - http_connect_handshaker_do_handshake, "http_connect"}; - -static grpc_handshaker* grpc_http_connect_handshaker_create() { - http_connect_handshaker* handshaker = - static_cast(gpr_malloc(sizeof(*handshaker))); - memset(handshaker, 0, sizeof(*handshaker)); - grpc_handshaker_init(&http_connect_handshaker_vtable, &handshaker->base); - gpr_mu_init(&handshaker->mu); - gpr_ref_init(&handshaker->refcount, 1); - grpc_slice_buffer_init(&handshaker->write_buffer); - GRPC_CLOSURE_INIT(&handshaker->request_done_closure, on_write_done, - handshaker, grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&handshaker->response_read_closure, on_read_done, - handshaker, grpc_schedule_on_exec_ctx); - grpc_http_parser_init(&handshaker->http_parser, GRPC_HTTP_RESPONSE, - &handshaker->http_response); - return &handshaker->base; +HttpConnectHandshaker::HttpConnectHandshaker() { + gpr_mu_init(&mu_); + grpc_slice_buffer_init(&write_buffer_); + GRPC_CLOSURE_INIT(&request_done_closure_, &HttpConnectHandshaker::OnWriteDone, + this, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&response_read_closure_, &HttpConnectHandshaker::OnReadDone, + this, grpc_schedule_on_exec_ctx); + grpc_http_parser_init(&http_parser_, GRPC_HTTP_RESPONSE, &http_response_); } // // handshaker factory // -static void handshaker_factory_add_handshakers( - grpc_handshaker_factory* factory, const grpc_channel_args* args, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_handshake_manager_add(handshake_mgr, - grpc_http_connect_handshaker_create()); -} - -static void handshaker_factory_destroy(grpc_handshaker_factory* factory) {} +class HttpConnectHandshakerFactory : public HandshakerFactory { + public: + void AddHandshakers(const grpc_channel_args* args, + grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr) override { + handshake_mgr->Add(MakeRefCounted()); + } + ~HttpConnectHandshakerFactory() override = default; +}; -static const grpc_handshaker_factory_vtable handshaker_factory_vtable = { - handshaker_factory_add_handshakers, handshaker_factory_destroy}; +} // namespace -static grpc_handshaker_factory handshaker_factory = { - &handshaker_factory_vtable}; +} // namespace grpc_core void grpc_http_connect_register_handshaker_factory() { - grpc_handshaker_factory_register(true /* at_start */, HANDSHAKER_CLIENT, - &handshaker_factory); + using namespace grpc_core; + HandshakerRegistry::RegisterHandshakerFactory( + true /* at_start */, HANDSHAKER_CLIENT, + UniquePtr(New())); } diff --git a/src/core/ext/transport/chttp2/client/chttp2_connector.cc b/src/core/ext/transport/chttp2/client/chttp2_connector.cc index 1e9a75d0630..c324c2c9243 100644 --- a/src/core/ext/transport/chttp2/client/chttp2_connector.cc +++ b/src/core/ext/transport/chttp2/client/chttp2_connector.cc @@ -55,7 +55,7 @@ typedef struct { grpc_closure connected; - grpc_handshake_manager* handshake_mgr; + grpc_core::RefCountedPtr handshake_mgr; } chttp2_connector; static void chttp2_connector_ref(grpc_connector* con) { @@ -79,7 +79,7 @@ static void chttp2_connector_shutdown(grpc_connector* con, grpc_error* why) { gpr_mu_lock(&c->mu); c->shutdown = true; if (c->handshake_mgr != nullptr) { - grpc_handshake_manager_shutdown(c->handshake_mgr, GRPC_ERROR_REF(why)); + c->handshake_mgr->Shutdown(GRPC_ERROR_REF(why)); } // If handshaking is not yet in progress, shutdown the endpoint. // Otherwise, the handshaker will do this for us. @@ -91,7 +91,7 @@ static void chttp2_connector_shutdown(grpc_connector* con, grpc_error* why) { } static void on_handshake_done(void* arg, grpc_error* error) { - grpc_handshaker_args* args = static_cast(arg); + auto* args = static_cast(arg); chttp2_connector* c = static_cast(args->user_data); gpr_mu_lock(&c->mu); if (error != GRPC_ERROR_NONE || c->shutdown) { @@ -152,20 +152,20 @@ static void on_handshake_done(void* arg, grpc_error* error) { grpc_closure* notify = c->notify; c->notify = nullptr; GRPC_CLOSURE_SCHED(notify, error); - grpc_handshake_manager_destroy(c->handshake_mgr); - c->handshake_mgr = nullptr; + c->handshake_mgr.reset(); gpr_mu_unlock(&c->mu); chttp2_connector_unref(reinterpret_cast(c)); } static void start_handshake_locked(chttp2_connector* c) { - c->handshake_mgr = grpc_handshake_manager_create(); - grpc_handshakers_add(HANDSHAKER_CLIENT, c->args.channel_args, - c->args.interested_parties, c->handshake_mgr); + c->handshake_mgr = grpc_core::MakeRefCounted(); + grpc_core::HandshakerRegistry::AddHandshakers( + grpc_core::HANDSHAKER_CLIENT, c->args.channel_args, + c->args.interested_parties, c->handshake_mgr.get()); grpc_endpoint_add_to_pollset_set(c->endpoint, c->args.interested_parties); - grpc_handshake_manager_do_handshake( - c->handshake_mgr, c->endpoint, c->args.channel_args, c->args.deadline, - nullptr /* acceptor */, on_handshake_done, c); + c->handshake_mgr->DoHandshake(c->endpoint, c->args.channel_args, + c->args.deadline, nullptr /* acceptor */, + on_handshake_done, c); c->endpoint = nullptr; // Endpoint handed off to handshake manager. } diff --git a/src/core/ext/transport/chttp2/server/chttp2_server.cc b/src/core/ext/transport/chttp2/server/chttp2_server.cc index 3d09187b9ba..040ea2044b1 100644 --- a/src/core/ext/transport/chttp2/server/chttp2_server.cc +++ b/src/core/ext/transport/chttp2/server/chttp2_server.cc @@ -54,7 +54,7 @@ typedef struct { bool shutdown; grpc_closure tcp_server_shutdown_complete; grpc_closure* server_destroy_listener_done; - grpc_handshake_manager* pending_handshake_mgrs; + grpc_core::HandshakeManager* pending_handshake_mgrs; grpc_core::RefCountedPtr channelz_listen_socket; } server_state; @@ -64,7 +64,7 @@ typedef struct { server_state* svr_state; grpc_pollset* accepting_pollset; grpc_tcp_server_acceptor* acceptor; - grpc_handshake_manager* handshake_mgr; + grpc_core::RefCountedPtr handshake_mgr; // State for enforcing handshake timeout on receiving HTTP/2 settings. grpc_chttp2_transport* transport; grpc_millis deadline; @@ -112,7 +112,7 @@ static void on_receive_settings(void* arg, grpc_error* error) { } static void on_handshake_done(void* arg, grpc_error* error) { - grpc_handshaker_args* args = static_cast(arg); + auto* args = static_cast(arg); server_connection_state* connection_state = static_cast(args->user_data); gpr_mu_lock(&connection_state->svr_state->mu); @@ -175,11 +175,10 @@ static void on_handshake_done(void* arg, grpc_error* error) { } } } - grpc_handshake_manager_pending_list_remove( - &connection_state->svr_state->pending_handshake_mgrs, - connection_state->handshake_mgr); + connection_state->handshake_mgr->RemoveFromPendingMgrList( + &connection_state->svr_state->pending_handshake_mgrs); gpr_mu_unlock(&connection_state->svr_state->mu); - grpc_handshake_manager_destroy(connection_state->handshake_mgr); + connection_state->handshake_mgr.reset(); gpr_free(connection_state->acceptor); grpc_tcp_server_unref(connection_state->svr_state->tcp_server); server_connection_state_unref(connection_state); @@ -211,9 +210,8 @@ static void on_accept(void* arg, grpc_endpoint* tcp, gpr_free(acceptor); return; } - grpc_handshake_manager* handshake_mgr = grpc_handshake_manager_create(); - grpc_handshake_manager_pending_list_add(&state->pending_handshake_mgrs, - handshake_mgr); + auto handshake_mgr = grpc_core::MakeRefCounted(); + handshake_mgr->AddToPendingMgrList(&state->pending_handshake_mgrs); grpc_tcp_server_ref(state->tcp_server); gpr_mu_unlock(&state->mu); server_connection_state* connection_state = @@ -227,19 +225,19 @@ static void on_accept(void* arg, grpc_endpoint* tcp, connection_state->interested_parties = grpc_pollset_set_create(); grpc_pollset_set_add_pollset(connection_state->interested_parties, connection_state->accepting_pollset); - grpc_handshakers_add(HANDSHAKER_SERVER, state->args, - connection_state->interested_parties, - connection_state->handshake_mgr); + grpc_core::HandshakerRegistry::AddHandshakers( + grpc_core::HANDSHAKER_SERVER, state->args, + connection_state->interested_parties, + connection_state->handshake_mgr.get()); const grpc_arg* timeout_arg = grpc_channel_args_find(state->args, GRPC_ARG_SERVER_HANDSHAKE_TIMEOUT_MS); connection_state->deadline = grpc_core::ExecCtx::Get()->Now() + grpc_channel_arg_get_integer(timeout_arg, {120 * GPR_MS_PER_SEC, 1, INT_MAX}); - grpc_handshake_manager_do_handshake(connection_state->handshake_mgr, tcp, - state->args, connection_state->deadline, - acceptor, on_handshake_done, - connection_state); + connection_state->handshake_mgr->DoHandshake( + tcp, state->args, connection_state->deadline, acceptor, on_handshake_done, + connection_state); } /* Server callback: start listening on our ports */ @@ -260,8 +258,9 @@ static void tcp_server_shutdown_complete(void* arg, grpc_error* error) { gpr_mu_lock(&state->mu); grpc_closure* destroy_done = state->server_destroy_listener_done; GPR_ASSERT(state->shutdown); - grpc_handshake_manager_pending_list_shutdown_all( - state->pending_handshake_mgrs, GRPC_ERROR_REF(error)); + if (state->pending_handshake_mgrs != nullptr) { + state->pending_handshake_mgrs->ShutdownAllPending(GRPC_ERROR_REF(error)); + } state->channelz_listen_socket.reset(); gpr_mu_unlock(&state->mu); // Flush queued work before destroying handshaker factory, since that diff --git a/src/core/lib/channel/handshaker.cc b/src/core/lib/channel/handshaker.cc index e516b56b743..6bb05cee24e 100644 --- a/src/core/lib/channel/handshaker.cc +++ b/src/core/lib/channel/handshaker.cc @@ -30,302 +30,229 @@ #include "src/core/lib/iomgr/timer.h" #include "src/core/lib/slice/slice_internal.h" -grpc_core::TraceFlag grpc_handshaker_trace(false, "handshaker"); +namespace grpc_core { -// -// grpc_handshaker -// +TraceFlag grpc_handshaker_trace(false, "handshaker"); -void grpc_handshaker_init(const grpc_handshaker_vtable* vtable, - grpc_handshaker* handshaker) { - handshaker->vtable = vtable; -} - -void grpc_handshaker_destroy(grpc_handshaker* handshaker) { - handshaker->vtable->destroy(handshaker); -} - -void grpc_handshaker_shutdown(grpc_handshaker* handshaker, grpc_error* why) { - handshaker->vtable->shutdown(handshaker, why); -} - -void grpc_handshaker_do_handshake(grpc_handshaker* handshaker, - grpc_tcp_server_acceptor* acceptor, - grpc_closure* on_handshake_done, - grpc_handshaker_args* args) { - handshaker->vtable->do_handshake(handshaker, acceptor, on_handshake_done, - args); -} +namespace { -const char* grpc_handshaker_name(grpc_handshaker* handshaker) { - return handshaker->vtable->name; +char* HandshakerArgsString(HandshakerArgs* args) { + char* args_str = grpc_channel_args_string(args->args); + size_t num_args = args->args != nullptr ? args->args->num_args : 0; + size_t read_buffer_length = + args->read_buffer != nullptr ? args->read_buffer->length : 0; + char* str; + gpr_asprintf(&str, + "{endpoint=%p, args=%p {size=%" PRIuPTR + ": %s}, read_buffer=%p (length=%" PRIuPTR "), exit_early=%d}", + args->endpoint, args->args, num_args, args_str, + args->read_buffer, read_buffer_length, args->exit_early); + gpr_free(args_str); + return str; } -// -// grpc_handshake_manager -// +} // namespace -struct grpc_handshake_manager { - gpr_mu mu; - gpr_refcount refs; - bool shutdown; - // An array of handshakers added via grpc_handshake_manager_add(). - size_t count; - grpc_handshaker** handshakers; - // The index of the handshaker to invoke next and closure to invoke it. - size_t index; - grpc_closure call_next_handshaker; - // The acceptor to call the handshakers with. - grpc_tcp_server_acceptor* acceptor; - // Deadline timer across all handshakers. - grpc_timer deadline_timer; - grpc_closure on_timeout; - // The final callback and user_data to invoke after the last handshaker. - grpc_closure on_handshake_done; - void* user_data; - // Handshaker args. - grpc_handshaker_args args; - // Links to the previous and next managers in a list of all pending handshakes - // Used at server side only. - grpc_handshake_manager* prev; - grpc_handshake_manager* next; -}; +HandshakeManager::HandshakeManager() { gpr_mu_init(&mu_); } -grpc_handshake_manager* grpc_handshake_manager_create() { - grpc_handshake_manager* mgr = static_cast( - gpr_zalloc(sizeof(grpc_handshake_manager))); - gpr_mu_init(&mgr->mu); - gpr_ref_init(&mgr->refs, 1); - return mgr; -} - -void grpc_handshake_manager_pending_list_add(grpc_handshake_manager** head, - grpc_handshake_manager* mgr) { - GPR_ASSERT(mgr->prev == nullptr); - GPR_ASSERT(mgr->next == nullptr); - mgr->next = *head; +/// Add \a mgr to the server side list of all pending handshake managers, the +/// list starts with \a *head. +// Not thread-safe. Caller needs to synchronize. +void HandshakeManager::AddToPendingMgrList(HandshakeManager** head) { + GPR_ASSERT(prev_ == nullptr); + GPR_ASSERT(next_ == nullptr); + next_ = *head; if (*head) { - (*head)->prev = mgr; + (*head)->prev_ = this; } - *head = mgr; + *head = this; } -void grpc_handshake_manager_pending_list_remove(grpc_handshake_manager** head, - grpc_handshake_manager* mgr) { - if (mgr->next != nullptr) { - mgr->next->prev = mgr->prev; +/// Remove \a mgr from the server side list of all pending handshake managers. +// Not thread-safe. Caller needs to synchronize. +void HandshakeManager::RemoveFromPendingMgrList(HandshakeManager** head) { + if (next_ != nullptr) { + next_->prev_ = prev_; } - if (mgr->prev != nullptr) { - mgr->prev->next = mgr->next; + if (prev_ != nullptr) { + prev_->next_ = next_; } else { - GPR_ASSERT(*head == mgr); - *head = mgr->next; + GPR_ASSERT(*head == this); + *head = next_; } } -void grpc_handshake_manager_pending_list_shutdown_all( - grpc_handshake_manager* head, grpc_error* why) { +/// Shutdown all pending handshake managers starting at head on the server +/// side. Not thread-safe. Caller needs to synchronize. +void HandshakeManager::ShutdownAllPending(grpc_error* why) { + auto* head = this; while (head != nullptr) { - grpc_handshake_manager_shutdown(head, GRPC_ERROR_REF(why)); - head = head->next; + head->Shutdown(GRPC_ERROR_REF(why)); + head = head->next_; } GRPC_ERROR_UNREF(why); } -static bool is_power_of_2(size_t n) { return (n & (n - 1)) == 0; } - -void grpc_handshake_manager_add(grpc_handshake_manager* mgr, - grpc_handshaker* handshaker) { +void HandshakeManager::Add(RefCountedPtr handshaker) { if (grpc_handshaker_trace.enabled()) { gpr_log( GPR_INFO, "handshake_manager %p: adding handshaker %s [%p] at index %" PRIuPTR, - mgr, grpc_handshaker_name(handshaker), handshaker, mgr->count); - } - gpr_mu_lock(&mgr->mu); - // To avoid allocating memory for each handshaker we add, we double - // the number of elements every time we need more. - size_t realloc_count = 0; - if (mgr->count == 0) { - realloc_count = 2; - } else if (mgr->count >= 2 && is_power_of_2(mgr->count)) { - realloc_count = mgr->count * 2; - } - if (realloc_count > 0) { - mgr->handshakers = static_cast(gpr_realloc( - mgr->handshakers, realloc_count * sizeof(grpc_handshaker*))); + this, handshaker->name(), handshaker.get(), handshakers_.size()); } - mgr->handshakers[mgr->count++] = handshaker; - gpr_mu_unlock(&mgr->mu); + MutexLock lock(&mu_); + handshakers_.push_back(std::move(handshaker)); } -static void grpc_handshake_manager_unref(grpc_handshake_manager* mgr) { - if (gpr_unref(&mgr->refs)) { - for (size_t i = 0; i < mgr->count; ++i) { - grpc_handshaker_destroy(mgr->handshakers[i]); - } - gpr_free(mgr->handshakers); - gpr_mu_destroy(&mgr->mu); - gpr_free(mgr); - } -} - -void grpc_handshake_manager_destroy(grpc_handshake_manager* mgr) { - grpc_handshake_manager_unref(mgr); +HandshakeManager::~HandshakeManager() { + handshakers_.clear(); + gpr_mu_destroy(&mu_); } -void grpc_handshake_manager_shutdown(grpc_handshake_manager* mgr, - grpc_error* why) { - gpr_mu_lock(&mgr->mu); - // Shutdown the handshaker that's currently in progress, if any. - if (!mgr->shutdown && mgr->index > 0) { - mgr->shutdown = true; - grpc_handshaker_shutdown(mgr->handshakers[mgr->index - 1], - GRPC_ERROR_REF(why)); +void HandshakeManager::Shutdown(grpc_error* why) { + { + MutexLock lock(&mu_); + // Shutdown the handshaker that's currently in progress, if any. + if (!is_shutdown_ && index_ > 0) { + is_shutdown_ = true; + handshakers_[index_ - 1]->Shutdown(GRPC_ERROR_REF(why)); + } } - gpr_mu_unlock(&mgr->mu); GRPC_ERROR_UNREF(why); } -static char* handshaker_args_string(grpc_handshaker_args* args) { - char* args_str = grpc_channel_args_string(args->args); - size_t num_args = args->args != nullptr ? args->args->num_args : 0; - size_t read_buffer_length = - args->read_buffer != nullptr ? args->read_buffer->length : 0; - char* str; - gpr_asprintf(&str, - "{endpoint=%p, args=%p {size=%" PRIuPTR - ": %s}, read_buffer=%p (length=%" PRIuPTR "), exit_early=%d}", - args->endpoint, args->args, num_args, args_str, - args->read_buffer, read_buffer_length, args->exit_early); - gpr_free(args_str); - return str; -} - // Helper function to call either the next handshaker or the // on_handshake_done callback. // Returns true if we've scheduled the on_handshake_done callback. -static bool call_next_handshaker_locked(grpc_handshake_manager* mgr, - grpc_error* error) { +bool HandshakeManager::CallNextHandshakerLocked(grpc_error* error) { if (grpc_handshaker_trace.enabled()) { - char* args_str = handshaker_args_string(&mgr->args); + char* args_str = HandshakerArgsString(&args_); gpr_log(GPR_INFO, "handshake_manager %p: error=%s shutdown=%d index=%" PRIuPTR ", args=%s", - mgr, grpc_error_string(error), mgr->shutdown, mgr->index, args_str); + this, grpc_error_string(error), is_shutdown_, index_, args_str); gpr_free(args_str); } - GPR_ASSERT(mgr->index <= mgr->count); + GPR_ASSERT(index_ <= handshakers_.size()); // If we got an error or we've been shut down or we're exiting early or // we've finished the last handshaker, invoke the on_handshake_done // callback. Otherwise, call the next handshaker. - if (error != GRPC_ERROR_NONE || mgr->shutdown || mgr->args.exit_early || - mgr->index == mgr->count) { - if (error == GRPC_ERROR_NONE && mgr->shutdown) { + if (error != GRPC_ERROR_NONE || is_shutdown_ || args_.exit_early || + index_ == handshakers_.size()) { + if (error == GRPC_ERROR_NONE && is_shutdown_) { error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("handshaker shutdown"); // It is possible that the endpoint has already been destroyed by // a shutdown call while this callback was sitting on the ExecCtx // with no error. - if (mgr->args.endpoint != nullptr) { + if (args_.endpoint != nullptr) { // TODO(roth): It is currently necessary to shutdown endpoints // before destroying then, even when we know that there are no // pending read/write callbacks. This should be fixed, at which // point this can be removed. - grpc_endpoint_shutdown(mgr->args.endpoint, GRPC_ERROR_REF(error)); - grpc_endpoint_destroy(mgr->args.endpoint); - mgr->args.endpoint = nullptr; - grpc_channel_args_destroy(mgr->args.args); - mgr->args.args = nullptr; - grpc_slice_buffer_destroy_internal(mgr->args.read_buffer); - gpr_free(mgr->args.read_buffer); - mgr->args.read_buffer = nullptr; + grpc_endpoint_shutdown(args_.endpoint, GRPC_ERROR_REF(error)); + grpc_endpoint_destroy(args_.endpoint); + args_.endpoint = nullptr; + grpc_channel_args_destroy(args_.args); + args_.args = nullptr; + grpc_slice_buffer_destroy_internal(args_.read_buffer); + gpr_free(args_.read_buffer); + args_.read_buffer = nullptr; } } if (grpc_handshaker_trace.enabled()) { gpr_log(GPR_INFO, "handshake_manager %p: handshaking complete -- scheduling " "on_handshake_done with error=%s", - mgr, grpc_error_string(error)); + this, grpc_error_string(error)); } // Cancel deadline timer, since we're invoking the on_handshake_done // callback now. - grpc_timer_cancel(&mgr->deadline_timer); - GRPC_CLOSURE_SCHED(&mgr->on_handshake_done, error); - mgr->shutdown = true; + grpc_timer_cancel(&deadline_timer_); + GRPC_CLOSURE_SCHED(&on_handshake_done_, error); + is_shutdown_ = true; } else { + auto handshaker = handshakers_[index_]; if (grpc_handshaker_trace.enabled()) { gpr_log( GPR_INFO, "handshake_manager %p: calling handshaker %s [%p] at index %" PRIuPTR, - mgr, grpc_handshaker_name(mgr->handshakers[mgr->index]), - mgr->handshakers[mgr->index], mgr->index); + this, handshaker->name(), handshaker.get(), index_); } - grpc_handshaker_do_handshake(mgr->handshakers[mgr->index], mgr->acceptor, - &mgr->call_next_handshaker, &mgr->args); + handshaker->DoHandshake(acceptor_, &call_next_handshaker_, &args_); } - ++mgr->index; - return mgr->shutdown; + ++index_; + return is_shutdown_; } -// A function used as the handshaker-done callback when chaining -// handshakers together. -static void call_next_handshaker(void* arg, grpc_error* error) { - grpc_handshake_manager* mgr = static_cast(arg); - gpr_mu_lock(&mgr->mu); - bool done = call_next_handshaker_locked(mgr, GRPC_ERROR_REF(error)); - gpr_mu_unlock(&mgr->mu); +void HandshakeManager::CallNextHandshakerFn(void* arg, grpc_error* error) { + auto* mgr = static_cast(arg); + bool done; + { + MutexLock lock(&mgr->mu_); + done = mgr->CallNextHandshakerLocked(GRPC_ERROR_REF(error)); + } // If we're invoked the final callback, we won't be coming back // to this function, so we can release our reference to the // handshake manager. if (done) { - grpc_handshake_manager_unref(mgr); + mgr->Unref(); } } -// Callback invoked when deadline is exceeded. -static void on_timeout(void* arg, grpc_error* error) { - grpc_handshake_manager* mgr = static_cast(arg); - if (error == GRPC_ERROR_NONE) { // Timer fired, rather than being cancelled. - grpc_handshake_manager_shutdown( - mgr, GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshake timed out")); +void HandshakeManager::OnTimeoutFn(void* arg, grpc_error* error) { + auto* mgr = static_cast(arg); + if (error == GRPC_ERROR_NONE) { // Timer fired, rather than being cancelled + mgr->Shutdown(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshake timed out")); } - grpc_handshake_manager_unref(mgr); + mgr->Unref(); } -void grpc_handshake_manager_do_handshake(grpc_handshake_manager* mgr, - grpc_endpoint* endpoint, - const grpc_channel_args* channel_args, - grpc_millis deadline, - grpc_tcp_server_acceptor* acceptor, - grpc_iomgr_cb_func on_handshake_done, - void* user_data) { - gpr_mu_lock(&mgr->mu); - GPR_ASSERT(mgr->index == 0); - GPR_ASSERT(!mgr->shutdown); - // Construct handshaker args. These will be passed through all - // handshakers and eventually be freed by the on_handshake_done callback. - mgr->args.endpoint = endpoint; - mgr->args.args = grpc_channel_args_copy(channel_args); - mgr->args.user_data = user_data; - mgr->args.read_buffer = static_cast( - gpr_malloc(sizeof(*mgr->args.read_buffer))); - grpc_slice_buffer_init(mgr->args.read_buffer); - // Initialize state needed for calling handshakers. - mgr->acceptor = acceptor; - GRPC_CLOSURE_INIT(&mgr->call_next_handshaker, call_next_handshaker, mgr, - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&mgr->on_handshake_done, on_handshake_done, &mgr->args, - grpc_schedule_on_exec_ctx); - // Start deadline timer, which owns a ref. - gpr_ref(&mgr->refs); - GRPC_CLOSURE_INIT(&mgr->on_timeout, on_timeout, mgr, - grpc_schedule_on_exec_ctx); - grpc_timer_init(&mgr->deadline_timer, deadline, &mgr->on_timeout); - // Start first handshaker, which also owns a ref. - gpr_ref(&mgr->refs); - bool done = call_next_handshaker_locked(mgr, GRPC_ERROR_NONE); - gpr_mu_unlock(&mgr->mu); +void HandshakeManager::DoHandshake(grpc_endpoint* endpoint, + const grpc_channel_args* channel_args, + grpc_millis deadline, + grpc_tcp_server_acceptor* acceptor, + grpc_iomgr_cb_func on_handshake_done, + void* user_data) { + bool done; + { + MutexLock lock(&mu_); + GPR_ASSERT(index_ == 0); + GPR_ASSERT(!is_shutdown_); + // Construct handshaker args. These will be passed through all + // handshakers and eventually be freed by the on_handshake_done callback. + args_.endpoint = endpoint; + args_.args = grpc_channel_args_copy(channel_args); + args_.user_data = user_data; + args_.read_buffer = + static_cast(gpr_malloc(sizeof(*args_.read_buffer))); + grpc_slice_buffer_init(args_.read_buffer); + // Initialize state needed for calling handshakers. + acceptor_ = acceptor; + GRPC_CLOSURE_INIT(&call_next_handshaker_, + &HandshakeManager::CallNextHandshakerFn, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_handshake_done_, on_handshake_done, &args_, + grpc_schedule_on_exec_ctx); + // Start deadline timer, which owns a ref. + Ref().release(); + GRPC_CLOSURE_INIT(&on_timeout_, &HandshakeManager::OnTimeoutFn, this, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&deadline_timer_, deadline, &on_timeout_); + // Start first handshaker, which also owns a ref. + Ref().release(); + done = CallNextHandshakerLocked(GRPC_ERROR_NONE); + } if (done) { - grpc_handshake_manager_unref(mgr); + Unref(); } } + +} // namespace grpc_core + +void grpc_handshake_manager_add(grpc_handshake_manager* mgr, + grpc_handshaker* handshaker) { + // This is a transition method to aid the API change for handshakers. + using namespace grpc_core; + RefCountedPtr refd_hs(static_cast(handshaker)); + mgr->Add(refd_hs); +} diff --git a/src/core/lib/channel/handshaker.h b/src/core/lib/channel/handshaker.h index a65990fceb4..912d524c8db 100644 --- a/src/core/lib/channel/handshaker.h +++ b/src/core/lib/channel/handshaker.h @@ -21,12 +21,21 @@ #include +#include + #include +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/inlined_vector.h" +#include "src/core/lib/gprpp/mutex_lock.h" +#include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/iomgr/timer.h" + +namespace grpc_core { /// Handshakers are used to perform initial handshakes on a connection /// before the client sends the initial request. Some examples of what @@ -35,12 +44,6 @@ /// /// In general, handshakers should be used via a handshake manager. -/// -/// grpc_handshaker -/// - -typedef struct grpc_handshaker grpc_handshaker; - /// Arguments passed through handshakers and to the on_handshake_done callback. /// /// For handshakers, all members are input/output parameters; for @@ -55,115 +58,121 @@ typedef struct grpc_handshaker grpc_handshaker; /// /// For the on_handshake_done callback, all members are input arguments, /// which the callback takes ownership of. -typedef struct { - grpc_endpoint* endpoint; - grpc_channel_args* args; - grpc_slice_buffer* read_buffer; +struct HandshakerArgs { + grpc_endpoint* endpoint = nullptr; + grpc_channel_args* args = nullptr; + grpc_slice_buffer* read_buffer = nullptr; // A handshaker may set this to true before invoking on_handshake_done // to indicate that subsequent handshakers should be skipped. - bool exit_early; + bool exit_early = false; // User data passed through the handshake manager. Not used by // individual handshakers. - void* user_data; -} grpc_handshaker_args; + void* user_data = nullptr; +}; -typedef struct { - /// Destroys the handshaker. - void (*destroy)(grpc_handshaker* handshaker); +/// +/// Handshaker +/// - /// Shuts down the handshaker (e.g., to clean up when the operation is - /// aborted in the middle). - void (*shutdown)(grpc_handshaker* handshaker, grpc_error* why); - - /// Performs handshaking, modifying \a args as needed (e.g., to - /// replace \a endpoint with a wrapped endpoint). - /// When finished, invokes \a on_handshake_done. - /// \a acceptor will be NULL for client-side handshakers. - void (*do_handshake)(grpc_handshaker* handshaker, - grpc_tcp_server_acceptor* acceptor, - grpc_closure* on_handshake_done, - grpc_handshaker_args* args); - - /// The name of the handshaker, for debugging purposes. - const char* name; -} grpc_handshaker_vtable; - -/// Base struct. To subclass, make this the first member of the -/// implementation struct. -struct grpc_handshaker { - const grpc_handshaker_vtable* vtable; +class Handshaker : public RefCounted { + public: + virtual ~Handshaker() = default; + virtual void Shutdown(grpc_error* why) GRPC_ABSTRACT; + virtual void DoHandshake(grpc_tcp_server_acceptor* acceptor, + grpc_closure* on_handshake_done, + HandshakerArgs* args) GRPC_ABSTRACT; + virtual const char* name() const GRPC_ABSTRACT; + GRPC_ABSTRACT_BASE_CLASS }; -/// Called by concrete implementations to initialize the base struct. -void grpc_handshaker_init(const grpc_handshaker_vtable* vtable, - grpc_handshaker* handshaker); +// +// HandshakeManager +// -void grpc_handshaker_destroy(grpc_handshaker* handshaker); -void grpc_handshaker_shutdown(grpc_handshaker* handshaker, grpc_error* why); -void grpc_handshaker_do_handshake(grpc_handshaker* handshaker, - grpc_tcp_server_acceptor* acceptor, - grpc_closure* on_handshake_done, - grpc_handshaker_args* args); -const char* grpc_handshaker_name(grpc_handshaker* handshaker); +class HandshakeManager : public RefCounted { + public: + HandshakeManager(); + ~HandshakeManager(); -/// -/// grpc_handshake_manager -/// + /// Add \a mgr to the server side list of all pending handshake managers, the + /// list starts with \a *head. + // Not thread-safe. Caller needs to synchronize. + void AddToPendingMgrList(HandshakeManager** head); + + /// Remove \a mgr from the server side list of all pending handshake managers. + // Not thread-safe. Caller needs to synchronize. + void RemoveFromPendingMgrList(HandshakeManager** head); -typedef struct grpc_handshake_manager grpc_handshake_manager; + /// Shutdown all pending handshake managers starting at head on the server + /// side. Not thread-safe. Caller needs to synchronize. + void ShutdownAllPending(grpc_error* why); -/// Creates a new handshake manager. Caller takes ownership. -grpc_handshake_manager* grpc_handshake_manager_create(); + /// Adds a handshaker to the handshake manager. + /// Takes ownership of \a handshaker. + void Add(RefCountedPtr handshaker); -/// Adds a handshaker to the handshake manager. -/// Takes ownership of \a handshaker. + /// Shuts down the handshake manager (e.g., to clean up when the operation is + /// aborted in the middle). + void Shutdown(grpc_error* why); + + /// Invokes handshakers in the order they were added. + /// Takes ownership of \a endpoint, and then passes that ownership to + /// the \a on_handshake_done callback. + /// Does NOT take ownership of \a channel_args. Instead, makes a copy before + /// invoking the first handshaker. + /// \a acceptor will be nullptr for client-side handshakers. + /// + /// When done, invokes \a on_handshake_done with a HandshakerArgs + /// object as its argument. If the callback is invoked with error != + /// GRPC_ERROR_NONE, then handshaking failed and the handshaker has done + /// the necessary clean-up. Otherwise, the callback takes ownership of + /// the arguments. + void DoHandshake(grpc_endpoint* endpoint, + const grpc_channel_args* channel_args, grpc_millis deadline, + grpc_tcp_server_acceptor* acceptor, + grpc_iomgr_cb_func on_handshake_done, void* user_data); + + private: + bool CallNextHandshakerLocked(grpc_error* error); + + // A function used as the handshaker-done callback when chaining + // handshakers together. + static void CallNextHandshakerFn(void* arg, grpc_error* error); + + // Callback invoked when deadline is exceeded. + static void OnTimeoutFn(void* arg, grpc_error* error); + + static const size_t HANDSHAKERS_INIT_SIZE = 2; + + gpr_mu mu_; + bool is_shutdown_ = false; + // An array of handshakers added via grpc_handshake_manager_add(). + InlinedVector, HANDSHAKERS_INIT_SIZE> handshakers_; + // The index of the handshaker to invoke next and closure to invoke it. + size_t index_ = 0; + grpc_closure call_next_handshaker_; + // The acceptor to call the handshakers with. + grpc_tcp_server_acceptor* acceptor_; + // Deadline timer across all handshakers. + grpc_timer deadline_timer_; + grpc_closure on_timeout_; + // The final callback and user_data to invoke after the last handshaker. + grpc_closure on_handshake_done_; + // Handshaker args. + HandshakerArgs args_; + // Links to the previous and next managers in a list of all pending handshakes + // Used at server side only. + HandshakeManager* prev_ = nullptr; + HandshakeManager* next_ = nullptr; +}; + +} // namespace grpc_core + +// TODO(arjunroy): These are transitional to account for the new handshaker API +// and will eventually be removed entirely. +typedef grpc_core::HandshakeManager grpc_handshake_manager; +typedef grpc_core::Handshaker grpc_handshaker; void grpc_handshake_manager_add(grpc_handshake_manager* mgr, grpc_handshaker* handshaker); -/// Destroys the handshake manager. -void grpc_handshake_manager_destroy(grpc_handshake_manager* mgr); - -/// Shuts down the handshake manager (e.g., to clean up when the operation is -/// aborted in the middle). -/// The caller must still call grpc_handshake_manager_destroy() after -/// calling this function. -void grpc_handshake_manager_shutdown(grpc_handshake_manager* mgr, - grpc_error* why); - -/// Invokes handshakers in the order they were added. -/// Takes ownership of \a endpoint, and then passes that ownership to -/// the \a on_handshake_done callback. -/// Does NOT take ownership of \a channel_args. Instead, makes a copy before -/// invoking the first handshaker. -/// \a acceptor will be nullptr for client-side handshakers. -/// -/// When done, invokes \a on_handshake_done with a grpc_handshaker_args -/// object as its argument. If the callback is invoked with error != -/// GRPC_ERROR_NONE, then handshaking failed and the handshaker has done -/// the necessary clean-up. Otherwise, the callback takes ownership of -/// the arguments. -void grpc_handshake_manager_do_handshake(grpc_handshake_manager* mgr, - grpc_endpoint* endpoint, - const grpc_channel_args* channel_args, - grpc_millis deadline, - grpc_tcp_server_acceptor* acceptor, - grpc_iomgr_cb_func on_handshake_done, - void* user_data); - -/// Add \a mgr to the server side list of all pending handshake managers, the -/// list starts with \a *head. -// Not thread-safe. Caller needs to synchronize. -void grpc_handshake_manager_pending_list_add(grpc_handshake_manager** head, - grpc_handshake_manager* mgr); - -/// Remove \a mgr from the server side list of all pending handshake managers. -// Not thread-safe. Caller needs to synchronize. -void grpc_handshake_manager_pending_list_remove(grpc_handshake_manager** head, - grpc_handshake_manager* mgr); - -/// Shutdown all pending handshake managers on the server side. -// Not thread-safe. Caller needs to synchronize. -void grpc_handshake_manager_pending_list_shutdown_all( - grpc_handshake_manager* head, grpc_error* why); - #endif /* GRPC_CORE_LIB_CHANNEL_HANDSHAKER_H */ diff --git a/src/core/lib/channel/handshaker_factory.cc b/src/core/lib/channel/handshaker_factory.cc deleted file mode 100644 index 8ade8fe4e23..00000000000 --- a/src/core/lib/channel/handshaker_factory.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* - * - * Copyright 2016 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -#include - -#include "src/core/lib/channel/handshaker_factory.h" - -#include - -void grpc_handshaker_factory_add_handshakers( - grpc_handshaker_factory* handshaker_factory, const grpc_channel_args* args, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - if (handshaker_factory != nullptr) { - GPR_ASSERT(handshaker_factory->vtable != nullptr); - handshaker_factory->vtable->add_handshakers( - handshaker_factory, args, interested_parties, handshake_mgr); - } -} - -void grpc_handshaker_factory_destroy( - grpc_handshaker_factory* handshaker_factory) { - if (handshaker_factory != nullptr) { - GPR_ASSERT(handshaker_factory->vtable != nullptr); - handshaker_factory->vtable->destroy(handshaker_factory); - } -} diff --git a/src/core/lib/channel/handshaker_factory.h b/src/core/lib/channel/handshaker_factory.h index e17a6781798..3972af1f439 100644 --- a/src/core/lib/channel/handshaker_factory.h +++ b/src/core/lib/channel/handshaker_factory.h @@ -27,26 +27,18 @@ // A handshaker factory is used to create handshakers. -typedef struct grpc_handshaker_factory grpc_handshaker_factory; - -typedef struct { - void (*add_handshakers)(grpc_handshaker_factory* handshaker_factory, - const grpc_channel_args* args, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); - void (*destroy)(grpc_handshaker_factory* handshaker_factory); -} grpc_handshaker_factory_vtable; - -struct grpc_handshaker_factory { - const grpc_handshaker_factory_vtable* vtable; -}; +namespace grpc_core { + +class HandshakerFactory { + public: + virtual void AddHandshakers(const grpc_channel_args* args, + grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr) GRPC_ABSTRACT; + virtual ~HandshakerFactory() = default; -void grpc_handshaker_factory_add_handshakers( - grpc_handshaker_factory* handshaker_factory, const grpc_channel_args* args, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); + GRPC_ABSTRACT_BASE_CLASS +}; -void grpc_handshaker_factory_destroy( - grpc_handshaker_factory* handshaker_factory); +} // namespace grpc_core #endif /* GRPC_CORE_LIB_CHANNEL_HANDSHAKER_FACTORY_H */ diff --git a/src/core/lib/channel/handshaker_registry.cc b/src/core/lib/channel/handshaker_registry.cc index fbafc43e795..b65129a6ed6 100644 --- a/src/core/lib/channel/handshaker_registry.cc +++ b/src/core/lib/channel/handshaker_registry.cc @@ -19,8 +19,11 @@ #include #include "src/core/lib/channel/handshaker_registry.h" +#include "src/core/lib/gprpp/inlined_vector.h" +#include "src/core/lib/gprpp/memory.h" #include +#include #include @@ -28,74 +31,83 @@ // grpc_handshaker_factory_list // -typedef struct { - grpc_handshaker_factory** list; - size_t num_factories; -} grpc_handshaker_factory_list; - -static void grpc_handshaker_factory_list_register( - grpc_handshaker_factory_list* list, bool at_start, - grpc_handshaker_factory* factory) { - list->list = static_cast(gpr_realloc( - list->list, - (list->num_factories + 1) * sizeof(grpc_handshaker_factory*))); - if (at_start) { - memmove(list->list + 1, list->list, - sizeof(grpc_handshaker_factory*) * list->num_factories); - list->list[0] = factory; - } else { - list->list[list->num_factories] = factory; - } - ++list->num_factories; -} +namespace grpc_core { + +namespace { + +class HandshakerFactoryList { + public: + void Register(bool at_start, UniquePtr factory); + void AddHandshakers(const grpc_channel_args* args, + grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr); + + private: + InlinedVector, 2> factories_; +}; -static void grpc_handshaker_factory_list_add_handshakers( - grpc_handshaker_factory_list* list, const grpc_channel_args* args, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - for (size_t i = 0; i < list->num_factories; ++i) { - grpc_handshaker_factory_add_handshakers(list->list[i], args, - interested_parties, handshake_mgr); +HandshakerFactoryList* g_handshaker_factory_lists = nullptr; + +} // namespace + +void HandshakerFactoryList::Register(bool at_start, + UniquePtr factory) { + factories_.push_back(std::move(factory)); + if (at_start) { + auto* end = &factories_[factories_.size() - 1]; + std::rotate(&factories_[0], end, end + 1); } } -static void grpc_handshaker_factory_list_destroy( - grpc_handshaker_factory_list* list) { - for (size_t i = 0; i < list->num_factories; ++i) { - grpc_handshaker_factory_destroy(list->list[i]); +void HandshakerFactoryList::AddHandshakers(const grpc_channel_args* args, + grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr) { + for (size_t idx = 0; idx < factories_.size(); ++idx) { + auto& handshaker_factory = factories_[idx]; + handshaker_factory->AddHandshakers(args, interested_parties, handshake_mgr); } - gpr_free(list->list); } // // plugin // -static grpc_handshaker_factory_list - g_handshaker_factory_lists[NUM_HANDSHAKER_TYPES]; - -void grpc_handshaker_factory_registry_init() { - memset(g_handshaker_factory_lists, 0, sizeof(g_handshaker_factory_lists)); +void HandshakerRegistry::Init() { + GPR_ASSERT(g_handshaker_factory_lists == nullptr); + g_handshaker_factory_lists = static_cast( + gpr_malloc(sizeof(*g_handshaker_factory_lists) * NUM_HANDSHAKER_TYPES)); + GPR_ASSERT(g_handshaker_factory_lists != nullptr); + for (auto idx = 0; idx < NUM_HANDSHAKER_TYPES; ++idx) { + auto factory_list = g_handshaker_factory_lists + idx; + new (factory_list) HandshakerFactoryList(); + } } -void grpc_handshaker_factory_registry_shutdown() { - for (size_t i = 0; i < NUM_HANDSHAKER_TYPES; ++i) { - grpc_handshaker_factory_list_destroy(&g_handshaker_factory_lists[i]); +void HandshakerRegistry::Shutdown() { + GPR_ASSERT(g_handshaker_factory_lists != nullptr); + for (auto idx = 0; idx < NUM_HANDSHAKER_TYPES; ++idx) { + auto factory_list = g_handshaker_factory_lists + idx; + factory_list->~HandshakerFactoryList(); } + gpr_free(g_handshaker_factory_lists); + g_handshaker_factory_lists = nullptr; } -void grpc_handshaker_factory_register(bool at_start, - grpc_handshaker_type handshaker_type, - grpc_handshaker_factory* factory) { - grpc_handshaker_factory_list_register( - &g_handshaker_factory_lists[handshaker_type], at_start, factory); +void HandshakerRegistry::RegisterHandshakerFactory( + bool at_start, HandshakerType handshaker_type, + UniquePtr factory) { + GPR_ASSERT(g_handshaker_factory_lists != nullptr); + auto& factory_list = g_handshaker_factory_lists[handshaker_type]; + factory_list.Register(at_start, std::move(factory)); } -void grpc_handshakers_add(grpc_handshaker_type handshaker_type, - const grpc_channel_args* args, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_handshaker_factory_list_add_handshakers( - &g_handshaker_factory_lists[handshaker_type], args, interested_parties, - handshake_mgr); +void HandshakerRegistry::AddHandshakers(HandshakerType handshaker_type, + const grpc_channel_args* args, + grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr) { + GPR_ASSERT(g_handshaker_factory_lists != nullptr); + auto& factory_list = g_handshaker_factory_lists[handshaker_type]; + factory_list.AddHandshakers(args, interested_parties, handshake_mgr); } + +} // namespace grpc_core diff --git a/src/core/lib/channel/handshaker_registry.h b/src/core/lib/channel/handshaker_registry.h index 3dd4316de67..1b93a8dd47e 100644 --- a/src/core/lib/channel/handshaker_registry.h +++ b/src/core/lib/channel/handshaker_registry.h @@ -25,25 +25,30 @@ #include "src/core/lib/channel/handshaker_factory.h" +namespace grpc_core { + typedef enum { HANDSHAKER_CLIENT = 0, HANDSHAKER_SERVER, NUM_HANDSHAKER_TYPES, // Must be last. -} grpc_handshaker_type; - -void grpc_handshaker_factory_registry_init(); -void grpc_handshaker_factory_registry_shutdown(); - -/// Registers a new handshaker factory. Takes ownership. -/// If \a at_start is true, the new handshaker will be at the beginning of -/// the list. Otherwise, it will be added to the end. -void grpc_handshaker_factory_register(bool at_start, - grpc_handshaker_type handshaker_type, - grpc_handshaker_factory* factory); - -void grpc_handshakers_add(grpc_handshaker_type handshaker_type, - const grpc_channel_args* args, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); +} HandshakerType; + +class HandshakerRegistry { + public: + /// Registers a new handshaker factory. Takes ownership. + /// If \a at_start is true, the new handshaker will be at the beginning of + /// the list. Otherwise, it will be added to the end. + static void RegisterHandshakerFactory(bool at_start, + HandshakerType handshaker_type, + UniquePtr factory); + static void AddHandshakers(HandshakerType handshaker_type, + const grpc_channel_args* args, + grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr); + static void Init(); + static void Shutdown(); +}; + +} // namespace grpc_core #endif /* GRPC_CORE_LIB_CHANNEL_HANDSHAKER_REGISTRY_H */ diff --git a/src/core/lib/http/httpcli_security_connector.cc b/src/core/lib/http/httpcli_security_connector.cc index fdea7511cca..3f288e045a6 100644 --- a/src/core/lib/http/httpcli_security_connector.cc +++ b/src/core/lib/http/httpcli_security_connector.cc @@ -67,7 +67,7 @@ class grpc_httpcli_ssl_channel_security_connector final } void add_handshakers(grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) override { + grpc_core::HandshakeManager* handshake_mgr) override { tsi_handshaker* handshaker = nullptr; if (handshaker_factory_ != nullptr) { tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( @@ -77,8 +77,7 @@ class grpc_httpcli_ssl_channel_security_connector final tsi_result_to_string(result)); } } - grpc_handshake_manager_add( - handshake_mgr, grpc_security_handshaker_create(handshaker, this)); + handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(handshaker, this)); } tsi_ssl_client_handshaker_factory* handshaker_factory() const { @@ -155,11 +154,11 @@ httpcli_ssl_channel_security_connector_create( typedef struct { void (*func)(void* arg, grpc_endpoint* endpoint); void* arg; - grpc_handshake_manager* handshake_mgr; + grpc_core::RefCountedPtr handshake_mgr; } on_done_closure; static void on_handshake_done(void* arg, grpc_error* error) { - grpc_handshaker_args* args = static_cast(arg); + auto* args = static_cast(arg); on_done_closure* c = static_cast(args->user_data); if (error != GRPC_ERROR_NONE) { const char* msg = grpc_error_string(error); @@ -172,14 +171,13 @@ static void on_handshake_done(void* arg, grpc_error* error) { gpr_free(args->read_buffer); c->func(c->arg, args->endpoint); } - grpc_handshake_manager_destroy(c->handshake_mgr); - gpr_free(c); + grpc_core::Delete(c); } static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, grpc_millis deadline, void (*on_done)(void* arg, grpc_endpoint* endpoint)) { - on_done_closure* c = static_cast(gpr_malloc(sizeof(*c))); + auto* c = grpc_core::New(); const char* pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); const tsi_ssl_root_certs_store* root_store = @@ -198,12 +196,13 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, GPR_ASSERT(sc != nullptr); grpc_arg channel_arg = grpc_security_connector_to_arg(sc.get()); grpc_channel_args args = {1, &channel_arg}; - c->handshake_mgr = grpc_handshake_manager_create(); - grpc_handshakers_add(HANDSHAKER_CLIENT, &args, - nullptr /* interested_parties */, c->handshake_mgr); - grpc_handshake_manager_do_handshake( - c->handshake_mgr, tcp, nullptr /* channel_args */, deadline, - nullptr /* acceptor */, on_handshake_done, c /* user_data */); + c->handshake_mgr = grpc_core::MakeRefCounted(); + grpc_core::HandshakerRegistry::AddHandshakers( + grpc_core::HANDSHAKER_CLIENT, &args, /*interested_parties=*/nullptr, + c->handshake_mgr.get()); + c->handshake_mgr->DoHandshake(tcp, /*channel_args=*/nullptr, deadline, + /*acceptor=*/nullptr, on_handshake_done, + /*user_data=*/c); sc.reset(DEBUG_LOCATION, "httpcli"); } diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.cc b/src/core/lib/security/security_connector/alts/alts_security_connector.cc index 3ad0cc353cb..38b1f856d52 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.cc +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.cc @@ -80,8 +80,9 @@ class grpc_alts_channel_security_connector final ~grpc_alts_channel_security_connector() override { gpr_free(target_name_); } - void add_handshakers(grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) override { + void add_handshakers( + grpc_pollset_set* interested_parties, + grpc_core::HandshakeManager* handshake_manager) override { tsi_handshaker* handshaker = nullptr; const grpc_alts_credentials* creds = static_cast(channel_creds()); @@ -89,8 +90,8 @@ class grpc_alts_channel_security_connector final creds->handshaker_service_url(), true, interested_parties, &handshaker) == TSI_OK); - grpc_handshake_manager_add( - handshake_manager, grpc_security_handshaker_create(handshaker, this)); + handshake_manager->Add( + grpc_core::SecurityHandshakerCreate(handshaker, this)); } void check_peer(tsi_peer peer, grpc_endpoint* ep, @@ -139,16 +140,17 @@ class grpc_alts_server_security_connector final } ~grpc_alts_server_security_connector() override = default; - void add_handshakers(grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) override { + void add_handshakers( + grpc_pollset_set* interested_parties, + grpc_core::HandshakeManager* handshake_manager) override { tsi_handshaker* handshaker = nullptr; const grpc_alts_server_credentials* creds = static_cast(server_creds()); GPR_ASSERT(alts_tsi_handshaker_create( creds->options(), nullptr, creds->handshaker_service_url(), false, interested_parties, &handshaker) == TSI_OK); - grpc_handshake_manager_add( - handshake_manager, grpc_security_handshaker_create(handshaker, this)); + handshake_manager->Add( + grpc_core::SecurityHandshakerCreate(handshaker, this)); } void check_peer(tsi_peer peer, grpc_endpoint* ep, diff --git a/src/core/lib/security/security_connector/fake/fake_security_connector.cc b/src/core/lib/security/security_connector/fake/fake_security_connector.cc index e3b8affb360..a0e2e6f030b 100644 --- a/src/core/lib/security/security_connector/fake/fake_security_connector.cc +++ b/src/core/lib/security/security_connector/fake/fake_security_connector.cc @@ -92,11 +92,9 @@ class grpc_fake_channel_security_connector final } void add_handshakers(grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) override { - grpc_handshake_manager_add( - handshake_mgr, - grpc_security_handshaker_create( - tsi_create_fake_handshaker(/*is_client=*/true), this)); + grpc_core::HandshakeManager* handshake_mgr) override { + handshake_mgr->Add(grpc_core::SecurityHandshakerCreate( + tsi_create_fake_handshaker(/*is_client=*/true), this)); } bool check_call_host(const char* host, grpc_auth_context* auth_context, @@ -273,11 +271,9 @@ class grpc_fake_server_security_connector } void add_handshakers(grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) override { - grpc_handshake_manager_add( - handshake_mgr, - grpc_security_handshaker_create( - tsi_create_fake_handshaker(/*=is_client*/ false), this)); + grpc_core::HandshakeManager* handshake_mgr) override { + handshake_mgr->Add(grpc_core::SecurityHandshakerCreate( + tsi_create_fake_handshaker(/*=is_client*/ false), this)); } int cmp(const grpc_security_connector* other) const override { diff --git a/src/core/lib/security/security_connector/local/local_security_connector.cc b/src/core/lib/security/security_connector/local/local_security_connector.cc index 7cc482c16c5..c1a101d4ab8 100644 --- a/src/core/lib/security/security_connector/local/local_security_connector.cc +++ b/src/core/lib/security/security_connector/local/local_security_connector.cc @@ -128,13 +128,14 @@ class grpc_local_channel_security_connector final ~grpc_local_channel_security_connector() override { gpr_free(target_name_); } - void add_handshakers(grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) override { + void add_handshakers( + grpc_pollset_set* interested_parties, + grpc_core::HandshakeManager* handshake_manager) override { tsi_handshaker* handshaker = nullptr; GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) == TSI_OK); - grpc_handshake_manager_add( - handshake_manager, grpc_security_handshaker_create(handshaker, this)); + handshake_manager->Add( + grpc_core::SecurityHandshakerCreate(handshaker, this)); } int cmp(const grpc_security_connector* other_sc) const override { @@ -184,13 +185,14 @@ class grpc_local_server_security_connector final : grpc_server_security_connector(nullptr, std::move(server_creds)) {} ~grpc_local_server_security_connector() override = default; - void add_handshakers(grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) override { + void add_handshakers( + grpc_pollset_set* interested_parties, + grpc_core::HandshakeManager* handshake_manager) override { tsi_handshaker* handshaker = nullptr; GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */, &handshaker) == TSI_OK); - grpc_handshake_manager_add( - handshake_manager, grpc_security_handshaker_create(handshaker, this)); + handshake_manager->Add( + grpc_core::SecurityHandshakerCreate(handshaker, this)); } void check_peer(tsi_peer peer, grpc_endpoint* ep, diff --git a/src/core/lib/security/security_connector/security_connector.h b/src/core/lib/security/security_connector/security_connector.h index 74b0ef21a62..4c74c5cfea0 100644 --- a/src/core/lib/security/security_connector/security_connector.h +++ b/src/core/lib/security/security_connector/security_connector.h @@ -109,7 +109,7 @@ class grpc_channel_security_connector : public grpc_security_connector { grpc_error* error) GRPC_ABSTRACT; /// Registers handshakers with \a handshake_mgr. virtual void add_handshakers(grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) + grpc_core::HandshakeManager* handshake_mgr) GRPC_ABSTRACT; const grpc_channel_credentials* channel_creds() const { @@ -150,7 +150,7 @@ class grpc_server_security_connector : public grpc_security_connector { ~grpc_server_security_connector() override = default; virtual void add_handshakers(grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) + grpc_core::HandshakeManager* handshake_mgr) GRPC_ABSTRACT; const grpc_server_credentials* server_creds() const { diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc index 7414ab1a37f..37cb41b9637 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc @@ -128,7 +128,7 @@ class grpc_ssl_channel_security_connector final } void add_handshakers(grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) override { + grpc_core::HandshakeManager* handshake_mgr) override { // Instantiate TSI handshaker. tsi_handshaker* tsi_hs = nullptr; tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( @@ -142,8 +142,7 @@ class grpc_ssl_channel_security_connector final return; } // Create handshakers. - grpc_handshake_manager_add(handshake_mgr, - grpc_security_handshaker_create(tsi_hs, this)); + handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this)); } void check_peer(tsi_peer peer, grpc_endpoint* ep, @@ -283,7 +282,7 @@ class grpc_ssl_server_security_connector } void add_handshakers(grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) override { + grpc_core::HandshakeManager* handshake_mgr) override { // Instantiate TSI handshaker. try_fetch_ssl_server_credentials(); tsi_handshaker* tsi_hs = nullptr; @@ -295,8 +294,7 @@ class grpc_ssl_server_security_connector return; } // Create handshakers. - grpc_handshake_manager_add(handshake_mgr, - grpc_security_handshaker_create(tsi_hs, this)); + handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this)); } void check_peer(tsi_peer peer, grpc_endpoint* ep, diff --git a/src/core/lib/security/transport/security_handshaker.cc b/src/core/lib/security/transport/security_handshaker.cc index 01831dab10f..a6fd2481a4a 100644 --- a/src/core/lib/security/transport/security_handshaker.cc +++ b/src/core/lib/security/transport/security_handshaker.cc @@ -39,74 +39,113 @@ #define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256 -namespace { -struct security_handshaker { - security_handshaker(tsi_handshaker* handshaker, - grpc_security_connector* connector); - ~security_handshaker() { - gpr_mu_destroy(&mu); - tsi_handshaker_destroy(handshaker); - tsi_handshaker_result_destroy(handshaker_result); - if (endpoint_to_destroy != nullptr) { - grpc_endpoint_destroy(endpoint_to_destroy); - } - if (read_buffer_to_destroy != nullptr) { - grpc_slice_buffer_destroy_internal(read_buffer_to_destroy); - gpr_free(read_buffer_to_destroy); - } - gpr_free(handshake_buffer); - grpc_slice_buffer_destroy_internal(&outgoing); - auth_context.reset(DEBUG_LOCATION, "handshake"); - connector.reset(DEBUG_LOCATION, "handshake"); - } +namespace grpc_core { - void Ref() { refs.Ref(); } - void Unref() { - if (refs.Unref()) { - grpc_core::Delete(this); - } - } +namespace { - grpc_handshaker base; +class SecurityHandshaker : public Handshaker { + public: + SecurityHandshaker(tsi_handshaker* handshaker, + grpc_security_connector* connector); + ~SecurityHandshaker() override; + void Shutdown(grpc_error* why) override; + void DoHandshake(grpc_tcp_server_acceptor* acceptor, + grpc_closure* on_handshake_done, + HandshakerArgs* args) override; + const char* name() const override { return "security"; } + + private: + grpc_error* DoHandshakerNextLocked(const unsigned char* bytes_received, + size_t bytes_received_size); + + grpc_error* OnHandshakeNextDoneLocked( + tsi_result result, const unsigned char* bytes_to_send, + size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result); + void HandshakeFailedLocked(grpc_error* error); + void CleanupArgsForFailureLocked(); + + static void OnHandshakeDataReceivedFromPeerFn(void* arg, grpc_error* error); + static void OnHandshakeDataSentToPeerFn(void* arg, grpc_error* error); + static void OnHandshakeNextDoneGrpcWrapper( + tsi_result result, void* user_data, const unsigned char* bytes_to_send, + size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result); + static void OnPeerCheckedFn(void* arg, grpc_error* error); + void OnPeerCheckedInner(grpc_error* error); + size_t MoveReadBufferIntoHandshakeBuffer(); + grpc_error* CheckPeerLocked(); // State set at creation time. - tsi_handshaker* handshaker; - grpc_core::RefCountedPtr connector; + tsi_handshaker* handshaker_; + RefCountedPtr connector_; - gpr_mu mu; - grpc_core::RefCount refs; + gpr_mu mu_; - bool shutdown = false; + bool is_shutdown_ = false; // Endpoint and read buffer to destroy after a shutdown. - grpc_endpoint* endpoint_to_destroy = nullptr; - grpc_slice_buffer* read_buffer_to_destroy = nullptr; + grpc_endpoint* endpoint_to_destroy_ = nullptr; + grpc_slice_buffer* read_buffer_to_destroy_ = nullptr; // State saved while performing the handshake. - grpc_handshaker_args* args = nullptr; - grpc_closure* on_handshake_done = nullptr; - - size_t handshake_buffer_size; - unsigned char* handshake_buffer; - grpc_slice_buffer outgoing; - grpc_closure on_handshake_data_sent_to_peer; - grpc_closure on_handshake_data_received_from_peer; - grpc_closure on_peer_checked; - grpc_core::RefCountedPtr auth_context; - tsi_handshaker_result* handshaker_result = nullptr; + HandshakerArgs* args_ = nullptr; + grpc_closure* on_handshake_done_ = nullptr; + + size_t handshake_buffer_size_; + unsigned char* handshake_buffer_; + grpc_slice_buffer outgoing_; + grpc_closure on_handshake_data_sent_to_peer_; + grpc_closure on_handshake_data_received_from_peer_; + grpc_closure on_peer_checked_; + RefCountedPtr auth_context_; + tsi_handshaker_result* handshaker_result_ = nullptr; }; -} // namespace -static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) { - size_t bytes_in_read_buffer = h->args->read_buffer->length; - if (h->handshake_buffer_size < bytes_in_read_buffer) { - h->handshake_buffer = static_cast( - gpr_realloc(h->handshake_buffer, bytes_in_read_buffer)); - h->handshake_buffer_size = bytes_in_read_buffer; +SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker, + grpc_security_connector* connector) + : handshaker_(handshaker), + connector_(connector->Ref(DEBUG_LOCATION, "handshake")), + handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE), + handshake_buffer_( + static_cast(gpr_malloc(handshake_buffer_size_))) { + gpr_mu_init(&mu_); + grpc_slice_buffer_init(&outgoing_); + GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer_, + &SecurityHandshaker::OnHandshakeDataSentToPeerFn, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_handshake_data_received_from_peer_, + &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn, + this, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_peer_checked_, &SecurityHandshaker::OnPeerCheckedFn, + this, grpc_schedule_on_exec_ctx); +} + +SecurityHandshaker::~SecurityHandshaker() { + gpr_mu_destroy(&mu_); + tsi_handshaker_destroy(handshaker_); + tsi_handshaker_result_destroy(handshaker_result_); + if (endpoint_to_destroy_ != nullptr) { + grpc_endpoint_destroy(endpoint_to_destroy_); + } + if (read_buffer_to_destroy_ != nullptr) { + grpc_slice_buffer_destroy_internal(read_buffer_to_destroy_); + gpr_free(read_buffer_to_destroy_); + } + gpr_free(handshake_buffer_); + grpc_slice_buffer_destroy_internal(&outgoing_); + auth_context_.reset(DEBUG_LOCATION, "handshake"); + connector_.reset(DEBUG_LOCATION, "handshake"); +} + +size_t SecurityHandshaker::MoveReadBufferIntoHandshakeBuffer() { + size_t bytes_in_read_buffer = args_->read_buffer->length; + if (handshake_buffer_size_ < bytes_in_read_buffer) { + handshake_buffer_ = static_cast( + gpr_realloc(handshake_buffer_, bytes_in_read_buffer)); + handshake_buffer_size_ = bytes_in_read_buffer; } size_t offset = 0; - while (h->args->read_buffer->count > 0) { - grpc_slice next_slice = grpc_slice_buffer_take_first(h->args->read_buffer); - memcpy(h->handshake_buffer + offset, GRPC_SLICE_START_PTR(next_slice), + while (args_->read_buffer->count > 0) { + grpc_slice next_slice = grpc_slice_buffer_take_first(args_->read_buffer); + memcpy(handshake_buffer_ + offset, GRPC_SLICE_START_PTR(next_slice), GRPC_SLICE_LENGTH(next_slice)); offset += GRPC_SLICE_LENGTH(next_slice); grpc_slice_unref_internal(next_slice); @@ -114,21 +153,20 @@ static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) { return bytes_in_read_buffer; } -// Set args fields to NULL, saving the endpoint and read buffer for +// Set args_ fields to NULL, saving the endpoint and read buffer for // later destruction. -static void cleanup_args_for_failure_locked(security_handshaker* h) { - h->endpoint_to_destroy = h->args->endpoint; - h->args->endpoint = nullptr; - h->read_buffer_to_destroy = h->args->read_buffer; - h->args->read_buffer = nullptr; - grpc_channel_args_destroy(h->args->args); - h->args->args = nullptr; +void SecurityHandshaker::CleanupArgsForFailureLocked() { + endpoint_to_destroy_ = args_->endpoint; + args_->endpoint = nullptr; + read_buffer_to_destroy_ = args_->read_buffer; + args_->read_buffer = nullptr; + grpc_channel_args_destroy(args_->args); + args_->args = nullptr; } // If the handshake failed or we're shutting down, clean up and invoke the // callback with the error. -static void security_handshake_failed_locked(security_handshaker* h, - grpc_error* error) { +void SecurityHandshaker::HandshakeFailedLocked(grpc_error* error) { if (error == GRPC_ERROR_NONE) { // If we were shut down after the handshake succeeded but before an // endpoint callback was invoked, we need to generate our own error. @@ -137,50 +175,51 @@ static void security_handshake_failed_locked(security_handshaker* h, const char* msg = grpc_error_string(error); gpr_log(GPR_DEBUG, "Security handshake failed: %s", msg); - if (!h->shutdown) { + if (!is_shutdown_) { // TODO(ctiller): It is currently necessary to shutdown endpoints // before destroying them, even if we know that there are no // pending read/write callbacks. This should be fixed, at which // point this can be removed. - grpc_endpoint_shutdown(h->args->endpoint, GRPC_ERROR_REF(error)); + grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(error)); // Not shutting down, so the write failed. Clean up before // invoking the callback. - cleanup_args_for_failure_locked(h); + CleanupArgsForFailureLocked(); // Set shutdown to true so that subsequent calls to // security_handshaker_shutdown() do nothing. - h->shutdown = true; + is_shutdown_ = true; } // Invoke callback. - GRPC_CLOSURE_SCHED(h->on_handshake_done, error); + GRPC_CLOSURE_SCHED(on_handshake_done_, error); } -static void on_peer_checked_inner(security_handshaker* h, grpc_error* error) { - if (error != GRPC_ERROR_NONE || h->shutdown) { - security_handshake_failed_locked(h, GRPC_ERROR_REF(error)); +void SecurityHandshaker::OnPeerCheckedInner(grpc_error* error) { + MutexLock lock(&mu_); + if (error != GRPC_ERROR_NONE || is_shutdown_) { + HandshakeFailedLocked(GRPC_ERROR_REF(error)); return; } // Create zero-copy frame protector, if implemented. tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr; tsi_result result = tsi_handshaker_result_create_zero_copy_grpc_protector( - h->handshaker_result, nullptr, &zero_copy_protector); + handshaker_result_, nullptr, &zero_copy_protector); if (result != TSI_OK && result != TSI_UNIMPLEMENTED) { error = grpc_set_tsi_error_result( GRPC_ERROR_CREATE_FROM_STATIC_STRING( "Zero-copy frame protector creation failed"), result); - security_handshake_failed_locked(h, error); + HandshakeFailedLocked(error); return; } // Create frame protector if zero-copy frame protector is NULL. tsi_frame_protector* protector = nullptr; if (zero_copy_protector == nullptr) { - result = tsi_handshaker_result_create_frame_protector(h->handshaker_result, + result = tsi_handshaker_result_create_frame_protector(handshaker_result_, nullptr, &protector); if (result != TSI_OK) { error = grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING( "Frame protector creation failed"), result); - security_handshake_failed_locked(h, error); + HandshakeFailedLocked(error); return; } } @@ -188,68 +227,63 @@ static void on_peer_checked_inner(security_handshaker* h, grpc_error* error) { const unsigned char* unused_bytes = nullptr; size_t unused_bytes_size = 0; result = tsi_handshaker_result_get_unused_bytes( - h->handshaker_result, &unused_bytes, &unused_bytes_size); + handshaker_result_, &unused_bytes, &unused_bytes_size); // Create secure endpoint. if (unused_bytes_size > 0) { grpc_slice slice = grpc_slice_from_copied_buffer((char*)unused_bytes, unused_bytes_size); - h->args->endpoint = grpc_secure_endpoint_create( - protector, zero_copy_protector, h->args->endpoint, &slice, 1); + args_->endpoint = grpc_secure_endpoint_create( + protector, zero_copy_protector, args_->endpoint, &slice, 1); grpc_slice_unref_internal(slice); } else { - h->args->endpoint = grpc_secure_endpoint_create( - protector, zero_copy_protector, h->args->endpoint, nullptr, 0); + args_->endpoint = grpc_secure_endpoint_create( + protector, zero_copy_protector, args_->endpoint, nullptr, 0); } - tsi_handshaker_result_destroy(h->handshaker_result); - h->handshaker_result = nullptr; + tsi_handshaker_result_destroy(handshaker_result_); + handshaker_result_ = nullptr; // Add auth context to channel args. - grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context.get()); - grpc_channel_args* tmp_args = h->args->args; - h->args->args = - grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1); + grpc_arg auth_context_arg = grpc_auth_context_to_arg(auth_context_.get()); + grpc_channel_args* tmp_args = args_->args; + args_->args = grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1); grpc_channel_args_destroy(tmp_args); // Invoke callback. - GRPC_CLOSURE_SCHED(h->on_handshake_done, GRPC_ERROR_NONE); + GRPC_CLOSURE_SCHED(on_handshake_done_, GRPC_ERROR_NONE); // Set shutdown to true so that subsequent calls to // security_handshaker_shutdown() do nothing. - h->shutdown = true; + is_shutdown_ = true; } -static void on_peer_checked(void* arg, grpc_error* error) { - security_handshaker* h = static_cast(arg); - gpr_mu_lock(&h->mu); - on_peer_checked_inner(h, error); - gpr_mu_unlock(&h->mu); - h->Unref(); +void SecurityHandshaker::OnPeerCheckedFn(void* arg, grpc_error* error) { + RefCountedPtr(static_cast(arg)) + ->OnPeerCheckedInner(error); } -static grpc_error* check_peer_locked(security_handshaker* h) { +grpc_error* SecurityHandshaker::CheckPeerLocked() { tsi_peer peer; tsi_result result = - tsi_handshaker_result_extract_peer(h->handshaker_result, &peer); + tsi_handshaker_result_extract_peer(handshaker_result_, &peer); if (result != TSI_OK) { return grpc_set_tsi_error_result( GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result); } - h->connector->check_peer(peer, h->args->endpoint, &h->auth_context, - &h->on_peer_checked); + connector_->check_peer(peer, args_->endpoint, &auth_context_, + &on_peer_checked_); return GRPC_ERROR_NONE; } -static grpc_error* on_handshake_next_done_locked( - security_handshaker* h, tsi_result result, - const unsigned char* bytes_to_send, size_t bytes_to_send_size, - tsi_handshaker_result* handshaker_result) { +grpc_error* SecurityHandshaker::OnHandshakeNextDoneLocked( + tsi_result result, const unsigned char* bytes_to_send, + size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) { grpc_error* error = GRPC_ERROR_NONE; // Handshaker was shutdown. - if (h->shutdown) { + if (is_shutdown_) { return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown"); } // Read more if we need to. if (result == TSI_INCOMPLETE_DATA) { GPR_ASSERT(bytes_to_send_size == 0); - grpc_endpoint_read(h->args->endpoint, h->args->read_buffer, - &h->on_handshake_data_received_from_peer); + grpc_endpoint_read(args_->endpoint, args_->read_buffer, + &on_handshake_data_received_from_peer_); return error; } if (result != TSI_OK) { @@ -258,55 +292,52 @@ static grpc_error* on_handshake_next_done_locked( } // Update handshaker result. if (handshaker_result != nullptr) { - GPR_ASSERT(h->handshaker_result == nullptr); - h->handshaker_result = handshaker_result; + GPR_ASSERT(handshaker_result_ == nullptr); + handshaker_result_ = handshaker_result; } if (bytes_to_send_size > 0) { // Send data to peer, if needed. grpc_slice to_send = grpc_slice_from_copied_buffer( reinterpret_cast(bytes_to_send), bytes_to_send_size); - grpc_slice_buffer_reset_and_unref_internal(&h->outgoing); - grpc_slice_buffer_add(&h->outgoing, to_send); - grpc_endpoint_write(h->args->endpoint, &h->outgoing, - &h->on_handshake_data_sent_to_peer, nullptr); + grpc_slice_buffer_reset_and_unref_internal(&outgoing_); + grpc_slice_buffer_add(&outgoing_, to_send); + grpc_endpoint_write(args_->endpoint, &outgoing_, + &on_handshake_data_sent_to_peer_, nullptr); } else if (handshaker_result == nullptr) { // There is nothing to send, but need to read from peer. - grpc_endpoint_read(h->args->endpoint, h->args->read_buffer, - &h->on_handshake_data_received_from_peer); + grpc_endpoint_read(args_->endpoint, args_->read_buffer, + &on_handshake_data_received_from_peer_); } else { // Handshake has finished, check peer and so on. - error = check_peer_locked(h); + error = CheckPeerLocked(); } return error; } -static void on_handshake_next_done_grpc_wrapper( +void SecurityHandshaker::OnHandshakeNextDoneGrpcWrapper( tsi_result result, void* user_data, const unsigned char* bytes_to_send, size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) { - security_handshaker* h = static_cast(user_data); - gpr_mu_lock(&h->mu); - grpc_error* error = on_handshake_next_done_locked( - h, result, bytes_to_send, bytes_to_send_size, handshaker_result); + RefCountedPtr h( + static_cast(user_data)); + MutexLock lock(&h->mu_); + grpc_error* error = h->OnHandshakeNextDoneLocked( + result, bytes_to_send, bytes_to_send_size, handshaker_result); if (error != GRPC_ERROR_NONE) { - security_handshake_failed_locked(h, error); - gpr_mu_unlock(&h->mu); - h->Unref(); + h->HandshakeFailedLocked(error); } else { - gpr_mu_unlock(&h->mu); + h.release(); // Avoid unref } } -static grpc_error* do_handshaker_next_locked( - security_handshaker* h, const unsigned char* bytes_received, - size_t bytes_received_size) { +grpc_error* SecurityHandshaker::DoHandshakerNextLocked( + const unsigned char* bytes_received, size_t bytes_received_size) { // Invoke TSI handshaker. const unsigned char* bytes_to_send = nullptr; size_t bytes_to_send_size = 0; - tsi_handshaker_result* handshaker_result = nullptr; + tsi_handshaker_result* hs_result = nullptr; tsi_result result = tsi_handshaker_next( - h->handshaker, bytes_received, bytes_received_size, &bytes_to_send, - &bytes_to_send_size, &handshaker_result, - &on_handshake_next_done_grpc_wrapper, h); + handshaker_, bytes_received, bytes_received_size, &bytes_to_send, + &bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper, this); if (result == TSI_ASYNC) { // Handshaker operating asynchronously. Nothing else to do here; // callback will be invoked in a TSI thread. @@ -314,233 +345,169 @@ static grpc_error* do_handshaker_next_locked( } // Handshaker returned synchronously. Invoke callback directly in // this thread with our existing exec_ctx. - return on_handshake_next_done_locked(h, result, bytes_to_send, - bytes_to_send_size, handshaker_result); + return OnHandshakeNextDoneLocked(result, bytes_to_send, bytes_to_send_size, + hs_result); } -static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) { - security_handshaker* h = static_cast(arg); - gpr_mu_lock(&h->mu); - if (error != GRPC_ERROR_NONE || h->shutdown) { - security_handshake_failed_locked( - h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( - "Handshake read failed", &error, 1)); - gpr_mu_unlock(&h->mu); - h->Unref(); +void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(void* arg, + grpc_error* error) { + RefCountedPtr h(static_cast(arg)); + MutexLock lock(&h->mu_); + if (error != GRPC_ERROR_NONE || h->is_shutdown_) { + h->HandshakeFailedLocked(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Handshake read failed", &error, 1)); return; } // Copy all slices received. - size_t bytes_received_size = move_read_buffer_into_handshake_buffer(h); + size_t bytes_received_size = h->MoveReadBufferIntoHandshakeBuffer(); // Call TSI handshaker. - error = - do_handshaker_next_locked(h, h->handshake_buffer, bytes_received_size); + error = h->DoHandshakerNextLocked(h->handshake_buffer_, bytes_received_size); if (error != GRPC_ERROR_NONE) { - security_handshake_failed_locked(h, error); - gpr_mu_unlock(&h->mu); - h->Unref(); + h->HandshakeFailedLocked(error); } else { - gpr_mu_unlock(&h->mu); + h.release(); // Avoid unref } } -static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) { - security_handshaker* h = static_cast(arg); - gpr_mu_lock(&h->mu); - if (error != GRPC_ERROR_NONE || h->shutdown) { - security_handshake_failed_locked( - h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( - "Handshake write failed", &error, 1)); - gpr_mu_unlock(&h->mu); - h->Unref(); +void SecurityHandshaker::OnHandshakeDataSentToPeerFn(void* arg, + grpc_error* error) { + RefCountedPtr h(static_cast(arg)); + MutexLock lock(&h->mu_); + if (error != GRPC_ERROR_NONE || h->is_shutdown_) { + h->HandshakeFailedLocked(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Handshake write failed", &error, 1)); return; } // We may be done. - if (h->handshaker_result == nullptr) { - grpc_endpoint_read(h->args->endpoint, h->args->read_buffer, - &h->on_handshake_data_received_from_peer); + if (h->handshaker_result_ == nullptr) { + grpc_endpoint_read(h->args_->endpoint, h->args_->read_buffer, + &h->on_handshake_data_received_from_peer_); } else { - error = check_peer_locked(h); + error = h->CheckPeerLocked(); if (error != GRPC_ERROR_NONE) { - security_handshake_failed_locked(h, error); - gpr_mu_unlock(&h->mu); - h->Unref(); + h->HandshakeFailedLocked(error); return; } } - gpr_mu_unlock(&h->mu); + h.release(); // Avoid unref } // // public handshaker API // -static void security_handshaker_destroy(grpc_handshaker* handshaker) { - security_handshaker* h = reinterpret_cast(handshaker); - h->Unref(); -} - -static void security_handshaker_shutdown(grpc_handshaker* handshaker, - grpc_error* why) { - security_handshaker* h = reinterpret_cast(handshaker); - gpr_mu_lock(&h->mu); - if (!h->shutdown) { - h->shutdown = true; - tsi_handshaker_shutdown(h->handshaker); - grpc_endpoint_shutdown(h->args->endpoint, GRPC_ERROR_REF(why)); - cleanup_args_for_failure_locked(h); +void SecurityHandshaker::Shutdown(grpc_error* why) { + MutexLock lock(&mu_); + if (!is_shutdown_) { + is_shutdown_ = true; + tsi_handshaker_shutdown(handshaker_); + grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(why)); + CleanupArgsForFailureLocked(); } - gpr_mu_unlock(&h->mu); GRPC_ERROR_UNREF(why); } -static void security_handshaker_do_handshake(grpc_handshaker* handshaker, - grpc_tcp_server_acceptor* acceptor, - grpc_closure* on_handshake_done, - grpc_handshaker_args* args) { - security_handshaker* h = reinterpret_cast(handshaker); - gpr_mu_lock(&h->mu); - h->args = args; - h->on_handshake_done = on_handshake_done; - h->Ref(); - size_t bytes_received_size = move_read_buffer_into_handshake_buffer(h); +void SecurityHandshaker::DoHandshake(grpc_tcp_server_acceptor* acceptor, + grpc_closure* on_handshake_done, + HandshakerArgs* args) { + auto ref = Ref(); + MutexLock lock(&mu_); + args_ = args; + on_handshake_done_ = on_handshake_done; + size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer(); grpc_error* error = - do_handshaker_next_locked(h, h->handshake_buffer, bytes_received_size); + DoHandshakerNextLocked(handshake_buffer_, bytes_received_size); if (error != GRPC_ERROR_NONE) { - security_handshake_failed_locked(h, error); - gpr_mu_unlock(&h->mu); - h->Unref(); - return; + HandshakeFailedLocked(error); + } else { + ref.release(); // Avoid unref } - gpr_mu_unlock(&h->mu); -} - -static const grpc_handshaker_vtable security_handshaker_vtable = { - security_handshaker_destroy, security_handshaker_shutdown, - security_handshaker_do_handshake, "security"}; - -namespace { -security_handshaker::security_handshaker(tsi_handshaker* handshaker, - grpc_security_connector* connector) - : handshaker(handshaker), - connector(connector->Ref(DEBUG_LOCATION, "handshake")), - handshake_buffer_size(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE), - handshake_buffer( - static_cast(gpr_malloc(handshake_buffer_size))) { - grpc_handshaker_init(&security_handshaker_vtable, &base); - gpr_mu_init(&mu); - grpc_slice_buffer_init(&outgoing); - GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer, - ::on_handshake_data_sent_to_peer, this, - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&on_handshake_data_received_from_peer, - ::on_handshake_data_received_from_peer, this, - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&on_peer_checked, ::on_peer_checked, this, - grpc_schedule_on_exec_ctx); -} -} // namespace - -static grpc_handshaker* security_handshaker_create( - tsi_handshaker* handshaker, grpc_security_connector* connector) { - security_handshaker* h = - grpc_core::New(handshaker, connector); - return &h->base; } // -// fail_handshaker +// FailHandshaker // -static void fail_handshaker_destroy(grpc_handshaker* handshaker) { - gpr_free(handshaker); -} - -static void fail_handshaker_shutdown(grpc_handshaker* handshaker, - grpc_error* why) { - GRPC_ERROR_UNREF(why); -} - -static void fail_handshaker_do_handshake(grpc_handshaker* handshaker, - grpc_tcp_server_acceptor* acceptor, - grpc_closure* on_handshake_done, - grpc_handshaker_args* args) { - GRPC_CLOSURE_SCHED(on_handshake_done, - GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Failed to create security handshaker")); -} - -static const grpc_handshaker_vtable fail_handshaker_vtable = { - fail_handshaker_destroy, fail_handshaker_shutdown, - fail_handshaker_do_handshake, "security_fail"}; +class FailHandshaker : public Handshaker { + public: + const char* name() const override { return "security_fail"; } + void Shutdown(grpc_error* why) override { GRPC_ERROR_UNREF(why); } + void DoHandshake(grpc_tcp_server_acceptor* acceptor, + grpc_closure* on_handshake_done, + HandshakerArgs* args) override { + GRPC_CLOSURE_SCHED(on_handshake_done, + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Failed to create security handshaker")); + } -static grpc_handshaker* fail_handshaker_create() { - grpc_handshaker* h = static_cast(gpr_malloc(sizeof(*h))); - grpc_handshaker_init(&fail_handshaker_vtable, h); - return h; -} + private: + virtual ~FailHandshaker() = default; +}; // // handshaker factories // -static void client_handshaker_factory_add_handshakers( - grpc_handshaker_factory* handshaker_factory, const grpc_channel_args* args, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_channel_security_connector* security_connector = - reinterpret_cast( - grpc_security_connector_find_in_args(args)); - if (security_connector) { - security_connector->add_handshakers(interested_parties, handshake_mgr); +class ClientSecurityHandshakerFactory : public HandshakerFactory { + public: + void AddHandshakers(const grpc_channel_args* args, + grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr) override { + auto* security_connector = + reinterpret_cast( + grpc_security_connector_find_in_args(args)); + if (security_connector) { + security_connector->add_handshakers(interested_parties, handshake_mgr); + } } -} + ~ClientSecurityHandshakerFactory() override = default; +}; -static void server_handshaker_factory_add_handshakers( - grpc_handshaker_factory* hf, const grpc_channel_args* args, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_server_security_connector* security_connector = - reinterpret_cast( - grpc_security_connector_find_in_args(args)); - if (security_connector) { - security_connector->add_handshakers(interested_parties, handshake_mgr); +class ServerSecurityHandshakerFactory : public HandshakerFactory { + public: + void AddHandshakers(const grpc_channel_args* args, + grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr) override { + auto* security_connector = + reinterpret_cast( + grpc_security_connector_find_in_args(args)); + if (security_connector) { + security_connector->add_handshakers(interested_parties, handshake_mgr); + } } -} - -static void handshaker_factory_destroy( - grpc_handshaker_factory* handshaker_factory) {} - -static const grpc_handshaker_factory_vtable client_handshaker_factory_vtable = { - client_handshaker_factory_add_handshakers, handshaker_factory_destroy}; - -static grpc_handshaker_factory client_handshaker_factory = { - &client_handshaker_factory_vtable}; - -static const grpc_handshaker_factory_vtable server_handshaker_factory_vtable = { - server_handshaker_factory_add_handshakers, handshaker_factory_destroy}; + ~ServerSecurityHandshakerFactory() override = default; +}; -static grpc_handshaker_factory server_handshaker_factory = { - &server_handshaker_factory_vtable}; +} // namespace // // exported functions // -grpc_handshaker* grpc_security_handshaker_create( +RefCountedPtr SecurityHandshakerCreate( tsi_handshaker* handshaker, grpc_security_connector* connector) { // If no TSI handshaker was created, return a handshaker that always fails. // Otherwise, return a real security handshaker. if (handshaker == nullptr) { - return fail_handshaker_create(); + return MakeRefCounted(); } else { - return security_handshaker_create(handshaker, connector); + return MakeRefCounted(handshaker, connector); } } -void grpc_security_register_handshaker_factories() { - grpc_handshaker_factory_register(false /* at_start */, HANDSHAKER_CLIENT, - &client_handshaker_factory); - grpc_handshaker_factory_register(false /* at_start */, HANDSHAKER_SERVER, - &server_handshaker_factory); +grpc_handshaker* grpc_security_handshaker_create( + tsi_handshaker* handshaker, grpc_security_connector* connector) { + return SecurityHandshakerCreate(handshaker, connector).release(); } + +void SecurityRegisterHandshakerFactories() { + HandshakerRegistry::RegisterHandshakerFactory( + false /* at_start */, HANDSHAKER_CLIENT, + UniquePtr(New())); + HandshakerRegistry::RegisterHandshakerFactory( + false /* at_start */, HANDSHAKER_SERVER, + UniquePtr(New())); +} + +} // namespace grpc_core diff --git a/src/core/lib/security/transport/security_handshaker.h b/src/core/lib/security/transport/security_handshaker.h index 88483b02e74..263fe555967 100644 --- a/src/core/lib/security/transport/security_handshaker.h +++ b/src/core/lib/security/transport/security_handshaker.h @@ -24,11 +24,20 @@ #include "src/core/lib/channel/handshaker.h" #include "src/core/lib/security/security_connector/security_connector.h" +namespace grpc_core { + /// Creates a security handshaker using \a handshaker. -grpc_handshaker* grpc_security_handshaker_create( +RefCountedPtr SecurityHandshakerCreate( tsi_handshaker* handshaker, grpc_security_connector* connector); /// Registers security handshaker factories. -void grpc_security_register_handshaker_factories(); +void SecurityRegisterHandshakerFactories(); + +} // namespace grpc_core + +// TODO(arjunroy): This is transitional to account for the new handshaker API +// and will eventually be removed entirely. +grpc_handshaker* grpc_security_handshaker_create( + tsi_handshaker* handshaker, grpc_security_connector* connector); #endif /* GRPC_CORE_LIB_SECURITY_TRANSPORT_SECURITY_HANDSHAKER_H */ diff --git a/src/core/lib/surface/init.cc b/src/core/lib/surface/init.cc index f704a64b1c9..e507de87c2a 100644 --- a/src/core/lib/surface/init.cc +++ b/src/core/lib/surface/init.cc @@ -134,7 +134,7 @@ void grpc_init(void) { grpc_core::ExecCtx::GlobalInit(); grpc_iomgr_init(); gpr_timers_global_init(); - grpc_handshaker_factory_registry_init(); + grpc_core::HandshakerRegistry::Init(); grpc_security_init(); for (i = 0; i < g_number_of_plugins; i++) { if (g_all_of_the_plugins[i].init != nullptr) { @@ -177,7 +177,7 @@ void grpc_shutdown(void) { gpr_timers_global_destroy(); grpc_tracer_shutdown(); grpc_mdctx_global_shutdown(); - grpc_handshaker_factory_registry_shutdown(); + grpc_core::HandshakerRegistry::Shutdown(); grpc_slice_intern_shutdown(); grpc_core::channelz::ChannelzRegistry::Shutdown(); grpc_stats_shutdown(); diff --git a/src/core/lib/surface/init_secure.cc b/src/core/lib/surface/init_secure.cc index 765350cced0..0e83a11a5f0 100644 --- a/src/core/lib/surface/init_secure.cc +++ b/src/core/lib/surface/init_secure.cc @@ -78,4 +78,4 @@ void grpc_register_security_filters(void) { maybe_prepend_server_auth_filter, nullptr); } -void grpc_security_init() { grpc_security_register_handshaker_factories(); } +void grpc_security_init() { grpc_core::SecurityRegisterHandshakerFactories(); } diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py index 19d27412205..71de0c4abe0 100644 --- a/src/python/grpcio/grpc_core_dependencies.py +++ b/src/python/grpcio/grpc_core_dependencies.py @@ -68,7 +68,6 @@ CORE_SOURCE_FILES = [ 'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/handshaker.cc', - 'src/core/lib/channel/handshaker_factory.cc', 'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/status_util.cc', 'src/core/lib/compression/compression.cc', diff --git a/test/core/handshake/readahead_handshaker_server_ssl.cc b/test/core/handshake/readahead_handshaker_server_ssl.cc index 14d96b5d89c..e4584105e65 100644 --- a/test/core/handshake/readahead_handshaker_server_ssl.cc +++ b/test/core/handshake/readahead_handshaker_server_ssl.cc @@ -49,51 +49,38 @@ * to the security_handshaker). This test is meant to protect code relying on * this functionality that lives outside of this repo. */ -static void readahead_handshaker_destroy(grpc_handshaker* handshaker) { - gpr_free(handshaker); -} - -static void readahead_handshaker_shutdown(grpc_handshaker* handshaker, - grpc_error* error) {} - -static void readahead_handshaker_do_handshake( - grpc_handshaker* handshaker, grpc_tcp_server_acceptor* acceptor, - grpc_closure* on_handshake_done, grpc_handshaker_args* args) { - grpc_endpoint_read(args->endpoint, args->read_buffer, on_handshake_done); -} +namespace grpc_core { -const grpc_handshaker_vtable readahead_handshaker_vtable = { - readahead_handshaker_destroy, readahead_handshaker_shutdown, - readahead_handshaker_do_handshake, "read_ahead"}; - -static grpc_handshaker* readahead_handshaker_create() { - grpc_handshaker* h = - static_cast(gpr_zalloc(sizeof(grpc_handshaker))); - grpc_handshaker_init(&readahead_handshaker_vtable, h); - return h; -} - -static void readahead_handshaker_factory_add_handshakers( - grpc_handshaker_factory* hf, const grpc_channel_args* args, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_handshake_manager_add(handshake_mgr, readahead_handshaker_create()); -} +class ReadAheadHandshaker : public Handshaker { + public: + virtual ~ReadAheadHandshaker() {} + const char* name() const override { return "read_ahead"; } + void Shutdown(grpc_error* why) override {} + void DoHandshake(grpc_tcp_server_acceptor* acceptor, + grpc_closure* on_handshake_done, + HandshakerArgs* args) override { + grpc_endpoint_read(args->endpoint, args->read_buffer, on_handshake_done); + } +}; -static void readahead_handshaker_factory_destroy( - grpc_handshaker_factory* handshaker_factory) {} +class ReadAheadHandshakerFactory : public HandshakerFactory { + public: + void AddHandshakers(const grpc_channel_args* args, + grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr) override { + handshake_mgr->Add(MakeRefCounted()); + } + ~ReadAheadHandshakerFactory() override = default; +}; -static const grpc_handshaker_factory_vtable - readahead_handshaker_factory_vtable = { - readahead_handshaker_factory_add_handshakers, - readahead_handshaker_factory_destroy}; +} // namespace grpc_core int main(int argc, char* argv[]) { - grpc_handshaker_factory readahead_handshaker_factory = { - &readahead_handshaker_factory_vtable}; + using namespace grpc_core; grpc_init(); - grpc_handshaker_factory_register(true /* at_start */, HANDSHAKER_SERVER, - &readahead_handshaker_factory); + HandshakerRegistry::RegisterHandshakerFactory( + true /* at_start */, HANDSHAKER_SERVER, + UniquePtr(New())); const char* full_alpn_list[] = {"grpc-exp", "h2"}; GPR_ASSERT(server_ssl_test(full_alpn_list, 2, "grpc-exp")); grpc_shutdown(); diff --git a/test/core/security/ssl_server_fuzzer.cc b/test/core/security/ssl_server_fuzzer.cc index c9380126dd0..8533644aceb 100644 --- a/test/core/security/ssl_server_fuzzer.cc +++ b/test/core/security/ssl_server_fuzzer.cc @@ -41,7 +41,8 @@ struct handshake_state { }; static void on_handshake_done(void* arg, grpc_error* error) { - grpc_handshaker_args* args = static_cast(arg); + grpc_core::HandshakerArgs* args = + static_cast(arg); struct handshake_state* state = static_cast(args->user_data); GPR_ASSERT(state->done_callback_called == false); @@ -89,11 +90,12 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { struct handshake_state state; state.done_callback_called = false; - grpc_handshake_manager* handshake_mgr = grpc_handshake_manager_create(); - sc->add_handshakers(nullptr, handshake_mgr); - grpc_handshake_manager_do_handshake( - handshake_mgr, mock_endpoint, nullptr /* channel_args */, deadline, - nullptr /* acceptor */, on_handshake_done, &state); + auto handshake_mgr = + grpc_core::MakeRefCounted(); + sc->add_handshakers(nullptr, handshake_mgr.get()); + handshake_mgr->DoHandshake(mock_endpoint, nullptr /* channel_args */, + deadline, nullptr /* acceptor */, + on_handshake_done, &state); grpc_core::ExecCtx::Get()->Flush(); // If the given string happens to be part of the correct client hello, the @@ -108,7 +110,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { GPR_ASSERT(state.done_callback_called); - grpc_handshake_manager_destroy(handshake_mgr); sc.reset(DEBUG_LOCATION, "test"); grpc_server_credentials_release(creds); grpc_slice_unref(cert_slice); diff --git a/tools/doxygen/Doxyfile.core.internal b/tools/doxygen/Doxyfile.core.internal index 2aced414218..86b57b23d9a 100644 --- a/tools/doxygen/Doxyfile.core.internal +++ b/tools/doxygen/Doxyfile.core.internal @@ -1080,7 +1080,6 @@ src/core/lib/channel/connected_channel.h \ src/core/lib/channel/context.h \ src/core/lib/channel/handshaker.cc \ src/core/lib/channel/handshaker.h \ -src/core/lib/channel/handshaker_factory.cc \ src/core/lib/channel/handshaker_factory.h \ src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/handshaker_registry.h \ diff --git a/tools/run_tests/generated/sources_and_headers.json b/tools/run_tests/generated/sources_and_headers.json index ab01b8fca6a..84d5c45095f 100644 --- a/tools/run_tests/generated/sources_and_headers.json +++ b/tools/run_tests/generated/sources_and_headers.json @@ -9440,7 +9440,6 @@ "src/core/lib/channel/channelz_registry.cc", "src/core/lib/channel/connected_channel.cc", "src/core/lib/channel/handshaker.cc", - "src/core/lib/channel/handshaker_factory.cc", "src/core/lib/channel/handshaker_registry.cc", "src/core/lib/channel/status_util.cc", "src/core/lib/compression/compression.cc",