Add resource quota for secure endpoint (#28970)

* Add resource quota for secure_endpoint

* Fix double linked list error in tests

* Experiments to prevent  memory leaks

* Fix memory leaks

* Fix asan error

* Add locking to fix data races

* Automated change: Fix sanity tests

* Address review comments

Co-authored-by: ananda1066 <ananda1066@users.noreply.github.com>
pull/29064/head
Alisha Nanda 3 years ago committed by GitHub
parent 82ae3c7043
commit be4e1051a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      BUILD
  2. 327
      src/core/lib/security/transport/secure_endpoint.cc
  3. 2
      src/core/lib/security/transport/secure_endpoint.h
  4. 10
      src/core/lib/security/transport/security_handshaker.cc
  5. 14
      test/core/security/secure_endpoint_test.cc

@ -4154,9 +4154,12 @@ grpc_cc_library(
"grpc_base",
"grpc_trace",
"json",
"memory_quota",
"promise",
"ref_counted",
"ref_counted_ptr",
"resource_quota",
"resource_quota_trace",
"try_seq",
"tsi_base",
],

@ -33,6 +33,9 @@
#include "src/core/lib/gprpp/memory.h"
#include "src/core/lib/iomgr/sockaddr.h"
#include "src/core/lib/profiling/timers.h"
#include "src/core/lib/resource_quota/api.h"
#include "src/core/lib/resource_quota/memory_quota.h"
#include "src/core/lib/resource_quota/trace.h"
#include "src/core/lib/security/transport/tsi_error.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/lib/slice/slice_string_helpers.h"
@ -48,6 +51,7 @@ struct secure_endpoint {
tsi_frame_protector* protector,
tsi_zero_copy_grpc_protector* zero_copy_protector,
grpc_endpoint* transport, grpc_slice* leftover_slices,
const grpc_channel_args* channel_args,
size_t leftover_nslices)
: wrapped_ep(transport),
protector(protector),
@ -62,6 +66,17 @@ struct secure_endpoint {
grpc_slice_ref_internal(leftover_slices[i]));
}
grpc_slice_buffer_init(&output_buffer);
memory_owner =
grpc_core::ResourceQuotaFromChannelArgs(channel_args)
->memory_quota()
->CreateMemoryOwner(absl::StrCat(grpc_endpoint_get_peer(transport),
":secure_endpoint"));
self_reservation = memory_owner.MakeReservation(sizeof(this));
read_staging_buffer =
memory_owner.MakeSlice(grpc_core::MemoryRequest(STAGING_BUFFER_SIZE));
write_staging_buffer =
memory_owner.MakeSlice(grpc_core::MemoryRequest(STAGING_BUFFER_SIZE));
has_posted_reclaimer.store(false, std::memory_order_relaxed);
gpr_ref_init(&ref, 1);
}
@ -82,6 +97,8 @@ struct secure_endpoint {
struct tsi_frame_protector* protector;
struct tsi_zero_copy_grpc_protector* zero_copy_protector;
gpr_mu protector_mu;
absl::Mutex read_mu;
absl::Mutex write_mu;
/* saved upper level callbacks and user_data. */
grpc_closure* read_cb = nullptr;
grpc_closure* write_cb = nullptr;
@ -91,9 +108,12 @@ struct secure_endpoint {
/* saved handshaker leftover data to unprotect. */
grpc_slice_buffer leftover_bytes;
/* buffers for read and write */
grpc_slice read_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
grpc_slice write_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
grpc_slice read_staging_buffer ABSL_GUARDED_BY(read_mu);
grpc_slice write_staging_buffer ABSL_GUARDED_BY(write_mu);
grpc_slice_buffer output_buffer;
grpc_core::MemoryOwner memory_owner;
grpc_core::MemoryAllocator::Reservation self_reservation;
std::atomic<bool> has_posted_reclaimer;
gpr_refcount ref;
};
@ -143,10 +163,46 @@ static void secure_endpoint_unref(secure_endpoint* ep) {
static void secure_endpoint_ref(secure_endpoint* ep) { gpr_ref(&ep->ref); }
#endif
static void maybe_post_reclaimer(secure_endpoint* ep) {
if (!ep->has_posted_reclaimer) {
SECURE_ENDPOINT_REF(ep, "benign_reclaimer");
ep->has_posted_reclaimer.exchange(true, std::memory_order_relaxed);
ep->memory_owner.PostReclaimer(
grpc_core::ReclamationPass::kBenign,
[ep](absl::optional<grpc_core::ReclamationSweep> sweep) {
if (sweep.has_value()) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) {
gpr_log(GPR_INFO,
"secure endpoint: benign reclamation to free memory");
}
grpc_slice temp_read_slice;
grpc_slice temp_write_slice;
ep->read_mu.Lock();
temp_read_slice = ep->read_staging_buffer;
ep->read_staging_buffer = grpc_empty_slice();
ep->read_mu.Unlock();
ep->write_mu.Lock();
temp_write_slice = ep->write_staging_buffer;
ep->write_staging_buffer = grpc_empty_slice();
ep->write_mu.Unlock();
grpc_slice_unref_internal(temp_read_slice);
grpc_slice_unref_internal(temp_write_slice);
ep->has_posted_reclaimer.exchange(false, std::memory_order_relaxed);
}
SECURE_ENDPOINT_UNREF(ep, "benign_reclaimer");
});
}
}
static void flush_read_staging_buffer(secure_endpoint* ep, uint8_t** cur,
uint8_t** end) {
grpc_slice_buffer_add(ep->read_buffer, ep->read_staging_buffer);
ep->read_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
uint8_t** end)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(ep->read_mu) {
grpc_slice_buffer_add_indexed(ep->read_buffer, ep->read_staging_buffer);
ep->read_staging_buffer =
ep->memory_owner.MakeSlice(grpc_core::MemoryRequest(STAGING_BUFFER_SIZE));
*cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer);
*end = GRPC_SLICE_END_PTR(ep->read_staging_buffer);
}
@ -171,68 +227,73 @@ static void on_read(void* user_data, grpc_error_handle error) {
uint8_t keep_looping = 0;
tsi_result result = TSI_OK;
secure_endpoint* ep = static_cast<secure_endpoint*>(user_data);
uint8_t* cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer);
uint8_t* end = GRPC_SLICE_END_PTR(ep->read_staging_buffer);
if (error != GRPC_ERROR_NONE) {
grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer);
call_read_cb(ep, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
"Secure read failed", &error, 1));
return;
}
{
absl::MutexLock l(&ep->read_mu);
uint8_t* cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer);
uint8_t* end = GRPC_SLICE_END_PTR(ep->read_staging_buffer);
if (ep->zero_copy_protector != nullptr) {
// Use zero-copy grpc protector to unprotect.
result = tsi_zero_copy_grpc_protector_unprotect(
ep->zero_copy_protector, &ep->source_buffer, ep->read_buffer);
} else {
// Use frame protector to unprotect.
/* TODO(yangg) check error, maybe bail out early */
for (i = 0; i < ep->source_buffer.count; i++) {
grpc_slice encrypted = ep->source_buffer.slices[i];
uint8_t* message_bytes = GRPC_SLICE_START_PTR(encrypted);
size_t message_size = GRPC_SLICE_LENGTH(encrypted);
while (message_size > 0 || keep_looping) {
size_t unprotected_buffer_size_written = static_cast<size_t>(end - cur);
size_t processed_message_size = message_size;
gpr_mu_lock(&ep->protector_mu);
result = tsi_frame_protector_unprotect(
ep->protector, message_bytes, &processed_message_size, cur,
&unprotected_buffer_size_written);
gpr_mu_unlock(&ep->protector_mu);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Decryption error: %s",
tsi_result_to_string(result));
break;
}
message_bytes += processed_message_size;
message_size -= processed_message_size;
cur += unprotected_buffer_size_written;
if (cur == end) {
flush_read_staging_buffer(ep, &cur, &end);
/* Force to enter the loop again to extract buffered bytes in
protector. The bytes could be buffered because of running out of
staging_buffer. If this happens at the end of all slices, doing
another unprotect avoids leaving data in the protector. */
keep_looping = 1;
} else if (unprotected_buffer_size_written > 0) {
keep_looping = 1;
} else {
keep_looping = 0;
if (error != GRPC_ERROR_NONE) {
grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer);
call_read_cb(ep, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
"Secure read failed", &error, 1));
return;
}
if (ep->zero_copy_protector != nullptr) {
// Use zero-copy grpc protector to unprotect.
result = tsi_zero_copy_grpc_protector_unprotect(
ep->zero_copy_protector, &ep->source_buffer, ep->read_buffer);
} else {
// Use frame protector to unprotect.
/* TODO(yangg) check error, maybe bail out early */
for (i = 0; i < ep->source_buffer.count; i++) {
grpc_slice encrypted = ep->source_buffer.slices[i];
uint8_t* message_bytes = GRPC_SLICE_START_PTR(encrypted);
size_t message_size = GRPC_SLICE_LENGTH(encrypted);
while (message_size > 0 || keep_looping) {
size_t unprotected_buffer_size_written =
static_cast<size_t>(end - cur);
size_t processed_message_size = message_size;
gpr_mu_lock(&ep->protector_mu);
result = tsi_frame_protector_unprotect(
ep->protector, message_bytes, &processed_message_size, cur,
&unprotected_buffer_size_written);
gpr_mu_unlock(&ep->protector_mu);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Decryption error: %s",
tsi_result_to_string(result));
break;
}
message_bytes += processed_message_size;
message_size -= processed_message_size;
cur += unprotected_buffer_size_written;
if (cur == end) {
flush_read_staging_buffer(ep, &cur, &end);
/* Force to enter the loop again to extract buffered bytes in
protector. The bytes could be buffered because of running out of
staging_buffer. If this happens at the end of all slices, doing
another unprotect avoids leaving data in the protector. */
keep_looping = 1;
} else if (unprotected_buffer_size_written > 0) {
keep_looping = 1;
} else {
keep_looping = 0;
}
}
if (result != TSI_OK) break;
}
if (result != TSI_OK) break;
}
if (cur != GRPC_SLICE_START_PTR(ep->read_staging_buffer)) {
grpc_slice_buffer_add(
ep->read_buffer,
grpc_slice_split_head(
&ep->read_staging_buffer,
static_cast<size_t>(
cur - GRPC_SLICE_START_PTR(ep->read_staging_buffer))));
if (cur != GRPC_SLICE_START_PTR(ep->read_staging_buffer)) {
grpc_slice_buffer_add(
ep->read_buffer,
grpc_slice_split_head(
&ep->read_staging_buffer,
static_cast<size_t>(
cur - GRPC_SLICE_START_PTR(ep->read_staging_buffer))));
}
}
}
@ -270,11 +331,14 @@ static void endpoint_read(grpc_endpoint* secure_ep, grpc_slice_buffer* slices,
}
static void flush_write_staging_buffer(secure_endpoint* ep, uint8_t** cur,
uint8_t** end) {
grpc_slice_buffer_add(&ep->output_buffer, ep->write_staging_buffer);
ep->write_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
uint8_t** end)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(ep->write_mu) {
grpc_slice_buffer_add_indexed(&ep->output_buffer, ep->write_staging_buffer);
ep->write_staging_buffer =
ep->memory_owner.MakeSlice(grpc_core::MemoryRequest(STAGING_BUFFER_SIZE));
*cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer);
*end = GRPC_SLICE_END_PTR(ep->write_staging_buffer);
maybe_post_reclaimer(ep);
}
static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices,
@ -284,75 +348,79 @@ static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices,
unsigned i;
tsi_result result = TSI_OK;
secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
uint8_t* cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer);
uint8_t* end = GRPC_SLICE_END_PTR(ep->write_staging_buffer);
grpc_slice_buffer_reset_and_unref_internal(&ep->output_buffer);
{
absl::MutexLock l(&ep->write_mu);
uint8_t* cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer);
uint8_t* end = GRPC_SLICE_END_PTR(ep->write_staging_buffer);
if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) {
for (i = 0; i < slices->count; i++) {
char* data =
grpc_dump_slice(slices->slices[i], GPR_DUMP_HEX | GPR_DUMP_ASCII);
gpr_log(GPR_INFO, "WRITE %p: %s", ep, data);
gpr_free(data);
}
}
if (ep->zero_copy_protector != nullptr) {
// Use zero-copy grpc protector to protect.
result = tsi_zero_copy_grpc_protector_protect(ep->zero_copy_protector,
slices, &ep->output_buffer);
} else {
// Use frame protector to protect.
for (i = 0; i < slices->count; i++) {
grpc_slice plain = slices->slices[i];
uint8_t* message_bytes = GRPC_SLICE_START_PTR(plain);
size_t message_size = GRPC_SLICE_LENGTH(plain);
while (message_size > 0) {
size_t protected_buffer_size_to_send = static_cast<size_t>(end - cur);
size_t processed_message_size = message_size;
gpr_mu_lock(&ep->protector_mu);
result = tsi_frame_protector_protect(ep->protector, message_bytes,
&processed_message_size, cur,
&protected_buffer_size_to_send);
gpr_mu_unlock(&ep->protector_mu);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Encryption error: %s",
tsi_result_to_string(result));
break;
}
message_bytes += processed_message_size;
message_size -= processed_message_size;
cur += protected_buffer_size_to_send;
grpc_slice_buffer_reset_and_unref_internal(&ep->output_buffer);
if (cur == end) {
flush_write_staging_buffer(ep, &cur, &end);
}
if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) {
for (i = 0; i < slices->count; i++) {
char* data =
grpc_dump_slice(slices->slices[i], GPR_DUMP_HEX | GPR_DUMP_ASCII);
gpr_log(GPR_INFO, "WRITE %p: %s", ep, data);
gpr_free(data);
}
if (result != TSI_OK) break;
}
if (result == TSI_OK) {
size_t still_pending_size;
do {
size_t protected_buffer_size_to_send = static_cast<size_t>(end - cur);
gpr_mu_lock(&ep->protector_mu);
result = tsi_frame_protector_protect_flush(
ep->protector, cur, &protected_buffer_size_to_send,
&still_pending_size);
gpr_mu_unlock(&ep->protector_mu);
if (ep->zero_copy_protector != nullptr) {
// Use zero-copy grpc protector to protect.
result = tsi_zero_copy_grpc_protector_protect(ep->zero_copy_protector,
slices, &ep->output_buffer);
} else {
// Use frame protector to protect.
for (i = 0; i < slices->count; i++) {
grpc_slice plain = slices->slices[i];
uint8_t* message_bytes = GRPC_SLICE_START_PTR(plain);
size_t message_size = GRPC_SLICE_LENGTH(plain);
while (message_size > 0) {
size_t protected_buffer_size_to_send = static_cast<size_t>(end - cur);
size_t processed_message_size = message_size;
gpr_mu_lock(&ep->protector_mu);
result = tsi_frame_protector_protect(ep->protector, message_bytes,
&processed_message_size, cur,
&protected_buffer_size_to_send);
gpr_mu_unlock(&ep->protector_mu);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Encryption error: %s",
tsi_result_to_string(result));
break;
}
message_bytes += processed_message_size;
message_size -= processed_message_size;
cur += protected_buffer_size_to_send;
if (cur == end) {
flush_write_staging_buffer(ep, &cur, &end);
}
}
if (result != TSI_OK) break;
cur += protected_buffer_size_to_send;
if (cur == end) {
flush_write_staging_buffer(ep, &cur, &end);
}
if (result == TSI_OK) {
size_t still_pending_size;
do {
size_t protected_buffer_size_to_send = static_cast<size_t>(end - cur);
gpr_mu_lock(&ep->protector_mu);
result = tsi_frame_protector_protect_flush(
ep->protector, cur, &protected_buffer_size_to_send,
&still_pending_size);
gpr_mu_unlock(&ep->protector_mu);
if (result != TSI_OK) break;
cur += protected_buffer_size_to_send;
if (cur == end) {
flush_write_staging_buffer(ep, &cur, &end);
}
} while (still_pending_size > 0);
if (cur != GRPC_SLICE_START_PTR(ep->write_staging_buffer)) {
grpc_slice_buffer_add(
&ep->output_buffer,
grpc_slice_split_head(
&ep->write_staging_buffer,
static_cast<size_t>(
cur - GRPC_SLICE_START_PTR(ep->write_staging_buffer))));
}
} while (still_pending_size > 0);
if (cur != GRPC_SLICE_START_PTR(ep->write_staging_buffer)) {
grpc_slice_buffer_add(
&ep->output_buffer,
grpc_slice_split_head(
&ep->write_staging_buffer,
static_cast<size_t>(
cur - GRPC_SLICE_START_PTR(ep->write_staging_buffer))));
}
}
}
@ -377,6 +445,7 @@ static void endpoint_shutdown(grpc_endpoint* secure_ep, grpc_error_handle why) {
static void endpoint_destroy(grpc_endpoint* secure_ep) {
secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
ep->memory_owner.Reset();
SECURE_ENDPOINT_UNREF(ep, "destroy");
}
@ -434,9 +503,9 @@ grpc_endpoint* grpc_secure_endpoint_create(
struct tsi_frame_protector* protector,
struct tsi_zero_copy_grpc_protector* zero_copy_protector,
grpc_endpoint* to_wrap, grpc_slice* leftover_slices,
size_t leftover_nslices) {
const grpc_channel_args* channel_args, size_t leftover_nslices) {
secure_endpoint* ep =
new secure_endpoint(&vtable, protector, zero_copy_protector, to_wrap,
leftover_slices, leftover_nslices);
leftover_slices, channel_args, leftover_nslices);
return &ep->base;
}

