From f48119619ed71e4d9d15093c4c588da2c360121c Mon Sep 17 00:00:00 2001 From: "Mark D. Roth" Date: Tue, 28 Sep 2021 16:18:39 -0700 Subject: [PATCH] Fix local TSI impl to pass along unused bytes (#27508) * try insecure creds in proxy case * revert server-side changes * fix local transport security to pass along unused bytes * fix security handshaker to check the result of TSI get_unused_bytes() * fix local TSI impl get_unused_bytes() to check its params * clang-format Co-authored-by: yihuaz --- .../security/transport/security_handshaker.cc | 7 ++++ src/core/tsi/local_transport_security.cc | 35 ++++++++++++++++--- test/core/end2end/fixtures/h2_http_proxy.cc | 7 ++-- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/src/core/lib/security/transport/security_handshaker.cc b/src/core/lib/security/transport/security_handshaker.cc index d062e90fdc4..1d894876b9c 100644 --- a/src/core/lib/security/transport/security_handshaker.cc +++ b/src/core/lib/security/transport/security_handshaker.cc @@ -264,6 +264,13 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) { size_t unused_bytes_size = 0; result = tsi_handshaker_result_get_unused_bytes( handshaker_result_, &unused_bytes, &unused_bytes_size); + if (result != TSI_OK) { + HandshakeFailedLocked(grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "TSI handshaker result does not provide unused bytes"), + result)); + return; + } // Create secure endpoint. if (unused_bytes_size > 0) { grpc_slice slice = grpc_slice_from_copied_buffer( diff --git a/src/core/tsi/local_transport_security.cc b/src/core/tsi/local_transport_security.cc index 19043cd1a9a..97a33493e21 100644 --- a/src/core/tsi/local_transport_security.cc +++ b/src/core/tsi/local_transport_security.cc @@ -42,6 +42,8 @@ typedef struct local_zero_copy_grpc_protector { typedef struct local_tsi_handshaker_result { tsi_handshaker_result base; bool is_client; + unsigned char* unused_bytes; + size_t unused_bytes_size; } local_tsi_handshaker_result; /* Main struct for local TSI handshaker. */ @@ -127,6 +129,20 @@ static tsi_result handshaker_result_create_zero_copy_grpc_protector( return ok; } +static tsi_result handshaker_result_get_unused_bytes( + const tsi_handshaker_result* self, const unsigned char** bytes, + size_t* bytes_size) { + if (self == nullptr || bytes == nullptr || bytes_size == nullptr) { + gpr_log(GPR_ERROR, "Invalid arguments to get_unused_bytes()"); + return TSI_INVALID_ARGUMENT; + } + auto* result = reinterpret_cast( + const_cast(self)); + *bytes_size = result->unused_bytes_size; + *bytes = result->unused_bytes; + return TSI_OK; +} + static void handshaker_result_destroy(tsi_handshaker_result* self) { if (self == nullptr) { return; @@ -134,6 +150,7 @@ static void handshaker_result_destroy(tsi_handshaker_result* self) { local_tsi_handshaker_result* result = reinterpret_cast( const_cast(self)); + gpr_free(result->unused_bytes); gpr_free(result); } @@ -141,10 +158,11 @@ static const tsi_handshaker_result_vtable result_vtable = { handshaker_result_extract_peer, handshaker_result_create_zero_copy_grpc_protector, nullptr, /* handshaker_result_create_frame_protector */ - nullptr, /* handshaker_result_get_unused_bytes */ - handshaker_result_destroy}; + handshaker_result_get_unused_bytes, handshaker_result_destroy}; static tsi_result create_handshaker_result(bool is_client, + const unsigned char* received_bytes, + size_t received_bytes_size, tsi_handshaker_result** self) { if (self == nullptr) { gpr_log(GPR_ERROR, "Invalid arguments to create_handshaker_result()"); @@ -153,6 +171,12 @@ static tsi_result create_handshaker_result(bool is_client, local_tsi_handshaker_result* result = static_cast(gpr_zalloc(sizeof(*result))); result->is_client = is_client; + if (received_bytes_size > 0) { + result->unused_bytes = + static_cast(gpr_malloc(received_bytes_size)); + memcpy(result->unused_bytes, received_bytes, received_bytes_size); + } + result->unused_bytes_size = received_bytes_size; result->base.vtable = &result_vtable; *self = &result->base; return TSI_OK; @@ -161,8 +185,8 @@ static tsi_result create_handshaker_result(bool is_client, /* --- tsi_handshaker methods implementation. --- */ static tsi_result handshaker_next( - tsi_handshaker* self, const unsigned char* /*received_bytes*/, - size_t /*received_bytes_size*/, const unsigned char** /*bytes_to_send*/, + tsi_handshaker* self, const unsigned char* received_bytes, + size_t received_bytes_size, const unsigned char** /*bytes_to_send*/, size_t* bytes_to_send_size, tsi_handshaker_result** result, tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/) { if (self == nullptr) { @@ -175,7 +199,8 @@ static tsi_result handshaker_next( local_tsi_handshaker* handshaker = reinterpret_cast(self); *bytes_to_send_size = 0; - create_handshaker_result(handshaker->is_client, result); + create_handshaker_result(handshaker->is_client, received_bytes, + received_bytes_size, result); return TSI_OK; } diff --git a/test/core/end2end/fixtures/h2_http_proxy.cc b/test/core/end2end/fixtures/h2_http_proxy.cc index 61a3955633b..a4b389c86fe 100644 --- a/test/core/end2end/fixtures/h2_http_proxy.cc +++ b/test/core/end2end/fixtures/h2_http_proxy.cc @@ -22,6 +22,7 @@ #include "absl/strings/str_format.h" +#include #include #include #include @@ -81,8 +82,10 @@ void chttp2_init_client_fullstack(grpc_end2end_test_fixture* f, grpc_end2end_http_proxy_get_proxy_name(ffd->proxy)); } gpr_setenv("http_proxy", proxy_uri.c_str()); - f->client = grpc_insecure_channel_create(ffd->server_addr.c_str(), - client_args, nullptr); + grpc_channel_credentials* creds = grpc_insecure_credentials_create(); + f->client = grpc_secure_channel_create(creds, ffd->server_addr.c_str(), + client_args, nullptr); + grpc_channel_credentials_release(creds); GPR_ASSERT(f->client); }