From 6fbc436b1172ef03361dac77b949712e9844a897 Mon Sep 17 00:00:00 2001 From: Yihua Zhang Date: Wed, 9 May 2018 09:45:47 -0700 Subject: [PATCH] Add tsi_handshaker_shutdown to TSI --- .../security/transport/security_handshaker.cc | 1 + .../alts/handshaker/alts_handshaker_client.cc | 26 ++++-- .../alts/handshaker/alts_handshaker_client.h | 10 +++ .../alts/handshaker/alts_tsi_handshaker.cc | 30 ++++++- .../handshaker/alts_tsi_handshaker_private.h | 3 + src/core/tsi/fake_transport_security.cc | 1 + src/core/tsi/ssl_transport_security.cc | 1 + src/core/tsi/transport_security.cc | 14 +++ src/core/tsi/transport_security.h | 2 + src/core/tsi/transport_security_adapter.cc | 7 ++ src/core/tsi/transport_security_interface.h | 10 ++- .../handshaker/alts_handshaker_client_test.cc | 3 + .../handshaker/alts_tsi_handshaker_test.cc | 88 ++++++++++++++++++- 13 files changed, 185 insertions(+), 11 deletions(-) diff --git a/src/core/lib/security/transport/security_handshaker.cc b/src/core/lib/security/transport/security_handshaker.cc index d9ba3483e65..aff723ed044 100644 --- a/src/core/lib/security/transport/security_handshaker.cc +++ b/src/core/lib/security/transport/security_handshaker.cc @@ -380,6 +380,7 @@ static void security_handshaker_shutdown(grpc_handshaker* handshaker, gpr_mu_lock(&h->mu); if (!h->shutdown) { h->shutdown = true; + tsi_handshaker_shutdown(h->handshaker); grpc_endpoint_shutdown(h->args->endpoint, GRPC_ERROR_REF(why)); cleanup_args_for_failure_locked(h); } diff --git a/src/core/tsi/alts/handshaker/alts_handshaker_client.cc b/src/core/tsi/alts/handshaker/alts_handshaker_client.cc index 40f30e41ca5..b5268add0d1 100644 --- a/src/core/tsi/alts/handshaker/alts_handshaker_client.cc +++ b/src/core/tsi/alts/handshaker/alts_handshaker_client.cc @@ -118,8 +118,7 @@ static grpc_byte_buffer* get_serialized_start_client(alts_tsi_event* event) { static tsi_result handshaker_client_start_client(alts_handshaker_client* client, alts_tsi_event* event) { if (client == nullptr || event == nullptr) { - gpr_log(GPR_ERROR, - "Invalid arguments to alts_grpc_handshaker_client_start_client()"); + gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_start_client()"); return TSI_INVALID_ARGUMENT; } grpc_byte_buffer* buffer = get_serialized_start_client(event); @@ -167,8 +166,7 @@ static tsi_result handshaker_client_start_server(alts_handshaker_client* client, alts_tsi_event* event, grpc_slice* bytes_received) { if (client == nullptr || event == nullptr || bytes_received == nullptr) { - gpr_log(GPR_ERROR, - "Invalid arguments to alts_grpc_handshaker_client_start_server()"); + gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_start_server()"); return TSI_INVALID_ARGUMENT; } grpc_byte_buffer* buffer = get_serialized_start_server(event, bytes_received); @@ -206,8 +204,7 @@ static tsi_result handshaker_client_next(alts_handshaker_client* client, alts_tsi_event* event, grpc_slice* bytes_received) { if (client == nullptr || event == nullptr || bytes_received == nullptr) { - gpr_log(GPR_ERROR, - "Invalid arguments to alts_grpc_handshaker_client_next()"); + gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_next()"); return TSI_INVALID_ARGUMENT; } grpc_byte_buffer* buffer = get_serialized_next(bytes_received); @@ -223,6 +220,13 @@ static tsi_result handshaker_client_next(alts_handshaker_client* client, return result; } +static void handshaker_client_shutdown(alts_handshaker_client* client) { + GPR_ASSERT(client != nullptr); + alts_grpc_handshaker_client* grpc_client = + reinterpret_cast(client); + GPR_ASSERT(grpc_call_cancel(grpc_client->call, nullptr) == GRPC_CALL_OK); +} + static void handshaker_client_destruct(alts_handshaker_client* client) { if (client == nullptr) { return; @@ -234,7 +238,8 @@ static void handshaker_client_destruct(alts_handshaker_client* client) { static const alts_handshaker_client_vtable vtable = { handshaker_client_start_client, handshaker_client_start_server, - handshaker_client_next, handshaker_client_destruct}; + handshaker_client_next, handshaker_client_shutdown, + handshaker_client_destruct}; alts_handshaker_client* alts_grpc_handshaker_client_create( grpc_channel* channel, grpc_completion_queue* queue, @@ -306,6 +311,13 @@ tsi_result alts_handshaker_client_next(alts_handshaker_client* client, return TSI_INVALID_ARGUMENT; } +void alts_handshaker_client_shutdown(alts_handshaker_client* client) { + if (client != nullptr && client->vtable != nullptr && + client->vtable->shutdown != nullptr) { + client->vtable->shutdown(client); + } +} + void alts_handshaker_client_destroy(alts_handshaker_client* client) { if (client != nullptr) { if (client->vtable != nullptr && client->vtable->destruct != nullptr) { diff --git a/src/core/tsi/alts/handshaker/alts_handshaker_client.h b/src/core/tsi/alts/handshaker/alts_handshaker_client.h index fb2d2cf68e6..8dd8fe440db 100644 --- a/src/core/tsi/alts/handshaker/alts_handshaker_client.h +++ b/src/core/tsi/alts/handshaker/alts_handshaker_client.h @@ -51,6 +51,7 @@ typedef struct alts_handshaker_client_vtable { alts_tsi_event* event, grpc_slice* bytes_received); tsi_result (*next)(alts_handshaker_client* client, alts_tsi_event* event, grpc_slice* bytes_received); + void (*shutdown)(alts_handshaker_client* client); void (*destruct)(alts_handshaker_client* client); } alts_handshaker_client_vtable; @@ -99,6 +100,15 @@ tsi_result alts_handshaker_client_next(alts_handshaker_client* client, alts_tsi_event* event, grpc_slice* bytes_received); +/** + * This method cancels previously scheduled, but yet executed handshaker + * requests to ALTS handshaker service. After this operation, the handshake + * will be shutdown, and no more handshaker requests will get scheduled. + * + * - client: ALTS handshaker client instance. + */ +void alts_handshaker_client_shutdown(alts_handshaker_client* client); + /** * This method destroys a ALTS handshaker client. * diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc index 529f2103c71..96760853808 100644 --- a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc @@ -241,6 +241,10 @@ static tsi_result handshaker_next( gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()"); return TSI_INVALID_ARGUMENT; } + if (self->handshake_shutdown) { + gpr_log(GPR_ERROR, "TSI handshake shutdown"); + return TSI_HANDSHAKE_SHUTDOWN; + } alts_tsi_handshaker* handshaker = reinterpret_cast(self); tsi_result ok = TSI_OK; @@ -277,6 +281,16 @@ static tsi_result handshaker_next( return TSI_ASYNC; } +static void handshaker_shutdown(tsi_handshaker* self) { + GPR_ASSERT(self != nullptr); + if (self->handshake_shutdown) { + return; + } + alts_tsi_handshaker* handshaker = + reinterpret_cast(self); + alts_handshaker_client_shutdown(handshaker->client); +} + static void handshaker_destroy(tsi_handshaker* self) { if (self == nullptr) { return; @@ -292,8 +306,10 @@ static void handshaker_destroy(tsi_handshaker* self) { } static const tsi_handshaker_vtable handshaker_vtable = { - nullptr, nullptr, nullptr, nullptr, nullptr, handshaker_destroy, - handshaker_next}; + nullptr, nullptr, + nullptr, nullptr, + nullptr, handshaker_destroy, + handshaker_next, handshaker_shutdown}; static void thread_worker(void* arg) { while (true) { @@ -401,6 +417,11 @@ void alts_tsi_handshaker_handle_response(alts_tsi_handshaker* handshaker, cb(TSI_INTERNAL_ERROR, user_data, nullptr, 0, nullptr); return; } + if (handshaker->base.handshake_shutdown) { + gpr_log(GPR_ERROR, "TSI handshake shutdown"); + cb(TSI_HANDSHAKE_SHUTDOWN, user_data, nullptr, 0, nullptr); + return; + } /* Failed grpc call check. */ if (!is_ok || status != GRPC_STATUS_OK) { gpr_log(GPR_ERROR, "grpc call made to handshaker service failed"); @@ -479,5 +500,10 @@ void alts_tsi_handshaker_set_client_for_testing( handshaker->client = client; } +alts_handshaker_client* alts_tsi_handshaker_get_client_for_testing( + alts_tsi_handshaker* handshaker) { + return handshaker->client; +} + } // namespace internal } // namespace grpc_core diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h b/src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h index 9b7b9bb6b1c..96120714076 100644 --- a/src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h @@ -33,6 +33,9 @@ namespace internal { void alts_tsi_handshaker_set_client_for_testing(alts_tsi_handshaker* handshaker, alts_handshaker_client* client); +alts_handshaker_client* alts_tsi_handshaker_get_client_for_testing( + alts_tsi_handshaker* handshaker); + /* For testing only. */ bool alts_tsi_handshaker_get_has_sent_start_message_for_testing( alts_tsi_handshaker* handshaker); diff --git a/src/core/tsi/fake_transport_security.cc b/src/core/tsi/fake_transport_security.cc index ad08b50ede3..4d4c4950451 100644 --- a/src/core/tsi/fake_transport_security.cc +++ b/src/core/tsi/fake_transport_security.cc @@ -738,6 +738,7 @@ static const tsi_handshaker_vtable handshaker_vtable = { nullptr, /* create_frame_protector -- deprecated */ fake_handshaker_destroy, fake_handshaker_next, + nullptr, /* shutdown */ }; tsi_handshaker* tsi_create_fake_handshaker(int is_client) { diff --git a/src/core/tsi/ssl_transport_security.cc b/src/core/tsi/ssl_transport_security.cc index a2301be40a0..8d0729ba05d 100644 --- a/src/core/tsi/ssl_transport_security.cc +++ b/src/core/tsi/ssl_transport_security.cc @@ -1189,6 +1189,7 @@ static const tsi_handshaker_vtable handshaker_vtable = { ssl_handshaker_create_frame_protector, ssl_handshaker_destroy, nullptr, + nullptr, /* shutdown */ }; /* --- tsi_ssl_handshaker_factory common methods. --- */ diff --git a/src/core/tsi/transport_security.cc b/src/core/tsi/transport_security.cc index 129533f7799..99b3229e882 100644 --- a/src/core/tsi/transport_security.cc +++ b/src/core/tsi/transport_security.cc @@ -136,6 +136,7 @@ tsi_result tsi_handshaker_get_bytes_to_send_to_peer(tsi_handshaker* self, return TSI_INVALID_ARGUMENT; } if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; if (self->vtable->get_bytes_to_send_to_peer == nullptr) return TSI_UNIMPLEMENTED; return self->vtable->get_bytes_to_send_to_peer(self, bytes, bytes_size); @@ -149,6 +150,7 @@ tsi_result tsi_handshaker_process_bytes_from_peer(tsi_handshaker* self, return TSI_INVALID_ARGUMENT; } if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; if (self->vtable->process_bytes_from_peer == nullptr) return TSI_UNIMPLEMENTED; return self->vtable->process_bytes_from_peer(self, bytes, bytes_size); @@ -157,6 +159,7 @@ tsi_result tsi_handshaker_process_bytes_from_peer(tsi_handshaker* self, tsi_result tsi_handshaker_get_result(tsi_handshaker* self) { if (self == nullptr || self->vtable == nullptr) return TSI_INVALID_ARGUMENT; if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; if (self->vtable->get_result == nullptr) return TSI_UNIMPLEMENTED; return self->vtable->get_result(self); } @@ -167,6 +170,7 @@ tsi_result tsi_handshaker_extract_peer(tsi_handshaker* self, tsi_peer* peer) { } memset(peer, 0, sizeof(tsi_peer)); if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; if (tsi_handshaker_get_result(self) != TSI_OK) { return TSI_FAILED_PRECONDITION; } @@ -182,6 +186,7 @@ tsi_result tsi_handshaker_create_frame_protector( return TSI_INVALID_ARGUMENT; } if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; if (tsi_handshaker_get_result(self) != TSI_OK) return TSI_FAILED_PRECONDITION; if (self->vtable->create_frame_protector == nullptr) return TSI_UNIMPLEMENTED; result = self->vtable->create_frame_protector(self, max_protected_frame_size, @@ -199,12 +204,21 @@ tsi_result tsi_handshaker_next( tsi_handshaker_on_next_done_cb cb, void* user_data) { if (self == nullptr || self->vtable == nullptr) return TSI_INVALID_ARGUMENT; if (self->handshaker_result_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; if (self->vtable->next == nullptr) return TSI_UNIMPLEMENTED; return self->vtable->next(self, received_bytes, received_bytes_size, bytes_to_send, bytes_to_send_size, handshaker_result, cb, user_data); } +void tsi_handshaker_shutdown(tsi_handshaker* self) { + if (self == nullptr || self->vtable == nullptr) return; + self->handshake_shutdown = true; + if (self->vtable->shutdown != nullptr) { + self->vtable->shutdown(self); + } +} + void tsi_handshaker_destroy(tsi_handshaker* self) { if (self == nullptr) return; self->vtable->destroy(self); diff --git a/src/core/tsi/transport_security.h b/src/core/tsi/transport_security.h index b1ec82d3f77..1923a702e50 100644 --- a/src/core/tsi/transport_security.h +++ b/src/core/tsi/transport_security.h @@ -73,12 +73,14 @@ typedef struct { size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result, tsi_handshaker_on_next_done_cb cb, void* user_data); + void (*shutdown)(tsi_handshaker* self); } tsi_handshaker_vtable; struct tsi_handshaker { const tsi_handshaker_vtable* vtable; bool frame_protector_created; bool handshaker_result_created; + bool handshake_shutdown; }; /* Base for tsi_handshaker_result implementations. diff --git a/src/core/tsi/transport_security_adapter.cc b/src/core/tsi/transport_security_adapter.cc index 25608f065ab..642188e6197 100644 --- a/src/core/tsi/transport_security_adapter.cc +++ b/src/core/tsi/transport_security_adapter.cc @@ -148,6 +148,12 @@ static void adapter_destroy(tsi_handshaker* self) { gpr_free(self); } +static void adapter_shutdown(tsi_handshaker* self) { + tsi_adapter_handshaker* impl = + reinterpret_cast(self); + tsi_handshaker_shutdown(impl->wrapped); +} + static tsi_result adapter_next( tsi_handshaker* self, const unsigned char* received_bytes, size_t received_bytes_size, const unsigned char** bytes_to_send, @@ -213,6 +219,7 @@ static const tsi_handshaker_vtable handshaker_vtable = { adapter_create_frame_protector, adapter_destroy, adapter_next, + adapter_shutdown, }; tsi_handshaker* tsi_create_adapter_handshaker(tsi_handshaker* wrapped) { diff --git a/src/core/tsi/transport_security_interface.h b/src/core/tsi/transport_security_interface.h index 8c10866934f..07f2bdfd81a 100644 --- a/src/core/tsi/transport_security_interface.h +++ b/src/core/tsi/transport_security_interface.h @@ -42,7 +42,8 @@ typedef enum { TSI_PROTOCOL_FAILURE = 10, TSI_HANDSHAKE_IN_PROGRESS = 11, TSI_OUT_OF_RESOURCES = 12, - TSI_ASYNC = 13 + TSI_ASYNC = 13, + TSI_HANDSHAKE_SHUTDOWN = 14, } tsi_result; typedef enum { @@ -440,6 +441,13 @@ tsi_result tsi_handshaker_next( size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result, tsi_handshaker_on_next_done_cb cb, void* user_data); +/* This method shuts down a TSI handshake that is in progress. + * + * This method will be invoked when TSI handshake should be terminated before + * being finished in order to free any resources being used. + */ +void tsi_handshaker_shutdown(tsi_handshaker* self); + /* This method releases the tsi_handshaker object. After this method is called, no other method can be called on the object. */ void tsi_handshaker_destroy(tsi_handshaker* self); diff --git a/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc b/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc index b9dd52a64a2..c8d88aa72c7 100644 --- a/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc +++ b/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc @@ -326,6 +326,9 @@ static void schedule_request_invalid_arg_test() { GPR_ASSERT(alts_handshaker_client_next(nullptr, event, &config->out_frame) == TSI_INVALID_ARGUMENT); + /* Check shutdown. */ + alts_handshaker_client_shutdown(nullptr); + /* Cleanup. */ alts_tsi_event_destroy(event); destroy_config(config); diff --git a/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc index 95724f84f43..85a58114ba6 100644 --- a/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc +++ b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc @@ -330,6 +330,8 @@ static tsi_result mock_client_start(alts_handshaker_client* self, return TSI_OK; } +static void mock_shutdown(alts_handshaker_client* self) {} + static tsi_result mock_server_start(alts_handshaker_client* self, alts_tsi_event* event, grpc_slice* bytes_received) { @@ -400,7 +402,8 @@ static tsi_result mock_next(alts_handshaker_client* self, alts_tsi_event* event, static void mock_destruct(alts_handshaker_client* client) {} static const alts_handshaker_client_vtable vtable = { - mock_client_start, mock_server_start, mock_next, mock_destruct}; + mock_client_start, mock_server_start, mock_next, mock_shutdown, + mock_destruct}; static alts_handshaker_client* alts_mock_handshaker_client_create( bool used_for_success_test) { @@ -442,6 +445,16 @@ static void check_handshaker_next_invalid_input() { tsi_handshaker_destroy(handshaker); } +static void check_handshaker_shutdown_invalid_input() { + /* Initialization. */ + tsi_handshaker* handshaker = create_test_handshaker( + false /* used_for_success_test */, true /* is_client */); + /* Check nullptr handshaker. */ + tsi_handshaker_shutdown(nullptr); + /* Cleanup. */ + tsi_handshaker_destroy(handshaker); +} + static void check_handshaker_next_success() { /** * Create handshakers for which internal mock client is going to do @@ -480,6 +493,33 @@ static void check_handshaker_next_success() { tsi_handshaker_destroy(client_handshaker); } +static void check_handshaker_next_with_shutdown() { + /* Initialization. */ + tsi_handshaker* handshaker = create_test_handshaker( + true /* used_for_success_test */, true /* is_client*/); + /* next(success) -- shutdown(success) -- next (fail) */ + GPR_ASSERT(tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr, + nullptr, on_client_start_success_cb, + nullptr) == TSI_ASYNC); + wait(&tsi_to_caller_notification); + tsi_handshaker_shutdown(handshaker); + GPR_ASSERT(tsi_handshaker_next( + handshaker, + (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr, + nullptr, on_client_next_success_cb, + nullptr) == TSI_HANDSHAKE_SHUTDOWN); + /* Cleanup. */ + tsi_handshaker_destroy(handshaker); +} + +static void check_handle_response_with_shutdown(void* unused) { + /* Client start. */ + wait(&caller_to_tsi_notification); + alts_tsi_event_dispatch_to_handshaker(client_start_event, true /* is_ok */); + alts_tsi_event_destroy(client_start_event); +} + static void check_handshaker_next_failure() { /** * Create handshakers for which internal mock client is always going to fail. @@ -647,6 +687,49 @@ static void check_handle_response_failure() { tsi_handshaker_destroy(handshaker); } +static void on_shutdown_resp_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_HANDSHAKE_SHUTDOWN); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(result == nullptr); +} + +static void check_handle_response_after_shutdown() { + tsi_handshaker* handshaker = create_test_handshaker( + true /* used_for_success_test */, true /* is_client */); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast(handshaker); + /* Tests. */ + tsi_handshaker_shutdown(handshaker); + grpc_byte_buffer* recv_buffer = generate_handshaker_response(CLIENT_START); + alts_tsi_handshaker_handle_response(alts_handshaker, recv_buffer, + GRPC_STATUS_OK, nullptr, + on_shutdown_resp_cb, nullptr, true); + grpc_byte_buffer_destroy(recv_buffer); + /* Cleanup. */ + tsi_handshaker_destroy(handshaker); +} + +void check_handshaker_next_fails_after_shutdown() { + /* Initialization. */ + notification_init(&caller_to_tsi_notification); + notification_init(&tsi_to_caller_notification); + client_start_event = nullptr; + /* Tests. */ + grpc_core::Thread thd("alts_tsi_handshaker_test", + &check_handle_response_with_shutdown, nullptr); + thd.Start(); + check_handshaker_next_with_shutdown(); + thd.Join(); + /* Cleanup. */ + notification_destroy(&caller_to_tsi_notification); + notification_destroy(&tsi_to_caller_notification); +} + void check_handshaker_success() { /* Initialization. */ notification_init(&caller_to_tsi_notification); @@ -672,10 +755,13 @@ int main(int argc, char** argv) { /* Tests. */ check_handshaker_success(); check_handshaker_next_invalid_input(); + check_handshaker_shutdown_invalid_input(); + check_handshaker_next_fails_after_shutdown(); check_handshaker_next_failure(); check_handle_response_invalid_input(); check_handle_response_invalid_resp(); check_handle_response_failure(); + check_handle_response_after_shutdown(); /* Cleanup. */ grpc_shutdown(); return 0;