Grpc: Change grpc_handshake and grpc_handshake_mgr to use CPP implementations.

grpc_handshake is renamed to GrpcHandshake, using C++ class definitions
instead of C-style vtable classes. Update callers to use new interfaces.
We use RefCountedPtr to simplify reference tracking.
pull/17740/head
Arjun Roy 6 years ago
parent 0c39279b78
commit 195a30bb8b
  1. 1
      BUILD
  2. 6
      CMakeLists.txt
  3. 6
      Makefile
  4. 1
      build.yaml
  5. 1
      config.m4
  6. 1
      config.w32
  7. 1
      gRPC-Core.podspec
  8. 1
      grpc.gemspec
  9. 4
      grpc.gyp
  10. 1
      package.xml
  11. 298
      src/core/ext/filters/client_channel/http_connect_handshaker.cc
  12. 22
      src/core/ext/transport/chttp2/client/chttp2_connector.cc
  13. 33
      src/core/ext/transport/chttp2/server/chttp2_server.cc
  14. 327
      src/core/lib/channel/handshaker.cc
  15. 207
      src/core/lib/channel/handshaker.h
  16. 42
      src/core/lib/channel/handshaker_factory.cc
  17. 24
      src/core/lib/channel/handshaker_factory.h
  18. 110
      src/core/lib/channel/handshaker_registry.cc
  19. 33
      src/core/lib/channel/handshaker_registry.h
  20. 27
      src/core/lib/http/httpcli_security_connector.cc
  21. 18
      src/core/lib/security/security_connector/alts/alts_security_connector.cc
  22. 12
      src/core/lib/security/security_connector/fake/fake_security_connector.cc
  23. 18
      src/core/lib/security/security_connector/local/local_security_connector.cc
  24. 4
      src/core/lib/security/security_connector/security_connector.h
  25. 10
      src/core/lib/security/security_connector/ssl/ssl_security_connector.cc
  26. 535
      src/core/lib/security/transport/security_handshaker.cc
  27. 13
      src/core/lib/security/transport/security_handshaker.h
  28. 4
      src/core/lib/surface/init.cc
  29. 2
      src/core/lib/surface/init_secure.cc
  30. 1
      src/python/grpcio/grpc_core_dependencies.py
  31. 61
      test/core/handshake/readahead_handshaker_server_ssl.cc
  32. 15
      test/core/security/ssl_server_fuzzer.cc
  33. 1
      tools/doxygen/Doxyfile.core.internal
  34. 1
      tools/run_tests/generated/sources_and_headers.json

