Revert "Revert "TSI: return handshaker error message for inclusion in RPC failure status (#30077)" (#30284)" (#30286)

This reverts commit 8aeb548590.
pull/30401/head
Mark D. Roth 2 years ago committed by GitHub
parent 309d83832c
commit 18d82d4a6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 9
      src/core/lib/security/transport/security_handshaker.cc
  2. 61
      src/core/tsi/alts/handshaker/alts_handshaker_client.cc
  3. 2
      src/core/tsi/alts/handshaker/alts_handshaker_client.h
  4. 21
      src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc
  5. 2
      src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h
  6. 83
      src/core/tsi/fake_transport_security.cc
  7. 14
      src/core/tsi/local_transport_security.cc
  8. 63
      src/core/tsi/ssl_transport_security.cc
  9. 24
      src/core/tsi/transport_security.cc
  10. 3
      src/core/tsi/transport_security.h
  11. 22
      src/core/tsi/transport_security_interface.h
  12. 4
      test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc
  13. 38
      test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc

@ -131,6 +131,7 @@ class SecurityHandshaker : public Handshaker {
RefCountedPtr<grpc_auth_context> auth_context_; RefCountedPtr<grpc_auth_context> auth_context_;
tsi_handshaker_result* handshaker_result_ = nullptr; tsi_handshaker_result* handshaker_result_ = nullptr;
size_t max_frame_size_ = 0; size_t max_frame_size_ = 0;
std::string tsi_handshake_error_;
}; };
SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker, SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
@ -392,8 +393,9 @@ grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked(
connector_type = security_connector->type().name(); connector_type = security_connector->type().name();
} }
return grpc_set_tsi_error_result( return grpc_set_tsi_error_result(
GRPC_ERROR_CREATE_FROM_CPP_STRING( GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat(
absl::StrCat(connector_type, " handshake failed")), connector_type, " handshake failed",
(tsi_handshake_error_.empty() ? "" : ": "), tsi_handshake_error_)),
result); result);
} }
// Update handshaker result. // Update handshaker result.
@ -453,7 +455,8 @@ grpc_error_handle SecurityHandshaker::DoHandshakerNextLocked(
tsi_handshaker_result* hs_result = nullptr; tsi_handshaker_result* hs_result = nullptr;
tsi_result result = tsi_handshaker_next( tsi_result result = tsi_handshaker_next(
handshaker_, bytes_received, bytes_received_size, &bytes_to_send, handshaker_, bytes_received, bytes_received_size, &bytes_to_send,
&bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper, this); &bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper, this,
&tsi_handshake_error_);
if (result == TSI_ASYNC) { if (result == TSI_ASYNC) {
// Handshaker operating asynchronously. Nothing else to do here; // Handshaker operating asynchronously. Nothing else to do here;
// callback will be invoked in a TSI thread. // callback will be invoked in a TSI thread.

@ -71,7 +71,8 @@ typedef struct alts_grpc_handshaker_client {
* handshaker service. */ * handshaker service. */
grpc_byte_buffer* send_buffer = nullptr; grpc_byte_buffer* send_buffer = nullptr;
grpc_byte_buffer* recv_buffer = nullptr; grpc_byte_buffer* recv_buffer = nullptr;
grpc_status_code status = GRPC_STATUS_OK; // Used to inject a read failure from tests.
bool inject_read_failure = false;
/* Initial metadata to be received from handshaker service. */ /* Initial metadata to be received from handshaker service. */
grpc_metadata_array recv_initial_metadata; grpc_metadata_array recv_initial_metadata;
/* A callback function provided by an application to be invoked when response /* A callback function provided by an application to be invoked when response
@ -106,6 +107,8 @@ typedef struct alts_grpc_handshaker_client {
recv_message_result* pending_recv_message_result = nullptr; recv_message_result* pending_recv_message_result = nullptr;
/* Maximum frame size used by frame protector. */ /* Maximum frame size used by frame protector. */
size_t max_frame_size; size_t max_frame_size;
// If non-null, will be populated with an error string upon error.
std::string* error;
} alts_grpc_handshaker_client; } alts_grpc_handshaker_client;
static void handshaker_client_send_buffer_destroy( static void handshaker_client_send_buffer_destroy(
@ -174,10 +177,11 @@ static void maybe_complete_tsi_next(
} }
static void handle_response_done(alts_grpc_handshaker_client* client, static void handle_response_done(alts_grpc_handshaker_client* client,
tsi_result status, tsi_result status, std::string error,
const unsigned char* bytes_to_send, const unsigned char* bytes_to_send,
size_t bytes_to_send_size, size_t bytes_to_send_size,
tsi_handshaker_result* result) { tsi_handshaker_result* result) {
if (client->error != nullptr) *client->error = std::move(error);
recv_message_result* p = grpc_core::Zalloc<recv_message_result>(); recv_message_result* p = grpc_core::Zalloc<recv_message_result>();
p->status = status; p->status = status;
p->bytes_to_send = bytes_to_send; p->bytes_to_send = bytes_to_send;
@ -193,7 +197,6 @@ void alts_handshaker_client_handle_response(alts_handshaker_client* c,
alts_grpc_handshaker_client* client = alts_grpc_handshaker_client* client =
reinterpret_cast<alts_grpc_handshaker_client*>(c); reinterpret_cast<alts_grpc_handshaker_client*>(c);
grpc_byte_buffer* recv_buffer = client->recv_buffer; grpc_byte_buffer* recv_buffer = client->recv_buffer;
grpc_status_code status = client->status;
alts_tsi_handshaker* handshaker = client->handshaker; alts_tsi_handshaker* handshaker = client->handshaker;
/* Invalid input check. */ /* Invalid input check. */
if (client->cb == nullptr) { if (client->cb == nullptr) {
@ -204,25 +207,34 @@ void alts_handshaker_client_handle_response(alts_handshaker_client* c,
if (handshaker == nullptr) { if (handshaker == nullptr) {
gpr_log(GPR_ERROR, gpr_log(GPR_ERROR,
"handshaker is nullptr in alts_tsi_handshaker_handle_response()"); "handshaker is nullptr in alts_tsi_handshaker_handle_response()");
handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr); handle_response_done(
client, TSI_INTERNAL_ERROR,
"handshaker is nullptr in alts_tsi_handshaker_handle_response()",
nullptr, 0, nullptr);
return; return;
} }
/* TSI handshake has been shutdown. */ /* TSI handshake has been shutdown. */
if (alts_tsi_handshaker_has_shutdown(handshaker)) { if (alts_tsi_handshaker_has_shutdown(handshaker)) {
gpr_log(GPR_INFO, "TSI handshake shutdown"); gpr_log(GPR_INFO, "TSI handshake shutdown");
handle_response_done(client, TSI_HANDSHAKE_SHUTDOWN, nullptr, 0, nullptr); handle_response_done(client, TSI_HANDSHAKE_SHUTDOWN,
"TSI handshake shutdown", nullptr, 0, nullptr);
return; return;
} }
/* Failed grpc call check. */ /* Check for failed grpc read. */
if (!is_ok || status != GRPC_STATUS_OK) { if (!is_ok || client->inject_read_failure) {
gpr_log(GPR_INFO, "grpc call made to handshaker service failed"); gpr_log(GPR_INFO, "read failed on grpc call to handshaker service");
handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr); handle_response_done(client, TSI_INTERNAL_ERROR,
"read failed on grpc call to handshaker service",
nullptr, 0, nullptr);
return; return;
} }
if (recv_buffer == nullptr) { if (recv_buffer == nullptr) {
gpr_log(GPR_ERROR, gpr_log(GPR_ERROR,
"recv_buffer is nullptr in alts_tsi_handshaker_handle_response()"); "recv_buffer is nullptr in alts_tsi_handshaker_handle_response()");
handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr); handle_response_done(
client, TSI_INTERNAL_ERROR,
"recv_buffer is nullptr in alts_tsi_handshaker_handle_response()",
nullptr, 0, nullptr);
return; return;
} }
upb::Arena arena; upb::Arena arena;
@ -233,14 +245,17 @@ void alts_handshaker_client_handle_response(alts_handshaker_client* c,
/* Invalid handshaker response check. */ /* Invalid handshaker response check. */
if (resp == nullptr) { if (resp == nullptr) {
gpr_log(GPR_ERROR, "alts_tsi_utils_deserialize_response() failed"); gpr_log(GPR_ERROR, "alts_tsi_utils_deserialize_response() failed");
handle_response_done(client, TSI_DATA_CORRUPTED, nullptr, 0, nullptr); handle_response_done(client, TSI_DATA_CORRUPTED,
"alts_tsi_utils_deserialize_response() failed",
nullptr, 0, nullptr);
return; return;
} }
const grpc_gcp_HandshakerStatus* resp_status = const grpc_gcp_HandshakerStatus* resp_status =
grpc_gcp_HandshakerResp_status(resp); grpc_gcp_HandshakerResp_status(resp);
if (resp_status == nullptr) { if (resp_status == nullptr) {
gpr_log(GPR_ERROR, "No status in HandshakerResp"); gpr_log(GPR_ERROR, "No status in HandshakerResp");
handle_response_done(client, TSI_DATA_CORRUPTED, nullptr, 0, nullptr); handle_response_done(client, TSI_DATA_CORRUPTED,
"No status in HandshakerResp", nullptr, 0, nullptr);
return; return;
} }
upb_StringView out_frames = grpc_gcp_HandshakerResp_out_frames(resp); upb_StringView out_frames = grpc_gcp_HandshakerResp_out_frames(resp);
@ -262,7 +277,9 @@ void alts_handshaker_client_handle_response(alts_handshaker_client* c,
alts_tsi_handshaker_result_create(resp, client->is_client, &result); alts_tsi_handshaker_result_create(resp, client->is_client, &result);
if (status != TSI_OK) { if (status != TSI_OK) {
gpr_log(GPR_ERROR, "alts_tsi_handshaker_result_create() failed"); gpr_log(GPR_ERROR, "alts_tsi_handshaker_result_create() failed");
handle_response_done(client, status, nullptr, 0, nullptr); handle_response_done(client, status,
"alts_tsi_handshaker_result_create() failed",
nullptr, 0, nullptr);
return; return;
} }
alts_tsi_handshaker_result_set_unused_bytes( alts_tsi_handshaker_result_set_unused_bytes(
@ -271,13 +288,13 @@ void alts_handshaker_client_handle_response(alts_handshaker_client* c,
} }
grpc_status_code code = static_cast<grpc_status_code>( grpc_status_code code = static_cast<grpc_status_code>(
grpc_gcp_HandshakerStatus_code(resp_status)); grpc_gcp_HandshakerStatus_code(resp_status));
std::string error;
if (code != GRPC_STATUS_OK) { if (code != GRPC_STATUS_OK) {
upb_StringView details = grpc_gcp_HandshakerStatus_details(resp_status); upb_StringView details = grpc_gcp_HandshakerStatus_details(resp_status);
if (details.size > 0) { if (details.size > 0) {
char* error_details = static_cast<char*>(gpr_zalloc(details.size + 1)); error = absl::StrCat("Status ", code, " from handshaker service: ",
memcpy(error_details, details.data, details.size); absl::string_view(details.data, details.size));
gpr_log(GPR_ERROR, "Error from handshaker service:%s", error_details); gpr_log(GPR_ERROR, "%s", error.c_str());
gpr_free(error_details);
} }
} }
// TODO(apolcyn): consider short ciruiting handle_response_done and // TODO(apolcyn): consider short ciruiting handle_response_done and
@ -285,7 +302,8 @@ void alts_handshaker_client_handle_response(alts_handshaker_client* c,
// handle_response_done's allocation per message received causes // handle_response_done's allocation per message received causes
// a performance issue. // a performance issue.
handle_response_done(client, alts_tsi_utils_convert_to_tsi_result(code), handle_response_done(client, alts_tsi_utils_convert_to_tsi_result(code),
bytes_to_send, bytes_to_send_size, result); std::move(error), bytes_to_send, bytes_to_send_size,
result);
} }
static tsi_result continue_make_grpc_call(alts_grpc_handshaker_client* client, static tsi_result continue_make_grpc_call(alts_grpc_handshaker_client* client,
@ -690,7 +708,7 @@ alts_handshaker_client* alts_grpc_handshaker_client_create(
grpc_alts_credentials_options* options, const grpc_slice& target_name, grpc_alts_credentials_options* options, const grpc_slice& target_name,
grpc_iomgr_cb_func grpc_cb, tsi_handshaker_on_next_done_cb cb, grpc_iomgr_cb_func grpc_cb, tsi_handshaker_on_next_done_cb cb,
void* user_data, alts_handshaker_client_vtable* vtable_for_testing, void* user_data, alts_handshaker_client_vtable* vtable_for_testing,
bool is_client, size_t max_frame_size) { bool is_client, size_t max_frame_size, std::string* error) {
if (channel == nullptr || handshaker_service_url == nullptr) { if (channel == nullptr || handshaker_service_url == nullptr) {
gpr_log(GPR_ERROR, "Invalid arguments to alts_handshaker_client_create()"); gpr_log(GPR_ERROR, "Invalid arguments to alts_handshaker_client_create()");
return nullptr; return nullptr;
@ -713,6 +731,7 @@ alts_handshaker_client* alts_grpc_handshaker_client_create(
client->buffer = static_cast<unsigned char*>(gpr_zalloc(client->buffer_size)); client->buffer = static_cast<unsigned char*>(gpr_zalloc(client->buffer_size));
client->handshake_status_details = grpc_empty_slice(); client->handshake_status_details = grpc_empty_slice();
client->max_frame_size = max_frame_size; client->max_frame_size = max_frame_size;
client->error = error;
grpc_slice slice = grpc_slice_from_copied_string(handshaker_service_url); grpc_slice slice = grpc_slice_from_copied_string(handshaker_service_url);
client->call = client->call =
strcmp(handshaker_service_url, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING) == strcmp(handshaker_service_url, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING) ==
@ -776,7 +795,7 @@ void alts_handshaker_client_set_recv_bytes_for_testing(
void alts_handshaker_client_set_fields_for_testing( void alts_handshaker_client_set_fields_for_testing(
alts_handshaker_client* c, alts_tsi_handshaker* handshaker, alts_handshaker_client* c, alts_tsi_handshaker* handshaker,
tsi_handshaker_on_next_done_cb cb, void* user_data, tsi_handshaker_on_next_done_cb cb, void* user_data,
grpc_byte_buffer* recv_buffer, grpc_status_code status) { grpc_byte_buffer* recv_buffer, bool inject_read_failure) {
GPR_ASSERT(c != nullptr); GPR_ASSERT(c != nullptr);
alts_grpc_handshaker_client* client = alts_grpc_handshaker_client* client =
reinterpret_cast<alts_grpc_handshaker_client*>(c); reinterpret_cast<alts_grpc_handshaker_client*>(c);
@ -784,7 +803,7 @@ void alts_handshaker_client_set_fields_for_testing(
client->cb = cb; client->cb = cb;
client->user_data = user_data; client->user_data = user_data;
client->recv_buffer = recv_buffer; client->recv_buffer = recv_buffer;
client->status = status; client->inject_read_failure = inject_read_failure;
} }
void alts_handshaker_client_check_fields_for_testing( void alts_handshaker_client_check_fields_for_testing(

@ -144,7 +144,7 @@ alts_handshaker_client* alts_grpc_handshaker_client_create(
grpc_alts_credentials_options* options, const grpc_slice& target_name, grpc_alts_credentials_options* options, const grpc_slice& target_name,
grpc_iomgr_cb_func grpc_cb, tsi_handshaker_on_next_done_cb cb, grpc_iomgr_cb_func grpc_cb, tsi_handshaker_on_next_done_cb cb,
void* user_data, alts_handshaker_client_vtable* vtable_for_testing, void* user_data, alts_handshaker_client_vtable* vtable_for_testing,
bool is_client, size_t max_frame_size); bool is_client, size_t max_frame_size, std::string* error);
/** /**
* This method handles handshaker response returned from ALTS handshaker * This method handles handshaker response returned from ALTS handshaker

@ -416,7 +416,7 @@ static void on_handshaker_service_resp_recv_dedicated(
static tsi_result alts_tsi_handshaker_continue_handshaker_next( static tsi_result alts_tsi_handshaker_continue_handshaker_next(
alts_tsi_handshaker* handshaker, const unsigned char* received_bytes, alts_tsi_handshaker* handshaker, const unsigned char* received_bytes,
size_t received_bytes_size, tsi_handshaker_on_next_done_cb cb, size_t received_bytes_size, tsi_handshaker_on_next_done_cb cb,
void* user_data) { void* user_data, std::string* error) {
if (!handshaker->has_created_handshaker_client) { if (!handshaker->has_created_handshaker_client) {
if (handshaker->channel == nullptr) { if (handshaker->channel == nullptr) {
grpc_alts_shared_resource_dedicated_start( grpc_alts_shared_resource_dedicated_start(
@ -437,9 +437,10 @@ static tsi_result alts_tsi_handshaker_continue_handshaker_next(
handshaker->interested_parties, handshaker->options, handshaker->interested_parties, handshaker->options,
handshaker->target_name, grpc_cb, cb, user_data, handshaker->target_name, grpc_cb, cb, user_data,
handshaker->client_vtable_for_testing, handshaker->is_client, handshaker->client_vtable_for_testing, handshaker->is_client,
handshaker->max_frame_size); handshaker->max_frame_size, error);
if (client == nullptr) { if (client == nullptr) {
gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client"); gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client");
if (error != nullptr) *error = "Failed to create ALTS handshaker client";
return TSI_FAILED_PRECONDITION; return TSI_FAILED_PRECONDITION;
} }
{ {
@ -448,6 +449,7 @@ static tsi_result alts_tsi_handshaker_continue_handshaker_next(
handshaker->client = client; handshaker->client = client;
if (handshaker->shutdown) { if (handshaker->shutdown) {
gpr_log(GPR_INFO, "TSI handshake shutdown"); gpr_log(GPR_INFO, "TSI handshake shutdown");
if (error != nullptr) *error = "TSI handshaker shutdown";
return TSI_HANDSHAKE_SHUTDOWN; return TSI_HANDSHAKE_SHUTDOWN;
} }
} }
@ -490,6 +492,7 @@ struct alts_tsi_handshaker_continue_handshaker_next_args {
tsi_handshaker_on_next_done_cb cb; tsi_handshaker_on_next_done_cb cb;
void* user_data; void* user_data;
grpc_closure closure; grpc_closure closure;
std::string* error = nullptr;
}; };
static void alts_tsi_handshaker_create_channel( static void alts_tsi_handshaker_create_channel(
@ -510,7 +513,8 @@ static void alts_tsi_handshaker_create_channel(
tsi_result continue_next_result = tsi_result continue_next_result =
alts_tsi_handshaker_continue_handshaker_next( alts_tsi_handshaker_continue_handshaker_next(
handshaker, next_args->received_bytes.get(), handshaker, next_args->received_bytes.get(),
next_args->received_bytes_size, next_args->cb, next_args->user_data); next_args->received_bytes_size, next_args->cb, next_args->user_data,
next_args->error);
if (continue_next_result != TSI_OK) { if (continue_next_result != TSI_OK) {
next_args->cb(continue_next_result, next_args->user_data, nullptr, 0, next_args->cb(continue_next_result, next_args->user_data, nullptr, 0,
nullptr); nullptr);
@ -522,9 +526,10 @@ static tsi_result handshaker_next(
tsi_handshaker* self, const unsigned char* received_bytes, tsi_handshaker* self, const unsigned char* received_bytes,
size_t received_bytes_size, const unsigned char** /*bytes_to_send*/, size_t received_bytes_size, const unsigned char** /*bytes_to_send*/,
size_t* /*bytes_to_send_size*/, tsi_handshaker_result** /*result*/, size_t* /*bytes_to_send_size*/, tsi_handshaker_result** /*result*/,
tsi_handshaker_on_next_done_cb cb, void* user_data) { tsi_handshaker_on_next_done_cb cb, void* user_data, std::string* error) {
if (self == nullptr || cb == nullptr) { if (self == nullptr || cb == nullptr) {
gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()"); gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()");
if (error != nullptr) *error = "invalid argument";
return TSI_INVALID_ARGUMENT; return TSI_INVALID_ARGUMENT;
} }
alts_tsi_handshaker* handshaker = alts_tsi_handshaker* handshaker =
@ -533,6 +538,7 @@ static tsi_result handshaker_next(
grpc_core::MutexLock lock(&handshaker->mu); grpc_core::MutexLock lock(&handshaker->mu);
if (handshaker->shutdown) { if (handshaker->shutdown) {
gpr_log(GPR_INFO, "TSI handshake shutdown"); gpr_log(GPR_INFO, "TSI handshake shutdown");
if (error != nullptr) *error = "handshake shutdown";
return TSI_HANDSHAKE_SHUTDOWN; return TSI_HANDSHAKE_SHUTDOWN;
} }
} }
@ -542,6 +548,7 @@ static tsi_result handshaker_next(
args->handshaker = handshaker; args->handshaker = handshaker;
args->received_bytes = nullptr; args->received_bytes = nullptr;
args->received_bytes_size = received_bytes_size; args->received_bytes_size = received_bytes_size;
args->error = error;
if (received_bytes_size > 0) { if (received_bytes_size > 0) {
args->received_bytes = std::unique_ptr<unsigned char>( args->received_bytes = std::unique_ptr<unsigned char>(
static_cast<unsigned char*>(gpr_zalloc(received_bytes_size))); static_cast<unsigned char*>(gpr_zalloc(received_bytes_size)));
@ -559,7 +566,7 @@ static tsi_result handshaker_next(
grpc_core::ExecCtx::Run(DEBUG_LOCATION, &args->closure, GRPC_ERROR_NONE); grpc_core::ExecCtx::Run(DEBUG_LOCATION, &args->closure, GRPC_ERROR_NONE);
} else { } else {
tsi_result ok = alts_tsi_handshaker_continue_handshaker_next( tsi_result ok = alts_tsi_handshaker_continue_handshaker_next(
handshaker, received_bytes, received_bytes_size, cb, user_data); handshaker, received_bytes, received_bytes_size, cb, user_data, error);
if (ok != TSI_OK) { if (ok != TSI_OK) {
gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests"); gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests");
return ok; return ok;
@ -577,11 +584,11 @@ static tsi_result handshaker_next_dedicated(
tsi_handshaker* self, const unsigned char* received_bytes, tsi_handshaker* self, const unsigned char* received_bytes,
size_t received_bytes_size, const unsigned char** bytes_to_send, size_t received_bytes_size, const unsigned char** bytes_to_send,
size_t* bytes_to_send_size, tsi_handshaker_result** result, size_t* bytes_to_send_size, tsi_handshaker_result** result,
tsi_handshaker_on_next_done_cb cb, void* user_data) { tsi_handshaker_on_next_done_cb cb, void* user_data, std::string* error) {
grpc_core::ExecCtx exec_ctx; grpc_core::ExecCtx exec_ctx;
return handshaker_next(self, received_bytes, received_bytes_size, return handshaker_next(self, received_bytes, received_bytes_size,
bytes_to_send, bytes_to_send_size, result, cb, bytes_to_send, bytes_to_send_size, result, cb,
user_data); user_data, error);
} }
static void handshaker_shutdown(tsi_handshaker* self) { static void handshaker_shutdown(tsi_handshaker* self) {

@ -63,7 +63,7 @@ void alts_handshaker_client_check_fields_for_testing(
void alts_handshaker_client_set_fields_for_testing( void alts_handshaker_client_set_fields_for_testing(
alts_handshaker_client* client, alts_tsi_handshaker* handshaker, alts_handshaker_client* client, alts_tsi_handshaker* handshaker,
tsi_handshaker_on_next_done_cb cb, void* user_data, tsi_handshaker_on_next_done_cb cb, void* user_data,
grpc_byte_buffer* recv_buffer, grpc_status_code status); grpc_byte_buffer* recv_buffer, bool inject_read_failure);
void alts_handshaker_client_set_vtable_for_testing( void alts_handshaker_client_set_vtable_for_testing(
alts_handshaker_client* client, alts_handshaker_client_vtable* vtable); alts_handshaker_client* client, alts_handshaker_client_vtable* vtable);

@ -96,7 +96,8 @@ static const char* tsi_fake_handshake_message_to_string(int msg) {
} }
static tsi_result tsi_fake_handshake_message_from_string( static tsi_result tsi_fake_handshake_message_from_string(
const char* msg_string, tsi_fake_handshake_message* msg) { const char* msg_string, tsi_fake_handshake_message* msg,
std::string* error) {
for (int i = 0; i < TSI_FAKE_HANDSHAKE_MESSAGE_MAX; i++) { for (int i = 0; i < TSI_FAKE_HANDSHAKE_MESSAGE_MAX; i++) {
if (strncmp(msg_string, tsi_fake_handshake_message_strings[i], if (strncmp(msg_string, tsi_fake_handshake_message_strings[i],
strlen(tsi_fake_handshake_message_strings[i])) == 0) { strlen(tsi_fake_handshake_message_strings[i])) == 0) {
@ -105,6 +106,7 @@ static tsi_result tsi_fake_handshake_message_from_string(
} }
} }
gpr_log(GPR_ERROR, "Invalid handshake message."); gpr_log(GPR_ERROR, "Invalid handshake message.");
if (error != nullptr) *error = "invalid handshake message";
return TSI_DATA_CORRUPTED; return TSI_DATA_CORRUPTED;
} }
@ -174,12 +176,16 @@ static void tsi_fake_frame_ensure_size(tsi_fake_frame* frame) {
* This method should not be called if frame->needs_framing is not 0. */ * This method should not be called if frame->needs_framing is not 0. */
static tsi_result tsi_fake_frame_decode(const unsigned char* incoming_bytes, static tsi_result tsi_fake_frame_decode(const unsigned char* incoming_bytes,
size_t* incoming_bytes_size, size_t* incoming_bytes_size,
tsi_fake_frame* frame) { tsi_fake_frame* frame,
std::string* error) {
size_t available_size = *incoming_bytes_size; size_t available_size = *incoming_bytes_size;
size_t to_read_size = 0; size_t to_read_size = 0;
const unsigned char* bytes_cursor = incoming_bytes; const unsigned char* bytes_cursor = incoming_bytes;
if (frame->needs_draining) return TSI_INTERNAL_ERROR; if (frame->needs_draining) {
if (error != nullptr) *error = "fake handshaker frame needs draining";
return TSI_INTERNAL_ERROR;
}
if (frame->data == nullptr) { if (frame->data == nullptr) {
frame->allocated_size = TSI_FAKE_FRAME_INITIAL_ALLOCATED_SIZE; frame->allocated_size = TSI_FAKE_FRAME_INITIAL_ALLOCATED_SIZE;
frame->data = frame->data =
@ -224,9 +230,13 @@ static tsi_result tsi_fake_frame_decode(const unsigned char* incoming_bytes,
* This method should not be called if frame->needs_framing is 0. */ * This method should not be called if frame->needs_framing is 0. */
static tsi_result tsi_fake_frame_encode(unsigned char* outgoing_bytes, static tsi_result tsi_fake_frame_encode(unsigned char* outgoing_bytes,
size_t* outgoing_bytes_size, size_t* outgoing_bytes_size,
tsi_fake_frame* frame) { tsi_fake_frame* frame,
std::string* error) {
size_t to_write_size = frame->size - frame->offset; size_t to_write_size = frame->size - frame->offset;
if (!frame->needs_draining) return TSI_INTERNAL_ERROR; if (!frame->needs_draining) {
if (error != nullptr) *error = "fake frame needs draining";
return TSI_INTERNAL_ERROR;
}
if (*outgoing_bytes_size < to_write_size) { if (*outgoing_bytes_size < to_write_size) {
memcpy(outgoing_bytes, frame->data + frame->offset, *outgoing_bytes_size); memcpy(outgoing_bytes, frame->data + frame->offset, *outgoing_bytes_size);
frame->offset += *outgoing_bytes_size; frame->offset += *outgoing_bytes_size;
@ -240,15 +250,14 @@ static tsi_result tsi_fake_frame_encode(unsigned char* outgoing_bytes,
/* Sets the payload of a fake frame to contain the given data blob, where /* Sets the payload of a fake frame to contain the given data blob, where
* data_size indicates the size of data. */ * data_size indicates the size of data. */
static tsi_result tsi_fake_frame_set_data(unsigned char* data, size_t data_size, static void tsi_fake_frame_set_data(unsigned char* data, size_t data_size,
tsi_fake_frame* frame) { tsi_fake_frame* frame) {
frame->offset = 0; frame->offset = 0;
frame->size = data_size + TSI_FAKE_FRAME_HEADER_SIZE; frame->size = data_size + TSI_FAKE_FRAME_HEADER_SIZE;
tsi_fake_frame_ensure_size(frame); tsi_fake_frame_ensure_size(frame);
store32_little_endian(static_cast<uint32_t>(frame->size), frame->data); store32_little_endian(static_cast<uint32_t>(frame->size), frame->data);
memcpy(frame->data + TSI_FAKE_FRAME_HEADER_SIZE, data, data_size); memcpy(frame->data + TSI_FAKE_FRAME_HEADER_SIZE, data, data_size);
tsi_fake_frame_reset(frame, 1 /* needs draining */); tsi_fake_frame_reset(frame, 1 /* needs draining */);
return TSI_OK;
} }
/* Destroys the contents of a fake frame. */ /* Destroys the contents of a fake frame. */
@ -276,8 +285,8 @@ static tsi_result fake_protector_protect(tsi_frame_protector* self,
/* Try to drain first. */ /* Try to drain first. */
if (frame->needs_draining) { if (frame->needs_draining) {
drained_size = saved_output_size - *num_bytes_written; drained_size = saved_output_size - *num_bytes_written;
result = result = tsi_fake_frame_encode(protected_output_frames, &drained_size,
tsi_fake_frame_encode(protected_output_frames, &drained_size, frame); frame, /*error=*/nullptr);
*num_bytes_written += drained_size; *num_bytes_written += drained_size;
protected_output_frames += drained_size; protected_output_frames += drained_size;
if (result != TSI_OK) { if (result != TSI_OK) {
@ -297,7 +306,8 @@ static tsi_result fake_protector_protect(tsi_frame_protector* self,
store32_little_endian(static_cast<uint32_t>(impl->max_frame_size), store32_little_endian(static_cast<uint32_t>(impl->max_frame_size),
frame_header); frame_header);
written_in_frame_size = TSI_FAKE_FRAME_HEADER_SIZE; written_in_frame_size = TSI_FAKE_FRAME_HEADER_SIZE;
result = tsi_fake_frame_decode(frame_header, &written_in_frame_size, frame); result = tsi_fake_frame_decode(frame_header, &written_in_frame_size, frame,
/*error=*/nullptr);
if (result != TSI_INCOMPLETE_DATA) { if (result != TSI_INCOMPLETE_DATA) {
gpr_log(GPR_ERROR, "tsi_fake_frame_decode returned %s", gpr_log(GPR_ERROR, "tsi_fake_frame_decode returned %s",
tsi_result_to_string(result)); tsi_result_to_string(result));
@ -305,7 +315,8 @@ static tsi_result fake_protector_protect(tsi_frame_protector* self,
} }
} }
result = result =
tsi_fake_frame_decode(unprotected_bytes, unprotected_bytes_size, frame); tsi_fake_frame_decode(unprotected_bytes, unprotected_bytes_size, frame,
/*error=*/nullptr);
if (result != TSI_OK) { if (result != TSI_OK) {
if (result == TSI_INCOMPLETE_DATA) result = TSI_OK; if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
return result; return result;
@ -315,7 +326,8 @@ static tsi_result fake_protector_protect(tsi_frame_protector* self,
if (!frame->needs_draining) return TSI_INTERNAL_ERROR; if (!frame->needs_draining) return TSI_INTERNAL_ERROR;
if (frame->offset != 0) return TSI_INTERNAL_ERROR; if (frame->offset != 0) return TSI_INTERNAL_ERROR;
drained_size = saved_output_size - *num_bytes_written; drained_size = saved_output_size - *num_bytes_written;
result = tsi_fake_frame_encode(protected_output_frames, &drained_size, frame); result = tsi_fake_frame_encode(protected_output_frames, &drained_size, frame,
/*error=*/nullptr);
*num_bytes_written += drained_size; *num_bytes_written += drained_size;
if (result == TSI_INCOMPLETE_DATA) result = TSI_OK; if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
return result; return result;
@ -337,7 +349,8 @@ static tsi_result fake_protector_protect_flush(
frame->data); /* Overwrite header. */ frame->data); /* Overwrite header. */
} }
result = tsi_fake_frame_encode(protected_output_frames, result = tsi_fake_frame_encode(protected_output_frames,
protected_output_frames_size, frame); protected_output_frames_size, frame,
/*error=*/nullptr);
if (result == TSI_INCOMPLETE_DATA) result = TSI_OK; if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
*still_pending_size = frame->size - frame->offset; *still_pending_size = frame->size - frame->offset;
return result; return result;
@ -361,7 +374,8 @@ static tsi_result fake_protector_unprotect(
/* Go past the header if needed. */ /* Go past the header if needed. */
if (frame->offset == 0) frame->offset = TSI_FAKE_FRAME_HEADER_SIZE; if (frame->offset == 0) frame->offset = TSI_FAKE_FRAME_HEADER_SIZE;
drained_size = saved_output_size - *num_bytes_written; drained_size = saved_output_size - *num_bytes_written;
result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame); result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame,
/*error=*/nullptr);
unprotected_bytes += drained_size; unprotected_bytes += drained_size;
*num_bytes_written += drained_size; *num_bytes_written += drained_size;
if (result != TSI_OK) { if (result != TSI_OK) {
@ -376,7 +390,8 @@ static tsi_result fake_protector_unprotect(
/* Now process the protected_bytes. */ /* Now process the protected_bytes. */
if (frame->needs_draining) return TSI_INTERNAL_ERROR; if (frame->needs_draining) return TSI_INTERNAL_ERROR;
result = tsi_fake_frame_decode(protected_frames_bytes, result = tsi_fake_frame_decode(protected_frames_bytes,
protected_frames_bytes_size, frame); protected_frames_bytes_size, frame,
/*error=*/nullptr);
if (result != TSI_OK) { if (result != TSI_OK) {
if (result == TSI_INCOMPLETE_DATA) result = TSI_OK; if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
return result; return result;
@ -387,7 +402,8 @@ static tsi_result fake_protector_unprotect(
if (frame->offset != 0) return TSI_INTERNAL_ERROR; if (frame->offset != 0) return TSI_INTERNAL_ERROR;
frame->offset = TSI_FAKE_FRAME_HEADER_SIZE; /* Go past the header. */ frame->offset = TSI_FAKE_FRAME_HEADER_SIZE; /* Go past the header. */
drained_size = saved_output_size - *num_bytes_written; drained_size = saved_output_size - *num_bytes_written;
result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame); result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame,
/*error=*/nullptr);
*num_bytes_written += drained_size; *num_bytes_written += drained_size;
if (result == TSI_INCOMPLETE_DATA) result = TSI_OK; if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
return result; return result;
@ -579,9 +595,10 @@ static const tsi_handshaker_result_vtable handshaker_result_vtable = {
static tsi_result fake_handshaker_result_create( static tsi_result fake_handshaker_result_create(
const unsigned char* unused_bytes, size_t unused_bytes_size, const unsigned char* unused_bytes, size_t unused_bytes_size,
tsi_handshaker_result** handshaker_result) { tsi_handshaker_result** handshaker_result, std::string* error) {
if ((unused_bytes_size > 0 && unused_bytes == nullptr) || if ((unused_bytes_size > 0 && unused_bytes == nullptr) ||
handshaker_result == nullptr) { handshaker_result == nullptr) {
if (error != nullptr) *error = "invalid argument";
return TSI_INVALID_ARGUMENT; return TSI_INVALID_ARGUMENT;
} }
fake_handshaker_result* result = grpc_core::Zalloc<fake_handshaker_result>(); fake_handshaker_result* result = grpc_core::Zalloc<fake_handshaker_result>();
@ -599,7 +616,8 @@ static tsi_result fake_handshaker_result_create(
/* --- tsi_handshaker methods implementation. ---*/ /* --- tsi_handshaker methods implementation. ---*/
static tsi_result fake_handshaker_get_bytes_to_send_to_peer( static tsi_result fake_handshaker_get_bytes_to_send_to_peer(
tsi_handshaker* self, unsigned char* bytes, size_t* bytes_size) { tsi_handshaker* self, unsigned char* bytes, size_t* bytes_size,
std::string* error) {
tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self); tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
tsi_result result = TSI_OK; tsi_result result = TSI_OK;
if (impl->needs_incoming_message || impl->result == TSI_OK) { if (impl->needs_incoming_message || impl->result == TSI_OK) {
@ -612,10 +630,9 @@ static tsi_result fake_handshaker_get_bytes_to_send_to_peer(
static_cast<tsi_fake_handshake_message>(impl->next_message_to_send + 2); static_cast<tsi_fake_handshake_message>(impl->next_message_to_send + 2);
const char* msg_string = const char* msg_string =
tsi_fake_handshake_message_to_string(impl->next_message_to_send); tsi_fake_handshake_message_to_string(impl->next_message_to_send);
result = tsi_fake_frame_set_data( tsi_fake_frame_set_data(
reinterpret_cast<unsigned char*>(const_cast<char*>(msg_string)), reinterpret_cast<unsigned char*>(const_cast<char*>(msg_string)),
strlen(msg_string), &impl->outgoing_frame); strlen(msg_string), &impl->outgoing_frame);
if (result != TSI_OK) return result;
if (next_message_to_send > TSI_FAKE_HANDSHAKE_MESSAGE_MAX) { if (next_message_to_send > TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
next_message_to_send = TSI_FAKE_HANDSHAKE_MESSAGE_MAX; next_message_to_send = TSI_FAKE_HANDSHAKE_MESSAGE_MAX;
} }
@ -626,7 +643,8 @@ static tsi_result fake_handshaker_get_bytes_to_send_to_peer(
} }
impl->next_message_to_send = next_message_to_send; impl->next_message_to_send = next_message_to_send;
} }
result = tsi_fake_frame_encode(bytes, bytes_size, &impl->outgoing_frame); result =
tsi_fake_frame_encode(bytes, bytes_size, &impl->outgoing_frame, error);
if (result != TSI_OK) return result; if (result != TSI_OK) return result;
if (!impl->is_client && if (!impl->is_client &&
impl->next_message_to_send == TSI_FAKE_HANDSHAKE_MESSAGE_MAX) { impl->next_message_to_send == TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
@ -642,7 +660,8 @@ static tsi_result fake_handshaker_get_bytes_to_send_to_peer(
} }
static tsi_result fake_handshaker_process_bytes_from_peer( static tsi_result fake_handshaker_process_bytes_from_peer(
tsi_handshaker* self, const unsigned char* bytes, size_t* bytes_size) { tsi_handshaker* self, const unsigned char* bytes, size_t* bytes_size,
std::string* error) {
tsi_result result = TSI_OK; tsi_result result = TSI_OK;
tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self); tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
tsi_fake_handshake_message expected_msg = tsi_fake_handshake_message expected_msg =
@ -653,14 +672,15 @@ static tsi_result fake_handshaker_process_bytes_from_peer(
*bytes_size = 0; *bytes_size = 0;
return TSI_OK; return TSI_OK;
} }
result = tsi_fake_frame_decode(bytes, bytes_size, &impl->incoming_frame); result =
tsi_fake_frame_decode(bytes, bytes_size, &impl->incoming_frame, error);
if (result != TSI_OK) return result; if (result != TSI_OK) return result;
/* We now have a complete frame. */ /* We now have a complete frame. */
result = tsi_fake_handshake_message_from_string( result = tsi_fake_handshake_message_from_string(
reinterpret_cast<const char*>(impl->incoming_frame.data) + reinterpret_cast<const char*>(impl->incoming_frame.data) +
TSI_FAKE_FRAME_HEADER_SIZE, TSI_FAKE_FRAME_HEADER_SIZE,
&received_msg); &received_msg, error);
if (result != TSI_OK) { if (result != TSI_OK) {
impl->result = result; impl->result = result;
return result; return result;
@ -703,11 +723,13 @@ static tsi_result fake_handshaker_next(
tsi_handshaker* self, const unsigned char* received_bytes, tsi_handshaker* self, const unsigned char* received_bytes,
size_t received_bytes_size, const unsigned char** bytes_to_send, size_t received_bytes_size, const unsigned char** bytes_to_send,
size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result, size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result,
tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/) { tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/,
std::string* error) {
/* Sanity check the arguments. */ /* Sanity check the arguments. */
if ((received_bytes_size > 0 && received_bytes == nullptr) || if ((received_bytes_size > 0 && received_bytes == nullptr) ||
bytes_to_send == nullptr || bytes_to_send_size == nullptr || bytes_to_send == nullptr || bytes_to_send_size == nullptr ||
handshaker_result == nullptr) { handshaker_result == nullptr) {
if (error != nullptr) *error = "invalid argument";
return TSI_INVALID_ARGUMENT; return TSI_INVALID_ARGUMENT;
} }
tsi_fake_handshaker* handshaker = tsi_fake_handshaker* handshaker =
@ -717,8 +739,8 @@ static tsi_result fake_handshaker_next(
/* Decode and process a handshake frame from the peer. */ /* Decode and process a handshake frame from the peer. */
size_t consumed_bytes_size = received_bytes_size; size_t consumed_bytes_size = received_bytes_size;
if (received_bytes_size > 0) { if (received_bytes_size > 0) {
result = fake_handshaker_process_bytes_from_peer(self, received_bytes, result = fake_handshaker_process_bytes_from_peer(
&consumed_bytes_size); self, received_bytes, &consumed_bytes_size, error);
if (result != TSI_OK) return result; if (result != TSI_OK) return result;
} }
@ -728,7 +750,8 @@ static tsi_result fake_handshaker_next(
do { do {
size_t sent_bytes_size = handshaker->outgoing_bytes_buffer_size - offset; size_t sent_bytes_size = handshaker->outgoing_bytes_buffer_size - offset;
result = fake_handshaker_get_bytes_to_send_to_peer( result = fake_handshaker_get_bytes_to_send_to_peer(
self, handshaker->outgoing_bytes_buffer + offset, &sent_bytes_size); self, handshaker->outgoing_bytes_buffer + offset, &sent_bytes_size,
error);
offset += sent_bytes_size; offset += sent_bytes_size;
if (result == TSI_INCOMPLETE_DATA) { if (result == TSI_INCOMPLETE_DATA) {
handshaker->outgoing_bytes_buffer_size *= 2; handshaker->outgoing_bytes_buffer_size *= 2;
@ -754,7 +777,7 @@ static tsi_result fake_handshaker_next(
/* Create a handshaker_result containing the unused bytes. */ /* Create a handshaker_result containing the unused bytes. */
result = fake_handshaker_result_create(unused_bytes, unused_bytes_size, result = fake_handshaker_result_create(unused_bytes, unused_bytes_size,
handshaker_result); handshaker_result, error);
if (result == TSI_OK) { if (result == TSI_OK) {
/* Indicate that the handshake has completed and that a handshaker_result /* Indicate that the handshake has completed and that a handshaker_result
* has been created. */ * has been created. */

@ -119,13 +119,17 @@ tsi_result create_handshaker_result(const unsigned char* received_bytes,
/* --- tsi_handshaker methods implementation. --- */ /* --- tsi_handshaker methods implementation. --- */
tsi_result handshaker_next( tsi_result handshaker_next(tsi_handshaker* self,
tsi_handshaker* self, const unsigned char* received_bytes, const unsigned char* received_bytes,
size_t received_bytes_size, const unsigned char** /*bytes_to_send*/, size_t received_bytes_size,
size_t* bytes_to_send_size, tsi_handshaker_result** result, const unsigned char** /*bytes_to_send*/,
tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/) { size_t* bytes_to_send_size,
tsi_handshaker_result** result,
tsi_handshaker_on_next_done_cb /*cb*/,
void* /*user_data*/, std::string* error) {
if (self == nullptr) { if (self == nullptr) {
gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()"); gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()");
if (error != nullptr) *error = "invalid argument";
return TSI_INVALID_ARGUMENT; return TSI_INVALID_ARGUMENT;
} }
/* Note that there is no interaction between TSI peers, and all operations are /* Note that there is no interaction between TSI peers, and all operations are

@ -45,6 +45,7 @@
#include <openssl/x509v3.h> #include <openssl/x509v3.h>
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include <grpc/grpc_security.h> #include <grpc/grpc_security.h>
@ -1415,9 +1416,11 @@ static const tsi_handshaker_result_vtable handshaker_result_vtable = {
static tsi_result ssl_handshaker_result_create( static tsi_result ssl_handshaker_result_create(
tsi_ssl_handshaker* handshaker, unsigned char* unused_bytes, tsi_ssl_handshaker* handshaker, unsigned char* unused_bytes,
size_t unused_bytes_size, tsi_handshaker_result** handshaker_result) { size_t unused_bytes_size, tsi_handshaker_result** handshaker_result,
std::string* error) {
if (handshaker == nullptr || handshaker_result == nullptr || if (handshaker == nullptr || handshaker_result == nullptr ||
(unused_bytes_size > 0 && unused_bytes == nullptr)) { (unused_bytes_size > 0 && unused_bytes == nullptr)) {
if (error != nullptr) *error = "invalid argument";
return TSI_INVALID_ARGUMENT; return TSI_INVALID_ARGUMENT;
} }
tsi_ssl_handshaker_result* result = tsi_ssl_handshaker_result* result =
@ -1438,9 +1441,11 @@ static tsi_result ssl_handshaker_result_create(
/* --- tsi_handshaker methods implementation. ---*/ /* --- tsi_handshaker methods implementation. ---*/
static tsi_result ssl_handshaker_get_bytes_to_send_to_peer( static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(
tsi_ssl_handshaker* impl, unsigned char* bytes, size_t* bytes_size) { tsi_ssl_handshaker* impl, unsigned char* bytes, size_t* bytes_size,
std::string* error) {
int bytes_read_from_ssl = 0; int bytes_read_from_ssl = 0;
if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) { if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
if (error != nullptr) *error = "invalid argument";
return TSI_INVALID_ARGUMENT; return TSI_INVALID_ARGUMENT;
} }
GPR_ASSERT(*bytes_size <= INT_MAX); GPR_ASSERT(*bytes_size <= INT_MAX);
@ -1449,6 +1454,7 @@ static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(
if (bytes_read_from_ssl < 0) { if (bytes_read_from_ssl < 0) {
*bytes_size = 0; *bytes_size = 0;
if (!BIO_should_retry(impl->network_io)) { if (!BIO_should_retry(impl->network_io)) {
if (error != nullptr) *error = "error reading from BIO";
impl->result = TSI_INTERNAL_ERROR; impl->result = TSI_INTERNAL_ERROR;
return impl->result; return impl->result;
} else { } else {
@ -1467,7 +1473,8 @@ static tsi_result ssl_handshaker_get_result(tsi_ssl_handshaker* impl) {
return impl->result; return impl->result;
} }
static tsi_result ssl_handshaker_do_handshake(tsi_ssl_handshaker* impl) { static tsi_result ssl_handshaker_do_handshake(tsi_ssl_handshaker* impl,
std::string* error) {
if (ssl_handshaker_get_result(impl) != TSI_HANDSHAKE_IN_PROGRESS) { if (ssl_handshaker_get_result(impl) != TSI_HANDSHAKE_IN_PROGRESS) {
impl->result = TSI_OK; impl->result = TSI_OK;
return impl->result; return impl->result;
@ -1493,6 +1500,9 @@ static tsi_result ssl_handshaker_do_handshake(tsi_ssl_handshaker* impl) {
ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str)); ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
gpr_log(GPR_ERROR, "Handshake failed with fatal error %s: %s.", gpr_log(GPR_ERROR, "Handshake failed with fatal error %s: %s.",
ssl_error_string(ssl_result), err_str); ssl_error_string(ssl_result), err_str);
if (error != nullptr) {
*error = absl::StrCat(ssl_error_string(ssl_result), ": ", err_str);
}
impl->result = TSI_PROTOCOL_FAILURE; impl->result = TSI_PROTOCOL_FAILURE;
return impl->result; return impl->result;
} }
@ -1501,9 +1511,11 @@ static tsi_result ssl_handshaker_do_handshake(tsi_ssl_handshaker* impl) {
} }
static tsi_result ssl_handshaker_process_bytes_from_peer( static tsi_result ssl_handshaker_process_bytes_from_peer(
tsi_ssl_handshaker* impl, const unsigned char* bytes, size_t* bytes_size) { tsi_ssl_handshaker* impl, const unsigned char* bytes, size_t* bytes_size,
std::string* error) {
int bytes_written_into_ssl_size = 0; int bytes_written_into_ssl_size = 0;
if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) { if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
if (error != nullptr) *error = "invalid argument";
return TSI_INVALID_ARGUMENT; return TSI_INVALID_ARGUMENT;
} }
GPR_ASSERT(*bytes_size <= INT_MAX); GPR_ASSERT(*bytes_size <= INT_MAX);
@ -1511,11 +1523,12 @@ static tsi_result ssl_handshaker_process_bytes_from_peer(
BIO_write(impl->network_io, bytes, static_cast<int>(*bytes_size)); BIO_write(impl->network_io, bytes, static_cast<int>(*bytes_size));
if (bytes_written_into_ssl_size < 0) { if (bytes_written_into_ssl_size < 0) {
gpr_log(GPR_ERROR, "Could not write to memory BIO."); gpr_log(GPR_ERROR, "Could not write to memory BIO.");
if (error != nullptr) *error = "could not write to memory BIO";
impl->result = TSI_INTERNAL_ERROR; impl->result = TSI_INTERNAL_ERROR;
return impl->result; return impl->result;
} }
*bytes_size = static_cast<size_t>(bytes_written_into_ssl_size); *bytes_size = static_cast<size_t>(bytes_written_into_ssl_size);
return ssl_handshaker_do_handshake(impl); return ssl_handshaker_do_handshake(impl, error);
} }
static void ssl_handshaker_destroy(tsi_handshaker* self) { static void ssl_handshaker_destroy(tsi_handshaker* self) {
@ -1531,9 +1544,11 @@ static void ssl_handshaker_destroy(tsi_handshaker* self) {
// |bytes_remaining|. // |bytes_remaining|.
static tsi_result ssl_bytes_remaining(tsi_ssl_handshaker* impl, static tsi_result ssl_bytes_remaining(tsi_ssl_handshaker* impl,
unsigned char** bytes_remaining, unsigned char** bytes_remaining,
size_t* bytes_remaining_size) { size_t* bytes_remaining_size,
std::string* error) {
if (impl == nullptr || bytes_remaining == nullptr || if (impl == nullptr || bytes_remaining == nullptr ||
bytes_remaining_size == nullptr) { bytes_remaining_size == nullptr) {
if (error != nullptr) *error = "invalid argument";
return TSI_INVALID_ARGUMENT; return TSI_INVALID_ARGUMENT;
} }
// Atempt to read all of the bytes in SSL's read BIO. These bytes should // Atempt to read all of the bytes in SSL's read BIO. These bytes should
@ -1551,6 +1566,9 @@ static tsi_result ssl_bytes_remaining(tsi_ssl_handshaker* impl,
"Failed to read the expected number of bytes from SSL object."); "Failed to read the expected number of bytes from SSL object.");
gpr_free(*bytes_remaining); gpr_free(*bytes_remaining);
*bytes_remaining = nullptr; *bytes_remaining = nullptr;
if (error != nullptr) {
*error = "Failed to read the expected number of bytes from SSL object.";
}
return TSI_INTERNAL_ERROR; return TSI_INTERNAL_ERROR;
} }
*bytes_remaining_size = static_cast<size_t>(bytes_read); *bytes_remaining_size = static_cast<size_t>(bytes_read);
@ -1562,14 +1580,15 @@ static tsi_result ssl_bytes_remaining(tsi_ssl_handshaker* impl,
// This API needs to be repeatedly called until all handshake data are // This API needs to be repeatedly called until all handshake data are
// received from SSL. // received from SSL.
static tsi_result ssl_handshaker_write_output_buffer(tsi_handshaker* self, static tsi_result ssl_handshaker_write_output_buffer(tsi_handshaker* self,
size_t* bytes_written) { size_t* bytes_written,
std::string* error) {
tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self); tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
tsi_result status = TSI_OK; tsi_result status = TSI_OK;
size_t offset = *bytes_written; size_t offset = *bytes_written;
do { do {
size_t to_send_size = impl->outgoing_bytes_buffer_size - offset; size_t to_send_size = impl->outgoing_bytes_buffer_size - offset;
status = ssl_handshaker_get_bytes_to_send_to_peer( status = ssl_handshaker_get_bytes_to_send_to_peer(
impl, impl->outgoing_bytes_buffer + offset, &to_send_size); impl, impl->outgoing_bytes_buffer + offset, &to_send_size, error);
offset += to_send_size; offset += to_send_size;
if (status == TSI_INCOMPLETE_DATA) { if (status == TSI_INCOMPLETE_DATA) {
impl->outgoing_bytes_buffer_size *= 2; impl->outgoing_bytes_buffer_size *= 2;
@ -1581,15 +1600,19 @@ static tsi_result ssl_handshaker_write_output_buffer(tsi_handshaker* self,
return status; return status;
} }
static tsi_result ssl_handshaker_next( static tsi_result ssl_handshaker_next(tsi_handshaker* self,
tsi_handshaker* self, const unsigned char* received_bytes, const unsigned char* received_bytes,
size_t received_bytes_size, const unsigned char** bytes_to_send, size_t received_bytes_size,
size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result, const unsigned char** bytes_to_send,
tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/) { size_t* bytes_to_send_size,
tsi_handshaker_result** handshaker_result,
tsi_handshaker_on_next_done_cb /*cb*/,
void* /*user_data*/, std::string* error) {
/* Input sanity check. */ /* Input sanity check. */
if ((received_bytes_size > 0 && received_bytes == nullptr) || if ((received_bytes_size > 0 && received_bytes == nullptr) ||
bytes_to_send == nullptr || bytes_to_send_size == nullptr || bytes_to_send == nullptr || bytes_to_send_size == nullptr ||
handshaker_result == nullptr) { handshaker_result == nullptr) {
if (error != nullptr) *error = "invalid argument";
return TSI_INVALID_ARGUMENT; return TSI_INVALID_ARGUMENT;
} }
/* If there are received bytes, process them first. */ /* If there are received bytes, process them first. */
@ -1599,16 +1622,16 @@ static tsi_result ssl_handshaker_next(
size_t bytes_written = 0; size_t bytes_written = 0;
if (received_bytes_size > 0) { if (received_bytes_size > 0) {
status = ssl_handshaker_process_bytes_from_peer(impl, received_bytes, status = ssl_handshaker_process_bytes_from_peer(impl, received_bytes,
&bytes_consumed); &bytes_consumed, error);
while (status == TSI_DRAIN_BUFFER) { while (status == TSI_DRAIN_BUFFER) {
status = ssl_handshaker_write_output_buffer(self, &bytes_written); status = ssl_handshaker_write_output_buffer(self, &bytes_written, error);
if (status != TSI_OK) return status; if (status != TSI_OK) return status;
status = ssl_handshaker_do_handshake(impl); status = ssl_handshaker_do_handshake(impl, error);
} }
} }
if (status != TSI_OK) return status; if (status != TSI_OK) return status;
/* Get bytes to send to the peer, if available. */ /* Get bytes to send to the peer, if available. */
status = ssl_handshaker_write_output_buffer(self, &bytes_written); status = ssl_handshaker_write_output_buffer(self, &bytes_written, error);
if (status != TSI_OK) return status; if (status != TSI_OK) return status;
*bytes_to_send = impl->outgoing_bytes_buffer; *bytes_to_send = impl->outgoing_bytes_buffer;
*bytes_to_send_size = bytes_written; *bytes_to_send_size = bytes_written;
@ -1622,15 +1645,17 @@ static tsi_result ssl_handshaker_next(
// peer that must be processed. // peer that must be processed.
unsigned char* unused_bytes = nullptr; unsigned char* unused_bytes = nullptr;
size_t unused_bytes_size = 0; size_t unused_bytes_size = 0;
status = ssl_bytes_remaining(impl, &unused_bytes, &unused_bytes_size); status =
ssl_bytes_remaining(impl, &unused_bytes, &unused_bytes_size, error);
if (status != TSI_OK) return status; if (status != TSI_OK) return status;
if (unused_bytes_size > received_bytes_size) { if (unused_bytes_size > received_bytes_size) {
gpr_log(GPR_ERROR, "More unused bytes than received bytes."); gpr_log(GPR_ERROR, "More unused bytes than received bytes.");
gpr_free(unused_bytes); gpr_free(unused_bytes);
if (error != nullptr) *error = "More unused bytes than received bytes.";
return TSI_INTERNAL_ERROR; return TSI_INTERNAL_ERROR;
} }
status = ssl_handshaker_result_create(impl, unused_bytes, unused_bytes_size, status = ssl_handshaker_result_create(impl, unused_bytes, unused_bytes_size,
handshaker_result); handshaker_result, error);
if (status == TSI_OK) { if (status == TSI_OK) {
/* Indicates that the handshake has completed and that a handshaker_result /* Indicates that the handshake has completed and that a handshaker_result
* has been created. */ * has been created. */

@ -216,14 +216,26 @@ tsi_result tsi_handshaker_next(
tsi_handshaker* self, const unsigned char* received_bytes, tsi_handshaker* self, const unsigned char* received_bytes,
size_t received_bytes_size, const unsigned char** bytes_to_send, size_t received_bytes_size, const unsigned char** bytes_to_send,
size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result, size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result,
tsi_handshaker_on_next_done_cb cb, void* user_data) { tsi_handshaker_on_next_done_cb cb, void* user_data, std::string* error) {
if (self == nullptr || self->vtable == nullptr) return TSI_INVALID_ARGUMENT; if (self == nullptr || self->vtable == nullptr) {
if (self->handshaker_result_created) return TSI_FAILED_PRECONDITION; if (error != nullptr) *error = "invalid argument";
if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; return TSI_INVALID_ARGUMENT;
if (self->vtable->next == nullptr) return TSI_UNIMPLEMENTED; }
if (self->handshaker_result_created) {
if (error != nullptr) *error = "handshaker already returned a result";
return TSI_FAILED_PRECONDITION;
}
if (self->handshake_shutdown) {
if (error != nullptr) *error = "handshaker shutdown";
return TSI_HANDSHAKE_SHUTDOWN;
}
if (self->vtable->next == nullptr) {
if (error != nullptr) *error = "TSI handshaker does not implement next()";
return TSI_UNIMPLEMENTED;
}
return self->vtable->next(self, received_bytes, received_bytes_size, return self->vtable->next(self, received_bytes, received_bytes_size,
bytes_to_send, bytes_to_send_size, bytes_to_send, bytes_to_send_size,
handshaker_result, cb, user_data); handshaker_result, cb, user_data, error);
} }
void tsi_handshaker_shutdown(tsi_handshaker* self) { void tsi_handshaker_shutdown(tsi_handshaker* self) {

@ -78,7 +78,8 @@ struct tsi_handshaker_vtable {
const unsigned char** bytes_to_send, const unsigned char** bytes_to_send,
size_t* bytes_to_send_size, size_t* bytes_to_send_size,
tsi_handshaker_result** handshaker_result, tsi_handshaker_result** handshaker_result,
tsi_handshaker_on_next_done_cb cb, void* user_data); tsi_handshaker_on_next_done_cb cb, void* user_data,
std::string* error);
void (*shutdown)(tsi_handshaker* self); void (*shutdown)(tsi_handshaker* self);
}; };
struct tsi_handshaker { struct tsi_handshaker {

@ -24,6 +24,8 @@
#include <stdint.h> #include <stdint.h>
#include <stdlib.h> #include <stdlib.h>
#include <string>
#include "src/core/lib/debug/trace.h" #include "src/core/lib/debug/trace.h"
/* --- tsi result --- */ /* --- tsi result --- */
@ -472,6 +474,13 @@ typedef void (*tsi_handshaker_on_next_done_cb)(
- cb is the callback function defined above. It can be NULL for synchronous - cb is the callback function defined above. It can be NULL for synchronous
TSI handshaker implementation. TSI handshaker implementation.
- user_data is the argument to callback function passed from the caller. - user_data is the argument to callback function passed from the caller.
- error, if non-null, will be populated with a human-readable error
message whenever the result value is something other than TSI_OK,
TSI_ASYNC, or TSI_INCOMPLETE_DATA. The object pointed to by this
argument is owned by the caller and must continue to exist until after the
handshake is finished. Some TSI implementations cache this value,
so callers must pass the same value to all calls to tsi_handshaker_next()
for a given handshake.
This method returns TSI_ASYNC if the TSI handshaker implementation is This method returns TSI_ASYNC if the TSI handshaker implementation is
asynchronous, and in this case, the callback is guaranteed to run in another asynchronous, and in this case, the callback is guaranteed to run in another
thread owned by TSI. It returns TSI_OK if the handshake completes or if thread owned by TSI. It returns TSI_OK if the handshake completes or if
@ -482,11 +491,14 @@ typedef void (*tsi_handshaker_on_next_done_cb)(
The caller is responsible for destroying the handshaker_result. However, The caller is responsible for destroying the handshaker_result. However,
the caller should not free bytes_to_send, as the buffer is owned by the the caller should not free bytes_to_send, as the buffer is owned by the
tsi_handshaker object. */ tsi_handshaker object. */
tsi_result tsi_handshaker_next( tsi_result tsi_handshaker_next(tsi_handshaker* self,
tsi_handshaker* self, const unsigned char* received_bytes, const unsigned char* received_bytes,
size_t received_bytes_size, const unsigned char** bytes_to_send, size_t received_bytes_size,
size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result, const unsigned char** bytes_to_send,
tsi_handshaker_on_next_done_cb cb, void* user_data); size_t* bytes_to_send_size,
tsi_handshaker_result** handshaker_result,
tsi_handshaker_on_next_done_cb cb,
void* user_data, std::string* error = nullptr);
/* This method shuts down a TSI handshake that is in progress. /* This method shuts down a TSI handshake that is in progress.
* *

@ -331,13 +331,13 @@ static alts_handshaker_client_test_config* create_config() {
nullptr, server_options, nullptr, server_options,
grpc_slice_from_static_string(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME), grpc_slice_from_static_string(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME),
nullptr, nullptr, nullptr, nullptr, false, nullptr, nullptr, nullptr, nullptr, false,
ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE); ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE, nullptr);
config->client = alts_grpc_handshaker_client_create( config->client = alts_grpc_handshaker_client_create(
nullptr, config->channel, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING, nullptr, config->channel, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING,
nullptr, client_options, nullptr, client_options,
grpc_slice_from_static_string(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME), grpc_slice_from_static_string(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME),
nullptr, nullptr, nullptr, nullptr, true, nullptr, nullptr, nullptr, nullptr, true,
ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE); ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE, nullptr);
EXPECT_NE(config->client, nullptr); EXPECT_NE(config->client, nullptr);
EXPECT_NE(config->server, nullptr); EXPECT_NE(config->server, nullptr);
grpc_alts_credentials_options_destroy(client_options); grpc_alts_credentials_options_destroy(client_options);

@ -774,9 +774,9 @@ TEST(AltsTsiHandshakerTest, CheckHandleResponseNullptrHandshaker) {
alts_handshaker_client* client = alts_handshaker_client* client =
alts_tsi_handshaker_get_client_for_testing(alts_handshaker); alts_tsi_handshaker_get_client_for_testing(alts_handshaker);
/* Check nullptr handshaker. */ /* Check nullptr handshaker. */
alts_handshaker_client_set_fields_for_testing(client, nullptr, alts_handshaker_client_set_fields_for_testing(
on_invalid_input_cb, nullptr, client, nullptr, on_invalid_input_cb, nullptr, recv_buffer,
recv_buffer, GRPC_STATUS_OK); /*inject_read_failure=*/false);
alts_handshaker_client_handle_response(client, true); alts_handshaker_client_handle_response(client, true);
/* Note: here and elsewhere in this test, we first ref the handshaker in order /* Note: here and elsewhere in this test, we first ref the handshaker in order
* to match the unref that on_status_received will do. This necessary * to match the unref that on_status_received will do. This necessary
@ -812,9 +812,9 @@ TEST(AltsTsiHandshakerTest, CheckHandleResponseNullptrRecvBytes) {
alts_handshaker_client* client = alts_handshaker_client* client =
alts_tsi_handshaker_get_client_for_testing(alts_handshaker); alts_tsi_handshaker_get_client_for_testing(alts_handshaker);
/* Check nullptr recv_bytes. */ /* Check nullptr recv_bytes. */
alts_handshaker_client_set_fields_for_testing(client, alts_handshaker, alts_handshaker_client_set_fields_for_testing(
on_invalid_input_cb, nullptr, client, alts_handshaker, on_invalid_input_cb, nullptr, nullptr,
nullptr, GRPC_STATUS_OK); /*inject_read_failure=*/false);
alts_handshaker_client_handle_response(client, true); alts_handshaker_client_handle_response(client, true);
alts_handshaker_client_ref_for_testing(client); alts_handshaker_client_ref_for_testing(client);
{ {
@ -850,7 +850,7 @@ TEST(AltsTsiHandshakerTest,
/* Check failed grpc call made to handshaker service. */ /* Check failed grpc call made to handshaker service. */
alts_handshaker_client_set_fields_for_testing( alts_handshaker_client_set_fields_for_testing(
client, alts_handshaker, on_failed_grpc_call_cb, nullptr, recv_buffer, client, alts_handshaker, on_failed_grpc_call_cb, nullptr, recv_buffer,
GRPC_STATUS_UNKNOWN); /*inject_read_failure=*/true);
alts_handshaker_client_handle_response(client, true); alts_handshaker_client_handle_response(client, true);
alts_handshaker_client_ref_for_testing(client); alts_handshaker_client_ref_for_testing(client);
{ {
@ -885,9 +885,9 @@ TEST(AltsTsiHandshakerTest,
alts_handshaker_client* client = alts_handshaker_client* client =
alts_tsi_handshaker_get_client_for_testing(alts_handshaker); alts_tsi_handshaker_get_client_for_testing(alts_handshaker);
/* Check failed recv message op from handshaker service. */ /* Check failed recv message op from handshaker service. */
alts_handshaker_client_set_fields_for_testing(client, alts_handshaker, alts_handshaker_client_set_fields_for_testing(
on_failed_grpc_call_cb, nullptr, client, alts_handshaker, on_failed_grpc_call_cb, nullptr, recv_buffer,
recv_buffer, GRPC_STATUS_OK); /*inject_read_failure=*/false);
alts_handshaker_client_handle_response(client, false); alts_handshaker_client_handle_response(client, false);
alts_handshaker_client_ref_for_testing(client); alts_handshaker_client_ref_for_testing(client);
{ {
@ -931,9 +931,9 @@ TEST(AltsTsiHandshakerTest, CheckHandleResponseInvalidResp) {
alts_tsi_handshaker_get_client_for_testing(alts_handshaker); alts_tsi_handshaker_get_client_for_testing(alts_handshaker);
/* Tests. */ /* Tests. */
grpc_byte_buffer* recv_buffer = generate_handshaker_response(INVALID); grpc_byte_buffer* recv_buffer = generate_handshaker_response(INVALID);
alts_handshaker_client_set_fields_for_testing(client, alts_handshaker, alts_handshaker_client_set_fields_for_testing(
on_invalid_resp_cb, nullptr, client, alts_handshaker, on_invalid_resp_cb, nullptr, recv_buffer,
recv_buffer, GRPC_STATUS_OK); /*inject_read_failure=*/false);
alts_handshaker_client_handle_response(client, true); alts_handshaker_client_handle_response(client, true);
alts_handshaker_client_ref_for_testing(client); alts_handshaker_client_ref_for_testing(client);
{ {
@ -1003,9 +1003,9 @@ TEST(AltsTsiHandshakerTest, CheckHandleResponseFailure) {
alts_tsi_handshaker_get_client_for_testing(alts_handshaker); alts_tsi_handshaker_get_client_for_testing(alts_handshaker);
/* Tests. */ /* Tests. */
grpc_byte_buffer* recv_buffer = generate_handshaker_response(FAILED); grpc_byte_buffer* recv_buffer = generate_handshaker_response(FAILED);
alts_handshaker_client_set_fields_for_testing(client, alts_handshaker, alts_handshaker_client_set_fields_for_testing(
on_failed_resp_cb, nullptr, client, alts_handshaker, on_failed_resp_cb, nullptr, recv_buffer,
recv_buffer, GRPC_STATUS_OK); /*inject_read_failure=*/false);
alts_handshaker_client_handle_response(client, true /* is_ok*/); alts_handshaker_client_handle_response(client, true /* is_ok*/);
alts_handshaker_client_ref_for_testing(client); alts_handshaker_client_ref_for_testing(client);
{ {
@ -1049,9 +1049,9 @@ TEST(AltsTsiHandshakerTest, CheckHandleResponseAfterShutdown) {
/* Tests. */ /* Tests. */
tsi_handshaker_shutdown(handshaker); tsi_handshaker_shutdown(handshaker);
grpc_byte_buffer* recv_buffer = generate_handshaker_response(CLIENT_START); grpc_byte_buffer* recv_buffer = generate_handshaker_response(CLIENT_START);
alts_handshaker_client_set_fields_for_testing(client, alts_handshaker, alts_handshaker_client_set_fields_for_testing(
on_shutdown_resp_cb, nullptr, client, alts_handshaker, on_shutdown_resp_cb, nullptr, recv_buffer,
recv_buffer, GRPC_STATUS_OK); /*inject_read_failure=*/false);
alts_handshaker_client_handle_response(client, true); alts_handshaker_client_handle_response(client, true);
alts_handshaker_client_ref_for_testing(client); alts_handshaker_client_ref_for_testing(client);
{ {

Loading…
Cancel
Save