don't create secure endpoint if TSI implementation does not require it (#27509)

* don't create secure endpoint if security level is TSI_SECURITY_NONE

* instead of depending on security level, add a method to tsi_handshaker_result

* clang-format

* fix build

* fix build for realz

* code review changes

* update comments
pull/27426/head
Mark D. Roth 3 years ago committed by GitHub
parent 67eb6386d3
commit 03a51fa9d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 111
      src/core/lib/security/transport/security_handshaker.cc
  2. 11
      src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc
  3. 9
      src/core/tsi/fake_transport_security.cc
  4. 79
      src/core/tsi/local_transport_security.cc
  5. 8
      src/core/tsi/ssl_transport_security.cc
  6. 12
      src/core/tsi/transport_security.cc
  7. 17
      src/core/tsi/transport_security.h
  8. 26
      src/core/tsi/transport_security_interface.h

@ -116,13 +116,10 @@ SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
connector_(connector->Ref(DEBUG_LOCATION, "handshake")),
handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE),
handshake_buffer_(
static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))) {
const grpc_arg* arg =
grpc_channel_args_find(args, GRPC_ARG_TSI_MAX_FRAME_SIZE);
if (arg != nullptr && arg->type == GRPC_ARG_INTEGER) {
max_frame_size_ = grpc_channel_arg_get_integer(
arg, {0, 0, std::numeric_limits<int>::max()});
}
static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))),
max_frame_size_(grpc_channel_args_find_integer(
args, GRPC_ARG_TSI_MAX_FRAME_SIZE,
{0, 0, std::numeric_limits<int>::max()})) {
grpc_slice_buffer_init(&outgoing_);
GRPC_CLOSURE_INIT(&on_peer_checked_, &SecurityHandshaker::OnPeerCheckedFn,
this, grpc_schedule_on_exec_ctx);
@ -232,37 +229,10 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) {
HandshakeFailedLocked(error);
return;
}
// Create zero-copy frame protector, if implemented.
tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr;
tsi_result result = tsi_handshaker_result_create_zero_copy_grpc_protector(
handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
&zero_copy_protector);
if (result != TSI_OK && result != TSI_UNIMPLEMENTED) {
error = grpc_set_tsi_error_result(
GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Zero-copy frame protector creation failed"),
result);
HandshakeFailedLocked(error);
return;
}
// Create frame protector if zero-copy frame protector is NULL.
tsi_frame_protector* protector = nullptr;
if (zero_copy_protector == nullptr) {
result = tsi_handshaker_result_create_frame_protector(
handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
&protector);
if (result != TSI_OK) {
error = grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Frame protector creation failed"),
result);
HandshakeFailedLocked(error);
return;
}
}
// Get unused bytes.
const unsigned char* unused_bytes = nullptr;
size_t unused_bytes_size = 0;
result = tsi_handshaker_result_get_unused_bytes(
tsi_result result = tsi_handshaker_result_get_unused_bytes(
handshaker_result_, &unused_bytes, &unused_bytes_size);
if (result != TSI_OK) {
HandshakeFailedLocked(grpc_set_tsi_error_result(
@ -271,17 +241,71 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) {
result));
return;
}
// Create secure endpoint.
if (unused_bytes_size > 0) {
// Check whether we need to wrap the endpoint.
tsi_frame_protector_type frame_protector_type;
result = tsi_handshaker_result_get_frame_protector_type(
handshaker_result_, &frame_protector_type);
if (result != TSI_OK) {
HandshakeFailedLocked(grpc_set_tsi_error_result(
GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"TSI handshaker result does not implement "
"get_frame_protector_type"),
result));
return;
}
tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr;
tsi_frame_protector* protector = nullptr;
switch (frame_protector_type) {
case TSI_FRAME_PROTECTOR_ZERO_COPY:
ABSL_FALLTHROUGH_INTENDED;
case TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY:
// Create zero-copy frame protector.
result = tsi_handshaker_result_create_zero_copy_grpc_protector(
handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
&zero_copy_protector);
if (result != TSI_OK) {
HandshakeFailedLocked(grpc_set_tsi_error_result(
GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Zero-copy frame protector creation failed"),
result));
return;
}
break;
case TSI_FRAME_PROTECTOR_NORMAL:
// Create normal frame protector.
result = tsi_handshaker_result_create_frame_protector(
handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
&protector);
if (result != TSI_OK) {
HandshakeFailedLocked(
grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Frame protector creation failed"),
result));
return;
}
break;
case TSI_FRAME_PROTECTOR_NONE:
break;
}
// If we have a frame protector, create a secure endpoint.
if (zero_copy_protector != nullptr || protector != nullptr) {
if (unused_bytes_size > 0) {
grpc_slice slice = grpc_slice_from_copied_buffer(
reinterpret_cast<const char*>(unused_bytes), unused_bytes_size);
args_->endpoint = grpc_secure_endpoint_create(
protector, zero_copy_protector, args_->endpoint, &slice, 1);
grpc_slice_unref_internal(slice);
} else {
args_->endpoint = grpc_secure_endpoint_create(
protector, zero_copy_protector, args_->endpoint, nullptr, 0);
}
} else if (unused_bytes_size > 0) {
// Not wrapping the endpoint, so just pass along unused bytes.
grpc_slice slice = grpc_slice_from_copied_buffer(
reinterpret_cast<const char*>(unused_bytes), unused_bytes_size);
args_->endpoint = grpc_secure_endpoint_create(
protector, zero_copy_protector, args_->endpoint, &slice, 1);
grpc_slice_unref_internal(slice);
} else {
args_->endpoint = grpc_secure_endpoint_create(
protector, zero_copy_protector, args_->endpoint, nullptr, 0);
grpc_slice_buffer_add(args_->read_buffer, slice);
}
// Done with handshaker result.
tsi_handshaker_result_destroy(handshaker_result_);
handshaker_result_ = nullptr;
// Add auth context to channel args.
@ -437,7 +461,6 @@ void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(
size_t bytes_received_size = h->MoveReadBufferIntoHandshakeBuffer();
// Call TSI handshaker.
error = h->DoHandshakerNextLocked(h->handshake_buffer_, bytes_received_size);
if (error != GRPC_ERROR_NONE) {
h->HandshakeFailedLocked(error);
} else {

@ -151,6 +151,13 @@ static tsi_result handshaker_result_extract_peer(
return ok;
}
static tsi_result handshaker_result_get_frame_protector_type(
const tsi_handshaker_result* /*self*/,
tsi_frame_protector_type* frame_protector_type) {
*frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY;
return TSI_OK;
}
static tsi_result handshaker_result_create_zero_copy_grpc_protector(
const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
tsi_zero_copy_grpc_protector** protector) {
@ -247,9 +254,11 @@ static void handshaker_result_destroy(tsi_handshaker_result* self) {
static const tsi_handshaker_result_vtable result_vtable = {
handshaker_result_extract_peer,
handshaker_result_get_frame_protector_type,
handshaker_result_create_zero_copy_grpc_protector,
handshaker_result_create_frame_protector,
handshaker_result_get_unused_bytes, handshaker_result_destroy};
handshaker_result_get_unused_bytes,
handshaker_result_destroy};
tsi_result alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp* resp,
bool is_client,

@ -498,6 +498,7 @@ struct fake_handshaker_result {
unsigned char* unused_bytes;
size_t unused_bytes_size;
};
static tsi_result fake_handshaker_result_extract_peer(
const tsi_handshaker_result* /*self*/, tsi_peer* peer) {
/* Construct a tsi_peer with 1 property: certificate type, security_level. */
@ -514,6 +515,13 @@ static tsi_result fake_handshaker_result_extract_peer(
return result;
}
static tsi_result fake_handshaker_result_get_frame_protector_type(
const tsi_handshaker_result* /*self*/,
tsi_frame_protector_type* frame_protector_type) {
*frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY;
return TSI_OK;
}
static tsi_result fake_handshaker_result_create_zero_copy_grpc_protector(
const tsi_handshaker_result* /*self*/,
size_t* max_output_protected_frame_size,
@ -549,6 +557,7 @@ static void fake_handshaker_result_destroy(tsi_handshaker_result* self) {
static const tsi_handshaker_result_vtable handshaker_result_vtable = {
fake_handshaker_result_extract_peer,
fake_handshaker_result_get_frame_protector_type,
fake_handshaker_result_create_zero_copy_grpc_protector,
fake_handshaker_result_create_frame_protector,
fake_handshaker_result_get_unused_bytes,

@ -52,60 +52,6 @@ typedef struct local_tsi_handshaker {
bool is_client;
} local_tsi_handshaker;
/* --- tsi_zero_copy_grpc_protector methods implementation. --- */
static tsi_result local_zero_copy_grpc_protector_protect(
tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* unprotected_slices,
grpc_slice_buffer* protected_slices) {
if (self == nullptr || unprotected_slices == nullptr ||
protected_slices == nullptr) {
gpr_log(GPR_ERROR, "Invalid nullptr arguments to zero-copy grpc protect.");
return TSI_INVALID_ARGUMENT;
}
grpc_slice_buffer_move_into(unprotected_slices, protected_slices);
return TSI_OK;
}
static tsi_result local_zero_copy_grpc_protector_unprotect(
tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* protected_slices,
grpc_slice_buffer* unprotected_slices) {
if (self == nullptr || unprotected_slices == nullptr ||
protected_slices == nullptr) {
gpr_log(GPR_ERROR,
"Invalid nullptr arguments to zero-copy grpc unprotect.");
return TSI_INVALID_ARGUMENT;
}
grpc_slice_buffer_move_into(protected_slices, unprotected_slices);
return TSI_OK;
}
static void local_zero_copy_grpc_protector_destroy(
tsi_zero_copy_grpc_protector* self) {
gpr_free(self);
}
static const tsi_zero_copy_grpc_protector_vtable
local_zero_copy_grpc_protector_vtable = {
local_zero_copy_grpc_protector_protect,
local_zero_copy_grpc_protector_unprotect,
local_zero_copy_grpc_protector_destroy,
nullptr /* local_zero_copy_grpc_protector_max_frame_size */};
tsi_result local_zero_copy_grpc_protector_create(
tsi_zero_copy_grpc_protector** protector) {
if (grpc_core::ExecCtx::Get() == nullptr || protector == nullptr) {
gpr_log(
GPR_ERROR,
"Invalid nullptr arguments to local_zero_copy_grpc_protector create.");
return TSI_INVALID_ARGUMENT;
}
local_zero_copy_grpc_protector* impl =
static_cast<local_zero_copy_grpc_protector*>(gpr_zalloc(sizeof(*impl)));
impl->base.vtable = &local_zero_copy_grpc_protector_vtable;
*protector = &impl->base;
return TSI_OK;
}
/* --- tsi_handshaker_result methods implementation. --- */
static tsi_result handshaker_result_extract_peer(
@ -113,20 +59,11 @@ static tsi_result handshaker_result_extract_peer(
return TSI_OK;
}
static tsi_result handshaker_result_create_zero_copy_grpc_protector(
const tsi_handshaker_result* self,
size_t* /*max_output_protected_frame_size*/,
tsi_zero_copy_grpc_protector** protector) {
if (self == nullptr || protector == nullptr) {
gpr_log(GPR_ERROR,
"Invalid arguments to create_zero_copy_grpc_protector()");
return TSI_INVALID_ARGUMENT;
}
tsi_result ok = local_zero_copy_grpc_protector_create(protector);
if (ok != TSI_OK) {
gpr_log(GPR_ERROR, "Failed to create zero-copy grpc protector");
}
return ok;
static tsi_result handshaker_result_get_frame_protector_type(
const tsi_handshaker_result* /*self*/,
tsi_frame_protector_type* frame_protector_type) {
*frame_protector_type = TSI_FRAME_PROTECTOR_NONE;
return TSI_OK;
}
static tsi_result handshaker_result_get_unused_bytes(
@ -156,9 +93,11 @@ static void handshaker_result_destroy(tsi_handshaker_result* self) {
static const tsi_handshaker_result_vtable result_vtable = {
handshaker_result_extract_peer,
handshaker_result_create_zero_copy_grpc_protector,
handshaker_result_get_frame_protector_type,
nullptr, /* handshaker_result_create_zero_copy_grpc_protector */
nullptr, /* handshaker_result_create_frame_protector */
handshaker_result_get_unused_bytes, handshaker_result_destroy};
handshaker_result_get_unused_bytes,
handshaker_result_destroy};
static tsi_result create_handshaker_result(bool is_client,
const unsigned char* received_bytes,

@ -1302,6 +1302,13 @@ static tsi_result ssl_handshaker_result_extract_peer(
return result;
}
static tsi_result ssl_handshaker_result_get_frame_protector_type(
const tsi_handshaker_result* /*self*/,
tsi_frame_protector_type* frame_protector_type) {
*frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL;
return TSI_OK;
}
static tsi_result ssl_handshaker_result_create_frame_protector(
const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
tsi_frame_protector** protector) {
@ -1368,6 +1375,7 @@ static void ssl_handshaker_result_destroy(tsi_handshaker_result* self) {
static const tsi_handshaker_result_vtable handshaker_result_vtable = {
ssl_handshaker_result_extract_peer,
ssl_handshaker_result_get_frame_protector_type,
nullptr, /* create_zero_copy_grpc_protector */
ssl_handshaker_result_create_frame_protector,
ssl_handshaker_result_get_unused_bytes,

@ -251,6 +251,18 @@ tsi_result tsi_handshaker_result_extract_peer(const tsi_handshaker_result* self,
return self->vtable->extract_peer(self, peer);
}
tsi_result tsi_handshaker_result_get_frame_protector_type(
const tsi_handshaker_result* self,
tsi_frame_protector_type* frame_protector_type) {
if (self == nullptr || frame_protector_type == nullptr) {
return TSI_INVALID_ARGUMENT;
}
if (self->vtable->get_frame_protector_type == nullptr) {
return TSI_UNIMPLEMENTED;
}
return self->vtable->get_frame_protector_type(self, frame_protector_type);
}
tsi_result tsi_handshaker_result_create_frame_protector(
const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
tsi_frame_protector** protector) {

@ -29,7 +29,8 @@
extern grpc_core::TraceFlag tsi_tracing_enabled;
/* Base for tsi_frame_protector implementations.
See transport_security_interface.h for documentation. */
See transport_security_interface.h for documentation.
All methods must be implemented. */
struct tsi_frame_protector_vtable {
tsi_result (*protect)(tsi_frame_protector* self,
const unsigned char* unprotected_bytes,
@ -54,6 +55,9 @@ struct tsi_frame_protector {
/* Base for tsi_handshaker implementations.
See transport_security_interface.h for documentation. */
struct tsi_handshaker_vtable {
/* Methods for supporting the old synchronous API.
These can be null if the TSI impl supports only the new
async-capable API. */
tsi_result (*get_bytes_to_send_to_peer)(tsi_handshaker* self,
unsigned char* bytes,
size_t* bytes_size);
@ -65,7 +69,10 @@ struct tsi_handshaker_vtable {
tsi_result (*create_frame_protector)(tsi_handshaker* self,
size_t* max_protected_frame_size,
tsi_frame_protector** protector);
/* Must be implemented by all TSI impls. */
void (*destroy)(tsi_handshaker* self);
/* Methods for supporting the new async-capable API.
These can be null if the TSI impl supports only the old sync API. */
tsi_result (*next)(tsi_handshaker* self, const unsigned char* received_bytes,
size_t received_bytes_size,
const unsigned char** bytes_to_send,
@ -88,13 +95,21 @@ struct tsi_handshaker {
API depend on grpc. The create_zero_copy_grpc_protector() method is only used
in grpc, where we do need the exec_ctx passed through, but the API still
needs to compile in other applications, where grpc_exec_ctx is not defined.
All methods must be non-null, except where noted below.
*/
struct tsi_handshaker_result_vtable {
tsi_result (*extract_peer)(const tsi_handshaker_result* self, tsi_peer* peer);
tsi_result (*get_frame_protector_type)(
const tsi_handshaker_result* self,
tsi_frame_protector_type* frame_protector_type);
/* May be null if get_frame_protector_type() returns
TSI_FRAME_PROTECTOR_NORMAL or TSI_FRAME_PROTECTOR_NONE. */
tsi_result (*create_zero_copy_grpc_protector)(
const tsi_handshaker_result* self,
size_t* max_output_protected_frame_size,
tsi_zero_copy_grpc_protector** protector);
/* May be null if get_frame_protector_type() returns
TSI_FRAME_PROTECTOR_ZERO_COPY or TSI_FRAME_PROTECTOR_NONE. */
tsi_result (*create_frame_protector)(const tsi_handshaker_result* self,
size_t* max_output_protected_frame_size,
tsi_frame_protector** protector);

@ -64,6 +64,26 @@ typedef enum {
TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY,
} tsi_client_certificate_request_type;
typedef enum {
// TSI implementation provides a normal frame protector. The caller
// should invoke tsi_handshaker_result_create_frame_protector() to
// generate the frame protector.
TSI_FRAME_PROTECTOR_NORMAL,
// TSI implementation provides a zero-copy frame protector. The caller
// should invoke tsi_handshaker_result_create_zero_copy_grpc_protector()
// to generate the frame protector.
TSI_FRAME_PROTECTOR_ZERO_COPY,
// TSI implementation provides both normal and zero-copy frame protectors.
// The caller should invoke either
// tsi_handshaker_result_create_frame_protector() or
// tsi_handshaker_result_create_zero_copy_grpc_protector() to generate
// the frame protector.
TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY,
// TSI implementation does not provide any frame protector. This means
// that it is safe for the caller to send bytes unprotected on the wire.
TSI_FRAME_PROTECTOR_NONE,
} tsi_frame_protector_type;
typedef enum {
TSI_TLS1_2,
TSI_TLS1_3,
@ -234,6 +254,12 @@ typedef struct tsi_handshaker_result tsi_handshaker_result;
tsi_result tsi_handshaker_result_extract_peer(const tsi_handshaker_result* self,
tsi_peer* peer);
/* This method indicates what type of frame protector is provided by the
TSI implementation. */
tsi_result tsi_handshaker_result_get_frame_protector_type(
const tsi_handshaker_result* self,
tsi_frame_protector_type* frame_protector_type);
/* This method creates a tsi_frame_protector object. It returns TSI_OK assuming
there is no fatal error.
The caller is responsible for destroying the protector. */

Loading…
Cancel
Save