@ -701,7 +701,6 @@ grpc_cc_library(
"src/core/lib/channel/channelz_registry.cc", "src/core/lib/channel/channelz_registry.cc",
"src/core/lib/channel/connected_channel.cc", "src/core/lib/channel/connected_channel.cc",
"src/core/lib/channel/handshaker.cc", "src/core/lib/channel/handshaker.cc",
"src/core/lib/channel/handshaker_factory.cc",
"src/core/lib/channel/handshaker_registry.cc", "src/core/lib/channel/handshaker_registry.cc",
"src/core/lib/channel/status_util.cc", "src/core/lib/channel/status_util.cc",
"src/core/lib/compression/compression.cc", "src/core/lib/compression/compression.cc",

@ -971,7 +971,6 @@ add_library(grpc
src/core/lib/channel/channelz_registry.cc src/core/lib/channel/channelz_registry.cc
src/core/lib/channel/connected_channel.cc src/core/lib/channel/connected_channel.cc
src/core/lib/channel/handshaker.cc src/core/lib/channel/handshaker.cc
src/core/lib/channel/handshaker_factory.cc
src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/handshaker_registry.cc
src/core/lib/channel/status_util.cc src/core/lib/channel/status_util.cc
src/core/lib/compression/compression.cc src/core/lib/compression/compression.cc
@ -1397,7 +1396,6 @@ add_library(grpc_cronet
src/core/lib/channel/channelz_registry.cc src/core/lib/channel/channelz_registry.cc
src/core/lib/channel/connected_channel.cc src/core/lib/channel/connected_channel.cc
src/core/lib/channel/handshaker.cc src/core/lib/channel/handshaker.cc
src/core/lib/channel/handshaker_factory.cc
src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/handshaker_registry.cc
src/core/lib/channel/status_util.cc src/core/lib/channel/status_util.cc
src/core/lib/compression/compression.cc src/core/lib/compression/compression.cc
@ -1808,7 +1806,6 @@ add_library(grpc_test_util
src/core/lib/channel/channelz_registry.cc src/core/lib/channel/channelz_registry.cc
src/core/lib/channel/connected_channel.cc src/core/lib/channel/connected_channel.cc
src/core/lib/channel/handshaker.cc src/core/lib/channel/handshaker.cc
src/core/lib/channel/handshaker_factory.cc
src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/handshaker_registry.cc
src/core/lib/channel/status_util.cc src/core/lib/channel/status_util.cc
src/core/lib/compression/compression.cc src/core/lib/compression/compression.cc
@ -2134,7 +2131,6 @@ add_library(grpc_test_util_unsecure
src/core/lib/channel/channelz_registry.cc src/core/lib/channel/channelz_registry.cc
src/core/lib/channel/connected_channel.cc src/core/lib/channel/connected_channel.cc
src/core/lib/channel/handshaker.cc src/core/lib/channel/handshaker.cc
src/core/lib/channel/handshaker_factory.cc
src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/handshaker_registry.cc
src/core/lib/channel/status_util.cc src/core/lib/channel/status_util.cc
src/core/lib/compression/compression.cc src/core/lib/compression/compression.cc
@ -2436,7 +2432,6 @@ add_library(grpc_unsecure
src/core/lib/channel/channelz_registry.cc src/core/lib/channel/channelz_registry.cc
src/core/lib/channel/connected_channel.cc src/core/lib/channel/connected_channel.cc
src/core/lib/channel/handshaker.cc src/core/lib/channel/handshaker.cc
src/core/lib/channel/handshaker_factory.cc
src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/handshaker_registry.cc
src/core/lib/channel/status_util.cc src/core/lib/channel/status_util.cc
src/core/lib/compression/compression.cc src/core/lib/compression/compression.cc
@ -3324,7 +3319,6 @@ add_library(grpc++_cronet
src/core/lib/channel/channelz_registry.cc src/core/lib/channel/channelz_registry.cc
src/core/lib/channel/connected_channel.cc src/core/lib/channel/connected_channel.cc
src/core/lib/channel/handshaker.cc src/core/lib/channel/handshaker.cc
src/core/lib/channel/handshaker_factory.cc
src/core/lib/channel/handshaker_registry.cc src/core/lib/channel/handshaker_registry.cc
src/core/lib/channel/status_util.cc src/core/lib/channel/status_util.cc
src/core/lib/compression/compression.cc src/core/lib/compression/compression.cc

@ -3497,7 +3497,6 @@ LIBGRPC_SRC = \
src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/channelz_registry.cc \
src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/connected_channel.cc \
src/core/lib/channel/handshaker.cc \ src/core/lib/channel/handshaker.cc \
src/core/lib/channel/handshaker_factory.cc \
src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/handshaker_registry.cc \
src/core/lib/channel/status_util.cc \ src/core/lib/channel/status_util.cc \
src/core/lib/compression/compression.cc \ src/core/lib/compression/compression.cc \
@ -3917,7 +3916,6 @@ LIBGRPC_CRONET_SRC = \
src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/channelz_registry.cc \
src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/connected_channel.cc \
src/core/lib/channel/handshaker.cc \ src/core/lib/channel/handshaker.cc \
src/core/lib/channel/handshaker_factory.cc \
src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/handshaker_registry.cc \
src/core/lib/channel/status_util.cc \ src/core/lib/channel/status_util.cc \
src/core/lib/compression/compression.cc \ src/core/lib/compression/compression.cc \
@ -4321,7 +4319,6 @@ LIBGRPC_TEST_UTIL_SRC = \
src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/channelz_registry.cc \
src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/connected_channel.cc \
src/core/lib/channel/handshaker.cc \ src/core/lib/channel/handshaker.cc \
src/core/lib/channel/handshaker_factory.cc \
src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/handshaker_registry.cc \
src/core/lib/channel/status_util.cc \ src/core/lib/channel/status_util.cc \
src/core/lib/compression/compression.cc \ src/core/lib/compression/compression.cc \
@ -4634,7 +4631,6 @@ LIBGRPC_TEST_UTIL_UNSECURE_SRC = \
src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/channelz_registry.cc \
src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/connected_channel.cc \
src/core/lib/channel/handshaker.cc \ src/core/lib/channel/handshaker.cc \
src/core/lib/channel/handshaker_factory.cc \
src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/handshaker_registry.cc \
src/core/lib/channel/status_util.cc \ src/core/lib/channel/status_util.cc \
src/core/lib/compression/compression.cc \ src/core/lib/compression/compression.cc \
@ -4910,7 +4906,6 @@ LIBGRPC_UNSECURE_SRC = \
src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/channelz_registry.cc \
src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/connected_channel.cc \
src/core/lib/channel/handshaker.cc \ src/core/lib/channel/handshaker.cc \
src/core/lib/channel/handshaker_factory.cc \
src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/handshaker_registry.cc \
src/core/lib/channel/status_util.cc \ src/core/lib/channel/status_util.cc \
src/core/lib/compression/compression.cc \ src/core/lib/compression/compression.cc \
@ -5775,7 +5770,6 @@ LIBGRPC++_CRONET_SRC = \
src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/channelz_registry.cc \
src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/connected_channel.cc \
src/core/lib/channel/handshaker.cc \ src/core/lib/channel/handshaker.cc \
src/core/lib/channel/handshaker_factory.cc \
src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/handshaker_registry.cc \
src/core/lib/channel/status_util.cc \ src/core/lib/channel/status_util.cc \
src/core/lib/compression/compression.cc \ src/core/lib/compression/compression.cc \

@ -242,7 +242,6 @@ filegroups:
- src/core/lib/channel/channelz_registry.cc - src/core/lib/channel/channelz_registry.cc
- src/core/lib/channel/connected_channel.cc - src/core/lib/channel/connected_channel.cc
- src/core/lib/channel/handshaker.cc - src/core/lib/channel/handshaker.cc
- src/core/lib/channel/handshaker_factory.cc
- src/core/lib/channel/handshaker_registry.cc - src/core/lib/channel/handshaker_registry.cc
- src/core/lib/channel/status_util.cc - src/core/lib/channel/status_util.cc
- src/core/lib/compression/compression.cc - src/core/lib/compression/compression.cc

@ -94,7 +94,6 @@ if test "$PHP_GRPC" != "no"; then
src/core/lib/channel/channelz_registry.cc \ src/core/lib/channel/channelz_registry.cc \
src/core/lib/channel/connected_channel.cc \ src/core/lib/channel/connected_channel.cc \
src/core/lib/channel/handshaker.cc \ src/core/lib/channel/handshaker.cc \
src/core/lib/channel/handshaker_factory.cc \
src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/handshaker_registry.cc \
src/core/lib/channel/status_util.cc \ src/core/lib/channel/status_util.cc \
src/core/lib/compression/compression.cc \ src/core/lib/compression/compression.cc \

@ -69,7 +69,6 @@ if (PHP_GRPC != "no") {
"src\\core\\lib\\channel\\channelz_registry.cc " + "src\\core\\lib\\channel\\channelz_registry.cc " +
"src\\core\\lib\\channel\\connected_channel.cc " + "src\\core\\lib\\channel\\connected_channel.cc " +
"src\\core\\lib\\channel\\handshaker.cc " + "src\\core\\lib\\channel\\handshaker.cc " +
"src\\core\\lib\\channel\\handshaker_factory.cc " +
"src\\core\\lib\\channel\\handshaker_registry.cc " + "src\\core\\lib\\channel\\handshaker_registry.cc " +
"src\\core\\lib\\channel\\status_util.cc " + "src\\core\\lib\\channel\\status_util.cc " +
"src\\core\\lib\\compression\\compression.cc " + "src\\core\\lib\\compression\\compression.cc " +

@ -543,7 +543,6 @@ Pod::Spec.new do |s|
'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/channelz_registry.cc',
'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/connected_channel.cc',
'src/core/lib/channel/handshaker.cc', 'src/core/lib/channel/handshaker.cc',
'src/core/lib/channel/handshaker_factory.cc',
'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/handshaker_registry.cc',
'src/core/lib/channel/status_util.cc', 'src/core/lib/channel/status_util.cc',
'src/core/lib/compression/compression.cc', 'src/core/lib/compression/compression.cc',

@ -477,7 +477,6 @@ Gem::Specification.new do |s|
s.files += %w( src/core/lib/channel/channelz_registry.cc ) s.files += %w( src/core/lib/channel/channelz_registry.cc )
s.files += %w( src/core/lib/channel/connected_channel.cc ) s.files += %w( src/core/lib/channel/connected_channel.cc )
s.files += %w( src/core/lib/channel/handshaker.cc ) s.files += %w( src/core/lib/channel/handshaker.cc )
s.files += %w( src/core/lib/channel/handshaker_factory.cc )
s.files += %w( src/core/lib/channel/handshaker_registry.cc ) s.files += %w( src/core/lib/channel/handshaker_registry.cc )
s.files += %w( src/core/lib/channel/status_util.cc ) s.files += %w( src/core/lib/channel/status_util.cc )
s.files += %w( src/core/lib/compression/compression.cc ) s.files += %w( src/core/lib/compression/compression.cc )

@ -276,7 +276,6 @@
'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/channelz_registry.cc',
'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/connected_channel.cc',
'src/core/lib/channel/handshaker.cc', 'src/core/lib/channel/handshaker.cc',
'src/core/lib/channel/handshaker_factory.cc',
'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/handshaker_registry.cc',
'src/core/lib/channel/status_util.cc', 'src/core/lib/channel/status_util.cc',
'src/core/lib/compression/compression.cc', 'src/core/lib/compression/compression.cc',
@ -643,7 +642,6 @@
'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/channelz_registry.cc',
'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/connected_channel.cc',
'src/core/lib/channel/handshaker.cc', 'src/core/lib/channel/handshaker.cc',
'src/core/lib/channel/handshaker_factory.cc',
'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/handshaker_registry.cc',
'src/core/lib/channel/status_util.cc', 'src/core/lib/channel/status_util.cc',
'src/core/lib/compression/compression.cc', 'src/core/lib/compression/compression.cc',
@ -889,7 +887,6 @@
'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/channelz_registry.cc',
'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/connected_channel.cc',
'src/core/lib/channel/handshaker.cc', 'src/core/lib/channel/handshaker.cc',
'src/core/lib/channel/handshaker_factory.cc',
'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/handshaker_registry.cc',
'src/core/lib/channel/status_util.cc', 'src/core/lib/channel/status_util.cc',
'src/core/lib/compression/compression.cc', 'src/core/lib/compression/compression.cc',
@ -1111,7 +1108,6 @@
'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/channelz_registry.cc',
'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/connected_channel.cc',
'src/core/lib/channel/handshaker.cc', 'src/core/lib/channel/handshaker.cc',
'src/core/lib/channel/handshaker_factory.cc',
'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/handshaker_registry.cc',
'src/core/lib/channel/status_util.cc', 'src/core/lib/channel/status_util.cc',
'src/core/lib/compression/compression.cc', 'src/core/lib/compression/compression.cc',

@ -482,7 +482,6 @@
<file baseinstalldir="/" name="src/core/lib/channel/channelz_registry.cc" role="src" /> <file baseinstalldir="/" name="src/core/lib/channel/channelz_registry.cc" role="src" />
<file baseinstalldir="/" name="src/core/lib/channel/connected_channel.cc" role="src" /> <file baseinstalldir="/" name="src/core/lib/channel/connected_channel.cc" role="src" />
<file baseinstalldir="/" name="src/core/lib/channel/handshaker.cc" role="src" /> <file baseinstalldir="/" name="src/core/lib/channel/handshaker.cc" role="src" />
<file baseinstalldir="/" name="src/core/lib/channel/handshaker_factory.cc" role="src" />
<file baseinstalldir="/" name="src/core/lib/channel/handshaker_registry.cc" role="src" /> <file baseinstalldir="/" name="src/core/lib/channel/handshaker_registry.cc" role="src" />
<file baseinstalldir="/" name="src/core/lib/channel/status_util.cc" role="src" /> <file baseinstalldir="/" name="src/core/lib/channel/status_util.cc" role="src" />
<file baseinstalldir="/" name="src/core/lib/compression/compression.cc" role="src" /> <file baseinstalldir="/" name="src/core/lib/compression/compression.cc" role="src" />

@ -33,151 +33,160 @@
#include "src/core/lib/channel/handshaker_registry.h" #include "src/core/lib/channel/handshaker_registry.h"
#include "src/core/lib/gpr/env.h" #include "src/core/lib/gpr/env.h"
#include "src/core/lib/gpr/string.h" #include "src/core/lib/gpr/string.h"
#include "src/core/lib/gprpp/mutex_lock.h"
#include "src/core/lib/http/format_request.h" #include "src/core/lib/http/format_request.h"
#include "src/core/lib/http/parser.h" #include "src/core/lib/http/parser.h"
#include "src/core/lib/slice/slice_internal.h" #include "src/core/lib/slice/slice_internal.h"
#include "src/core/lib/uri/uri_parser.h" #include "src/core/lib/uri/uri_parser.h"
typedef struct http_connect_handshaker { namespace grpc_core {
// Base class. Must be first.
grpc_handshaker base;
gpr_refcount refcount; namespace {
gpr_mu mu;
bool shutdown; class HttpConnectHandshaker : public Handshaker {
public:
HttpConnectHandshaker();
void Shutdown(grpc_error* why) override;
void DoHandshake(grpc_tcp_server_acceptor* acceptor,
grpc_closure* on_handshake_done,
HandshakerArgs* args) override;
const char* name() const override { return "http_connect"; }
private:
virtual ~HttpConnectHandshaker();
void CleanupArgsForFailureLocked();
void HandshakeFailedLocked(grpc_error* error);
static void OnWriteDone(void* arg, grpc_error* error);
static void OnReadDone(void* arg, grpc_error* error);
gpr_mu mu_;
bool is_shutdown_ = false;
// Endpoint and read buffer to destroy after a shutdown. // Endpoint and read buffer to destroy after a shutdown.
grpc_endpoint* endpoint_to_destroy; grpc_endpoint* endpoint_to_destroy_ = nullptr;
grpc_slice_buffer* read_buffer_to_destroy; grpc_slice_buffer* read_buffer_to_destroy_ = nullptr;
// State saved while performing the handshake. // State saved while performing the handshake.
grpc_handshaker_args* args; HandshakerArgs* args_ = nullptr;
grpc_closure* on_handshake_done; grpc_closure* on_handshake_done_ = nullptr;
// Objects for processing the HTTP CONNECT request and response. // Objects for processing the HTTP CONNECT request and response.
grpc_slice_buffer write_buffer; grpc_slice_buffer write_buffer_;
grpc_closure request_done_closure; grpc_closure request_done_closure_;
grpc_closure response_read_closure; grpc_closure response_read_closure_;
grpc_http_parser http_parser; grpc_http_parser http_parser_;
grpc_http_response http_response; grpc_http_response http_response_;
} http_connect_handshaker; };
// Unref and clean up handshaker. HttpConnectHandshaker::~HttpConnectHandshaker() {
static void http_connect_handshaker_unref(http_connect_handshaker* handshaker) { gpr_mu_destroy(&mu_);
if (gpr_unref(&handshaker->refcount)) { if (endpoint_to_destroy_ != nullptr) {
gpr_mu_destroy(&handshaker->mu); grpc_endpoint_destroy(endpoint_to_destroy_);
if (handshaker->endpoint_to_destroy != nullptr) {
grpc_endpoint_destroy(handshaker->endpoint_to_destroy);
} }
if (handshaker->read_buffer_to_destroy != nullptr) { if (read_buffer_to_destroy_ != nullptr) {
grpc_slice_buffer_destroy_internal(handshaker->read_buffer_to_destroy); grpc_slice_buffer_destroy_internal(read_buffer_to_destroy_);
gpr_free(handshaker->read_buffer_to_destroy); gpr_free(read_buffer_to_destroy_);
}
grpc_slice_buffer_destroy_internal(&handshaker->write_buffer);
grpc_http_parser_destroy(&handshaker->http_parser);
grpc_http_response_destroy(&handshaker->http_response);
gpr_free(handshaker);
} }
grpc_slice_buffer_destroy_internal(&write_buffer_);
grpc_http_parser_destroy(&http_parser_);
grpc_http_response_destroy(&http_response_);
} }
// Set args fields to nullptr, saving the endpoint and read buffer for // Set args fields to nullptr, saving the endpoint and read buffer for
// later destruction. // later destruction.
static void cleanup_args_for_failure_locked( void HttpConnectHandshaker::CleanupArgsForFailureLocked() {
http_connect_handshaker* handshaker) { endpoint_to_destroy_ = args_->endpoint;
handshaker->endpoint_to_destroy = handshaker->args->endpoint; args_->endpoint = nullptr;
handshaker->args->endpoint = nullptr; read_buffer_to_destroy_ = args_->read_buffer;
handshaker->read_buffer_to_destroy = handshaker->args->read_buffer; args_->read_buffer = nullptr;
handshaker->args->read_buffer = nullptr; grpc_channel_args_destroy(args_->args);
grpc_channel_args_destroy(handshaker->args->args); args_->args = nullptr;
handshaker->args->args = nullptr;
} }
// If the handshake failed or we're shutting down, clean up and invoke the // If the handshake failed or we're shutting down, clean up and invoke the
// callback with the error. // callback with the error.
static void handshake_failed_locked(http_connect_handshaker* handshaker, void HttpConnectHandshaker::HandshakeFailedLocked(grpc_error* error) {
grpc_error* error) {
if (error == GRPC_ERROR_NONE) { if (error == GRPC_ERROR_NONE) {
// If we were shut down after an endpoint operation succeeded but // If we were shut down after an endpoint operation succeeded but
// before the endpoint callback was invoked, we need to generate our // before the endpoint callback was invoked, we need to generate our
// own error. // own error.
error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown"); error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown");
} }
if (!handshaker->shutdown) { if (!is_shutdown_) {
// TODO(ctiller): It is currently necessary to shutdown endpoints // TODO(ctiller): It is currently necessary to shutdown endpoints
// before destroying them, even if we know that there are no // before destroying them, even if we know that there are no
// pending read/write callbacks. This should be fixed, at which // pending read/write callbacks. This should be fixed, at which
// point this can be removed. // point this can be removed.
grpc_endpoint_shutdown(handshaker->args->endpoint, GRPC_ERROR_REF(error)); grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(error));
// Not shutting down, so the handshake failed. Clean up before // Not shutting down, so the handshake failed. Clean up before
// invoking the callback. // invoking the callback.
cleanup_args_for_failure_locked(handshaker); CleanupArgsForFailureLocked();
// Set shutdown to true so that subsequent calls to // Set shutdown to true so that subsequent calls to
// http_connect_handshaker_shutdown() do nothing. // http_connect_handshaker_shutdown() do nothing.
handshaker->shutdown = true; is_shutdown_ = true;
} }
// Invoke callback. // Invoke callback.
GRPC_CLOSURE_SCHED(handshaker->on_handshake_done, error); GRPC_CLOSURE_SCHED(on_handshake_done_, error);
} }
// Callback invoked when finished writing HTTP CONNECT request. // Callback invoked when finished writing HTTP CONNECT request.
static void on_write_done(void* arg, grpc_error* error) { void HttpConnectHandshaker::OnWriteDone(void* arg, grpc_error* error) {
http_connect_handshaker* handshaker = auto* handshaker = static_cast<HttpConnectHandshaker*>(arg);
static_cast<http_connect_handshaker*>(arg); gpr_mu_lock(&handshaker->mu_);
gpr_mu_lock(&handshaker->mu); if (error != GRPC_ERROR_NONE || handshaker->is_shutdown_) {
if (error != GRPC_ERROR_NONE || handshaker->shutdown) {
// If the write failed or we're shutting down, clean up and invoke the // If the write failed or we're shutting down, clean up and invoke the
// callback with the error. // callback with the error.
handshake_failed_locked(handshaker, GRPC_ERROR_REF(error)); handshaker->HandshakeFailedLocked(GRPC_ERROR_REF(error));
gpr_mu_unlock(&handshaker->mu); gpr_mu_unlock(&handshaker->mu_);
http_connect_handshaker_unref(handshaker); handshaker->Unref();
} else { } else {
// Otherwise, read the response. // Otherwise, read the response.
// The read callback inherits our ref to the handshaker. // The read callback inherits our ref to the handshaker.
grpc_endpoint_read(handshaker->args->endpoint, grpc_endpoint_read(handshaker->args_->endpoint,
handshaker->args->read_buffer, handshaker->args_->read_buffer,
&handshaker->response_read_closure); &handshaker->response_read_closure_);
gpr_mu_unlock(&handshaker->mu); gpr_mu_unlock(&handshaker->mu_);
} }
} }
// Callback invoked for reading HTTP CONNECT response. // Callback invoked for reading HTTP CONNECT response.
static void on_read_done(void* arg, grpc_error* error) { void HttpConnectHandshaker::OnReadDone(void* arg, grpc_error* error) {
http_connect_handshaker* handshaker = auto* handshaker = static_cast<HttpConnectHandshaker*>(arg);
static_cast<http_connect_handshaker*>(arg);
gpr_mu_lock(&handshaker->mu); gpr_mu_lock(&handshaker->mu_);
if (error != GRPC_ERROR_NONE || handshaker->shutdown) { if (error != GRPC_ERROR_NONE || handshaker->is_shutdown_) {
// If the read failed or we're shutting down, clean up and invoke the // If the read failed or we're shutting down, clean up and invoke the
// callback with the error. // callback with the error.
handshake_failed_locked(handshaker, GRPC_ERROR_REF(error)); handshaker->HandshakeFailedLocked(GRPC_ERROR_REF(error));
goto done; goto done;
} }
// Add buffer to parser. // Add buffer to parser.
for (size_t i = 0; i < handshaker->args->read_buffer->count; ++i) { for (size_t i = 0; i < handshaker->args_->read_buffer->count; ++i) {
if (GRPC_SLICE_LENGTH(handshaker->args->read_buffer->slices[i]) > 0) { if (GRPC_SLICE_LENGTH(handshaker->args_->read_buffer->slices[i]) > 0) {
size_t body_start_offset = 0; size_t body_start_offset = 0;
error = grpc_http_parser_parse(&handshaker->http_parser, error = grpc_http_parser_parse(&handshaker->http_parser_,
handshaker->args->read_buffer->slices[i], handshaker->args_->read_buffer->slices[i],
&body_start_offset); &body_start_offset);
if (error != GRPC_ERROR_NONE) { if (error != GRPC_ERROR_NONE) {
handshake_failed_locked(handshaker, error); handshaker->HandshakeFailedLocked(error);
goto done; goto done;
} }
if (handshaker->http_parser.state == GRPC_HTTP_BODY) { if (handshaker->http_parser_.state == GRPC_HTTP_BODY) {
// Remove the data we've already read from the read buffer, // Remove the data we've already read from the read buffer,
// leaving only the leftover bytes (if any). // leaving only the leftover bytes (if any).
grpc_slice_buffer tmp_buffer; grpc_slice_buffer tmp_buffer;
grpc_slice_buffer_init(&tmp_buffer); grpc_slice_buffer_init(&tmp_buffer);
if (body_start_offset < if (body_start_offset <
GRPC_SLICE_LENGTH(handshaker->args->read_buffer->slices[i])) { GRPC_SLICE_LENGTH(handshaker->args_->read_buffer->slices[i])) {
grpc_slice_buffer_add( grpc_slice_buffer_add(
&tmp_buffer, &tmp_buffer,
grpc_slice_split_tail(&handshaker->args->read_buffer->slices[i], grpc_slice_split_tail(&handshaker->args_->read_buffer->slices[i],
body_start_offset)); body_start_offset));
} }
grpc_slice_buffer_addn(&tmp_buffer, grpc_slice_buffer_addn(&tmp_buffer,
&handshaker->args->read_buffer->slices[i + 1], &handshaker->args_->read_buffer->slices[i + 1],
handshaker->args->read_buffer->count - i - 1); handshaker->args_->read_buffer->count - i - 1);
grpc_slice_buffer_swap(handshaker->args->read_buffer, &tmp_buffer); grpc_slice_buffer_swap(handshaker->args_->read_buffer, &tmp_buffer);
grpc_slice_buffer_destroy_internal(&tmp_buffer); grpc_slice_buffer_destroy_internal(&tmp_buffer);
break; break;
} }
@ -194,64 +203,53 @@ static void on_read_done(void* arg, grpc_error* error) {
// need to fix the HTTP parser to understand when the body is // need to fix the HTTP parser to understand when the body is
// complete (e.g., handling chunked transfer encoding or looking // complete (e.g., handling chunked transfer encoding or looking
// at the Content-Length: header). // at the Content-Length: header).
if (handshaker->http_parser.state != GRPC_HTTP_BODY) { if (handshaker->http_parser_.state != GRPC_HTTP_BODY) {
grpc_slice_buffer_reset_and_unref_internal(handshaker->args->read_buffer); grpc_slice_buffer_reset_and_unref_internal(handshaker->args_->read_buffer);
grpc_endpoint_read(handshaker->args->endpoint, grpc_endpoint_read(handshaker->args_->endpoint,
handshaker->args->read_buffer, handshaker->args_->read_buffer,
&handshaker->response_read_closure); &handshaker->response_read_closure_);
gpr_mu_unlock(&handshaker->mu); gpr_mu_unlock(&handshaker->mu_);
return; return;
} }
// Make sure we got a 2xx response. // Make sure we got a 2xx response.
if (handshaker->http_response.status < 200 || if (handshaker->http_response_.status < 200 ||
handshaker->http_response.status >= 300) { handshaker->http_response_.status >= 300) {
char* msg; char* msg;
gpr_asprintf(&msg, "HTTP proxy returned response code %d", gpr_asprintf(&msg, "HTTP proxy returned response code %d",
handshaker->http_response.status); handshaker->http_response_.status);
error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg);
gpr_free(msg); gpr_free(msg);
handshake_failed_locked(handshaker, error); handshaker->HandshakeFailedLocked(error);
goto done; goto done;
} }
// Success. Invoke handshake-done callback. // Success. Invoke handshake-done callback.
GRPC_CLOSURE_SCHED(handshaker->on_handshake_done, error); GRPC_CLOSURE_SCHED(handshaker->on_handshake_done_, error);
done: done:
// Set shutdown to true so that subsequent calls to // Set shutdown to true so that subsequent calls to
// http_connect_handshaker_shutdown() do nothing. // http_connect_handshaker_shutdown() do nothing.
handshaker->shutdown = true; handshaker->is_shutdown_ = true;
gpr_mu_unlock(&handshaker->mu); gpr_mu_unlock(&handshaker->mu_);
http_connect_handshaker_unref(handshaker); handshaker->Unref();
} }
// //
// Public handshaker methods // Public handshaker methods
// //
static void http_connect_handshaker_destroy(grpc_handshaker* handshaker_in) { void HttpConnectHandshaker::Shutdown(grpc_error* why) {
http_connect_handshaker* handshaker = gpr_mu_lock(&mu_);
reinterpret_cast<http_connect_handshaker*>(handshaker_in); if (!is_shutdown_) {
http_connect_handshaker_unref(handshaker); is_shutdown_ = true;
} grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(why));
CleanupArgsForFailureLocked();
static void http_connect_handshaker_shutdown(grpc_handshaker* handshaker_in,
grpc_error* why) {
http_connect_handshaker* handshaker =
reinterpret_cast<http_connect_handshaker*>(handshaker_in);
gpr_mu_lock(&handshaker->mu);
if (!handshaker->shutdown) {
handshaker->shutdown = true;
grpc_endpoint_shutdown(handshaker->args->endpoint, GRPC_ERROR_REF(why));
cleanup_args_for_failure_locked(handshaker);
} }
gpr_mu_unlock(&handshaker->mu); gpr_mu_unlock(&mu_);
GRPC_ERROR_UNREF(why); GRPC_ERROR_UNREF(why);
} }
static void http_connect_handshaker_do_handshake( void HttpConnectHandshaker::DoHandshake(grpc_tcp_server_acceptor* acceptor,
grpc_handshaker* handshaker_in, grpc_tcp_server_acceptor* acceptor, grpc_closure* on_handshake_done,
grpc_closure* on_handshake_done, grpc_handshaker_args* args) { HandshakerArgs* args) {
http_connect_handshaker* handshaker =
reinterpret_cast<http_connect_handshaker*>(handshaker_in);
// Check for HTTP CONNECT channel arg. // Check for HTTP CONNECT channel arg.
// If not found, invoke on_handshake_done without doing anything. // If not found, invoke on_handshake_done without doing anything.
const grpc_arg* arg = const grpc_arg* arg =
@ -260,9 +258,9 @@ static void http_connect_handshaker_do_handshake(
if (server_name == nullptr) { if (server_name == nullptr) {
// Set shutdown to true so that subsequent calls to // Set shutdown to true so that subsequent calls to
// http_connect_handshaker_shutdown() do nothing. // http_connect_handshaker_shutdown() do nothing.
gpr_mu_lock(&handshaker->mu); gpr_mu_lock(&mu_);
handshaker->shutdown = true; is_shutdown_ = true;
gpr_mu_unlock(&handshaker->mu); gpr_mu_unlock(&mu_);
GRPC_CLOSURE_SCHED(on_handshake_done, GRPC_ERROR_NONE); GRPC_CLOSURE_SCHED(on_handshake_done, GRPC_ERROR_NONE);
return; return;
} }
@ -280,6 +278,7 @@ static void http_connect_handshaker_do_handshake(
gpr_malloc(sizeof(grpc_http_header) * num_header_strings)); gpr_malloc(sizeof(grpc_http_header) * num_header_strings));
for (size_t i = 0; i < num_header_strings; ++i) { for (size_t i = 0; i < num_header_strings; ++i) {
char* sep = strchr(header_strings[i], ':'); char* sep = strchr(header_strings[i], ':');
if (sep == nullptr) { if (sep == nullptr) {
gpr_log(GPR_ERROR, "skipping unparseable HTTP CONNECT header: %s", gpr_log(GPR_ERROR, "skipping unparseable HTTP CONNECT header: %s",
header_strings[i]); header_strings[i]);
@ -292,9 +291,9 @@ static void http_connect_handshaker_do_handshake(
} }
} }
// Save state in the handshaker object. // Save state in the handshaker object.
gpr_mu_lock(&handshaker->mu); MutexLock lock(&mu_);
handshaker->args = args; args_ = args;
handshaker->on_handshake_done = on_handshake_done; on_handshake_done_ = on_handshake_done;
// Log connection via proxy. // Log connection via proxy.
char* proxy_name = grpc_endpoint_get_peer(args->endpoint); char* proxy_name = grpc_endpoint_get_peer(args->endpoint);
gpr_log(GPR_INFO, "Connecting to server %s via HTTP proxy %s", server_name, gpr_log(GPR_INFO, "Connecting to server %s via HTTP proxy %s", server_name,
@ -302,15 +301,18 @@ static void http_connect_handshaker_do_handshake(
gpr_free(proxy_name); gpr_free(proxy_name);
// Construct HTTP CONNECT request. // Construct HTTP CONNECT request.
grpc_httpcli_request request; grpc_httpcli_request request;
memset(&request, 0, sizeof(request));
request.host = server_name; request.host = server_name;
request.ssl_host_override = nullptr;
request.http.method = (char*)"CONNECT"; request.http.method = (char*)"CONNECT";
request.http.path = server_name; request.http.path = server_name;
request.http.version = GRPC_HTTP_HTTP10; // Set by OnReadDone
request.http.hdrs = headers; request.http.hdrs = headers;
request.http.hdr_count = num_headers; request.http.hdr_count = num_headers;
request.http.body_length = 0;
request.http.body = nullptr;
request.handshaker = &grpc_httpcli_plaintext; request.handshaker = &grpc_httpcli_plaintext;
grpc_slice request_slice = grpc_httpcli_format_connect_request(&request); grpc_slice request_slice = grpc_httpcli_format_connect_request(&request);
grpc_slice_buffer_add(&handshaker->write_buffer, request_slice); grpc_slice_buffer_add(&write_buffer_, request_slice);
// Clean up. // Clean up.
gpr_free(headers); gpr_free(headers);
for (size_t i = 0; i < num_header_strings; ++i) { for (size_t i = 0; i < num_header_strings; ++i) {
@ -318,54 +320,42 @@ static void http_connect_handshaker_do_handshake(
} }
gpr_free(header_strings); gpr_free(header_strings);
// Take a new ref to be held by the write callback. // Take a new ref to be held by the write callback.
gpr_ref(&handshaker->refcount); Ref().release();
grpc_endpoint_write(args->endpoint, &handshaker->write_buffer, grpc_endpoint_write(args->endpoint, &write_buffer_, &request_done_closure_,
&handshaker->request_done_closure, nullptr); nullptr);
gpr_mu_unlock(&handshaker->mu);
} }
static const grpc_handshaker_vtable http_connect_handshaker_vtable = { HttpConnectHandshaker::HttpConnectHandshaker() {
http_connect_handshaker_destroy, http_connect_handshaker_shutdown, gpr_mu_init(&mu_);
http_connect_handshaker_do_handshake, "http_connect"}; grpc_slice_buffer_init(&write_buffer_);
GRPC_CLOSURE_INIT(&request_done_closure_, &HttpConnectHandshaker::OnWriteDone,
static grpc_handshaker* grpc_http_connect_handshaker_create() { this, grpc_schedule_on_exec_ctx);
http_connect_handshaker* handshaker = GRPC_CLOSURE_INIT(&response_read_closure_, &HttpConnectHandshaker::OnReadDone,
static_cast<http_connect_handshaker*>(gpr_malloc(sizeof(*handshaker))); this, grpc_schedule_on_exec_ctx);
memset(handshaker, 0, sizeof(*handshaker)); grpc_http_parser_init(&http_parser_, GRPC_HTTP_RESPONSE, &http_response_);
grpc_handshaker_init(&http_connect_handshaker_vtable, &handshaker->base);
gpr_mu_init(&handshaker->mu);
gpr_ref_init(&handshaker->refcount, 1);
grpc_slice_buffer_init(&handshaker->write_buffer);
GRPC_CLOSURE_INIT(&handshaker->request_done_closure, on_write_done,
handshaker, grpc_schedule_on_exec_ctx);
GRPC_CLOSURE_INIT(&handshaker->response_read_closure, on_read_done,
handshaker, grpc_schedule_on_exec_ctx);
grpc_http_parser_init(&handshaker->http_parser, GRPC_HTTP_RESPONSE,
&handshaker->http_response);
return &handshaker->base;
} }
// //
// handshaker factory // handshaker factory
// //
static void handshaker_factory_add_handshakers( class HttpConnectHandshakerFactory : public HandshakerFactory {
grpc_handshaker_factory* factory, const grpc_channel_args* args, public:
void AddHandshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) { HandshakeManager* handshake_mgr) override {
grpc_handshake_manager_add(handshake_mgr, handshake_mgr->Add(MakeRefCounted<HttpConnectHandshaker>());
grpc_http_connect_handshaker_create()); }
} ~HttpConnectHandshakerFactory() override = default;
};
static void handshaker_factory_destroy(grpc_handshaker_factory* factory) {}
static const grpc_handshaker_factory_vtable handshaker_factory_vtable = { } // namespace
handshaker_factory_add_handshakers, handshaker_factory_destroy};
static grpc_handshaker_factory handshaker_factory = { } // namespace grpc_core
&handshaker_factory_vtable};
void grpc_http_connect_register_handshaker_factory() { void grpc_http_connect_register_handshaker_factory() {
grpc_handshaker_factory_register(true /* at_start */, HANDSHAKER_CLIENT, using namespace grpc_core;
&handshaker_factory); HandshakerRegistry::RegisterHandshakerFactory(
true /* at_start */, HANDSHAKER_CLIENT,
UniquePtr<HandshakerFactory>(New<HttpConnectHandshakerFactory>()));
} }

@ -55,7 +55,7 @@ typedef struct {
grpc_closure connected; grpc_closure connected;
grpc_handshake_manager* handshake_mgr; grpc_core::RefCountedPtr<grpc_core::HandshakeManager> handshake_mgr;
} chttp2_connector; } chttp2_connector;
static void chttp2_connector_ref(grpc_connector* con) { static void chttp2_connector_ref(grpc_connector* con) {
@ -79,7 +79,7 @@ static void chttp2_connector_shutdown(grpc_connector* con, grpc_error* why) {
gpr_mu_lock(&c->mu); gpr_mu_lock(&c->mu);
c->shutdown = true; c->shutdown = true;
if (c->handshake_mgr != nullptr) { if (c->handshake_mgr != nullptr) {
grpc_handshake_manager_shutdown(c->handshake_mgr, GRPC_ERROR_REF(why)); c->handshake_mgr->Shutdown(GRPC_ERROR_REF(why));
} }
// If handshaking is not yet in progress, shutdown the endpoint. // If handshaking is not yet in progress, shutdown the endpoint.
// Otherwise, the handshaker will do this for us. // Otherwise, the handshaker will do this for us.
@ -91,7 +91,7 @@ static void chttp2_connector_shutdown(grpc_connector* con, grpc_error* why) {
} }
static void on_handshake_done(void* arg, grpc_error* error) { static void on_handshake_done(void* arg, grpc_error* error) {
grpc_handshaker_args* args = static_cast<grpc_handshaker_args*>(arg); auto* args = static_cast<grpc_core::HandshakerArgs*>(arg);
chttp2_connector* c = static_cast<chttp2_connector*>(args->user_data); chttp2_connector* c = static_cast<chttp2_connector*>(args->user_data);
gpr_mu_lock(&c->mu); gpr_mu_lock(&c->mu);
if (error != GRPC_ERROR_NONE || c->shutdown) { if (error != GRPC_ERROR_NONE || c->shutdown) {
@ -152,20 +152,20 @@ static void on_handshake_done(void* arg, grpc_error* error) {
grpc_closure* notify = c->notify; grpc_closure* notify = c->notify;
c->notify = nullptr; c->notify = nullptr;
GRPC_CLOSURE_SCHED(notify, error); GRPC_CLOSURE_SCHED(notify, error);
grpc_handshake_manager_destroy(c->handshake_mgr); c->handshake_mgr.reset();
c->handshake_mgr = nullptr;
gpr_mu_unlock(&c->mu); gpr_mu_unlock(&c->mu);
chttp2_connector_unref(reinterpret_cast<grpc_connector*>(c)); chttp2_connector_unref(reinterpret_cast<grpc_connector*>(c));
} }
static void start_handshake_locked(chttp2_connector* c) { static void start_handshake_locked(chttp2_connector* c) {
c->handshake_mgr = grpc_handshake_manager_create(); c->handshake_mgr = grpc_core::MakeRefCounted<grpc_core::HandshakeManager>();
grpc_handshakers_add(HANDSHAKER_CLIENT, c->args.channel_args, grpc_core::HandshakerRegistry::AddHandshakers(
c->args.interested_parties, c->handshake_mgr); grpc_core::HANDSHAKER_CLIENT, c->args.channel_args,
c->args.interested_parties, c->handshake_mgr.get());
grpc_endpoint_add_to_pollset_set(c->endpoint, c->args.interested_parties); grpc_endpoint_add_to_pollset_set(c->endpoint, c->args.interested_parties);
grpc_handshake_manager_do_handshake( c->handshake_mgr->DoHandshake(c->endpoint, c->args.channel_args,
c->handshake_mgr, c->endpoint, c->args.channel_args, c->args.deadline, c->args.deadline, nullptr /* acceptor */,
nullptr /* acceptor */, on_handshake_done, c); on_handshake_done, c);
c->endpoint = nullptr; // Endpoint handed off to handshake manager. c->endpoint = nullptr; // Endpoint handed off to handshake manager.
} }

@ -54,7 +54,7 @@ typedef struct {
bool shutdown; bool shutdown;
grpc_closure tcp_server_shutdown_complete; grpc_closure tcp_server_shutdown_complete;
grpc_closure* server_destroy_listener_done; grpc_closure* server_destroy_listener_done;
grpc_handshake_manager* pending_handshake_mgrs; grpc_core::HandshakeManager* pending_handshake_mgrs;
grpc_core::RefCountedPtr<grpc_core::channelz::ListenSocketNode> grpc_core::RefCountedPtr<grpc_core::channelz::ListenSocketNode>
channelz_listen_socket; channelz_listen_socket;
} server_state; } server_state;
@ -64,7 +64,7 @@ typedef struct {
server_state* svr_state; server_state* svr_state;
grpc_pollset* accepting_pollset; grpc_pollset* accepting_pollset;
grpc_tcp_server_acceptor* acceptor; grpc_tcp_server_acceptor* acceptor;
grpc_handshake_manager* handshake_mgr; grpc_core::RefCountedPtr<grpc_core::HandshakeManager> handshake_mgr;
// State for enforcing handshake timeout on receiving HTTP/2 settings. // State for enforcing handshake timeout on receiving HTTP/2 settings.
grpc_chttp2_transport* transport; grpc_chttp2_transport* transport;
grpc_millis deadline; grpc_millis deadline;
@ -112,7 +112,7 @@ static void on_receive_settings(void* arg, grpc_error* error) {
} }
static void on_handshake_done(void* arg, grpc_error* error) { static void on_handshake_done(void* arg, grpc_error* error) {
grpc_handshaker_args* args = static_cast<grpc_handshaker_args*>(arg); auto* args = static_cast<grpc_core::HandshakerArgs*>(arg);
server_connection_state* connection_state = server_connection_state* connection_state =
static_cast<server_connection_state*>(args->user_data); static_cast<server_connection_state*>(args->user_data);
gpr_mu_lock(&connection_state->svr_state->mu); gpr_mu_lock(&connection_state->svr_state->mu);
@ -175,11 +175,10 @@ static void on_handshake_done(void* arg, grpc_error* error) {
} }
} }
} }
grpc_handshake_manager_pending_list_remove( connection_state->handshake_mgr->RemoveFromPendingMgrList(
&connection_state->svr_state->pending_handshake_mgrs, &connection_state->svr_state->pending_handshake_mgrs);
connection_state->handshake_mgr);
gpr_mu_unlock(&connection_state->svr_state->mu); gpr_mu_unlock(&connection_state->svr_state->mu);
grpc_handshake_manager_destroy(connection_state->handshake_mgr); connection_state->handshake_mgr.reset();
gpr_free(connection_state->acceptor); gpr_free(connection_state->acceptor);
grpc_tcp_server_unref(connection_state->svr_state->tcp_server); grpc_tcp_server_unref(connection_state->svr_state->tcp_server);
server_connection_state_unref(connection_state); server_connection_state_unref(connection_state);
@ -211,9 +210,8 @@ static void on_accept(void* arg, grpc_endpoint* tcp,
gpr_free(acceptor); gpr_free(acceptor);
return; return;
} }
grpc_handshake_manager* handshake_mgr = grpc_handshake_manager_create(); auto handshake_mgr = grpc_core::MakeRefCounted<grpc_core::HandshakeManager>();
grpc_handshake_manager_pending_list_add(&state->pending_handshake_mgrs, handshake_mgr->AddToPendingMgrList(&state->pending_handshake_mgrs);
handshake_mgr);
grpc_tcp_server_ref(state->tcp_server); grpc_tcp_server_ref(state->tcp_server);
gpr_mu_unlock(&state->mu); gpr_mu_unlock(&state->mu);
server_connection_state* connection_state = server_connection_state* connection_state =
@ -227,18 +225,18 @@ static void on_accept(void* arg, grpc_endpoint* tcp,
connection_state->interested_parties = grpc_pollset_set_create(); connection_state->interested_parties = grpc_pollset_set_create();
grpc_pollset_set_add_pollset(connection_state->interested_parties, grpc_pollset_set_add_pollset(connection_state->interested_parties,
connection_state->accepting_pollset); connection_state->accepting_pollset);
grpc_handshakers_add(HANDSHAKER_SERVER, state->args, grpc_core::HandshakerRegistry::AddHandshakers(
grpc_core::HANDSHAKER_SERVER, state->args,
connection_state->interested_parties, connection_state->interested_parties,
connection_state->handshake_mgr); connection_state->handshake_mgr.get());
const grpc_arg* timeout_arg = const grpc_arg* timeout_arg =
grpc_channel_args_find(state->args, GRPC_ARG_SERVER_HANDSHAKE_TIMEOUT_MS); grpc_channel_args_find(state->args, GRPC_ARG_SERVER_HANDSHAKE_TIMEOUT_MS);
connection_state->deadline = connection_state->deadline =
grpc_core::ExecCtx::Get()->Now() + grpc_core::ExecCtx::Get()->Now() +
grpc_channel_arg_get_integer(timeout_arg, grpc_channel_arg_get_integer(timeout_arg,
{120 * GPR_MS_PER_SEC, 1, INT_MAX}); {120 * GPR_MS_PER_SEC, 1, INT_MAX});
grpc_handshake_manager_do_handshake(connection_state->handshake_mgr, tcp, connection_state->handshake_mgr->DoHandshake(
state->args, connection_state->deadline, tcp, state->args, connection_state->deadline, acceptor, on_handshake_done,
acceptor, on_handshake_done,
connection_state); connection_state);
} }
@ -260,8 +258,9 @@ static void tcp_server_shutdown_complete(void* arg, grpc_error* error) {
gpr_mu_lock(&state->mu); gpr_mu_lock(&state->mu);
grpc_closure* destroy_done = state->server_destroy_listener_done; grpc_closure* destroy_done = state->server_destroy_listener_done;
GPR_ASSERT(state->shutdown); GPR_ASSERT(state->shutdown);
grpc_handshake_manager_pending_list_shutdown_all( if (state->pending_handshake_mgrs != nullptr) {
state->pending_handshake_mgrs, GRPC_ERROR_REF(error)); state->pending_handshake_mgrs->ShutdownAllPending(GRPC_ERROR_REF(error));
}
state->channelz_listen_socket.reset(); state->channelz_listen_socket.reset();
gpr_mu_unlock(&state->mu); gpr_mu_unlock(&state->mu);
// Flush queued work before destroying handshaker factory, since that // Flush queued work before destroying handshaker factory, since that

@ -30,302 +30,229 @@
#include "src/core/lib/iomgr/timer.h" #include "src/core/lib/iomgr/timer.h"
#include "src/core/lib/slice/slice_internal.h" #include "src/core/lib/slice/slice_internal.h"
grpc_core::TraceFlag grpc_handshaker_trace(false, "handshaker"); namespace grpc_core {
// TraceFlag grpc_handshaker_trace(false, "handshaker");
// grpc_handshaker
//
void grpc_handshaker_init(const grpc_handshaker_vtable* vtable, namespace {
grpc_handshaker* handshaker) {
handshaker->vtable = vtable;
}
void grpc_handshaker_destroy(grpc_handshaker* handshaker) {
handshaker->vtable->destroy(handshaker);
}
void grpc_handshaker_shutdown(grpc_handshaker* handshaker, grpc_error* why) {
handshaker->vtable->shutdown(handshaker, why);
}
void grpc_handshaker_do_handshake(grpc_handshaker* handshaker,
grpc_tcp_server_acceptor* acceptor,
grpc_closure* on_handshake_done,
grpc_handshaker_args* args) {
handshaker->vtable->do_handshake(handshaker, acceptor, on_handshake_done,
args);
}
const char* grpc_handshaker_name(grpc_handshaker* handshaker) { char* HandshakerArgsString(HandshakerArgs* args) {
return handshaker->vtable->name; char* args_str = grpc_channel_args_string(args->args);
size_t num_args = args->args != nullptr ? args->args->num_args : 0;
size_t read_buffer_length =
args->read_buffer != nullptr ? args->read_buffer->length : 0;
char* str;
gpr_asprintf(&str,
"{endpoint=%p, args=%p {size=%" PRIuPTR
": %s}, read_buffer=%p (length=%" PRIuPTR "), exit_early=%d}",
args->endpoint, args->args, num_args, args_str,
args->read_buffer, read_buffer_length, args->exit_early);
gpr_free(args_str);
return str;
} }
// } // namespace
// grpc_handshake_manager
//
struct grpc_handshake_manager { HandshakeManager::HandshakeManager() { gpr_mu_init(&mu_); }
gpr_mu mu;
gpr_refcount refs;
bool shutdown;
// An array of handshakers added via grpc_handshake_manager_add().
size_t count;
grpc_handshaker** handshakers;
// The index of the handshaker to invoke next and closure to invoke it.
size_t index;
grpc_closure call_next_handshaker;
// The acceptor to call the handshakers with.
grpc_tcp_server_acceptor* acceptor;
// Deadline timer across all handshakers.
grpc_timer deadline_timer;
grpc_closure on_timeout;
// The final callback and user_data to invoke after the last handshaker.
grpc_closure on_handshake_done;
void* user_data;
// Handshaker args.
grpc_handshaker_args args;
// Links to the previous and next managers in a list of all pending handshakes
// Used at server side only.
grpc_handshake_manager* prev;
grpc_handshake_manager* next;
};
grpc_handshake_manager* grpc_handshake_manager_create() { /// Add \a mgr to the server side list of all pending handshake managers, the
grpc_handshake_manager* mgr = static_cast<grpc_handshake_manager*>( /// list starts with \a *head.
gpr_zalloc(sizeof(grpc_handshake_manager))); // Not thread-safe. Caller needs to synchronize.
gpr_mu_init(&mgr->mu); void HandshakeManager::AddToPendingMgrList(HandshakeManager** head) {
gpr_ref_init(&mgr->refs, 1); GPR_ASSERT(prev_ == nullptr);
return mgr; GPR_ASSERT(next_ == nullptr);
} next_ = *head;
void grpc_handshake_manager_pending_list_add(grpc_handshake_manager** head,
grpc_handshake_manager* mgr) {
GPR_ASSERT(mgr->prev == nullptr);
GPR_ASSERT(mgr->next == nullptr);
mgr->next = *head;
if (*head) { if (*head) {
(*head)->prev = mgr; (*head)->prev_ = this;
} }
*head = mgr; *head = this;
} }
void grpc_handshake_manager_pending_list_remove(grpc_handshake_manager** head, /// Remove \a mgr from the server side list of all pending handshake managers.
grpc_handshake_manager* mgr) { // Not thread-safe. Caller needs to synchronize.
if (mgr->next != nullptr) { void HandshakeManager::RemoveFromPendingMgrList(HandshakeManager** head) {
mgr->next->prev = mgr->prev; if (next_ != nullptr) {
next_->prev_ = prev_;
} }
if (mgr->prev != nullptr) { if (prev_ != nullptr) {
mgr->prev->next = mgr->next; prev_->next_ = next_;
} else { } else {
GPR_ASSERT(*head == mgr); GPR_ASSERT(*head == this);
*head = mgr->next; *head = next_;
} }
} }
void grpc_handshake_manager_pending_list_shutdown_all( /// Shutdown all pending handshake managers starting at head on the server
grpc_handshake_manager* head, grpc_error* why) { /// side. Not thread-safe. Caller needs to synchronize.
void HandshakeManager::ShutdownAllPending(grpc_error* why) {
auto* head = this;
while (head != nullptr) { while (head != nullptr) {
grpc_handshake_manager_shutdown(head, GRPC_ERROR_REF(why)); head->Shutdown(GRPC_ERROR_REF(why));
head = head->next; head = head->next_;
} }
GRPC_ERROR_UNREF(why); GRPC_ERROR_UNREF(why);
} }
static bool is_power_of_2(size_t n) { return (n & (n - 1)) == 0; } void HandshakeManager::Add(RefCountedPtr<Handshaker> handshaker) {
void grpc_handshake_manager_add(grpc_handshake_manager* mgr,
grpc_handshaker* handshaker) {
if (grpc_handshaker_trace.enabled()) { if (grpc_handshaker_trace.enabled()) {
gpr_log( gpr_log(
GPR_INFO, GPR_INFO,
"handshake_manager %p: adding handshaker %s [%p] at index %" PRIuPTR, "handshake_manager %p: adding handshaker %s [%p] at index %" PRIuPTR,
mgr, grpc_handshaker_name(handshaker), handshaker, mgr->count); this, handshaker->name(), handshaker.get(), handshakers_.size());
}
gpr_mu_lock(&mgr->mu);
// To avoid allocating memory for each handshaker we add, we double
// the number of elements every time we need more.
size_t realloc_count = 0;
if (mgr->count == 0) {
realloc_count = 2;
} else if (mgr->count >= 2 && is_power_of_2(mgr->count)) {
realloc_count = mgr->count * 2;
} }
if (realloc_count > 0) { MutexLock lock(&mu_);
mgr->handshakers = static_cast<grpc_handshaker**>(gpr_realloc( handshakers_.push_back(std::move(handshaker));
mgr->handshakers, realloc_count * sizeof(grpc_handshaker*)));
}
mgr->handshakers[mgr->count++] = handshaker;
gpr_mu_unlock(&mgr->mu);
} }
static void grpc_handshake_manager_unref(grpc_handshake_manager* mgr) { HandshakeManager::~HandshakeManager() {
if (gpr_unref(&mgr->refs)) { handshakers_.clear();
for (size_t i = 0; i < mgr->count; ++i) { gpr_mu_destroy(&mu_);
grpc_handshaker_destroy(mgr->handshakers[i]);
}
gpr_free(mgr->handshakers);
gpr_mu_destroy(&mgr->mu);
gpr_free(mgr);
}
} }
void grpc_handshake_manager_destroy(grpc_handshake_manager* mgr) { void HandshakeManager::Shutdown(grpc_error* why) {
grpc_handshake_manager_unref(mgr); {
} MutexLock lock(&mu_);
void grpc_handshake_manager_shutdown(grpc_handshake_manager* mgr,
grpc_error* why) {
gpr_mu_lock(&mgr->mu);
// Shutdown the handshaker that's currently in progress, if any. // Shutdown the handshaker that's currently in progress, if any.
if (!mgr->shutdown && mgr->index > 0) { if (!is_shutdown_ && index_ > 0) {
mgr->shutdown = true; is_shutdown_ = true;
grpc_handshaker_shutdown(mgr->handshakers[mgr->index - 1], handshakers_[index_ - 1]->Shutdown(GRPC_ERROR_REF(why));
GRPC_ERROR_REF(why)); }
} }
gpr_mu_unlock(&mgr->mu);
GRPC_ERROR_UNREF(why); GRPC_ERROR_UNREF(why);
} }
static char* handshaker_args_string(grpc_handshaker_args* args) {
char* args_str = grpc_channel_args_string(args->args);
size_t num_args = args->args != nullptr ? args->args->num_args : 0;
size_t read_buffer_length =
args->read_buffer != nullptr ? args->read_buffer->length : 0;
char* str;
gpr_asprintf(&str,
"{endpoint=%p, args=%p {size=%" PRIuPTR
": %s}, read_buffer=%p (length=%" PRIuPTR "), exit_early=%d}",
args->endpoint, args->args, num_args, args_str,
args->read_buffer, read_buffer_length, args->exit_early);
gpr_free(args_str);
return str;
}
// Helper function to call either the next handshaker or the // Helper function to call either the next handshaker or the
// on_handshake_done callback. // on_handshake_done callback.
// Returns true if we've scheduled the on_handshake_done callback. // Returns true if we've scheduled the on_handshake_done callback.
static bool call_next_handshaker_locked(grpc_handshake_manager* mgr, bool HandshakeManager::CallNextHandshakerLocked(grpc_error* error) {
grpc_error* error) {
if (grpc_handshaker_trace.enabled()) { if (grpc_handshaker_trace.enabled()) {
char* args_str = handshaker_args_string(&mgr->args); char* args_str = HandshakerArgsString(&args_);
gpr_log(GPR_INFO, gpr_log(GPR_INFO,
"handshake_manager %p: error=%s shutdown=%d index=%" PRIuPTR "handshake_manager %p: error=%s shutdown=%d index=%" PRIuPTR
", args=%s", ", args=%s",
mgr, grpc_error_string(error), mgr->shutdown, mgr->index, args_str); this, grpc_error_string(error), is_shutdown_, index_, args_str);
gpr_free(args_str); gpr_free(args_str);
} }
GPR_ASSERT(mgr->index <= mgr->count); GPR_ASSERT(index_ <= handshakers_.size());
// If we got an error or we've been shut down or we're exiting early or // If we got an error or we've been shut down or we're exiting early or
// we've finished the last handshaker, invoke the on_handshake_done // we've finished the last handshaker, invoke the on_handshake_done
// callback. Otherwise, call the next handshaker. // callback. Otherwise, call the next handshaker.
if (error != GRPC_ERROR_NONE || mgr->shutdown || mgr->args.exit_early || if (error != GRPC_ERROR_NONE || is_shutdown_ || args_.exit_early ||
mgr->index == mgr->count) { index_ == handshakers_.size()) {
if (error == GRPC_ERROR_NONE && mgr->shutdown) { if (error == GRPC_ERROR_NONE && is_shutdown_) {
error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("handshaker shutdown"); error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("handshaker shutdown");
// It is possible that the endpoint has already been destroyed by // It is possible that the endpoint has already been destroyed by
// a shutdown call while this callback was sitting on the ExecCtx // a shutdown call while this callback was sitting on the ExecCtx
// with no error. // with no error.
if (mgr->args.endpoint != nullptr) { if (args_.endpoint != nullptr) {
// TODO(roth): It is currently necessary to shutdown endpoints // TODO(roth): It is currently necessary to shutdown endpoints
// before destroying then, even when we know that there are no // before destroying then, even when we know that there are no
// pending read/write callbacks. This should be fixed, at which // pending read/write callbacks. This should be fixed, at which
// point this can be removed. // point this can be removed.
grpc_endpoint_shutdown(mgr->args.endpoint, GRPC_ERROR_REF(error)); grpc_endpoint_shutdown(args_.endpoint, GRPC_ERROR_REF(error));
grpc_endpoint_destroy(mgr->args.endpoint); grpc_endpoint_destroy(args_.endpoint);
mgr->args.endpoint = nullptr; args_.endpoint = nullptr;
grpc_channel_args_destroy(mgr->args.args); grpc_channel_args_destroy(args_.args);
mgr->args.args = nullptr; args_.args = nullptr;
grpc_slice_buffer_destroy_internal(mgr->args.read_buffer); grpc_slice_buffer_destroy_internal(args_.read_buffer);
gpr_free(mgr->args.read_buffer); gpr_free(args_.read_buffer);
mgr->args.read_buffer = nullptr; args_.read_buffer = nullptr;
} }
} }
if (grpc_handshaker_trace.enabled()) { if (grpc_handshaker_trace.enabled()) {
gpr_log(GPR_INFO, gpr_log(GPR_INFO,
"handshake_manager %p: handshaking complete -- scheduling " "handshake_manager %p: handshaking complete -- scheduling "
"on_handshake_done with error=%s", "on_handshake_done with error=%s",
mgr, grpc_error_string(error)); this, grpc_error_string(error));
} }
// Cancel deadline timer, since we're invoking the on_handshake_done // Cancel deadline timer, since we're invoking the on_handshake_done
// callback now. // callback now.
grpc_timer_cancel(&mgr->deadline_timer); grpc_timer_cancel(&deadline_timer_);
GRPC_CLOSURE_SCHED(&mgr->on_handshake_done, error); GRPC_CLOSURE_SCHED(&on_handshake_done_, error);
mgr->shutdown = true; is_shutdown_ = true;
} else { } else {
auto handshaker = handshakers_[index_];
if (grpc_handshaker_trace.enabled()) { if (grpc_handshaker_trace.enabled()) {
gpr_log( gpr_log(
GPR_INFO, GPR_INFO,
"handshake_manager %p: calling handshaker %s [%p] at index %" PRIuPTR, "handshake_manager %p: calling handshaker %s [%p] at index %" PRIuPTR,
mgr, grpc_handshaker_name(mgr->handshakers[mgr->index]), this, handshaker->name(), handshaker.get(), index_);
mgr->handshakers[mgr->index], mgr->index);
} }
grpc_handshaker_do_handshake(mgr->handshakers[mgr->index], mgr->acceptor, handshaker->DoHandshake(acceptor_, &call_next_handshaker_, &args_);
&mgr->call_next_handshaker, &mgr->args);
} }
++mgr->index; ++index_;
return mgr->shutdown; return is_shutdown_;
} }
// A function used as the handshaker-done callback when chaining void HandshakeManager::CallNextHandshakerFn(void* arg, grpc_error* error) {
// handshakers together. auto* mgr = static_cast<HandshakeManager*>(arg);
static void call_next_handshaker(void* arg, grpc_error* error) { bool done;
grpc_handshake_manager* mgr = static_cast<grpc_handshake_manager*>(arg); {
gpr_mu_lock(&mgr->mu); MutexLock lock(&mgr->mu_);
bool done = call_next_handshaker_locked(mgr, GRPC_ERROR_REF(error)); done = mgr->CallNextHandshakerLocked(GRPC_ERROR_REF(error));
gpr_mu_unlock(&mgr->mu); }
// If we're invoked the final callback, we won't be coming back // If we're invoked the final callback, we won't be coming back
// to this function, so we can release our reference to the // to this function, so we can release our reference to the
// handshake manager. // handshake manager.
if (done) { if (done) {
grpc_handshake_manager_unref(mgr); mgr->Unref();
} }
} }
// Callback invoked when deadline is exceeded. void HandshakeManager::OnTimeoutFn(void* arg, grpc_error* error) {
static void on_timeout(void* arg, grpc_error* error) { auto* mgr = static_cast<HandshakeManager*>(arg);
grpc_handshake_manager* mgr = static_cast<grpc_handshake_manager*>(arg); if (error == GRPC_ERROR_NONE) { // Timer fired, rather than being cancelled
if (error == GRPC_ERROR_NONE) { // Timer fired, rather than being cancelled. mgr->Shutdown(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshake timed out"));
grpc_handshake_manager_shutdown(
mgr, GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshake timed out"));
} }
grpc_handshake_manager_unref(mgr); mgr->Unref();
} }
void grpc_handshake_manager_do_handshake(grpc_handshake_manager* mgr, void HandshakeManager::DoHandshake(grpc_endpoint* endpoint,
grpc_endpoint* endpoint,
const grpc_channel_args* channel_args, const grpc_channel_args* channel_args,
grpc_millis deadline, grpc_millis deadline,
grpc_tcp_server_acceptor* acceptor, grpc_tcp_server_acceptor* acceptor,
grpc_iomgr_cb_func on_handshake_done, grpc_iomgr_cb_func on_handshake_done,
void* user_data) { void* user_data) {
gpr_mu_lock(&mgr->mu); bool done;
GPR_ASSERT(mgr->index == 0); {
GPR_ASSERT(!mgr->shutdown); MutexLock lock(&mu_);
GPR_ASSERT(index_ == 0);
GPR_ASSERT(!is_shutdown_);
// Construct handshaker args. These will be passed through all // Construct handshaker args. These will be passed through all
// handshakers and eventually be freed by the on_handshake_done callback. // handshakers and eventually be freed by the on_handshake_done callback.
mgr->args.endpoint = endpoint; args_.endpoint = endpoint;
mgr->args.args = grpc_channel_args_copy(channel_args); args_.args = grpc_channel_args_copy(channel_args);
mgr->args.user_data = user_data; args_.user_data = user_data;
mgr->args.read_buffer = static_cast<grpc_slice_buffer*>( args_.read_buffer =
gpr_malloc(sizeof(*mgr->args.read_buffer))); static_cast<grpc_slice_buffer*>(gpr_malloc(sizeof(*args_.read_buffer)));
grpc_slice_buffer_init(mgr->args.read_buffer); grpc_slice_buffer_init(args_.read_buffer);
// Initialize state needed for calling handshakers. // Initialize state needed for calling handshakers.
mgr->acceptor = acceptor; acceptor_ = acceptor;
GRPC_CLOSURE_INIT(&mgr->call_next_handshaker, call_next_handshaker, mgr, GRPC_CLOSURE_INIT(&call_next_handshaker_,
&HandshakeManager::CallNextHandshakerFn, this,
grpc_schedule_on_exec_ctx); grpc_schedule_on_exec_ctx);
GRPC_CLOSURE_INIT(&mgr->on_handshake_done, on_handshake_done, &mgr->args, GRPC_CLOSURE_INIT(&on_handshake_done_, on_handshake_done, &args_,
grpc_schedule_on_exec_ctx); grpc_schedule_on_exec_ctx);
// Start deadline timer, which owns a ref. // Start deadline timer, which owns a ref.
gpr_ref(&mgr->refs); Ref().release();
GRPC_CLOSURE_INIT(&mgr->on_timeout, on_timeout, mgr, GRPC_CLOSURE_INIT(&on_timeout_, &HandshakeManager::OnTimeoutFn, this,
grpc_schedule_on_exec_ctx); grpc_schedule_on_exec_ctx);
grpc_timer_init(&mgr->deadline_timer, deadline, &mgr->on_timeout); grpc_timer_init(&deadline_timer_, deadline, &on_timeout_);
// Start first handshaker, which also owns a ref. // Start first handshaker, which also owns a ref.
gpr_ref(&mgr->refs); Ref().release();
bool done = call_next_handshaker_locked(mgr, GRPC_ERROR_NONE); done = CallNextHandshakerLocked(GRPC_ERROR_NONE);
gpr_mu_unlock(&mgr->mu); }
if (done) { if (done) {
grpc_handshake_manager_unref(mgr); Unref();
} }
} }
} // namespace grpc_core
void grpc_handshake_manager_add(grpc_handshake_manager* mgr,
grpc_handshaker* handshaker) {
// This is a transition method to aid the API change for handshakers.
using namespace grpc_core;
RefCountedPtr<Handshaker> refd_hs(static_cast<Handshaker*>(handshaker));
mgr->Add(refd_hs);
}

@ -21,12 +21,21 @@
#include <grpc/support/port_platform.h> #include <grpc/support/port_platform.h>
#include <grpc/support/string_util.h>
#include <grpc/impl/codegen/grpc_types.h> #include <grpc/impl/codegen/grpc_types.h>
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/gprpp/inlined_vector.h"
#include "src/core/lib/gprpp/mutex_lock.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/closure.h"
#include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/iomgr/endpoint.h"
#include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/iomgr/tcp_server.h" #include "src/core/lib/iomgr/tcp_server.h"
#include "src/core/lib/iomgr/timer.h"
namespace grpc_core {
/// Handshakers are used to perform initial handshakes on a connection /// Handshakers are used to perform initial handshakes on a connection
/// before the client sends the initial request. Some examples of what /// before the client sends the initial request. Some examples of what
@ -35,12 +44,6 @@
/// ///
/// In general, handshakers should be used via a handshake manager. /// In general, handshakers should be used via a handshake manager.
///
/// grpc_handshaker
///
typedef struct grpc_handshaker grpc_handshaker;
/// Arguments passed through handshakers and to the on_handshake_done callback. /// Arguments passed through handshakers and to the on_handshake_done callback.
/// ///
/// For handshakers, all members are input/output parameters; for /// For handshakers, all members are input/output parameters; for
@ -55,115 +58,121 @@ typedef struct grpc_handshaker grpc_handshaker;
/// ///
/// For the on_handshake_done callback, all members are input arguments, /// For the on_handshake_done callback, all members are input arguments,
/// which the callback takes ownership of. /// which the callback takes ownership of.
typedef struct { struct HandshakerArgs {
grpc_endpoint* endpoint; grpc_endpoint* endpoint = nullptr;
grpc_channel_args* args; grpc_channel_args* args = nullptr;
grpc_slice_buffer* read_buffer; grpc_slice_buffer* read_buffer = nullptr;
// A handshaker may set this to true before invoking on_handshake_done // A handshaker may set this to true before invoking on_handshake_done
// to indicate that subsequent handshakers should be skipped. // to indicate that subsequent handshakers should be skipped.
bool exit_early; bool exit_early = false;
// User data passed through the handshake manager. Not used by // User data passed through the handshake manager. Not used by
// individual handshakers. // individual handshakers.
void* user_data; void* user_data = nullptr;
} grpc_handshaker_args; };
typedef struct {
/// Destroys the handshaker.
void (*destroy)(grpc_handshaker* handshaker);
/// Shuts down the handshaker (e.g., to clean up when the operation is ///
/// aborted in the middle). /// Handshaker
void (*shutdown)(grpc_handshaker* handshaker, grpc_error* why); ///
/// Performs handshaking, modifying \a args as needed (e.g., to class Handshaker : public RefCounted<Handshaker> {
/// replace \a endpoint with a wrapped endpoint). public:
/// When finished, invokes \a on_handshake_done. virtual ~Handshaker() = default;
/// \a acceptor will be NULL for client-side handshakers. virtual void Shutdown(grpc_error* why) GRPC_ABSTRACT;
void (*do_handshake)(grpc_handshaker* handshaker, virtual void DoHandshake(grpc_tcp_server_acceptor* acceptor,
grpc_tcp_server_acceptor* acceptor,
grpc_closure* on_handshake_done, grpc_closure* on_handshake_done,
grpc_handshaker_args* args); HandshakerArgs* args) GRPC_ABSTRACT;
virtual const char* name() const GRPC_ABSTRACT;
GRPC_ABSTRACT_BASE_CLASS
};
/// The name of the handshaker, for debugging purposes. //
const char* name; // HandshakeManager
} grpc_handshaker_vtable; //
/// Base struct. To subclass, make this the first member of the class HandshakeManager : public RefCounted<HandshakeManager> {
/// implementation struct. public:
struct grpc_handshaker { HandshakeManager();
const grpc_handshaker_vtable* vtable; ~HandshakeManager();
};
/// Called by concrete implementations to initialize the base struct. /// Add \a mgr to the server side list of all pending handshake managers, the
void grpc_handshaker_init(const grpc_handshaker_vtable* vtable, /// list starts with \a *head.
grpc_handshaker* handshaker); // Not thread-safe. Caller needs to synchronize.
void AddToPendingMgrList(HandshakeManager** head);
void grpc_handshaker_destroy(grpc_handshaker* handshaker); /// Remove \a mgr from the server side list of all pending handshake managers.
void grpc_handshaker_shutdown(grpc_handshaker* handshaker, grpc_error* why); // Not thread-safe. Caller needs to synchronize.
void grpc_handshaker_do_handshake(grpc_handshaker* handshaker, void RemoveFromPendingMgrList(HandshakeManager** head);
grpc_tcp_server_acceptor* acceptor,
grpc_closure* on_handshake_done,
grpc_handshaker_args* args);
const char* grpc_handshaker_name(grpc_handshaker* handshaker);
/// /// Shutdown all pending handshake managers starting at head on the server
/// grpc_handshake_manager /// side. Not thread-safe. Caller needs to synchronize.
/// void ShutdownAllPending(grpc_error* why);
typedef struct grpc_handshake_manager grpc_handshake_manager; /// Adds a handshaker to the handshake manager.
/// Takes ownership of \a handshaker.
void Add(RefCountedPtr<Handshaker> handshaker);
/// Creates a new handshake manager. Caller takes ownership. /// Shuts down the handshake manager (e.g., to clean up when the operation is
grpc_handshake_manager* grpc_handshake_manager_create(); /// aborted in the middle).
void Shutdown(grpc_error* why);
/// Invokes handshakers in the order they were added.
/// Takes ownership of \a endpoint, and then passes that ownership to
/// the \a on_handshake_done callback.
/// Does NOT take ownership of \a channel_args. Instead, makes a copy before
/// invoking the first handshaker.
/// \a acceptor will be nullptr for client-side handshakers.
///
/// When done, invokes \a on_handshake_done with a HandshakerArgs
/// object as its argument. If the callback is invoked with error !=
/// GRPC_ERROR_NONE, then handshaking failed and the handshaker has done
/// the necessary clean-up. Otherwise, the callback takes ownership of
/// the arguments.
void DoHandshake(grpc_endpoint* endpoint,
const grpc_channel_args* channel_args, grpc_millis deadline,
grpc_tcp_server_acceptor* acceptor,
grpc_iomgr_cb_func on_handshake_done, void* user_data);
private:
bool CallNextHandshakerLocked(grpc_error* error);
// A function used as the handshaker-done callback when chaining
// handshakers together.
static void CallNextHandshakerFn(void* arg, grpc_error* error);
// Callback invoked when deadline is exceeded.
static void OnTimeoutFn(void* arg, grpc_error* error);
static const size_t HANDSHAKERS_INIT_SIZE = 2;
gpr_mu mu_;
bool is_shutdown_ = false;
// An array of handshakers added via grpc_handshake_manager_add().
InlinedVector<RefCountedPtr<Handshaker>, HANDSHAKERS_INIT_SIZE> handshakers_;
// The index of the handshaker to invoke next and closure to invoke it.
size_t index_ = 0;
grpc_closure call_next_handshaker_;
// The acceptor to call the handshakers with.
grpc_tcp_server_acceptor* acceptor_;
// Deadline timer across all handshakers.
grpc_timer deadline_timer_;
grpc_closure on_timeout_;
// The final callback and user_data to invoke after the last handshaker.
grpc_closure on_handshake_done_;
// Handshaker args.
HandshakerArgs args_;
// Links to the previous and next managers in a list of all pending handshakes
// Used at server side only.
HandshakeManager* prev_ = nullptr;
HandshakeManager* next_ = nullptr;
};
} // namespace grpc_core
/// Adds a handshaker to the handshake manager. // TODO(arjunroy): These are transitional to account for the new handshaker API
/// Takes ownership of \a handshaker. // and will eventually be removed entirely.
typedef grpc_core::HandshakeManager grpc_handshake_manager;
typedef grpc_core::Handshaker grpc_handshaker;
void grpc_handshake_manager_add(grpc_handshake_manager* mgr, void grpc_handshake_manager_add(grpc_handshake_manager* mgr,
grpc_handshaker* handshaker); grpc_handshaker* handshaker);
/// Destroys the handshake manager.
void grpc_handshake_manager_destroy(grpc_handshake_manager* mgr);
/// Shuts down the handshake manager (e.g., to clean up when the operation is
/// aborted in the middle).
/// The caller must still call grpc_handshake_manager_destroy() after
/// calling this function.
void grpc_handshake_manager_shutdown(grpc_handshake_manager* mgr,
grpc_error* why);
/// Invokes handshakers in the order they were added.
/// Takes ownership of \a endpoint, and then passes that ownership to
/// the \a on_handshake_done callback.
/// Does NOT take ownership of \a channel_args. Instead, makes a copy before
/// invoking the first handshaker.
/// \a acceptor will be nullptr for client-side handshakers.
///
/// When done, invokes \a on_handshake_done with a grpc_handshaker_args
/// object as its argument. If the callback is invoked with error !=
/// GRPC_ERROR_NONE, then handshaking failed and the handshaker has done
/// the necessary clean-up. Otherwise, the callback takes ownership of
/// the arguments.
void grpc_handshake_manager_do_handshake(grpc_handshake_manager* mgr,
grpc_endpoint* endpoint,
const grpc_channel_args* channel_args,
grpc_millis deadline,
grpc_tcp_server_acceptor* acceptor,
grpc_iomgr_cb_func on_handshake_done,
void* user_data);
/// Add \a mgr to the server side list of all pending handshake managers, the
/// list starts with \a *head.
// Not thread-safe. Caller needs to synchronize.
void grpc_handshake_manager_pending_list_add(grpc_handshake_manager** head,
grpc_handshake_manager* mgr);
/// Remove \a mgr from the server side list of all pending handshake managers.
// Not thread-safe. Caller needs to synchronize.
void grpc_handshake_manager_pending_list_remove(grpc_handshake_manager** head,
grpc_handshake_manager* mgr);
/// Shutdown all pending handshake managers on the server side.
// Not thread-safe. Caller needs to synchronize.
void grpc_handshake_manager_pending_list_shutdown_all(
grpc_handshake_manager* head, grpc_error* why);
#endif /* GRPC_CORE_LIB_CHANNEL_HANDSHAKER_H */ #endif /* GRPC_CORE_LIB_CHANNEL_HANDSHAKER_H */

@ -1,42 +0,0 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#include <grpc/support/port_platform.h>
#include "src/core/lib/channel/handshaker_factory.h"
#include <grpc/support/log.h>
void grpc_handshaker_factory_add_handshakers(
grpc_handshaker_factory* handshaker_factory, const grpc_channel_args* args,
grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) {
if (handshaker_factory != nullptr) {
GPR_ASSERT(handshaker_factory->vtable != nullptr);
handshaker_factory->vtable->add_handshakers(
handshaker_factory, args, interested_parties, handshake_mgr);
}
}
void grpc_handshaker_factory_destroy(
grpc_handshaker_factory* handshaker_factory) {
if (handshaker_factory != nullptr) {
GPR_ASSERT(handshaker_factory->vtable != nullptr);
handshaker_factory->vtable->destroy(handshaker_factory);
}
}

@ -27,26 +27,18 @@
// A handshaker factory is used to create handshakers. // A handshaker factory is used to create handshakers.
typedef struct grpc_handshaker_factory grpc_handshaker_factory; namespace grpc_core {
typedef struct { class HandshakerFactory {
void (*add_handshakers)(grpc_handshaker_factory* handshaker_factory, public:
const grpc_channel_args* args, virtual void AddHandshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr); HandshakeManager* handshake_mgr) GRPC_ABSTRACT;
void (*destroy)(grpc_handshaker_factory* handshaker_factory); virtual ~HandshakerFactory() = default;
} grpc_handshaker_factory_vtable;
struct grpc_handshaker_factory { GRPC_ABSTRACT_BASE_CLASS
const grpc_handshaker_factory_vtable* vtable;
}; };
void grpc_handshaker_factory_add_handshakers( } // namespace grpc_core
grpc_handshaker_factory* handshaker_factory, const grpc_channel_args* args,
grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr);
void grpc_handshaker_factory_destroy(
grpc_handshaker_factory* handshaker_factory);
#endif /* GRPC_CORE_LIB_CHANNEL_HANDSHAKER_FACTORY_H */ #endif /* GRPC_CORE_LIB_CHANNEL_HANDSHAKER_FACTORY_H */

@ -19,8 +19,11 @@
#include <grpc/support/port_platform.h> #include <grpc/support/port_platform.h>
#include "src/core/lib/channel/handshaker_registry.h" #include "src/core/lib/channel/handshaker_registry.h"
#include "src/core/lib/gprpp/inlined_vector.h"
#include "src/core/lib/gprpp/memory.h"
#include <string.h> #include <string.h>
#include <algorithm>
#include <grpc/support/alloc.h> #include <grpc/support/alloc.h>
@ -28,74 +31,83 @@
// grpc_handshaker_factory_list // grpc_handshaker_factory_list
// //
typedef struct { namespace grpc_core {
grpc_handshaker_factory** list;
size_t num_factories; namespace {
} grpc_handshaker_factory_list;
static void grpc_handshaker_factory_list_register(
grpc_handshaker_factory_list* list, bool at_start,
grpc_handshaker_factory* factory) {
list->list = static_cast<grpc_handshaker_factory**>(gpr_realloc(
list->list,
(list->num_factories + 1) * sizeof(grpc_handshaker_factory*)));
if (at_start) {
memmove(list->list + 1, list->list,
sizeof(grpc_handshaker_factory*) * list->num_factories);
list->list[0] = factory;
} else {
list->list[list->num_factories] = factory;
}
++list->num_factories;
}
static void grpc_handshaker_factory_list_add_handshakers( class HandshakerFactoryList {
grpc_handshaker_factory_list* list, const grpc_channel_args* args, public:
void Register(bool at_start, UniquePtr<HandshakerFactory> factory);
void AddHandshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) { HandshakeManager* handshake_mgr);
for (size_t i = 0; i < list->num_factories; ++i) {
grpc_handshaker_factory_add_handshakers(list->list[i], args, private:
interested_parties, handshake_mgr); InlinedVector<UniquePtr<HandshakerFactory>, 2> factories_;
};
HandshakerFactoryList* g_handshaker_factory_lists = nullptr;
} // namespace
void HandshakerFactoryList::Register(bool at_start,
UniquePtr<HandshakerFactory> factory) {
factories_.push_back(std::move(factory));
if (at_start) {
auto* end = &factories_[factories_.size() - 1];
std::rotate(&factories_[0], end, end + 1);
} }
} }
static void grpc_handshaker_factory_list_destroy( void HandshakerFactoryList::AddHandshakers(const grpc_channel_args* args,
grpc_handshaker_factory_list* list) { grpc_pollset_set* interested_parties,
for (size_t i = 0; i < list->num_factories; ++i) { HandshakeManager* handshake_mgr) {
grpc_handshaker_factory_destroy(list->list[i]); for (size_t idx = 0; idx < factories_.size(); ++idx) {
auto& handshaker_factory = factories_[idx];
handshaker_factory->AddHandshakers(args, interested_parties, handshake_mgr);
} }
gpr_free(list->list);
} }
// //
// plugin // plugin
// //
static grpc_handshaker_factory_list void HandshakerRegistry::Init() {
g_handshaker_factory_lists[NUM_HANDSHAKER_TYPES]; GPR_ASSERT(g_handshaker_factory_lists == nullptr);
g_handshaker_factory_lists = static_cast<HandshakerFactoryList*>(
void grpc_handshaker_factory_registry_init() { gpr_malloc(sizeof(*g_handshaker_factory_lists) * NUM_HANDSHAKER_TYPES));
memset(g_handshaker_factory_lists, 0, sizeof(g_handshaker_factory_lists)); GPR_ASSERT(g_handshaker_factory_lists != nullptr);
for (auto idx = 0; idx < NUM_HANDSHAKER_TYPES; ++idx) {
auto factory_list = g_handshaker_factory_lists + idx;
new (factory_list) HandshakerFactoryList();
}
} }
void grpc_handshaker_factory_registry_shutdown() { void HandshakerRegistry::Shutdown() {
for (size_t i = 0; i < NUM_HANDSHAKER_TYPES; ++i) { GPR_ASSERT(g_handshaker_factory_lists != nullptr);
grpc_handshaker_factory_list_destroy(&g_handshaker_factory_lists[i]); for (auto idx = 0; idx < NUM_HANDSHAKER_TYPES; ++idx) {
auto factory_list = g_handshaker_factory_lists + idx;
factory_list->~HandshakerFactoryList();
} }
gpr_free(g_handshaker_factory_lists);
g_handshaker_factory_lists = nullptr;
} }
void grpc_handshaker_factory_register(bool at_start, void HandshakerRegistry::RegisterHandshakerFactory(
grpc_handshaker_type handshaker_type, bool at_start, HandshakerType handshaker_type,
grpc_handshaker_factory* factory) { UniquePtr<HandshakerFactory> factory) {
grpc_handshaker_factory_list_register( GPR_ASSERT(g_handshaker_factory_lists != nullptr);
&g_handshaker_factory_lists[handshaker_type], at_start, factory); auto& factory_list = g_handshaker_factory_lists[handshaker_type];
factory_list.Register(at_start, std::move(factory));
} }
void grpc_handshakers_add(grpc_handshaker_type handshaker_type, void HandshakerRegistry::AddHandshakers(HandshakerType handshaker_type,
const grpc_channel_args* args, const grpc_channel_args* args,
grpc_pollset_set* interested_parties, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) { HandshakeManager* handshake_mgr) {
grpc_handshaker_factory_list_add_handshakers( GPR_ASSERT(g_handshaker_factory_lists != nullptr);
&g_handshaker_factory_lists[handshaker_type], args, interested_parties, auto& factory_list = g_handshaker_factory_lists[handshaker_type];
handshake_mgr); factory_list.AddHandshakers(args, interested_parties, handshake_mgr);
} }
} // namespace grpc_core

@ -25,25 +25,30 @@
#include "src/core/lib/channel/handshaker_factory.h" #include "src/core/lib/channel/handshaker_factory.h"
namespace grpc_core {
typedef enum { typedef enum {
HANDSHAKER_CLIENT = 0, HANDSHAKER_CLIENT = 0,
HANDSHAKER_SERVER, HANDSHAKER_SERVER,
NUM_HANDSHAKER_TYPES, // Must be last. NUM_HANDSHAKER_TYPES, // Must be last.
} grpc_handshaker_type; } HandshakerType;
void grpc_handshaker_factory_registry_init(); class HandshakerRegistry {
void grpc_handshaker_factory_registry_shutdown(); public:
/// Registers a new handshaker factory. Takes ownership.
/// Registers a new handshaker factory. Takes ownership. /// If \a at_start is true, the new handshaker will be at the beginning of
/// If \a at_start is true, the new handshaker will be at the beginning of /// the list. Otherwise, it will be added to the end.
/// the list. Otherwise, it will be added to the end. static void RegisterHandshakerFactory(bool at_start,
void grpc_handshaker_factory_register(bool at_start, HandshakerType handshaker_type,
grpc_handshaker_type handshaker_type, UniquePtr<HandshakerFactory> factory);
grpc_handshaker_factory* factory); static void AddHandshakers(HandshakerType handshaker_type,
void grpc_handshakers_add(grpc_handshaker_type handshaker_type,
const grpc_channel_args* args, const grpc_channel_args* args,
grpc_pollset_set* interested_parties, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr); HandshakeManager* handshake_mgr);
static void Init();
static void Shutdown();
};
} // namespace grpc_core
#endif /* GRPC_CORE_LIB_CHANNEL_HANDSHAKER_REGISTRY_H */ #endif /* GRPC_CORE_LIB_CHANNEL_HANDSHAKER_REGISTRY_H */

@ -67,7 +67,7 @@ class grpc_httpcli_ssl_channel_security_connector final
} }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) override { grpc_core::HandshakeManager* handshake_mgr) override {
tsi_handshaker* handshaker = nullptr; tsi_handshaker* handshaker = nullptr;
if (handshaker_factory_ != nullptr) { if (handshaker_factory_ != nullptr) {
tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
@ -77,8 +77,7 @@ class grpc_httpcli_ssl_channel_security_connector final
tsi_result_to_string(result)); tsi_result_to_string(result));
} }
} }
grpc_handshake_manager_add( handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(handshaker, this));
handshake_mgr, grpc_security_handshaker_create(handshaker, this));
} }
tsi_ssl_client_handshaker_factory* handshaker_factory() const { tsi_ssl_client_handshaker_factory* handshaker_factory() const {
@ -155,11 +154,11 @@ httpcli_ssl_channel_security_connector_create(
typedef struct { typedef struct {
void (*func)(void* arg, grpc_endpoint* endpoint); void (*func)(void* arg, grpc_endpoint* endpoint);
void* arg; void* arg;
grpc_handshake_manager* handshake_mgr; grpc_core::RefCountedPtr<grpc_core::HandshakeManager> handshake_mgr;
} on_done_closure; } on_done_closure;
static void on_handshake_done(void* arg, grpc_error* error) { static void on_handshake_done(void* arg, grpc_error* error) {
grpc_handshaker_args* args = static_cast<grpc_handshaker_args*>(arg); auto* args = static_cast<grpc_core::HandshakerArgs*>(arg);
on_done_closure* c = static_cast<on_done_closure*>(args->user_data); on_done_closure* c = static_cast<on_done_closure*>(args->user_data);
if (error != GRPC_ERROR_NONE) { if (error != GRPC_ERROR_NONE) {
const char* msg = grpc_error_string(error); const char* msg = grpc_error_string(error);
@ -172,14 +171,13 @@ static void on_handshake_done(void* arg, grpc_error* error) {
gpr_free(args->read_buffer); gpr_free(args->read_buffer);
c->func(c->arg, args->endpoint); c->func(c->arg, args->endpoint);
} }
grpc_handshake_manager_destroy(c->handshake_mgr); grpc_core::Delete<on_done_closure>(c);
gpr_free(c);
} }
static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host,
grpc_millis deadline, grpc_millis deadline,
void (*on_done)(void* arg, grpc_endpoint* endpoint)) { void (*on_done)(void* arg, grpc_endpoint* endpoint)) {
on_done_closure* c = static_cast<on_done_closure*>(gpr_malloc(sizeof(*c))); auto* c = grpc_core::New<on_done_closure>();
const char* pem_root_certs = const char* pem_root_certs =
grpc_core::DefaultSslRootStore::GetPemRootCerts(); grpc_core::DefaultSslRootStore::GetPemRootCerts();
const tsi_ssl_root_certs_store* root_store = const tsi_ssl_root_certs_store* root_store =
@ -198,12 +196,13 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host,
GPR_ASSERT(sc != nullptr); GPR_ASSERT(sc != nullptr);
grpc_arg channel_arg = grpc_security_connector_to_arg(sc.get()); grpc_arg channel_arg = grpc_security_connector_to_arg(sc.get());
grpc_channel_args args = {1, &channel_arg}; grpc_channel_args args = {1, &channel_arg};
c->handshake_mgr = grpc_handshake_manager_create(); c->handshake_mgr = grpc_core::MakeRefCounted<grpc_core::HandshakeManager>();
grpc_handshakers_add(HANDSHAKER_CLIENT, &args, grpc_core::HandshakerRegistry::AddHandshakers(
nullptr /* interested_parties */, c->handshake_mgr); grpc_core::HANDSHAKER_CLIENT, &args, /*interested_parties=*/nullptr,
grpc_handshake_manager_do_handshake( c->handshake_mgr.get());
c->handshake_mgr, tcp, nullptr /* channel_args */, deadline, c->handshake_mgr->DoHandshake(tcp, /*channel_args=*/nullptr, deadline,
nullptr /* acceptor */, on_handshake_done, c /* user_data */); /*acceptor=*/nullptr, on_handshake_done,
/*user_data=*/c);
sc.reset(DEBUG_LOCATION, "httpcli"); sc.reset(DEBUG_LOCATION, "httpcli");
} }

@ -80,8 +80,9 @@ class grpc_alts_channel_security_connector final
~grpc_alts_channel_security_connector() override { gpr_free(target_name_); } ~grpc_alts_channel_security_connector() override { gpr_free(target_name_); }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(
grpc_handshake_manager* handshake_manager) override { grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr; tsi_handshaker* handshaker = nullptr;
const grpc_alts_credentials* creds = const grpc_alts_credentials* creds =
static_cast<const grpc_alts_credentials*>(channel_creds()); static_cast<const grpc_alts_credentials*>(channel_creds());
@ -89,8 +90,8 @@ class grpc_alts_channel_security_connector final
creds->handshaker_service_url(), true, creds->handshaker_service_url(), true,
interested_parties, interested_parties,
&handshaker) == TSI_OK); &handshaker) == TSI_OK);
grpc_handshake_manager_add( handshake_manager->Add(
handshake_manager, grpc_security_handshaker_create(handshaker, this)); grpc_core::SecurityHandshakerCreate(handshaker, this));
} }
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,
@ -139,16 +140,17 @@ class grpc_alts_server_security_connector final
} }
~grpc_alts_server_security_connector() override = default; ~grpc_alts_server_security_connector() override = default;
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(
grpc_handshake_manager* handshake_manager) override { grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr; tsi_handshaker* handshaker = nullptr;
const grpc_alts_server_credentials* creds = const grpc_alts_server_credentials* creds =
static_cast<const grpc_alts_server_credentials*>(server_creds()); static_cast<const grpc_alts_server_credentials*>(server_creds());
GPR_ASSERT(alts_tsi_handshaker_create( GPR_ASSERT(alts_tsi_handshaker_create(
creds->options(), nullptr, creds->handshaker_service_url(), creds->options(), nullptr, creds->handshaker_service_url(),
false, interested_parties, &handshaker) == TSI_OK); false, interested_parties, &handshaker) == TSI_OK);
grpc_handshake_manager_add( handshake_manager->Add(
handshake_manager, grpc_security_handshaker_create(handshaker, this)); grpc_core::SecurityHandshakerCreate(handshaker, this));
} }
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,

