diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.cc b/src/core/lib/security/security_connector/alts/alts_security_connector.cc index 1274edb6e6e..6bfe6ea74c3 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.cc +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.cc @@ -82,10 +82,17 @@ class grpc_alts_channel_security_connector final tsi_handshaker* handshaker = nullptr; const grpc_alts_credentials* creds = static_cast(channel_creds()); - GPR_ASSERT(alts_tsi_handshaker_create(creds->options(), target_name_, - creds->handshaker_service_url(), true, - interested_parties, - &handshaker) == TSI_OK); + size_t user_specified_max_frame_size = 0; + const grpc_arg* arg = + grpc_channel_args_find(args, GRPC_ARG_TSI_MAX_FRAME_SIZE); + if (arg != nullptr && arg->type == GRPC_ARG_INTEGER) { + user_specified_max_frame_size = grpc_channel_arg_get_integer( + arg, {0, 0, std::numeric_limits::max()}); + } + GPR_ASSERT(alts_tsi_handshaker_create( + creds->options(), target_name_, + creds->handshaker_service_url(), true, interested_parties, + &handshaker, user_specified_max_frame_size) == TSI_OK); handshake_manager->Add( grpc_core::SecurityHandshakerCreate(handshaker, this, args)); } @@ -140,9 +147,17 @@ class grpc_alts_server_security_connector final tsi_handshaker* handshaker = nullptr; const grpc_alts_server_credentials* creds = static_cast(server_creds()); + size_t user_specified_max_frame_size = 0; + const grpc_arg* arg = + grpc_channel_args_find(args, GRPC_ARG_TSI_MAX_FRAME_SIZE); + if (arg != nullptr && arg->type == GRPC_ARG_INTEGER) { + user_specified_max_frame_size = grpc_channel_arg_get_integer( + arg, {0, 0, std::numeric_limits::max()}); + } GPR_ASSERT(alts_tsi_handshaker_create( creds->options(), nullptr, creds->handshaker_service_url(), - false, interested_parties, &handshaker) == TSI_OK); + false, interested_parties, &handshaker, + user_specified_max_frame_size) == TSI_OK); handshake_manager->Add( grpc_core::SecurityHandshakerCreate(handshaker, this, args)); } diff --git a/src/core/tsi/alts/handshaker/alts_handshaker_client.cc b/src/core/tsi/alts/handshaker/alts_handshaker_client.cc index 2592763e5a2..61927276195 100644 --- a/src/core/tsi/alts/handshaker/alts_handshaker_client.cc +++ b/src/core/tsi/alts/handshaker/alts_handshaker_client.cc @@ -102,6 +102,8 @@ typedef struct alts_grpc_handshaker_client { bool receive_status_finished; /* if non-null, contains arguments to complete a TSI next callback. */ recv_message_result* pending_recv_message_result; + /* Maximum frame size used by frame protector. */ + size_t max_frame_size; } alts_grpc_handshaker_client; static void handshaker_client_send_buffer_destroy( @@ -506,6 +508,8 @@ static grpc_byte_buffer* get_serialized_start_client( upb_strview_makez(ptr->data)); ptr = ptr->next; } + grpc_gcp_StartClientHandshakeReq_set_max_frame_size( + start_client, static_cast(client->max_frame_size)); return get_serialized_handshaker_req(req, arena.ptr()); } @@ -565,6 +569,8 @@ static grpc_byte_buffer* get_serialized_start_server( arena.ptr()); grpc_gcp_RpcProtocolVersions_assign_from_struct( server_version, arena.ptr(), &client->options->rpc_versions); + grpc_gcp_StartServerHandshakeReq_set_max_frame_size( + start_server, static_cast(client->max_frame_size)); return get_serialized_handshaker_req(req, arena.ptr()); } @@ -674,7 +680,7 @@ alts_handshaker_client* alts_grpc_handshaker_client_create( grpc_alts_credentials_options* options, const grpc_slice& target_name, grpc_iomgr_cb_func grpc_cb, tsi_handshaker_on_next_done_cb cb, void* user_data, alts_handshaker_client_vtable* vtable_for_testing, - bool is_client) { + bool is_client, size_t max_frame_size) { if (channel == nullptr || handshaker_service_url == nullptr) { gpr_log(GPR_ERROR, "Invalid arguments to alts_handshaker_client_create()"); return nullptr; @@ -694,6 +700,7 @@ alts_handshaker_client* alts_grpc_handshaker_client_create( client->recv_bytes = grpc_empty_slice(); grpc_metadata_array_init(&client->recv_initial_metadata); client->is_client = is_client; + client->max_frame_size = max_frame_size; client->buffer_size = TSI_ALTS_INITIAL_BUFFER_SIZE; client->buffer = static_cast(gpr_zalloc(client->buffer_size)); grpc_slice slice = grpc_slice_from_copied_string(handshaker_service_url); diff --git a/src/core/tsi/alts/handshaker/alts_handshaker_client.h b/src/core/tsi/alts/handshaker/alts_handshaker_client.h index 319a23c88c7..d8669da01cb 100644 --- a/src/core/tsi/alts/handshaker/alts_handshaker_client.h +++ b/src/core/tsi/alts/handshaker/alts_handshaker_client.h @@ -117,7 +117,7 @@ void alts_handshaker_client_destroy(alts_handshaker_client* client); * This method creates an ALTS handshaker client. * * - handshaker: ALTS TSI handshaker to which the created handshaker client - * belongs to. + * belongs to. * - channel: grpc channel to ALTS handshaker service. * - handshaker_service_url: address of ALTS handshaker service in the format of * "host:port". @@ -132,8 +132,12 @@ void alts_handshaker_client_destroy(alts_handshaker_client* client); * - vtable_for_testing: ALTS handshaker client vtable instance used for * testing purpose. * - is_client: a boolean value indicating if the created handshaker client is - * used at the client (is_client = true) or server (is_client = false) side. It - * returns the created ALTS handshaker client on success, and NULL on failure. + * used at the client (is_client = true) or server (is_client = false) side. + * - max_frame_size: Maximum frame size used by frame protector (User specified + * maximum frame size if present or default max frame size). + * + * It returns the created ALTS handshaker client on success, and NULL + * on failure. */ alts_handshaker_client* alts_grpc_handshaker_client_create( alts_tsi_handshaker* handshaker, grpc_channel* channel, @@ -141,7 +145,7 @@ alts_handshaker_client* alts_grpc_handshaker_client_create( grpc_alts_credentials_options* options, const grpc_slice& target_name, grpc_iomgr_cb_func grpc_cb, tsi_handshaker_on_next_done_cb cb, void* user_data, alts_handshaker_client_vtable* vtable_for_testing, - bool is_client); + bool is_client, size_t max_frame_size); /** * This method handles handshaker response returned from ALTS handshaker diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc index 0c700306d8f..2a925182d3f 100644 --- a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc @@ -63,6 +63,8 @@ struct alts_tsi_handshaker { // shutdown effectively follows base.handshake_shutdown, // but is synchronized by the mutex of this object. bool shutdown; + // Maximum frame size used by frame protector. + size_t max_frame_size; }; /* Main struct for ALTS TSI handshaker result. */ @@ -75,6 +77,8 @@ typedef struct alts_tsi_handshaker_result { grpc_slice rpc_versions; bool is_client; grpc_slice serialized_context; + // Peer's maximum frame size. + size_t max_frame_size; } alts_tsi_handshaker_result; static tsi_result handshaker_result_extract_peer( @@ -156,6 +160,26 @@ static tsi_result handshaker_result_create_zero_copy_grpc_protector( alts_tsi_handshaker_result* result = reinterpret_cast( const_cast(self)); + + // In case the peer does not send max frame size (e.g. peer is gRPC Go or + // peer uses an old binary), the negotiated frame size is set to + // kTsiAltsMinFrameSize (ignoring max_output_protected_frame_size value if + // present). Otherwise, it is based on peer and user specified max frame + // size (if present). + size_t max_frame_size = kTsiAltsMinFrameSize; + if (result->max_frame_size) { + size_t peer_max_frame_size = result->max_frame_size; + max_frame_size = std::min(peer_max_frame_size, + max_output_protected_frame_size == nullptr + ? kTsiAltsMaxFrameSize + : *max_output_protected_frame_size); + max_frame_size = std::max(max_frame_size, kTsiAltsMinFrameSize); + } + max_output_protected_frame_size = &max_frame_size; + gpr_log(GPR_DEBUG, + "After Frame Size Negotiation, maximum frame size used by frame " + "protector equals %zu", + *max_output_protected_frame_size); tsi_result ok = alts_zero_copy_grpc_protector_create( reinterpret_cast(result->key_data), kAltsAes128GcmRekeyKeyLength, /*is_rekey=*/true, result->is_client, @@ -288,6 +312,7 @@ tsi_result alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp* resp, static_cast(gpr_zalloc(peer_service_account.size + 1)); memcpy(result->peer_identity, peer_service_account.data, peer_service_account.size); + result->max_frame_size = grpc_gcp_HandshakerResult_max_frame_size(hresult); upb::Arena rpc_versions_arena; bool serialized = grpc_gcp_rpc_protocol_versions_encode( peer_rpc_version, rpc_versions_arena.ptr(), &result->rpc_versions); @@ -374,7 +399,8 @@ static tsi_result alts_tsi_handshaker_continue_handshaker_next( handshaker, channel, handshaker->handshaker_service_url, handshaker->interested_parties, handshaker->options, handshaker->target_name, grpc_cb, cb, user_data, - handshaker->client_vtable_for_testing, handshaker->is_client); + handshaker->client_vtable_for_testing, handshaker->is_client, + handshaker->max_frame_size); if (client == nullptr) { gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client"); return TSI_FAILED_PRECONDITION; @@ -570,7 +596,8 @@ bool alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker* handshaker) { tsi_result alts_tsi_handshaker_create( const grpc_alts_credentials_options* options, const char* target_name, const char* handshaker_service_url, bool is_client, - grpc_pollset_set* interested_parties, tsi_handshaker** self) { + grpc_pollset_set* interested_parties, tsi_handshaker** self, + size_t user_specified_max_frame_size) { if (handshaker_service_url == nullptr || self == nullptr || options == nullptr || (is_client && target_name == nullptr)) { gpr_log(GPR_ERROR, "Invalid arguments to alts_tsi_handshaker_create()"); @@ -590,6 +617,9 @@ tsi_result alts_tsi_handshaker_create( handshaker->has_created_handshaker_client = false; handshaker->handshaker_service_url = gpr_strdup(handshaker_service_url); handshaker->options = grpc_alts_credentials_options_copy(options); + handshaker->max_frame_size = user_specified_max_frame_size != 0 + ? user_specified_max_frame_size + : kTsiAltsMaxFrameSize; handshaker->base.vtable = handshaker->use_dedicated_cq ? &handshaker_vtable_dedicated : &handshaker_vtable; diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h index 5bace9affa8..e1ae985a84d 100644 --- a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h @@ -38,6 +38,11 @@ const size_t kTsiAltsNumOfPeerProperties = 5; +// Frame size negotiation extends send frame size range to +// [kTsiAltsMinFrameSize, kTsiAltsMaxFrameSize]. +const size_t kTsiAltsMinFrameSize = 16 * 1024; +const size_t kTsiAltsMaxFrameSize = 128 * 1024; + typedef struct alts_tsi_handshaker alts_tsi_handshaker; /** @@ -54,6 +59,8 @@ typedef struct alts_tsi_handshaker alts_tsi_handshaker; * - interested_parties: set of pollsets interested in this connection. * - self: address of ALTS TSI handshaker instance to be returned from the * method. + * - user_specified_max_frame_size: Determines the maximum frame size used by + * frame protector that is specified via user. If unspecified, the value is 0. * * It returns TSI_OK on success and an error status code on failure. Note that * if interested_parties is nullptr, a dedicated TSI thread will be created and @@ -62,7 +69,8 @@ typedef struct alts_tsi_handshaker alts_tsi_handshaker; tsi_result alts_tsi_handshaker_create( const grpc_alts_credentials_options* options, const char* target_name, const char* handshaker_service_url, bool is_client, - grpc_pollset_set* interested_parties, tsi_handshaker** self); + grpc_pollset_set* interested_parties, tsi_handshaker** self, + size_t user_specified_max_frame_size); /** * This method creates an ALTS TSI handshaker result instance. diff --git a/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc b/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc index 0e1ab006728..5f9a4b2d745 100644 --- a/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc +++ b/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc @@ -31,6 +31,7 @@ #define ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME "bigtable.google.api.com" #define ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT1 "A@google.com" #define ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT2 "B@google.com" +#define ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE 64 * 1024 const size_t kHandshakerClientOpNum = 4; const size_t kMaxRpcVersionMajor = 3; @@ -155,8 +156,8 @@ static grpc_call_error check_must_not_be_called(grpc_call* /*call*/, /** * A mock grpc_caller used to check correct execution of client_start operation. * It checks if the client_start handshaker request is populated with correct - * handshake_security_protocol, application_protocol, and record_protocol, and - * op is correctly populated. + * handshake_security_protocol, application_protocol, record_protocol and + * max_frame_size, and op is correctly populated. */ static grpc_call_error check_client_start_success(grpc_call* /*call*/, const grpc_op* op, @@ -196,7 +197,8 @@ static grpc_call_error check_client_start_success(grpc_call* /*call*/, GPR_ASSERT(upb_strview_eql( grpc_gcp_StartClientHandshakeReq_target_name(client_start), upb_strview_makez(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME))); - + GPR_ASSERT(grpc_gcp_StartClientHandshakeReq_max_frame_size(client_start) == + ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE); GPR_ASSERT(validate_op(client, op, nops, true /* is_start */)); return GRPC_CALL_OK; } @@ -204,8 +206,8 @@ static grpc_call_error check_client_start_success(grpc_call* /*call*/, /** * A mock grpc_caller used to check correct execution of server_start operation. * It checks if the server_start handshaker request is populated with correct - * handshake_security_protocol, application_protocol, and record_protocol, and - * op is correctly populated. + * handshake_security_protocol, application_protocol, record_protocol and + * max_frame_size, and op is correctly populated. */ static grpc_call_error check_server_start_success(grpc_call* /*call*/, const grpc_op* op, @@ -245,6 +247,8 @@ static grpc_call_error check_server_start_success(grpc_call* /*call*/, upb_strview_makez(ALTS_RECORD_PROTOCOL))); validate_rpc_protocol_versions( grpc_gcp_StartServerHandshakeReq_rpc_versions(server_start)); + GPR_ASSERT(grpc_gcp_StartServerHandshakeReq_max_frame_size(server_start) == + ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE); GPR_ASSERT(validate_op(client, op, nops, true /* is_start */)); return GRPC_CALL_OK; } @@ -321,12 +325,14 @@ static alts_handshaker_client_test_config* create_config() { nullptr, config->channel, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING, nullptr, server_options, grpc_slice_from_static_string(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME), - nullptr, nullptr, nullptr, nullptr, false); + nullptr, nullptr, nullptr, nullptr, false, + ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE); config->client = alts_grpc_handshaker_client_create( nullptr, config->channel, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING, nullptr, client_options, grpc_slice_from_static_string(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME), - nullptr, nullptr, nullptr, nullptr, true); + nullptr, nullptr, nullptr, nullptr, true, + ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE); GPR_ASSERT(config->client != nullptr); GPR_ASSERT(config->server != nullptr); grpc_alts_credentials_options_destroy(client_options); diff --git a/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc index 5dd76d82fdc..2127e980488 100644 --- a/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc +++ b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc @@ -27,6 +27,7 @@ #include "src/core/tsi/alts/handshaker/alts_shared_resource.h" #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h" +#include "src/core/tsi/transport_security_grpc.h" #include "src/proto/grpc/gcp/altscontext.upb.h" #include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h" #include "test/core/util/test_config.h" @@ -49,6 +50,7 @@ #define ALTS_TSI_HANDSHAKER_TEST_APPLICATION_PROTOCOL \ "test application protocol" #define ALTS_TSI_HANDSHAKER_TEST_RECORD_PROTOCOL "test record protocol" +#define ALTS_TSI_HANDSHAKER_TEST_MAX_FRAME_SIZE 256 * 1024 using grpc_core::internal::alts_handshaker_client_check_fields_for_testing; using grpc_core::internal::alts_handshaker_client_get_handshaker_for_testing; @@ -164,6 +166,8 @@ static grpc_byte_buffer* generate_handshaker_response( upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_APPLICATION_PROTOCOL)); grpc_gcp_HandshakerResult_set_record_protocol( result, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_RECORD_PROTOCOL)); + grpc_gcp_HandshakerResult_set_max_frame_size( + result, ALTS_TSI_HANDSHAKER_TEST_MAX_FRAME_SIZE); break; case SERVER_NEXT: grpc_gcp_HandshakerResp_set_bytes_consumed( @@ -283,6 +287,17 @@ static void on_client_next_success_cb(tsi_result status, void* user_data, GPR_ASSERT(memcmp(bytes_to_send, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME, bytes_to_send_size) == 0); GPR_ASSERT(result != nullptr); + // Validate max frame size value after Frame Size Negotiation. Here peer max + // frame size is greater than default value, and user specified max frame size + // is absent. + tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr; + GPR_ASSERT(tsi_handshaker_result_create_zero_copy_grpc_protector( + result, nullptr, &zero_copy_protector) == TSI_OK); + size_t actual_max_frame_size; + tsi_zero_copy_grpc_protector_max_frame_size(zero_copy_protector, + &actual_max_frame_size); + GPR_ASSERT(actual_max_frame_size == kTsiAltsMaxFrameSize); + tsi_zero_copy_grpc_protector_destroy(zero_copy_protector); /* Validate peer identity. */ tsi_peer peer; GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == TSI_OK); @@ -343,6 +358,20 @@ static void on_server_next_success_cb(tsi_result status, void* user_data, GPR_ASSERT(bytes_to_send_size == 0); GPR_ASSERT(bytes_to_send == nullptr); GPR_ASSERT(result != nullptr); + // Validate max frame size value after Frame Size Negotiation. The negotiated + // frame size value equals minimum send frame size, due to the absence of peer + // max frame size. + tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr; + size_t user_specified_max_frame_size = + ALTS_TSI_HANDSHAKER_TEST_MAX_FRAME_SIZE; + GPR_ASSERT(tsi_handshaker_result_create_zero_copy_grpc_protector( + result, &user_specified_max_frame_size, + &zero_copy_protector) == TSI_OK); + size_t actual_max_frame_size; + tsi_zero_copy_grpc_protector_max_frame_size(zero_copy_protector, + &actual_max_frame_size); + GPR_ASSERT(actual_max_frame_size == kTsiAltsMinFrameSize); + tsi_zero_copy_grpc_protector_destroy(zero_copy_protector); /* Validate peer identity. */ tsi_peer peer; GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == TSI_OK); @@ -478,7 +507,7 @@ static tsi_handshaker* create_test_handshaker(bool is_client) { grpc_alts_credentials_client_options_create(); alts_tsi_handshaker_create(options, "target_name", ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING, is_client, - nullptr, &handshaker); + nullptr, &handshaker, 0); alts_tsi_handshaker* alts_handshaker = reinterpret_cast(handshaker); alts_tsi_handshaker_set_client_vtable_for_testing(alts_handshaker, &vtable);