Add tsi_handshaker_shutdown to TSI

pull/15265/head
Yihua Zhang 7 years ago
parent 47537f51ac
commit 6fbc436b11
  1. 1
      src/core/lib/security/transport/security_handshaker.cc
  2. 26
      src/core/tsi/alts/handshaker/alts_handshaker_client.cc
  3. 10
      src/core/tsi/alts/handshaker/alts_handshaker_client.h
  4. 30
      src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc
  5. 3
      src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h
  6. 1
      src/core/tsi/fake_transport_security.cc
  7. 1
      src/core/tsi/ssl_transport_security.cc
  8. 14
      src/core/tsi/transport_security.cc
  9. 2
      src/core/tsi/transport_security.h
  10. 7
      src/core/tsi/transport_security_adapter.cc
  11. 10
      src/core/tsi/transport_security_interface.h
  12. 3
      test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc
  13. 88
      test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc

@ -380,6 +380,7 @@ static void security_handshaker_shutdown(grpc_handshaker* handshaker,
gpr_mu_lock(&h->mu);
if (!h->shutdown) {
h->shutdown = true;
tsi_handshaker_shutdown(h->handshaker);
grpc_endpoint_shutdown(h->args->endpoint, GRPC_ERROR_REF(why));
cleanup_args_for_failure_locked(h);
}

@ -118,8 +118,7 @@ static grpc_byte_buffer* get_serialized_start_client(alts_tsi_event* event) {
static tsi_result handshaker_client_start_client(alts_handshaker_client* client,
alts_tsi_event* event) {
if (client == nullptr || event == nullptr) {
gpr_log(GPR_ERROR,
"Invalid arguments to alts_grpc_handshaker_client_start_client()");
gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_start_client()");
return TSI_INVALID_ARGUMENT;
}
grpc_byte_buffer* buffer = get_serialized_start_client(event);
@ -167,8 +166,7 @@ static tsi_result handshaker_client_start_server(alts_handshaker_client* client,
alts_tsi_event* event,
grpc_slice* bytes_received) {
if (client == nullptr || event == nullptr || bytes_received == nullptr) {
gpr_log(GPR_ERROR,
"Invalid arguments to alts_grpc_handshaker_client_start_server()");
gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_start_server()");
return TSI_INVALID_ARGUMENT;
}
grpc_byte_buffer* buffer = get_serialized_start_server(event, bytes_received);
@ -206,8 +204,7 @@ static tsi_result handshaker_client_next(alts_handshaker_client* client,
alts_tsi_event* event,
grpc_slice* bytes_received) {
if (client == nullptr || event == nullptr || bytes_received == nullptr) {
gpr_log(GPR_ERROR,
"Invalid arguments to alts_grpc_handshaker_client_next()");
gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_next()");
return TSI_INVALID_ARGUMENT;
}
grpc_byte_buffer* buffer = get_serialized_next(bytes_received);
@ -223,6 +220,13 @@ static tsi_result handshaker_client_next(alts_handshaker_client* client,
return result;
}
static void handshaker_client_shutdown(alts_handshaker_client* client) {
GPR_ASSERT(client != nullptr);
alts_grpc_handshaker_client* grpc_client =
reinterpret_cast<alts_grpc_handshaker_client*>(client);
GPR_ASSERT(grpc_call_cancel(grpc_client->call, nullptr) == GRPC_CALL_OK);
}
static void handshaker_client_destruct(alts_handshaker_client* client) {
if (client == nullptr) {
return;
@ -234,7 +238,8 @@ static void handshaker_client_destruct(alts_handshaker_client* client) {
static const alts_handshaker_client_vtable vtable = {
handshaker_client_start_client, handshaker_client_start_server,
handshaker_client_next, handshaker_client_destruct};
handshaker_client_next, handshaker_client_shutdown,
handshaker_client_destruct};
alts_handshaker_client* alts_grpc_handshaker_client_create(
grpc_channel* channel, grpc_completion_queue* queue,
@ -306,6 +311,13 @@ tsi_result alts_handshaker_client_next(alts_handshaker_client* client,
return TSI_INVALID_ARGUMENT;
}
void alts_handshaker_client_shutdown(alts_handshaker_client* client) {
if (client != nullptr && client->vtable != nullptr &&
client->vtable->shutdown != nullptr) {
client->vtable->shutdown(client);
}
}
void alts_handshaker_client_destroy(alts_handshaker_client* client) {
if (client != nullptr) {
if (client->vtable != nullptr && client->vtable->destruct != nullptr) {

@ -51,6 +51,7 @@ typedef struct alts_handshaker_client_vtable {
alts_tsi_event* event, grpc_slice* bytes_received);
tsi_result (*next)(alts_handshaker_client* client, alts_tsi_event* event,
grpc_slice* bytes_received);
void (*shutdown)(alts_handshaker_client* client);
void (*destruct)(alts_handshaker_client* client);
} alts_handshaker_client_vtable;
@ -99,6 +100,15 @@ tsi_result alts_handshaker_client_next(alts_handshaker_client* client,
alts_tsi_event* event,
grpc_slice* bytes_received);
/**
* This method cancels previously scheduled, but yet executed handshaker
* requests to ALTS handshaker service. After this operation, the handshake
* will be shutdown, and no more handshaker requests will get scheduled.
*
* - client: ALTS handshaker client instance.
*/
void alts_handshaker_client_shutdown(alts_handshaker_client* client);
/**
* This method destroys a ALTS handshaker client.
*

@ -241,6 +241,10 @@ static tsi_result handshaker_next(
gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()");
return TSI_INVALID_ARGUMENT;
}
if (self->handshake_shutdown) {
gpr_log(GPR_ERROR, "TSI handshake shutdown");
return TSI_HANDSHAKE_SHUTDOWN;
}
alts_tsi_handshaker* handshaker =
reinterpret_cast<alts_tsi_handshaker*>(self);
tsi_result ok = TSI_OK;
@ -277,6 +281,16 @@ static tsi_result handshaker_next(
return TSI_ASYNC;
}
static void handshaker_shutdown(tsi_handshaker* self) {
GPR_ASSERT(self != nullptr);
if (self->handshake_shutdown) {
return;
}
alts_tsi_handshaker* handshaker =
reinterpret_cast<alts_tsi_handshaker*>(self);
alts_handshaker_client_shutdown(handshaker->client);
}
static void handshaker_destroy(tsi_handshaker* self) {
if (self == nullptr) {
return;
@ -292,8 +306,10 @@ static void handshaker_destroy(tsi_handshaker* self) {
}
static const tsi_handshaker_vtable handshaker_vtable = {
nullptr, nullptr, nullptr, nullptr, nullptr, handshaker_destroy,
handshaker_next};
nullptr, nullptr,
nullptr, nullptr,
nullptr, handshaker_destroy,
handshaker_next, handshaker_shutdown};
static void thread_worker(void* arg) {
while (true) {
@ -401,6 +417,11 @@ void alts_tsi_handshaker_handle_response(alts_tsi_handshaker* handshaker,
cb(TSI_INTERNAL_ERROR, user_data, nullptr, 0, nullptr);
return;
}
if (handshaker->base.handshake_shutdown) {
gpr_log(GPR_ERROR, "TSI handshake shutdown");
cb(TSI_HANDSHAKE_SHUTDOWN, user_data, nullptr, 0, nullptr);
return;
}
/* Failed grpc call check. */
if (!is_ok || status != GRPC_STATUS_OK) {
gpr_log(GPR_ERROR, "grpc call made to handshaker service failed");
@ -479,5 +500,10 @@ void alts_tsi_handshaker_set_client_for_testing(
handshaker->client = client;
}
alts_handshaker_client* alts_tsi_handshaker_get_client_for_testing(
alts_tsi_handshaker* handshaker) {
return handshaker->client;
}
} // namespace internal
} // namespace grpc_core

@ -33,6 +33,9 @@ namespace internal {
void alts_tsi_handshaker_set_client_for_testing(alts_tsi_handshaker* handshaker,
alts_handshaker_client* client);
alts_handshaker_client* alts_tsi_handshaker_get_client_for_testing(
alts_tsi_handshaker* handshaker);
/* For testing only. */
bool alts_tsi_handshaker_get_has_sent_start_message_for_testing(
alts_tsi_handshaker* handshaker);

@ -738,6 +738,7 @@ static const tsi_handshaker_vtable handshaker_vtable = {
nullptr, /* create_frame_protector -- deprecated */
fake_handshaker_destroy,
fake_handshaker_next,
nullptr, /* shutdown */
};
tsi_handshaker* tsi_create_fake_handshaker(int is_client) {

@ -1189,6 +1189,7 @@ static const tsi_handshaker_vtable handshaker_vtable = {
ssl_handshaker_create_frame_protector,
ssl_handshaker_destroy,
nullptr,
nullptr, /* shutdown */
};
/* --- tsi_ssl_handshaker_factory common methods. --- */

@ -136,6 +136,7 @@ tsi_result tsi_handshaker_get_bytes_to_send_to_peer(tsi_handshaker* self,
return TSI_INVALID_ARGUMENT;
}
if (self->frame_protector_created) return TSI_FAILED_PRECONDITION;
if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN;
if (self->vtable->get_bytes_to_send_to_peer == nullptr)
return TSI_UNIMPLEMENTED;
return self->vtable->get_bytes_to_send_to_peer(self, bytes, bytes_size);
@ -149,6 +150,7 @@ tsi_result tsi_handshaker_process_bytes_from_peer(tsi_handshaker* self,
return TSI_INVALID_ARGUMENT;
}
if (self->frame_protector_created) return TSI_FAILED_PRECONDITION;
if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN;
if (self->vtable->process_bytes_from_peer == nullptr)
return TSI_UNIMPLEMENTED;
return self->vtable->process_bytes_from_peer(self, bytes, bytes_size);
@ -157,6 +159,7 @@ tsi_result tsi_handshaker_process_bytes_from_peer(tsi_handshaker* self,
tsi_result tsi_handshaker_get_result(tsi_handshaker* self) {
if (self == nullptr || self->vtable == nullptr) return TSI_INVALID_ARGUMENT;
if (self->frame_protector_created) return TSI_FAILED_PRECONDITION;
if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN;
if (self->vtable->get_result == nullptr) return TSI_UNIMPLEMENTED;
return self->vtable->get_result(self);
}
@ -167,6 +170,7 @@ tsi_result tsi_handshaker_extract_peer(tsi_handshaker* self, tsi_peer* peer) {
}
memset(peer, 0, sizeof(tsi_peer));
if (self->frame_protector_created) return TSI_FAILED_PRECONDITION;
if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN;
if (tsi_handshaker_get_result(self) != TSI_OK) {
return TSI_FAILED_PRECONDITION;
}
@ -182,6 +186,7 @@ tsi_result tsi_handshaker_create_frame_protector(
return TSI_INVALID_ARGUMENT;
}
if (self->frame_protector_created) return TSI_FAILED_PRECONDITION;
if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN;
if (tsi_handshaker_get_result(self) != TSI_OK) return TSI_FAILED_PRECONDITION;
if (self->vtable->create_frame_protector == nullptr) return TSI_UNIMPLEMENTED;
result = self->vtable->create_frame_protector(self, max_protected_frame_size,
@ -199,12 +204,21 @@ tsi_result tsi_handshaker_next(
tsi_handshaker_on_next_done_cb cb, void* user_data) {
if (self == nullptr || self->vtable == nullptr) return TSI_INVALID_ARGUMENT;
if (self->handshaker_result_created) return TSI_FAILED_PRECONDITION;
if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN;
if (self->vtable->next == nullptr) return TSI_UNIMPLEMENTED;
return self->vtable->next(self, received_bytes, received_bytes_size,
bytes_to_send, bytes_to_send_size,
handshaker_result, cb, user_data);
}
void tsi_handshaker_shutdown(tsi_handshaker* self) {
if (self == nullptr || self->vtable == nullptr) return;
self->handshake_shutdown = true;
if (self->vtable->shutdown != nullptr) {
self->vtable->shutdown(self);
}
}
void tsi_handshaker_destroy(tsi_handshaker* self) {
if (self == nullptr) return;
self->vtable->destroy(self);

@ -73,12 +73,14 @@ typedef struct {
size_t* bytes_to_send_size,
tsi_handshaker_result** handshaker_result,
tsi_handshaker_on_next_done_cb cb, void* user_data);
void (*shutdown)(tsi_handshaker* self);
} tsi_handshaker_vtable;
struct tsi_handshaker {
const tsi_handshaker_vtable* vtable;
bool frame_protector_created;
bool handshaker_result_created;
bool handshake_shutdown;
};
/* Base for tsi_handshaker_result implementations.

@ -148,6 +148,12 @@ static void adapter_destroy(tsi_handshaker* self) {
gpr_free(self);
}
static void adapter_shutdown(tsi_handshaker* self) {
tsi_adapter_handshaker* impl =
reinterpret_cast<tsi_adapter_handshaker*>(self);
tsi_handshaker_shutdown(impl->wrapped);
}
static tsi_result adapter_next(
tsi_handshaker* self, const unsigned char* received_bytes,
size_t received_bytes_size, const unsigned char** bytes_to_send,
@ -213,6 +219,7 @@ static const tsi_handshaker_vtable handshaker_vtable = {
adapter_create_frame_protector,
adapter_destroy,
adapter_next,
adapter_shutdown,
};
tsi_handshaker* tsi_create_adapter_handshaker(tsi_handshaker* wrapped) {

@ -42,7 +42,8 @@ typedef enum {
TSI_PROTOCOL_FAILURE = 10,
TSI_HANDSHAKE_IN_PROGRESS = 11,
TSI_OUT_OF_RESOURCES = 12,
TSI_ASYNC = 13
TSI_ASYNC = 13,
TSI_HANDSHAKE_SHUTDOWN = 14,
} tsi_result;
typedef enum {
@ -440,6 +441,13 @@ tsi_result tsi_handshaker_next(
size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result,
tsi_handshaker_on_next_done_cb cb, void* user_data);
/* This method shuts down a TSI handshake that is in progress.
*
* This method will be invoked when TSI handshake should be terminated before
* being finished in order to free any resources being used.
*/
void tsi_handshaker_shutdown(tsi_handshaker* self);
/* This method releases the tsi_handshaker object. After this method is called,
no other method can be called on the object. */
void tsi_handshaker_destroy(tsi_handshaker* self);

@ -326,6 +326,9 @@ static void schedule_request_invalid_arg_test() {
GPR_ASSERT(alts_handshaker_client_next(nullptr, event, &config->out_frame) ==
TSI_INVALID_ARGUMENT);
/* Check shutdown. */
alts_handshaker_client_shutdown(nullptr);
/* Cleanup. */
alts_tsi_event_destroy(event);
destroy_config(config);

@ -330,6 +330,8 @@ static tsi_result mock_client_start(alts_handshaker_client* self,
return TSI_OK;
}
static void mock_shutdown(alts_handshaker_client* self) {}
static tsi_result mock_server_start(alts_handshaker_client* self,
alts_tsi_event* event,
grpc_slice* bytes_received) {
@ -400,7 +402,8 @@ static tsi_result mock_next(alts_handshaker_client* self, alts_tsi_event* event,
static void mock_destruct(alts_handshaker_client* client) {}
static const alts_handshaker_client_vtable vtable = {
mock_client_start, mock_server_start, mock_next, mock_destruct};
mock_client_start, mock_server_start, mock_next, mock_shutdown,
mock_destruct};
static alts_handshaker_client* alts_mock_handshaker_client_create(
bool used_for_success_test) {
@ -442,6 +445,16 @@ static void check_handshaker_next_invalid_input() {
tsi_handshaker_destroy(handshaker);
}
static void check_handshaker_shutdown_invalid_input() {
/* Initialization. */
tsi_handshaker* handshaker = create_test_handshaker(
false /* used_for_success_test */, true /* is_client */);
/* Check nullptr handshaker. */
tsi_handshaker_shutdown(nullptr);
/* Cleanup. */
tsi_handshaker_destroy(handshaker);
}
static void check_handshaker_next_success() {
/**
* Create handshakers for which internal mock client is going to do
@ -480,6 +493,33 @@ static void check_handshaker_next_success() {
tsi_handshaker_destroy(client_handshaker);
}
static void check_handshaker_next_with_shutdown() {
/* Initialization. */
tsi_handshaker* handshaker = create_test_handshaker(
true /* used_for_success_test */, true /* is_client*/);
/* next(success) -- shutdown(success) -- next (fail) */
GPR_ASSERT(tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr,
nullptr, on_client_start_success_cb,
nullptr) == TSI_ASYNC);
wait(&tsi_to_caller_notification);
tsi_handshaker_shutdown(handshaker);
GPR_ASSERT(tsi_handshaker_next(
handshaker,
(const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES,
strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr,
nullptr, on_client_next_success_cb,
nullptr) == TSI_HANDSHAKE_SHUTDOWN);
/* Cleanup. */
tsi_handshaker_destroy(handshaker);
}
static void check_handle_response_with_shutdown(void* unused) {
/* Client start. */
wait(&caller_to_tsi_notification);
alts_tsi_event_dispatch_to_handshaker(client_start_event, true /* is_ok */);
alts_tsi_event_destroy(client_start_event);
}
static void check_handshaker_next_failure() {
/**
* Create handshakers for which internal mock client is always going to fail.
@ -647,6 +687,49 @@ static void check_handle_response_failure() {
tsi_handshaker_destroy(handshaker);
}
static void on_shutdown_resp_cb(tsi_result status, void* user_data,
const unsigned char* bytes_to_send,
size_t bytes_to_send_size,
tsi_handshaker_result* result) {
GPR_ASSERT(status == TSI_HANDSHAKE_SHUTDOWN);
GPR_ASSERT(user_data == nullptr);
GPR_ASSERT(bytes_to_send == nullptr);
GPR_ASSERT(bytes_to_send_size == 0);
GPR_ASSERT(result == nullptr);
}
static void check_handle_response_after_shutdown() {
tsi_handshaker* handshaker = create_test_handshaker(
true /* used_for_success_test */, true /* is_client */);
alts_tsi_handshaker* alts_handshaker =
reinterpret_cast<alts_tsi_handshaker*>(handshaker);
/* Tests. */
tsi_handshaker_shutdown(handshaker);
grpc_byte_buffer* recv_buffer = generate_handshaker_response(CLIENT_START);
alts_tsi_handshaker_handle_response(alts_handshaker, recv_buffer,
GRPC_STATUS_OK, nullptr,
on_shutdown_resp_cb, nullptr, true);
grpc_byte_buffer_destroy(recv_buffer);
/* Cleanup. */
tsi_handshaker_destroy(handshaker);
}
void check_handshaker_next_fails_after_shutdown() {
/* Initialization. */
notification_init(&caller_to_tsi_notification);
notification_init(&tsi_to_caller_notification);
client_start_event = nullptr;
/* Tests. */
grpc_core::Thread thd("alts_tsi_handshaker_test",
&check_handle_response_with_shutdown, nullptr);
thd.Start();
check_handshaker_next_with_shutdown();
thd.Join();
/* Cleanup. */
notification_destroy(&caller_to_tsi_notification);
notification_destroy(&tsi_to_caller_notification);
}
void check_handshaker_success() {
/* Initialization. */
notification_init(&caller_to_tsi_notification);
@ -672,10 +755,13 @@ int main(int argc, char** argv) {
/* Tests. */
check_handshaker_success();
check_handshaker_next_invalid_input();
check_handshaker_shutdown_invalid_input();
check_handshaker_next_fails_after_shutdown();
check_handshaker_next_failure();
check_handle_response_invalid_input();
check_handle_response_invalid_resp();
check_handle_response_failure();
check_handle_response_after_shutdown();
/* Cleanup. */
grpc_shutdown();
return 0;

Loading…
Cancel
Save