@ -92,10 +92,8 @@ class grpc_fake_channel_security_connector final
} }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) override { grpc_core::HandshakeManager* handshake_mgr) override {
grpc_handshake_manager_add( handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(
handshake_mgr,
grpc_security_handshaker_create(
tsi_create_fake_handshaker(/*is_client=*/true), this)); tsi_create_fake_handshaker(/*is_client=*/true), this));
} }
@ -273,10 +271,8 @@ class grpc_fake_server_security_connector
} }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) override { grpc_core::HandshakeManager* handshake_mgr) override {
grpc_handshake_manager_add( handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(
handshake_mgr,
grpc_security_handshaker_create(
tsi_create_fake_handshaker(/*=is_client*/ false), this)); tsi_create_fake_handshaker(/*=is_client*/ false), this));
} }

@ -128,13 +128,14 @@ class grpc_local_channel_security_connector final
~grpc_local_channel_security_connector() override { gpr_free(target_name_); } ~grpc_local_channel_security_connector() override { gpr_free(target_name_); }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(
grpc_handshake_manager* handshake_manager) override { grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr; tsi_handshaker* handshaker = nullptr;
GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) == GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) ==
TSI_OK); TSI_OK);
grpc_handshake_manager_add( handshake_manager->Add(
handshake_manager, grpc_security_handshaker_create(handshaker, this)); grpc_core::SecurityHandshakerCreate(handshaker, this));
} }
int cmp(const grpc_security_connector* other_sc) const override { int cmp(const grpc_security_connector* other_sc) const override {
@ -184,13 +185,14 @@ class grpc_local_server_security_connector final
: grpc_server_security_connector(nullptr, std::move(server_creds)) {} : grpc_server_security_connector(nullptr, std::move(server_creds)) {}
~grpc_local_server_security_connector() override = default; ~grpc_local_server_security_connector() override = default;
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(
grpc_handshake_manager* handshake_manager) override { grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr; tsi_handshaker* handshaker = nullptr;
GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */, GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */,
&handshaker) == TSI_OK); &handshaker) == TSI_OK);
grpc_handshake_manager_add( handshake_manager->Add(
handshake_manager, grpc_security_handshaker_create(handshaker, this)); grpc_core::SecurityHandshakerCreate(handshaker, this));
} }
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,

