From 03a51fa9d1c20926d25acc3d0a378ec610e72661 Mon Sep 17 00:00:00 2001 From: "Mark D. Roth" Date: Fri, 1 Oct 2021 10:01:03 -0700 Subject: [PATCH] don't create secure endpoint if TSI implementation does not require it (#27509) * don't create secure endpoint if security level is TSI_SECURITY_NONE * instead of depending on security level, add a method to tsi_handshaker_result * clang-format * fix build * fix build for realz * code review changes * update comments --- .../security/transport/security_handshaker.cc | 111 +++++++++++------- .../alts/handshaker/alts_tsi_handshaker.cc | 11 +- src/core/tsi/fake_transport_security.cc | 9 ++ src/core/tsi/local_transport_security.cc | 79 ++----------- src/core/tsi/ssl_transport_security.cc | 8 ++ src/core/tsi/transport_security.cc | 12 ++ src/core/tsi/transport_security.h | 17 ++- src/core/tsi/transport_security_interface.h | 26 ++++ 8 files changed, 157 insertions(+), 116 deletions(-) diff --git a/src/core/lib/security/transport/security_handshaker.cc b/src/core/lib/security/transport/security_handshaker.cc index 1d894876b9c..0f9114d8f22 100644 --- a/src/core/lib/security/transport/security_handshaker.cc +++ b/src/core/lib/security/transport/security_handshaker.cc @@ -116,13 +116,10 @@ SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker, connector_(connector->Ref(DEBUG_LOCATION, "handshake")), handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE), handshake_buffer_( - static_cast(gpr_malloc(handshake_buffer_size_))) { - const grpc_arg* arg = - grpc_channel_args_find(args, GRPC_ARG_TSI_MAX_FRAME_SIZE); - if (arg != nullptr && arg->type == GRPC_ARG_INTEGER) { - max_frame_size_ = grpc_channel_arg_get_integer( - arg, {0, 0, std::numeric_limits::max()}); - } + static_cast(gpr_malloc(handshake_buffer_size_))), + max_frame_size_(grpc_channel_args_find_integer( + args, GRPC_ARG_TSI_MAX_FRAME_SIZE, + {0, 0, std::numeric_limits::max()})) { grpc_slice_buffer_init(&outgoing_); GRPC_CLOSURE_INIT(&on_peer_checked_, &SecurityHandshaker::OnPeerCheckedFn, this, grpc_schedule_on_exec_ctx); @@ -232,37 +229,10 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) { HandshakeFailedLocked(error); return; } - // Create zero-copy frame protector, if implemented. - tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr; - tsi_result result = tsi_handshaker_result_create_zero_copy_grpc_protector( - handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_, - &zero_copy_protector); - if (result != TSI_OK && result != TSI_UNIMPLEMENTED) { - error = grpc_set_tsi_error_result( - GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Zero-copy frame protector creation failed"), - result); - HandshakeFailedLocked(error); - return; - } - // Create frame protector if zero-copy frame protector is NULL. - tsi_frame_protector* protector = nullptr; - if (zero_copy_protector == nullptr) { - result = tsi_handshaker_result_create_frame_protector( - handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_, - &protector); - if (result != TSI_OK) { - error = grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Frame protector creation failed"), - result); - HandshakeFailedLocked(error); - return; - } - } // Get unused bytes. const unsigned char* unused_bytes = nullptr; size_t unused_bytes_size = 0; - result = tsi_handshaker_result_get_unused_bytes( + tsi_result result = tsi_handshaker_result_get_unused_bytes( handshaker_result_, &unused_bytes, &unused_bytes_size); if (result != TSI_OK) { HandshakeFailedLocked(grpc_set_tsi_error_result( @@ -271,17 +241,71 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) { result)); return; } - // Create secure endpoint. - if (unused_bytes_size > 0) { + // Check whether we need to wrap the endpoint. + tsi_frame_protector_type frame_protector_type; + result = tsi_handshaker_result_get_frame_protector_type( + handshaker_result_, &frame_protector_type); + if (result != TSI_OK) { + HandshakeFailedLocked(grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "TSI handshaker result does not implement " + "get_frame_protector_type"), + result)); + return; + } + tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr; + tsi_frame_protector* protector = nullptr; + switch (frame_protector_type) { + case TSI_FRAME_PROTECTOR_ZERO_COPY: + ABSL_FALLTHROUGH_INTENDED; + case TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY: + // Create zero-copy frame protector. + result = tsi_handshaker_result_create_zero_copy_grpc_protector( + handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_, + &zero_copy_protector); + if (result != TSI_OK) { + HandshakeFailedLocked(grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Zero-copy frame protector creation failed"), + result)); + return; + } + break; + case TSI_FRAME_PROTECTOR_NORMAL: + // Create normal frame protector. + result = tsi_handshaker_result_create_frame_protector( + handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_, + &protector); + if (result != TSI_OK) { + HandshakeFailedLocked( + grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Frame protector creation failed"), + result)); + return; + } + break; + case TSI_FRAME_PROTECTOR_NONE: + break; + } + // If we have a frame protector, create a secure endpoint. + if (zero_copy_protector != nullptr || protector != nullptr) { + if (unused_bytes_size > 0) { + grpc_slice slice = grpc_slice_from_copied_buffer( + reinterpret_cast(unused_bytes), unused_bytes_size); + args_->endpoint = grpc_secure_endpoint_create( + protector, zero_copy_protector, args_->endpoint, &slice, 1); + grpc_slice_unref_internal(slice); + } else { + args_->endpoint = grpc_secure_endpoint_create( + protector, zero_copy_protector, args_->endpoint, nullptr, 0); + } + } else if (unused_bytes_size > 0) { + // Not wrapping the endpoint, so just pass along unused bytes. grpc_slice slice = grpc_slice_from_copied_buffer( reinterpret_cast(unused_bytes), unused_bytes_size); - args_->endpoint = grpc_secure_endpoint_create( - protector, zero_copy_protector, args_->endpoint, &slice, 1); - grpc_slice_unref_internal(slice); - } else { - args_->endpoint = grpc_secure_endpoint_create( - protector, zero_copy_protector, args_->endpoint, nullptr, 0); + grpc_slice_buffer_add(args_->read_buffer, slice); } + // Done with handshaker result. tsi_handshaker_result_destroy(handshaker_result_); handshaker_result_ = nullptr; // Add auth context to channel args. @@ -437,7 +461,6 @@ void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn( size_t bytes_received_size = h->MoveReadBufferIntoHandshakeBuffer(); // Call TSI handshaker. error = h->DoHandshakerNextLocked(h->handshake_buffer_, bytes_received_size); - if (error != GRPC_ERROR_NONE) { h->HandshakeFailedLocked(error); } else { diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc index 9cca8ab4726..e3b529419b2 100644 --- a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc @@ -151,6 +151,13 @@ static tsi_result handshaker_result_extract_peer( return ok; } +static tsi_result handshaker_result_get_frame_protector_type( + const tsi_handshaker_result* /*self*/, + tsi_frame_protector_type* frame_protector_type) { + *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY; + return TSI_OK; +} + static tsi_result handshaker_result_create_zero_copy_grpc_protector( const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, tsi_zero_copy_grpc_protector** protector) { @@ -247,9 +254,11 @@ static void handshaker_result_destroy(tsi_handshaker_result* self) { static const tsi_handshaker_result_vtable result_vtable = { handshaker_result_extract_peer, + handshaker_result_get_frame_protector_type, handshaker_result_create_zero_copy_grpc_protector, handshaker_result_create_frame_protector, - handshaker_result_get_unused_bytes, handshaker_result_destroy}; + handshaker_result_get_unused_bytes, + handshaker_result_destroy}; tsi_result alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp* resp, bool is_client, diff --git a/src/core/tsi/fake_transport_security.cc b/src/core/tsi/fake_transport_security.cc index 727e79cca91..1484dce0d80 100644 --- a/src/core/tsi/fake_transport_security.cc +++ b/src/core/tsi/fake_transport_security.cc @@ -498,6 +498,7 @@ struct fake_handshaker_result { unsigned char* unused_bytes; size_t unused_bytes_size; }; + static tsi_result fake_handshaker_result_extract_peer( const tsi_handshaker_result* /*self*/, tsi_peer* peer) { /* Construct a tsi_peer with 1 property: certificate type, security_level. */ @@ -514,6 +515,13 @@ static tsi_result fake_handshaker_result_extract_peer( return result; } +static tsi_result fake_handshaker_result_get_frame_protector_type( + const tsi_handshaker_result* /*self*/, + tsi_frame_protector_type* frame_protector_type) { + *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY; + return TSI_OK; +} + static tsi_result fake_handshaker_result_create_zero_copy_grpc_protector( const tsi_handshaker_result* /*self*/, size_t* max_output_protected_frame_size, @@ -549,6 +557,7 @@ static void fake_handshaker_result_destroy(tsi_handshaker_result* self) { static const tsi_handshaker_result_vtable handshaker_result_vtable = { fake_handshaker_result_extract_peer, + fake_handshaker_result_get_frame_protector_type, fake_handshaker_result_create_zero_copy_grpc_protector, fake_handshaker_result_create_frame_protector, fake_handshaker_result_get_unused_bytes, diff --git a/src/core/tsi/local_transport_security.cc b/src/core/tsi/local_transport_security.cc index 97a33493e21..e143aa1451e 100644 --- a/src/core/tsi/local_transport_security.cc +++ b/src/core/tsi/local_transport_security.cc @@ -52,60 +52,6 @@ typedef struct local_tsi_handshaker { bool is_client; } local_tsi_handshaker; -/* --- tsi_zero_copy_grpc_protector methods implementation. --- */ - -static tsi_result local_zero_copy_grpc_protector_protect( - tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* unprotected_slices, - grpc_slice_buffer* protected_slices) { - if (self == nullptr || unprotected_slices == nullptr || - protected_slices == nullptr) { - gpr_log(GPR_ERROR, "Invalid nullptr arguments to zero-copy grpc protect."); - return TSI_INVALID_ARGUMENT; - } - grpc_slice_buffer_move_into(unprotected_slices, protected_slices); - return TSI_OK; -} - -static tsi_result local_zero_copy_grpc_protector_unprotect( - tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* protected_slices, - grpc_slice_buffer* unprotected_slices) { - if (self == nullptr || unprotected_slices == nullptr || - protected_slices == nullptr) { - gpr_log(GPR_ERROR, - "Invalid nullptr arguments to zero-copy grpc unprotect."); - return TSI_INVALID_ARGUMENT; - } - grpc_slice_buffer_move_into(protected_slices, unprotected_slices); - return TSI_OK; -} - -static void local_zero_copy_grpc_protector_destroy( - tsi_zero_copy_grpc_protector* self) { - gpr_free(self); -} - -static const tsi_zero_copy_grpc_protector_vtable - local_zero_copy_grpc_protector_vtable = { - local_zero_copy_grpc_protector_protect, - local_zero_copy_grpc_protector_unprotect, - local_zero_copy_grpc_protector_destroy, - nullptr /* local_zero_copy_grpc_protector_max_frame_size */}; - -tsi_result local_zero_copy_grpc_protector_create( - tsi_zero_copy_grpc_protector** protector) { - if (grpc_core::ExecCtx::Get() == nullptr || protector == nullptr) { - gpr_log( - GPR_ERROR, - "Invalid nullptr arguments to local_zero_copy_grpc_protector create."); - return TSI_INVALID_ARGUMENT; - } - local_zero_copy_grpc_protector* impl = - static_cast(gpr_zalloc(sizeof(*impl))); - impl->base.vtable = &local_zero_copy_grpc_protector_vtable; - *protector = &impl->base; - return TSI_OK; -} - /* --- tsi_handshaker_result methods implementation. --- */ static tsi_result handshaker_result_extract_peer( @@ -113,20 +59,11 @@ static tsi_result handshaker_result_extract_peer( return TSI_OK; } -static tsi_result handshaker_result_create_zero_copy_grpc_protector( - const tsi_handshaker_result* self, - size_t* /*max_output_protected_frame_size*/, - tsi_zero_copy_grpc_protector** protector) { - if (self == nullptr || protector == nullptr) { - gpr_log(GPR_ERROR, - "Invalid arguments to create_zero_copy_grpc_protector()"); - return TSI_INVALID_ARGUMENT; - } - tsi_result ok = local_zero_copy_grpc_protector_create(protector); - if (ok != TSI_OK) { - gpr_log(GPR_ERROR, "Failed to create zero-copy grpc protector"); - } - return ok; +static tsi_result handshaker_result_get_frame_protector_type( + const tsi_handshaker_result* /*self*/, + tsi_frame_protector_type* frame_protector_type) { + *frame_protector_type = TSI_FRAME_PROTECTOR_NONE; + return TSI_OK; } static tsi_result handshaker_result_get_unused_bytes( @@ -156,9 +93,11 @@ static void handshaker_result_destroy(tsi_handshaker_result* self) { static const tsi_handshaker_result_vtable result_vtable = { handshaker_result_extract_peer, - handshaker_result_create_zero_copy_grpc_protector, + handshaker_result_get_frame_protector_type, + nullptr, /* handshaker_result_create_zero_copy_grpc_protector */ nullptr, /* handshaker_result_create_frame_protector */ - handshaker_result_get_unused_bytes, handshaker_result_destroy}; + handshaker_result_get_unused_bytes, + handshaker_result_destroy}; static tsi_result create_handshaker_result(bool is_client, const unsigned char* received_bytes, diff --git a/src/core/tsi/ssl_transport_security.cc b/src/core/tsi/ssl_transport_security.cc index b687207bd5b..0ff41e72969 100644 --- a/src/core/tsi/ssl_transport_security.cc +++ b/src/core/tsi/ssl_transport_security.cc @@ -1302,6 +1302,13 @@ static tsi_result ssl_handshaker_result_extract_peer( return result; } +static tsi_result ssl_handshaker_result_get_frame_protector_type( + const tsi_handshaker_result* /*self*/, + tsi_frame_protector_type* frame_protector_type) { + *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL; + return TSI_OK; +} + static tsi_result ssl_handshaker_result_create_frame_protector( const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, tsi_frame_protector** protector) { @@ -1368,6 +1375,7 @@ static void ssl_handshaker_result_destroy(tsi_handshaker_result* self) { static const tsi_handshaker_result_vtable handshaker_result_vtable = { ssl_handshaker_result_extract_peer, + ssl_handshaker_result_get_frame_protector_type, nullptr, /* create_zero_copy_grpc_protector */ ssl_handshaker_result_create_frame_protector, ssl_handshaker_result_get_unused_bytes, diff --git a/src/core/tsi/transport_security.cc b/src/core/tsi/transport_security.cc index f08093dc58c..5d822604747 100644 --- a/src/core/tsi/transport_security.cc +++ b/src/core/tsi/transport_security.cc @@ -251,6 +251,18 @@ tsi_result tsi_handshaker_result_extract_peer(const tsi_handshaker_result* self, return self->vtable->extract_peer(self, peer); } +tsi_result tsi_handshaker_result_get_frame_protector_type( + const tsi_handshaker_result* self, + tsi_frame_protector_type* frame_protector_type) { + if (self == nullptr || frame_protector_type == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->get_frame_protector_type == nullptr) { + return TSI_UNIMPLEMENTED; + } + return self->vtable->get_frame_protector_type(self, frame_protector_type); +} + tsi_result tsi_handshaker_result_create_frame_protector( const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, tsi_frame_protector** protector) { diff --git a/src/core/tsi/transport_security.h b/src/core/tsi/transport_security.h index d3b1c7f4aa9..53f85d2852b 100644 --- a/src/core/tsi/transport_security.h +++ b/src/core/tsi/transport_security.h @@ -29,7 +29,8 @@ extern grpc_core::TraceFlag tsi_tracing_enabled; /* Base for tsi_frame_protector implementations. - See transport_security_interface.h for documentation. */ + See transport_security_interface.h for documentation. + All methods must be implemented. */ struct tsi_frame_protector_vtable { tsi_result (*protect)(tsi_frame_protector* self, const unsigned char* unprotected_bytes, @@ -54,6 +55,9 @@ struct tsi_frame_protector { /* Base for tsi_handshaker implementations. See transport_security_interface.h for documentation. */ struct tsi_handshaker_vtable { + /* Methods for supporting the old synchronous API. + These can be null if the TSI impl supports only the new + async-capable API. */ tsi_result (*get_bytes_to_send_to_peer)(tsi_handshaker* self, unsigned char* bytes, size_t* bytes_size); @@ -65,7 +69,10 @@ struct tsi_handshaker_vtable { tsi_result (*create_frame_protector)(tsi_handshaker* self, size_t* max_protected_frame_size, tsi_frame_protector** protector); + /* Must be implemented by all TSI impls. */ void (*destroy)(tsi_handshaker* self); + /* Methods for supporting the new async-capable API. + These can be null if the TSI impl supports only the old sync API. */ tsi_result (*next)(tsi_handshaker* self, const unsigned char* received_bytes, size_t received_bytes_size, const unsigned char** bytes_to_send, @@ -88,13 +95,21 @@ struct tsi_handshaker { API depend on grpc. The create_zero_copy_grpc_protector() method is only used in grpc, where we do need the exec_ctx passed through, but the API still needs to compile in other applications, where grpc_exec_ctx is not defined. + All methods must be non-null, except where noted below. */ struct tsi_handshaker_result_vtable { tsi_result (*extract_peer)(const tsi_handshaker_result* self, tsi_peer* peer); + tsi_result (*get_frame_protector_type)( + const tsi_handshaker_result* self, + tsi_frame_protector_type* frame_protector_type); + /* May be null if get_frame_protector_type() returns + TSI_FRAME_PROTECTOR_NORMAL or TSI_FRAME_PROTECTOR_NONE. */ tsi_result (*create_zero_copy_grpc_protector)( const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, tsi_zero_copy_grpc_protector** protector); + /* May be null if get_frame_protector_type() returns + TSI_FRAME_PROTECTOR_ZERO_COPY or TSI_FRAME_PROTECTOR_NONE. */ tsi_result (*create_frame_protector)(const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, tsi_frame_protector** protector); diff --git a/src/core/tsi/transport_security_interface.h b/src/core/tsi/transport_security_interface.h index 50beccac328..0ba139379be 100644 --- a/src/core/tsi/transport_security_interface.h +++ b/src/core/tsi/transport_security_interface.h @@ -64,6 +64,26 @@ typedef enum { TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY, } tsi_client_certificate_request_type; +typedef enum { + // TSI implementation provides a normal frame protector. The caller + // should invoke tsi_handshaker_result_create_frame_protector() to + // generate the frame protector. + TSI_FRAME_PROTECTOR_NORMAL, + // TSI implementation provides a zero-copy frame protector. The caller + // should invoke tsi_handshaker_result_create_zero_copy_grpc_protector() + // to generate the frame protector. + TSI_FRAME_PROTECTOR_ZERO_COPY, + // TSI implementation provides both normal and zero-copy frame protectors. + // The caller should invoke either + // tsi_handshaker_result_create_frame_protector() or + // tsi_handshaker_result_create_zero_copy_grpc_protector() to generate + // the frame protector. + TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY, + // TSI implementation does not provide any frame protector. This means + // that it is safe for the caller to send bytes unprotected on the wire. + TSI_FRAME_PROTECTOR_NONE, +} tsi_frame_protector_type; + typedef enum { TSI_TLS1_2, TSI_TLS1_3, @@ -234,6 +254,12 @@ typedef struct tsi_handshaker_result tsi_handshaker_result; tsi_result tsi_handshaker_result_extract_peer(const tsi_handshaker_result* self, tsi_peer* peer); +/* This method indicates what type of frame protector is provided by the + TSI implementation. */ +tsi_result tsi_handshaker_result_get_frame_protector_type( + const tsi_handshaker_result* self, + tsi_frame_protector_type* frame_protector_type); + /* This method creates a tsi_frame_protector object. It returns TSI_OK assuming there is no fatal error. The caller is responsible for destroying the protector. */