@ -42,6 +42,8 @@ typedef struct local_zero_copy_grpc_protector {
typedef struct local_tsi_handshaker_result {
tsi_handshaker_result base ;
bool is_client ;
unsigned char * unused_bytes ;
size_t unused_bytes_size ;
} local_tsi_handshaker_result ;
/* Main struct for local TSI handshaker. */
@ -127,6 +129,20 @@ static tsi_result handshaker_result_create_zero_copy_grpc_protector(
return ok ;
}
static tsi_result handshaker_result_get_unused_bytes (
const tsi_handshaker_result * self , const unsigned char * * bytes ,
size_t * bytes_size ) {
if ( self = = nullptr | | bytes = = nullptr | | bytes_size = = nullptr ) {
gpr_log ( GPR_ERROR , " Invalid arguments to get_unused_bytes() " ) ;
return TSI_INVALID_ARGUMENT ;
}
auto * result = reinterpret_cast < local_tsi_handshaker_result * > (
const_cast < tsi_handshaker_result * > ( self ) ) ;
* bytes_size = result - > unused_bytes_size ;
* bytes = result - > unused_bytes ;
return TSI_OK ;
}
static void handshaker_result_destroy ( tsi_handshaker_result * self ) {
if ( self = = nullptr ) {
return ;
@ -134,6 +150,7 @@ static void handshaker_result_destroy(tsi_handshaker_result* self) {
local_tsi_handshaker_result * result =
reinterpret_cast < local_tsi_handshaker_result * > (
const_cast < tsi_handshaker_result * > ( self ) ) ;
gpr_free ( result - > unused_bytes ) ;
gpr_free ( result ) ;
}
@ -141,10 +158,11 @@ static const tsi_handshaker_result_vtable result_vtable = {
handshaker_result_extract_peer ,
handshaker_result_create_zero_copy_grpc_protector ,
nullptr , /* handshaker_result_create_frame_protector */
nullptr , /* handshaker_result_get_unused_bytes */
handshaker_result_destroy } ;
handshaker_result_get_unused_bytes , handshaker_result_destroy } ;
static tsi_result create_handshaker_result ( bool is_client ,
const unsigned char * received_bytes ,
size_t received_bytes_size ,
tsi_handshaker_result * * self ) {
if ( self = = nullptr ) {
gpr_log ( GPR_ERROR , " Invalid arguments to create_handshaker_result() " ) ;
@ -153,6 +171,12 @@ static tsi_result create_handshaker_result(bool is_client,
local_tsi_handshaker_result * result =
static_cast < local_tsi_handshaker_result * > ( gpr_zalloc ( sizeof ( * result ) ) ) ;
result - > is_client = is_client ;
if ( received_bytes_size > 0 ) {
result - > unused_bytes =
static_cast < unsigned char * > ( gpr_malloc ( received_bytes_size ) ) ;
memcpy ( result - > unused_bytes , received_bytes , received_bytes_size ) ;
}
result - > unused_bytes_size = received_bytes_size ;
result - > base . vtable = & result_vtable ;
* self = & result - > base ;
return TSI_OK ;
@ -161,8 +185,8 @@ static tsi_result create_handshaker_result(bool is_client,
/* --- tsi_handshaker methods implementation. --- */
static tsi_result handshaker_next (
tsi_handshaker * self , const unsigned char * /*received_bytes*/ ,
size_t /*received_bytes_size*/ , const unsigned char * * /*bytes_to_send*/ ,
tsi_handshaker * self , const unsigned char * received_bytes ,
size_t received_bytes_size , const unsigned char * * /*bytes_to_send*/ ,
size_t * bytes_to_send_size , tsi_handshaker_result * * result ,
tsi_handshaker_on_next_done_cb /*cb*/ , void * /*user_data*/ ) {
if ( self = = nullptr ) {
@ -175,7 +199,8 @@ static tsi_result handshaker_next(
local_tsi_handshaker * handshaker =
reinterpret_cast < local_tsi_handshaker * > ( self ) ;
* bytes_to_send_size = 0 ;
create_handshaker_result ( handshaker - > is_client , result ) ;
create_handshaker_result ( handshaker - > is_client , received_bytes ,
received_bytes_size , result ) ;
return TSI_OK ;
}