@ -109,7 +109,7 @@ class grpc_channel_security_connector : public grpc_security_connector {
grpc_error* error) GRPC_ABSTRACT; grpc_error* error) GRPC_ABSTRACT;
/// Registers handshakers with \a handshake_mgr. /// Registers handshakers with \a handshake_mgr.
virtual void add_handshakers(grpc_pollset_set* interested_parties, virtual void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) grpc_core::HandshakeManager* handshake_mgr)
GRPC_ABSTRACT; GRPC_ABSTRACT;
const grpc_channel_credentials* channel_creds() const { const grpc_channel_credentials* channel_creds() const {
@ -150,7 +150,7 @@ class grpc_server_security_connector : public grpc_security_connector {
~grpc_server_security_connector() override = default; ~grpc_server_security_connector() override = default;
virtual void add_handshakers(grpc_pollset_set* interested_parties, virtual void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) grpc_core::HandshakeManager* handshake_mgr)
GRPC_ABSTRACT; GRPC_ABSTRACT;
const grpc_server_credentials* server_creds() const { const grpc_server_credentials* server_creds() const {

@ -128,7 +128,7 @@ class grpc_ssl_channel_security_connector final
} }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) override { grpc_core::HandshakeManager* handshake_mgr) override {
// Instantiate TSI handshaker. // Instantiate TSI handshaker.
tsi_handshaker* tsi_hs = nullptr; tsi_handshaker* tsi_hs = nullptr;
tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
@ -142,8 +142,7 @@ class grpc_ssl_channel_security_connector final
return; return;
} }
// Create handshakers. // Create handshakers.
grpc_handshake_manager_add(handshake_mgr, handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this));
grpc_security_handshaker_create(tsi_hs, this));
} }
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,
@ -283,7 +282,7 @@ class grpc_ssl_server_security_connector
} }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) override { grpc_core::HandshakeManager* handshake_mgr) override {
// Instantiate TSI handshaker. // Instantiate TSI handshaker.
try_fetch_ssl_server_credentials(); try_fetch_ssl_server_credentials();
tsi_handshaker* tsi_hs = nullptr; tsi_handshaker* tsi_hs = nullptr;
@ -295,8 +294,7 @@ class grpc_ssl_server_security_connector
return; return;
} }
// Create handshakers. // Create handshakers.
grpc_handshake_manager_add(handshake_mgr, handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this));
grpc_security_handshaker_create(tsi_hs, this));
} }
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,