@ -37,6 +37,6 @@ grpc_endpoint* grpc_secure_endpoint_create(
struct tsi_frame_protector* protector,
struct tsi_zero_copy_grpc_protector* zero_copy_protector,
grpc_endpoint* to_wrap, grpc_slice* leftover_slices,
size_t leftover_nslices);
const grpc_channel_args* channel_args, size_t leftover_nslices);
#endif /* GRPC_CORE_LIB_SECURITY_TRANSPORT_SECURE_ENDPOINT_H */

@ -294,12 +294,14 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) {
if (unused_bytes_size > 0) {
grpc_slice slice = grpc_slice_from_copied_buffer(
reinterpret_cast<const char*>(unused_bytes), unused_bytes_size);
args_->endpoint = grpc_secure_endpoint_create(
protector, zero_copy_protector, args_->endpoint, &slice, 1);
args_->endpoint =
grpc_secure_endpoint_create(protector, zero_copy_protector,
args_->endpoint, &slice, args_->args, 1);
grpc_slice_unref_internal(slice);
} else {
args_->endpoint = grpc_secure_endpoint_create(
protector, zero_copy_protector, args_->endpoint, nullptr, 0);
args_->endpoint =
grpc_secure_endpoint_create(protector, zero_copy_protector,
args_->endpoint, nullptr, args_->args, 0);
}
} else if (unused_bytes_size > 0) {
// Not wrapping the endpoint, so just pass along unused bytes.

@ -55,10 +55,14 @@ static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair(
grpc_endpoint_test_fixture f;
grpc_endpoint_pair tcp;
grpc_arg a[1];
grpc_arg a[2];
a[0].key = const_cast<char*>(GRPC_ARG_TCP_READ_CHUNK_SIZE);
a[0].type = GRPC_ARG_INTEGER;
a[0].value.integer = static_cast<int>(slice_size);
a[1].key = const_cast<char*>(GRPC_ARG_RESOURCE_QUOTA);
a[1].type = GRPC_ARG_POINTER;
a[1].value.pointer.p = grpc_resource_quota_create("test");
a[1].value.pointer.vtable = grpc_resource_quota_arg_vtable();
grpc_channel_args args = {GPR_ARRAY_SIZE(a), a};
tcp = grpc_iomgr_create_endpoint_pair("fixture", &args);
grpc_endpoint_add_to_pollset(tcp.client, g_pollset);
@ -67,7 +71,7 @@ static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair(
if (leftover_nslices == 0) {
f.client_ep = grpc_secure_endpoint_create(fake_read_protector,
fake_read_zero_copy_protector,
tcp.client, nullptr, 0);
tcp.client, nullptr, &args, 0);
} else {
unsigned i;
tsi_result result;
@ -111,14 +115,16 @@ static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair(
total_buffer_size - buffer_size);
f.client_ep = grpc_secure_endpoint_create(
fake_read_protector, fake_read_zero_copy_protector, tcp.client,
&encrypted_leftover, 1);
&encrypted_leftover, &args, 1);
grpc_slice_unref(encrypted_leftover);
gpr_free(encrypted_buffer);
}
f.server_ep = grpc_secure_endpoint_create(fake_write_protector,
fake_write_zero_copy_protector,
tcp.server, nullptr, 0);
tcp.server, nullptr, &args, 0);
grpc_resource_quota_unref(
static_cast<grpc_resource_quota*>(a[1].value.pointer.p));
return f;
}

Loading…
Cancel
Save