@ -39,74 +39,113 @@
#define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256 #define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256
namespace { namespace grpc_core {
struct security_handshaker {
security_handshaker(tsi_handshaker* handshaker,
grpc_security_connector* connector);
~security_handshaker() {
gpr_mu_destroy(&mu);
tsi_handshaker_destroy(handshaker);
tsi_handshaker_result_destroy(handshaker_result);
if (endpoint_to_destroy != nullptr) {
grpc_endpoint_destroy(endpoint_to_destroy);
}
if (read_buffer_to_destroy != nullptr) {
grpc_slice_buffer_destroy_internal(read_buffer_to_destroy);
gpr_free(read_buffer_to_destroy);
}
gpr_free(handshake_buffer);
grpc_slice_buffer_destroy_internal(&outgoing);
auth_context.reset(DEBUG_LOCATION, "handshake");
connector.reset(DEBUG_LOCATION, "handshake");
}
void Ref() { refs.Ref(); } namespace {
void Unref() {
if (refs.Unref()) {
grpc_core::Delete(this);
}
}
grpc_handshaker base; class SecurityHandshaker : public Handshaker {
public:
SecurityHandshaker(tsi_handshaker* handshaker,
grpc_security_connector* connector);
~SecurityHandshaker() override;
void Shutdown(grpc_error* why) override;
void DoHandshake(grpc_tcp_server_acceptor* acceptor,
grpc_closure* on_handshake_done,
HandshakerArgs* args) override;
const char* name() const override { return "security"; }
private:
grpc_error* DoHandshakerNextLocked(const unsigned char* bytes_received,
size_t bytes_received_size);
grpc_error* OnHandshakeNextDoneLocked(
tsi_result result, const unsigned char* bytes_to_send,
size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
void HandshakeFailedLocked(grpc_error* error);
void CleanupArgsForFailureLocked();
static void OnHandshakeDataReceivedFromPeerFn(void* arg, grpc_error* error);
static void OnHandshakeDataSentToPeerFn(void* arg, grpc_error* error);
static void OnHandshakeNextDoneGrpcWrapper(
tsi_result result, void* user_data, const unsigned char* bytes_to_send,
size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
static void OnPeerCheckedFn(void* arg, grpc_error* error);
void OnPeerCheckedInner(grpc_error* error);
size_t MoveReadBufferIntoHandshakeBuffer();
grpc_error* CheckPeerLocked();
// State set at creation time. // State set at creation time.
tsi_handshaker* handshaker; tsi_handshaker* handshaker_;
grpc_core::RefCountedPtr<grpc_security_connector> connector; RefCountedPtr<grpc_security_connector> connector_;
gpr_mu mu; gpr_mu mu_;
grpc_core::RefCount refs;
bool shutdown = false; bool is_shutdown_ = false;
// Endpoint and read buffer to destroy after a shutdown. // Endpoint and read buffer to destroy after a shutdown.
grpc_endpoint* endpoint_to_destroy = nullptr; grpc_endpoint* endpoint_to_destroy_ = nullptr;
grpc_slice_buffer* read_buffer_to_destroy = nullptr; grpc_slice_buffer* read_buffer_to_destroy_ = nullptr;
// State saved while performing the handshake. // State saved while performing the handshake.
grpc_handshaker_args* args = nullptr; HandshakerArgs* args_ = nullptr;
grpc_closure* on_handshake_done = nullptr; grpc_closure* on_handshake_done_ = nullptr;
size_t handshake_buffer_size; size_t handshake_buffer_size_;
unsigned char* handshake_buffer; unsigned char* handshake_buffer_;
grpc_slice_buffer outgoing; grpc_slice_buffer outgoing_;
grpc_closure on_handshake_data_sent_to_peer; grpc_closure on_handshake_data_sent_to_peer_;
grpc_closure on_handshake_data_received_from_peer; grpc_closure on_handshake_data_received_from_peer_;
grpc_closure on_peer_checked; grpc_closure on_peer_checked_;
grpc_core::RefCountedPtr<grpc_auth_context> auth_context; RefCountedPtr<grpc_auth_context> auth_context_;
tsi_handshaker_result* handshaker_result = nullptr; tsi_handshaker_result* handshaker_result_ = nullptr;
}; };
} // namespace
static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) { SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
size_t bytes_in_read_buffer = h->args->read_buffer->length; grpc_security_connector* connector)
if (h->handshake_buffer_size < bytes_in_read_buffer) { : handshaker_(handshaker),
h->handshake_buffer = static_cast<uint8_t*>( connector_(connector->Ref(DEBUG_LOCATION, "handshake")),
gpr_realloc(h->handshake_buffer, bytes_in_read_buffer)); handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE),
h->handshake_buffer_size = bytes_in_read_buffer; handshake_buffer_(
static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))) {
gpr_mu_init(&mu_);
grpc_slice_buffer_init(&outgoing_);
GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer_,
&SecurityHandshaker::OnHandshakeDataSentToPeerFn, this,
grpc_schedule_on_exec_ctx);
GRPC_CLOSURE_INIT(&on_handshake_data_received_from_peer_,
&SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn,
this, grpc_schedule_on_exec_ctx);
GRPC_CLOSURE_INIT(&on_peer_checked_, &SecurityHandshaker::OnPeerCheckedFn,
this, grpc_schedule_on_exec_ctx);
}
SecurityHandshaker::~SecurityHandshaker() {
gpr_mu_destroy(&mu_);
tsi_handshaker_destroy(handshaker_);
tsi_handshaker_result_destroy(handshaker_result_);
if (endpoint_to_destroy_ != nullptr) {
grpc_endpoint_destroy(endpoint_to_destroy_);
}
if (read_buffer_to_destroy_ != nullptr) {
grpc_slice_buffer_destroy_internal(read_buffer_to_destroy_);
gpr_free(read_buffer_to_destroy_);
}
gpr_free(handshake_buffer_);
grpc_slice_buffer_destroy_internal(&outgoing_);
auth_context_.reset(DEBUG_LOCATION, "handshake");
connector_.reset(DEBUG_LOCATION, "handshake");
}
size_t SecurityHandshaker::MoveReadBufferIntoHandshakeBuffer() {
size_t bytes_in_read_buffer = args_->read_buffer->length;
if (handshake_buffer_size_ < bytes_in_read_buffer) {
handshake_buffer_ = static_cast<uint8_t*>(
gpr_realloc(handshake_buffer_, bytes_in_read_buffer));
handshake_buffer_size_ = bytes_in_read_buffer;
} }
size_t offset = 0; size_t offset = 0;
while (h->args->read_buffer->count > 0) { while (args_->read_buffer->count > 0) {
grpc_slice next_slice = grpc_slice_buffer_take_first(h->args->read_buffer); grpc_slice next_slice = grpc_slice_buffer_take_first(args_->read_buffer);
memcpy(h->handshake_buffer + offset, GRPC_SLICE_START_PTR(next_slice), memcpy(handshake_buffer_ + offset, GRPC_SLICE_START_PTR(next_slice),
GRPC_SLICE_LENGTH(next_slice)); GRPC_SLICE_LENGTH(next_slice));
offset += GRPC_SLICE_LENGTH(next_slice); offset += GRPC_SLICE_LENGTH(next_slice);
grpc_slice_unref_internal(next_slice); grpc_slice_unref_internal(next_slice);
@ -114,21 +153,20 @@ static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) {
return bytes_in_read_buffer; return bytes_in_read_buffer;
} }
// Set args fields to NULL, saving the endpoint and read buffer for // Set args_ fields to NULL, saving the endpoint and read buffer for
// later destruction. // later destruction.
static void cleanup_args_for_failure_locked(security_handshaker* h) { void SecurityHandshaker::CleanupArgsForFailureLocked() {
h->endpoint_to_destroy = h->args->endpoint; endpoint_to_destroy_ = args_->endpoint;
h->args->endpoint = nullptr; args_->endpoint = nullptr;
h->read_buffer_to_destroy = h->args->read_buffer; read_buffer_to_destroy_ = args_->read_buffer;
h->args->read_buffer = nullptr; args_->read_buffer = nullptr;
grpc_channel_args_destroy(h->args->args); grpc_channel_args_destroy(args_->args);
h->args->args = nullptr; args_->args = nullptr;
} }
// If the handshake failed or we're shutting down, clean up and invoke the // If the handshake failed or we're shutting down, clean up and invoke the
// callback with the error. // callback with the error.
static void security_handshake_failed_locked(security_handshaker* h, void SecurityHandshaker::HandshakeFailedLocked(grpc_error* error) {
grpc_error* error) {
if (error == GRPC_ERROR_NONE) { if (error == GRPC_ERROR_NONE) {
// If we were shut down after the handshake succeeded but before an // If we were shut down after the handshake succeeded but before an
// endpoint callback was invoked, we need to generate our own error. // endpoint callback was invoked, we need to generate our own error.
@ -137,50 +175,51 @@ static void security_handshake_failed_locked(security_handshaker* h,
const char* msg = grpc_error_string(error); const char* msg = grpc_error_string(error);
gpr_log(GPR_DEBUG, "Security handshake failed: %s", msg); gpr_log(GPR_DEBUG, "Security handshake failed: %s", msg);
if (!h->shutdown) { if (!is_shutdown_) {
// TODO(ctiller): It is currently necessary to shutdown endpoints // TODO(ctiller): It is currently necessary to shutdown endpoints
// before destroying them, even if we know that there are no // before destroying them, even if we know that there are no
// pending read/write callbacks. This should be fixed, at which // pending read/write callbacks. This should be fixed, at which
// point this can be removed. // point this can be removed.
grpc_endpoint_shutdown(h->args->endpoint, GRPC_ERROR_REF(error)); grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(error));
// Not shutting down, so the write failed. Clean up before // Not shutting down, so the write failed. Clean up before
// invoking the callback. // invoking the callback.
cleanup_args_for_failure_locked(h); CleanupArgsForFailureLocked();
// Set shutdown to true so that subsequent calls to // Set shutdown to true so that subsequent calls to
// security_handshaker_shutdown() do nothing. // security_handshaker_shutdown() do nothing.
h->shutdown = true; is_shutdown_ = true;
} }
// Invoke callback. // Invoke callback.
GRPC_CLOSURE_SCHED(h->on_handshake_done, error); GRPC_CLOSURE_SCHED(on_handshake_done_, error);
} }
static void on_peer_checked_inner(security_handshaker* h, grpc_error* error) { void SecurityHandshaker::OnPeerCheckedInner(grpc_error* error) {
if (error != GRPC_ERROR_NONE || h->shutdown) { MutexLock lock(&mu_);
security_handshake_failed_locked(h, GRPC_ERROR_REF(error)); if (error != GRPC_ERROR_NONE || is_shutdown_) {
HandshakeFailedLocked(GRPC_ERROR_REF(error));
return; return;
} }
// Create zero-copy frame protector, if implemented. // Create zero-copy frame protector, if implemented.
tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr; tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr;
tsi_result result = tsi_handshaker_result_create_zero_copy_grpc_protector( tsi_result result = tsi_handshaker_result_create_zero_copy_grpc_protector(
h->handshaker_result, nullptr, &zero_copy_protector); handshaker_result_, nullptr, &zero_copy_protector);
if (result != TSI_OK && result != TSI_UNIMPLEMENTED) { if (result != TSI_OK && result != TSI_UNIMPLEMENTED) {
error = grpc_set_tsi_error_result( error = grpc_set_tsi_error_result(
GRPC_ERROR_CREATE_FROM_STATIC_STRING( GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Zero-copy frame protector creation failed"), "Zero-copy frame protector creation failed"),
result); result);
security_handshake_failed_locked(h, error); HandshakeFailedLocked(error);
return; return;
} }
// Create frame protector if zero-copy frame protector is NULL. // Create frame protector if zero-copy frame protector is NULL.
tsi_frame_protector* protector = nullptr; tsi_frame_protector* protector = nullptr;
if (zero_copy_protector == nullptr) { if (zero_copy_protector == nullptr) {
result = tsi_handshaker_result_create_frame_protector(h->handshaker_result, result = tsi_handshaker_result_create_frame_protector(handshaker_result_,
nullptr, &protector); nullptr, &protector);
if (result != TSI_OK) { if (result != TSI_OK) {
error = grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING( error = grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Frame protector creation failed"), "Frame protector creation failed"),
result); result);
security_handshake_failed_locked(h, error); HandshakeFailedLocked(error);
return; return;
} }
} }
@ -188,68 +227,63 @@ static void on_peer_checked_inner(security_handshaker* h, grpc_error* error) {
const unsigned char* unused_bytes = nullptr; const unsigned char* unused_bytes = nullptr;
size_t unused_bytes_size = 0; size_t unused_bytes_size = 0;
result = tsi_handshaker_result_get_unused_bytes( result = tsi_handshaker_result_get_unused_bytes(
h->handshaker_result, &unused_bytes, &unused_bytes_size); handshaker_result_, &unused_bytes, &unused_bytes_size);
// Create secure endpoint. // Create secure endpoint.
if (unused_bytes_size > 0) { if (unused_bytes_size > 0) {
grpc_slice slice = grpc_slice slice =
grpc_slice_from_copied_buffer((char*)unused_bytes, unused_bytes_size); grpc_slice_from_copied_buffer((char*)unused_bytes, unused_bytes_size);
h->args->endpoint = grpc_secure_endpoint_create( args_->endpoint = grpc_secure_endpoint_create(
protector, zero_copy_protector, h->args->endpoint, &slice, 1); protector, zero_copy_protector, args_->endpoint, &slice, 1);
grpc_slice_unref_internal(slice); grpc_slice_unref_internal(slice);
} else { } else {
h->args->endpoint = grpc_secure_endpoint_create( args_->endpoint = grpc_secure_endpoint_create(
protector, zero_copy_protector, h->args->endpoint, nullptr, 0); protector, zero_copy_protector, args_->endpoint, nullptr, 0);
} }
tsi_handshaker_result_destroy(h->handshaker_result); tsi_handshaker_result_destroy(handshaker_result_);
h->handshaker_result = nullptr; handshaker_result_ = nullptr;
// Add auth context to channel args. // Add auth context to channel args.
grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context.get()); grpc_arg auth_context_arg = grpc_auth_context_to_arg(auth_context_.get());
grpc_channel_args* tmp_args = h->args->args; grpc_channel_args* tmp_args = args_->args;
h->args->args = args_->args = grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1);
grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1);
grpc_channel_args_destroy(tmp_args); grpc_channel_args_destroy(tmp_args);
// Invoke callback. // Invoke callback.
GRPC_CLOSURE_SCHED(h->on_handshake_done, GRPC_ERROR_NONE); GRPC_CLOSURE_SCHED(on_handshake_done_, GRPC_ERROR_NONE);
// Set shutdown to true so that subsequent calls to // Set shutdown to true so that subsequent calls to
// security_handshaker_shutdown() do nothing. // security_handshaker_shutdown() do nothing.
h->shutdown = true; is_shutdown_ = true;
} }
static void on_peer_checked(void* arg, grpc_error* error) { void SecurityHandshaker::OnPeerCheckedFn(void* arg, grpc_error* error) {
security_handshaker* h = static_cast<security_handshaker*>(arg); RefCountedPtr<SecurityHandshaker>(static_cast<SecurityHandshaker*>(arg))
gpr_mu_lock(&h->mu); ->OnPeerCheckedInner(error);
on_peer_checked_inner(h, error);
gpr_mu_unlock(&h->mu);
h->Unref();
} }
static grpc_error* check_peer_locked(security_handshaker* h) { grpc_error* SecurityHandshaker::CheckPeerLocked() {
tsi_peer peer; tsi_peer peer;
tsi_result result = tsi_result result =
tsi_handshaker_result_extract_peer(h->handshaker_result, &peer); tsi_handshaker_result_extract_peer(handshaker_result_, &peer);
if (result != TSI_OK) { if (result != TSI_OK) {
return grpc_set_tsi_error_result( return grpc_set_tsi_error_result(
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result); GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result);
} }
h->connector->check_peer(peer, h->args->endpoint, &h->auth_context, connector_->check_peer(peer, args_->endpoint, &auth_context_,
&h->on_peer_checked); &on_peer_checked_);
return GRPC_ERROR_NONE; return GRPC_ERROR_NONE;
} }
static grpc_error* on_handshake_next_done_locked( grpc_error* SecurityHandshaker::OnHandshakeNextDoneLocked(
security_handshaker* h, tsi_result result, tsi_result result, const unsigned char* bytes_to_send,
const unsigned char* bytes_to_send, size_t bytes_to_send_size, size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
tsi_handshaker_result* handshaker_result) {
grpc_error* error = GRPC_ERROR_NONE; grpc_error* error = GRPC_ERROR_NONE;
// Handshaker was shutdown. // Handshaker was shutdown.
if (h->shutdown) { if (is_shutdown_) {
return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown"); return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown");
} }
// Read more if we need to. // Read more if we need to.
if (result == TSI_INCOMPLETE_DATA) { if (result == TSI_INCOMPLETE_DATA) {
GPR_ASSERT(bytes_to_send_size == 0); GPR_ASSERT(bytes_to_send_size == 0);
grpc_endpoint_read(h->args->endpoint, h->args->read_buffer, grpc_endpoint_read(args_->endpoint, args_->read_buffer,
&h->on_handshake_data_received_from_peer); &on_handshake_data_received_from_peer_);
return error; return error;
} }
if (result != TSI_OK) { if (result != TSI_OK) {
@ -258,55 +292,52 @@ static grpc_error* on_handshake_next_done_locked(
} }
// Update handshaker result. // Update handshaker result.
if (handshaker_result != nullptr) { if (handshaker_result != nullptr) {
GPR_ASSERT(h->handshaker_result == nullptr); GPR_ASSERT(handshaker_result_ == nullptr);
h->handshaker_result = handshaker_result; handshaker_result_ = handshaker_result;
} }
if (bytes_to_send_size > 0) { if (bytes_to_send_size > 0) {
// Send data to peer, if needed. // Send data to peer, if needed.
grpc_slice to_send = grpc_slice_from_copied_buffer( grpc_slice to_send = grpc_slice_from_copied_buffer(
reinterpret_cast<const char*>(bytes_to_send), bytes_to_send_size); reinterpret_cast<const char*>(bytes_to_send), bytes_to_send_size);
grpc_slice_buffer_reset_and_unref_internal(&h->outgoing); grpc_slice_buffer_reset_and_unref_internal(&outgoing_);
grpc_slice_buffer_add(&h->outgoing, to_send); grpc_slice_buffer_add(&outgoing_, to_send);
grpc_endpoint_write(h->args->endpoint, &h->outgoing, grpc_endpoint_write(args_->endpoint, &outgoing_,
&h->on_handshake_data_sent_to_peer, nullptr); &on_handshake_data_sent_to_peer_, nullptr);
} else if (handshaker_result == nullptr) { } else if (handshaker_result == nullptr) {
// There is nothing to send, but need to read from peer. // There is nothing to send, but need to read from peer.
grpc_endpoint_read(h->args->endpoint, h->args->read_buffer, grpc_endpoint_read(args_->endpoint, args_->read_buffer,
&h->on_handshake_data_received_from_peer); &on_handshake_data_received_from_peer_);
} else { } else {
// Handshake has finished, check peer and so on. // Handshake has finished, check peer and so on.
error = check_peer_locked(h); error = CheckPeerLocked();
} }
return error; return error;
} }
static void on_handshake_next_done_grpc_wrapper( void SecurityHandshaker::OnHandshakeNextDoneGrpcWrapper(
tsi_result result, void* user_data, const unsigned char* bytes_to_send, tsi_result result, void* user_data, const unsigned char* bytes_to_send,
size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) { size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
security_handshaker* h = static_cast<security_handshaker*>(user_data); RefCountedPtr<SecurityHandshaker> h(
gpr_mu_lock(&h->mu); static_cast<SecurityHandshaker*>(user_data));
grpc_error* error = on_handshake_next_done_locked( MutexLock lock(&h->mu_);
h, result, bytes_to_send, bytes_to_send_size, handshaker_result); grpc_error* error = h->OnHandshakeNextDoneLocked(
result, bytes_to_send, bytes_to_send_size, handshaker_result);
if (error != GRPC_ERROR_NONE) { if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error); h->HandshakeFailedLocked(error);
gpr_mu_unlock(&h->mu);
h->Unref();
} else { } else {
gpr_mu_unlock(&h->mu); h.release(); // Avoid unref
} }
} }
static grpc_error* do_handshaker_next_locked( grpc_error* SecurityHandshaker::DoHandshakerNextLocked(
security_handshaker* h, const unsigned char* bytes_received, const unsigned char* bytes_received, size_t bytes_received_size) {
size_t bytes_received_size) {
// Invoke TSI handshaker. // Invoke TSI handshaker.
const unsigned char* bytes_to_send = nullptr; const unsigned char* bytes_to_send = nullptr;
size_t bytes_to_send_size = 0; size_t bytes_to_send_size = 0;
tsi_handshaker_result* handshaker_result = nullptr; tsi_handshaker_result* hs_result = nullptr;
tsi_result result = tsi_handshaker_next( tsi_result result = tsi_handshaker_next(
h->handshaker, bytes_received, bytes_received_size, &bytes_to_send, handshaker_, bytes_received, bytes_received_size, &bytes_to_send,
&bytes_to_send_size, &handshaker_result, &bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper, this);
&on_handshake_next_done_grpc_wrapper, h);
if (result == TSI_ASYNC) { if (result == TSI_ASYNC) {
// Handshaker operating asynchronously. Nothing else to do here; // Handshaker operating asynchronously. Nothing else to do here;
// callback will be invoked in a TSI thread. // callback will be invoked in a TSI thread.
@ -314,233 +345,169 @@ static grpc_error* do_handshaker_next_locked(
} }
// Handshaker returned synchronously. Invoke callback directly in // Handshaker returned synchronously. Invoke callback directly in
// this thread with our existing exec_ctx. // this thread with our existing exec_ctx.
return on_handshake_next_done_locked(h, result, bytes_to_send, return OnHandshakeNextDoneLocked(result, bytes_to_send, bytes_to_send_size,
bytes_to_send_size, handshaker_result); hs_result);
} }
static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) { void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(void* arg,
security_handshaker* h = static_cast<security_handshaker*>(arg); grpc_error* error) {
gpr_mu_lock(&h->mu); RefCountedPtr<SecurityHandshaker> h(static_cast<SecurityHandshaker*>(arg));
if (error != GRPC_ERROR_NONE || h->shutdown) { MutexLock lock(&h->mu_);
security_handshake_failed_locked( if (error != GRPC_ERROR_NONE || h->is_shutdown_) {
h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( h->HandshakeFailedLocked(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
"Handshake read failed", &error, 1)); "Handshake read failed", &error, 1));
gpr_mu_unlock(&h->mu);
h->Unref();
return; return;
} }
// Copy all slices received. // Copy all slices received.
size_t bytes_received_size = move_read_buffer_into_handshake_buffer(h); size_t bytes_received_size = h->MoveReadBufferIntoHandshakeBuffer();
// Call TSI handshaker. // Call TSI handshaker.
error = error = h->DoHandshakerNextLocked(h->handshake_buffer_, bytes_received_size);
do_handshaker_next_locked(h, h->handshake_buffer, bytes_received_size);
if (error != GRPC_ERROR_NONE) { if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error); h->HandshakeFailedLocked(error);
gpr_mu_unlock(&h->mu);
h->Unref();
} else { } else {
gpr_mu_unlock(&h->mu); h.release(); // Avoid unref
} }
} }
static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) { void SecurityHandshaker::OnHandshakeDataSentToPeerFn(void* arg,
security_handshaker* h = static_cast<security_handshaker*>(arg); grpc_error* error) {
gpr_mu_lock(&h->mu); RefCountedPtr<SecurityHandshaker> h(static_cast<SecurityHandshaker*>(arg));
if (error != GRPC_ERROR_NONE || h->shutdown) { MutexLock lock(&h->mu_);
security_handshake_failed_locked( if (error != GRPC_ERROR_NONE || h->is_shutdown_) {
h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( h->HandshakeFailedLocked(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
"Handshake write failed", &error, 1)); "Handshake write failed", &error, 1));
gpr_mu_unlock(&h->mu);
h->Unref();
return; return;
} }
// We may be done. // We may be done.
if (h->handshaker_result == nullptr) { if (h->handshaker_result_ == nullptr) {
grpc_endpoint_read(h->args->endpoint, h->args->read_buffer, grpc_endpoint_read(h->args_->endpoint, h->args_->read_buffer,
&h->on_handshake_data_received_from_peer); &h->on_handshake_data_received_from_peer_);
} else { } else {
error = check_peer_locked(h); error = h->CheckPeerLocked();
if (error != GRPC_ERROR_NONE) { if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error); h->HandshakeFailedLocked(error);
gpr_mu_unlock(&h->mu);
h->Unref();
return; return;
} }
} }
gpr_mu_unlock(&h->mu); h.release(); // Avoid unref
} }
// //
// public handshaker API // public handshaker API
// //
static void security_handshaker_destroy(grpc_handshaker* handshaker) { void SecurityHandshaker::Shutdown(grpc_error* why) {
security_handshaker* h = reinterpret_cast<security_handshaker*>(handshaker); MutexLock lock(&mu_);
h->Unref(); if (!is_shutdown_) {
} is_shutdown_ = true;
tsi_handshaker_shutdown(handshaker_);
static void security_handshaker_shutdown(grpc_handshaker* handshaker, grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(why));
grpc_error* why) { CleanupArgsForFailureLocked();
security_handshaker* h = reinterpret_cast<security_handshaker*>(handshaker); }
gpr_mu_lock(&h->mu);
if (!h->shutdown) {
h->shutdown = true;
tsi_handshaker_shutdown(h->handshaker);
grpc_endpoint_shutdown(h->args->endpoint, GRPC_ERROR_REF(why));
cleanup_args_for_failure_locked(h);
}
gpr_mu_unlock(&h->mu);
GRPC_ERROR_UNREF(why); GRPC_ERROR_UNREF(why);
} }
static void security_handshaker_do_handshake(grpc_handshaker* handshaker, void SecurityHandshaker::DoHandshake(grpc_tcp_server_acceptor* acceptor,
grpc_tcp_server_acceptor* acceptor,
grpc_closure* on_handshake_done, grpc_closure* on_handshake_done,
grpc_handshaker_args* args) { HandshakerArgs* args) {
security_handshaker* h = reinterpret_cast<security_handshaker*>(handshaker); auto ref = Ref();
gpr_mu_lock(&h->mu); MutexLock lock(&mu_);
h->args = args; args_ = args;
h->on_handshake_done = on_handshake_done; on_handshake_done_ = on_handshake_done;
h->Ref(); size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer();
size_t bytes_received_size = move_read_buffer_into_handshake_buffer(h);
grpc_error* error = grpc_error* error =
do_handshaker_next_locked(h, h->handshake_buffer, bytes_received_size); DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
if (error != GRPC_ERROR_NONE) { if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error); HandshakeFailedLocked(error);
gpr_mu_unlock(&h->mu); } else {
h->Unref(); ref.release(); // Avoid unref
return;
} }
gpr_mu_unlock(&h->mu);
}
static const grpc_handshaker_vtable security_handshaker_vtable = {
security_handshaker_destroy, security_handshaker_shutdown,
security_handshaker_do_handshake, "security"};
namespace {
security_handshaker::security_handshaker(tsi_handshaker* handshaker,
grpc_security_connector* connector)
: handshaker(handshaker),
connector(connector->Ref(DEBUG_LOCATION, "handshake")),
handshake_buffer_size(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE),
handshake_buffer(
static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size))) {
grpc_handshaker_init(&security_handshaker_vtable, &base);
gpr_mu_init(&mu);
grpc_slice_buffer_init(&outgoing);
GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer,
::on_handshake_data_sent_to_peer, this,
grpc_schedule_on_exec_ctx);
GRPC_CLOSURE_INIT(&on_handshake_data_received_from_peer,
::on_handshake_data_received_from_peer, this,
grpc_schedule_on_exec_ctx);
GRPC_CLOSURE_INIT(&on_peer_checked, ::on_peer_checked, this,
grpc_schedule_on_exec_ctx);
}
} // namespace
static grpc_handshaker* security_handshaker_create(
tsi_handshaker* handshaker, grpc_security_connector* connector) {
security_handshaker* h =
grpc_core::New<security_handshaker>(handshaker, connector);
return &h->base;
} }
// //
// fail_handshaker // FailHandshaker
// //
static void fail_handshaker_destroy(grpc_handshaker* handshaker) { class FailHandshaker : public Handshaker {
gpr_free(handshaker); public:
} const char* name() const override { return "security_fail"; }
void Shutdown(grpc_error* why) override { GRPC_ERROR_UNREF(why); }
static void fail_handshaker_shutdown(grpc_handshaker* handshaker, void DoHandshake(grpc_tcp_server_acceptor* acceptor,
grpc_error* why) {
GRPC_ERROR_UNREF(why);
}
static void fail_handshaker_do_handshake(grpc_handshaker* handshaker,
grpc_tcp_server_acceptor* acceptor,
grpc_closure* on_handshake_done, grpc_closure* on_handshake_done,
grpc_handshaker_args* args) { HandshakerArgs* args) override {
GRPC_CLOSURE_SCHED(on_handshake_done, GRPC_CLOSURE_SCHED(on_handshake_done,
GRPC_ERROR_CREATE_FROM_STATIC_STRING( GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Failed to create security handshaker")); "Failed to create security handshaker"));
} }
static const grpc_handshaker_vtable fail_handshaker_vtable = {
fail_handshaker_destroy, fail_handshaker_shutdown,
fail_handshaker_do_handshake, "security_fail"};
static grpc_handshaker* fail_handshaker_create() { private:
grpc_handshaker* h = static_cast<grpc_handshaker*>(gpr_malloc(sizeof(*h))); virtual ~FailHandshaker() = default;
grpc_handshaker_init(&fail_handshaker_vtable, h); };
return h;
}
// //
// handshaker factories // handshaker factories
// //
static void client_handshaker_factory_add_handshakers( class ClientSecurityHandshakerFactory : public HandshakerFactory {
grpc_handshaker_factory* handshaker_factory, const grpc_channel_args* args, public:
void AddHandshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) { HandshakeManager* handshake_mgr) override {
grpc_channel_security_connector* security_connector = auto* security_connector =
reinterpret_cast<grpc_channel_security_connector*>( reinterpret_cast<grpc_channel_security_connector*>(
grpc_security_connector_find_in_args(args)); grpc_security_connector_find_in_args(args));
if (security_connector) { if (security_connector) {
security_connector->add_handshakers(interested_parties, handshake_mgr); security_connector->add_handshakers(interested_parties, handshake_mgr);
} }
} }
~ClientSecurityHandshakerFactory() override = default;
};
static void server_handshaker_factory_add_handshakers( class ServerSecurityHandshakerFactory : public HandshakerFactory {
grpc_handshaker_factory* hf, const grpc_channel_args* args, public:
void AddHandshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) { HandshakeManager* handshake_mgr) override {
grpc_server_security_connector* security_connector = auto* security_connector =
reinterpret_cast<grpc_server_security_connector*>( reinterpret_cast<grpc_server_security_connector*>(
grpc_security_connector_find_in_args(args)); grpc_security_connector_find_in_args(args));
if (security_connector) { if (security_connector) {
security_connector->add_handshakers(interested_parties, handshake_mgr); security_connector->add_handshakers(interested_parties, handshake_mgr);
} }
} }
~ServerSecurityHandshakerFactory() override = default;
static void handshaker_factory_destroy( };
grpc_handshaker_factory* handshaker_factory) {}
static const grpc_handshaker_factory_vtable client_handshaker_factory_vtable = {
client_handshaker_factory_add_handshakers, handshaker_factory_destroy};
static grpc_handshaker_factory client_handshaker_factory = {
&client_handshaker_factory_vtable};
static const grpc_handshaker_factory_vtable server_handshaker_factory_vtable = {
server_handshaker_factory_add_handshakers, handshaker_factory_destroy};
static grpc_handshaker_factory server_handshaker_factory = { } // namespace
&server_handshaker_factory_vtable};
// //
// exported functions // exported functions
// //
grpc_handshaker* grpc_security_handshaker_create( RefCountedPtr<Handshaker> SecurityHandshakerCreate(
tsi_handshaker* handshaker, grpc_security_connector* connector) { tsi_handshaker* handshaker, grpc_security_connector* connector) {
// If no TSI handshaker was created, return a handshaker that always fails. // If no TSI handshaker was created, return a handshaker that always fails.
// Otherwise, return a real security handshaker. // Otherwise, return a real security handshaker.
if (handshaker == nullptr) { if (handshaker == nullptr) {
return fail_handshaker_create(); return MakeRefCounted<FailHandshaker>();
} else { } else {
return security_handshaker_create(handshaker, connector); return MakeRefCounted<SecurityHandshaker>(handshaker, connector);
} }
} }
void grpc_security_register_handshaker_factories() { grpc_handshaker* grpc_security_handshaker_create(
grpc_handshaker_factory_register(false /* at_start */, HANDSHAKER_CLIENT, tsi_handshaker* handshaker, grpc_security_connector* connector) {
&client_handshaker_factory); return SecurityHandshakerCreate(handshaker, connector).release();
grpc_handshaker_factory_register(false /* at_start */, HANDSHAKER_SERVER, }
&server_handshaker_factory);
void SecurityRegisterHandshakerFactories() {
HandshakerRegistry::RegisterHandshakerFactory(
false /* at_start */, HANDSHAKER_CLIENT,
UniquePtr<HandshakerFactory>(New<ClientSecurityHandshakerFactory>()));
HandshakerRegistry::RegisterHandshakerFactory(
false /* at_start */, HANDSHAKER_SERVER,
UniquePtr<HandshakerFactory>(New<ServerSecurityHandshakerFactory>()));
} }
} // namespace grpc_core

@ -24,11 +24,20 @@
#include "src/core/lib/channel/handshaker.h" #include "src/core/lib/channel/handshaker.h"
#include "src/core/lib/security/security_connector/security_connector.h" #include "src/core/lib/security/security_connector/security_connector.h"
namespace grpc_core {
/// Creates a security handshaker using \a handshaker. /// Creates a security handshaker using \a handshaker.
grpc_handshaker* grpc_security_handshaker_create( RefCountedPtr<Handshaker> SecurityHandshakerCreate(
tsi_handshaker* handshaker, grpc_security_connector* connector); tsi_handshaker* handshaker, grpc_security_connector* connector);
/// Registers security handshaker factories. /// Registers security handshaker factories.
void grpc_security_register_handshaker_factories(); void SecurityRegisterHandshakerFactories();
} // namespace grpc_core
// TODO(arjunroy): This is transitional to account for the new handshaker API
// and will eventually be removed entirely.
grpc_handshaker* grpc_security_handshaker_create(
tsi_handshaker* handshaker, grpc_security_connector* connector);
#endif /* GRPC_CORE_LIB_SECURITY_TRANSPORT_SECURITY_HANDSHAKER_H */ #endif /* GRPC_CORE_LIB_SECURITY_TRANSPORT_SECURITY_HANDSHAKER_H */

@ -134,7 +134,7 @@ void grpc_init(void) {
grpc_core::ExecCtx::GlobalInit(); grpc_core::ExecCtx::GlobalInit();
grpc_iomgr_init(); grpc_iomgr_init();
gpr_timers_global_init(); gpr_timers_global_init();
grpc_handshaker_factory_registry_init(); grpc_core::HandshakerRegistry::Init();
grpc_security_init(); grpc_security_init();
for (i = 0; i < g_number_of_plugins; i++) { for (i = 0; i < g_number_of_plugins; i++) {
if (g_all_of_the_plugins[i].init != nullptr) { if (g_all_of_the_plugins[i].init != nullptr) {
@ -177,7 +177,7 @@ void grpc_shutdown(void) {
gpr_timers_global_destroy(); gpr_timers_global_destroy();
grpc_tracer_shutdown(); grpc_tracer_shutdown();
grpc_mdctx_global_shutdown(); grpc_mdctx_global_shutdown();
grpc_handshaker_factory_registry_shutdown(); grpc_core::HandshakerRegistry::Shutdown();
grpc_slice_intern_shutdown(); grpc_slice_intern_shutdown();
grpc_core::channelz::ChannelzRegistry::Shutdown(); grpc_core::channelz::ChannelzRegistry::Shutdown();
grpc_stats_shutdown(); grpc_stats_shutdown();

@ -78,4 +78,4 @@ void grpc_register_security_filters(void) {
maybe_prepend_server_auth_filter, nullptr); maybe_prepend_server_auth_filter, nullptr);
} }
void grpc_security_init() { grpc_security_register_handshaker_factories(); } void grpc_security_init() { grpc_core::SecurityRegisterHandshakerFactories(); }

@ -68,7 +68,6 @@ CORE_SOURCE_FILES = [
'src/core/lib/channel/channelz_registry.cc', 'src/core/lib/channel/channelz_registry.cc',
'src/core/lib/channel/connected_channel.cc', 'src/core/lib/channel/connected_channel.cc',
'src/core/lib/channel/handshaker.cc', 'src/core/lib/channel/handshaker.cc',
'src/core/lib/channel/handshaker_factory.cc',
'src/core/lib/channel/handshaker_registry.cc', 'src/core/lib/channel/handshaker_registry.cc',
'src/core/lib/channel/status_util.cc', 'src/core/lib/channel/status_util.cc',
'src/core/lib/compression/compression.cc', 'src/core/lib/compression/compression.cc',

@ -49,51 +49,38 @@
* to the security_handshaker). This test is meant to protect code relying on * to the security_handshaker). This test is meant to protect code relying on
* this functionality that lives outside of this repo. */ * this functionality that lives outside of this repo. */
static void readahead_handshaker_destroy(grpc_handshaker* handshaker) { namespace grpc_core {
gpr_free(handshaker);
}
static void readahead_handshaker_shutdown(grpc_handshaker* handshaker,
grpc_error* error) {}
static void readahead_handshaker_do_handshake( class ReadAheadHandshaker : public Handshaker {
grpc_handshaker* handshaker, grpc_tcp_server_acceptor* acceptor, public:
grpc_closure* on_handshake_done, grpc_handshaker_args* args) { virtual ~ReadAheadHandshaker() {}
const char* name() const override { return "read_ahead"; }
void Shutdown(grpc_error* why) override {}
void DoHandshake(grpc_tcp_server_acceptor* acceptor,
grpc_closure* on_handshake_done,
HandshakerArgs* args) override {
grpc_endpoint_read(args->endpoint, args->read_buffer, on_handshake_done); grpc_endpoint_read(args->endpoint, args->read_buffer, on_handshake_done);
} }
};
const grpc_handshaker_vtable readahead_handshaker_vtable = {
readahead_handshaker_destroy, readahead_handshaker_shutdown,
readahead_handshaker_do_handshake, "read_ahead"};
static grpc_handshaker* readahead_handshaker_create() { class ReadAheadHandshakerFactory : public HandshakerFactory {
grpc_handshaker* h = public:
static_cast<grpc_handshaker*>(gpr_zalloc(sizeof(grpc_handshaker))); void AddHandshakers(const grpc_channel_args* args,
grpc_handshaker_init(&readahead_handshaker_vtable, h);
return h;
}
static void readahead_handshaker_factory_add_handshakers(
grpc_handshaker_factory* hf, const grpc_channel_args* args,
grpc_pollset_set* interested_parties, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) { HandshakeManager* handshake_mgr) override {
grpc_handshake_manager_add(handshake_mgr, readahead_handshaker_create()); handshake_mgr->Add(MakeRefCounted<ReadAheadHandshaker>());
} }
~ReadAheadHandshakerFactory() override = default;
static void readahead_handshaker_factory_destroy( };
grpc_handshaker_factory* handshaker_factory) {}
static const grpc_handshaker_factory_vtable } // namespace grpc_core
readahead_handshaker_factory_vtable = {
readahead_handshaker_factory_add_handshakers,
readahead_handshaker_factory_destroy};
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
grpc_handshaker_factory readahead_handshaker_factory = { using namespace grpc_core;
&readahead_handshaker_factory_vtable};
grpc_init(); grpc_init();
grpc_handshaker_factory_register(true /* at_start */, HANDSHAKER_SERVER, HandshakerRegistry::RegisterHandshakerFactory(
&readahead_handshaker_factory); true /* at_start */, HANDSHAKER_SERVER,
UniquePtr<HandshakerFactory>(New<ReadAheadHandshakerFactory>()));
const char* full_alpn_list[] = {"grpc-exp", "h2"}; const char* full_alpn_list[] = {"grpc-exp", "h2"};
GPR_ASSERT(server_ssl_test(full_alpn_list, 2, "grpc-exp")); GPR_ASSERT(server_ssl_test(full_alpn_list, 2, "grpc-exp"));
grpc_shutdown(); grpc_shutdown();

@ -41,7 +41,8 @@ struct handshake_state {
}; };
static void on_handshake_done(void* arg, grpc_error* error) { static void on_handshake_done(void* arg, grpc_error* error) {
grpc_handshaker_args* args = static_cast<grpc_handshaker_args*>(arg); grpc_core::HandshakerArgs* args =
static_cast<grpc_core::HandshakerArgs*>(arg);
struct handshake_state* state = struct handshake_state* state =
static_cast<struct handshake_state*>(args->user_data); static_cast<struct handshake_state*>(args->user_data);
GPR_ASSERT(state->done_callback_called == false); GPR_ASSERT(state->done_callback_called == false);
@ -89,11 +90,12 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
struct handshake_state state; struct handshake_state state;
state.done_callback_called = false; state.done_callback_called = false;
grpc_handshake_manager* handshake_mgr = grpc_handshake_manager_create(); auto handshake_mgr =
sc->add_handshakers(nullptr, handshake_mgr); grpc_core::MakeRefCounted<grpc_core::HandshakeManager>();
grpc_handshake_manager_do_handshake( sc->add_handshakers(nullptr, handshake_mgr.get());
handshake_mgr, mock_endpoint, nullptr /* channel_args */, deadline, handshake_mgr->DoHandshake(mock_endpoint, nullptr /* channel_args */,
nullptr /* acceptor */, on_handshake_done, &state); deadline, nullptr /* acceptor */,
on_handshake_done, &state);
grpc_core::ExecCtx::Get()->Flush(); grpc_core::ExecCtx::Get()->Flush();
// If the given string happens to be part of the correct client hello, the // If the given string happens to be part of the correct client hello, the
@ -108,7 +110,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
GPR_ASSERT(state.done_callback_called); GPR_ASSERT(state.done_callback_called);
grpc_handshake_manager_destroy(handshake_mgr);
sc.reset(DEBUG_LOCATION, "test"); sc.reset(DEBUG_LOCATION, "test");
grpc_server_credentials_release(creds); grpc_server_credentials_release(creds);
grpc_slice_unref(cert_slice); grpc_slice_unref(cert_slice);

@ -1080,7 +1080,6 @@ src/core/lib/channel/connected_channel.h \
src/core/lib/channel/context.h \ src/core/lib/channel/context.h \
src/core/lib/channel/handshaker.cc \ src/core/lib/channel/handshaker.cc \
src/core/lib/channel/handshaker.h \ src/core/lib/channel/handshaker.h \
src/core/lib/channel/handshaker_factory.cc \
src/core/lib/channel/handshaker_factory.h \ src/core/lib/channel/handshaker_factory.h \
src/core/lib/channel/handshaker_registry.cc \ src/core/lib/channel/handshaker_registry.cc \
src/core/lib/channel/handshaker_registry.h \ src/core/lib/channel/handshaker_registry.h \

@ -9440,7 +9440,6 @@
"src/core/lib/channel/channelz_registry.cc", "src/core/lib/channel/channelz_registry.cc",
"src/core/lib/channel/connected_channel.cc", "src/core/lib/channel/connected_channel.cc",
"src/core/lib/channel/handshaker.cc", "src/core/lib/channel/handshaker.cc",
"src/core/lib/channel/handshaker_factory.cc",
"src/core/lib/channel/handshaker_registry.cc", "src/core/lib/channel/handshaker_registry.cc",
"src/core/lib/channel/status_util.cc", "src/core/lib/channel/status_util.cc",
"src/core/lib/compression/compression.cc", "src/core/lib/compression/compression.cc",

Loading…
Cancel
Save