From be4e1051a90d24364a702cf42f77675526115d34 Mon Sep 17 00:00:00 2001 From: Alisha Nanda Date: Wed, 9 Mar 2022 13:49:01 -0800 Subject: [PATCH] Add resource quota for secure endpoint (#28970) * Add resource quota for secure_endpoint * Fix double linked list error in tests * Experiments to prevent memory leaks * Fix memory leaks * Fix asan error * Add locking to fix data races * Automated change: Fix sanity tests * Address review comments Co-authored-by: ananda1066 --- BUILD | 3 + .../lib/security/transport/secure_endpoint.cc | 327 +++++++++++------- .../lib/security/transport/secure_endpoint.h | 2 +- .../security/transport/security_handshaker.cc | 10 +- test/core/security/secure_endpoint_test.cc | 14 +- 5 files changed, 218 insertions(+), 138 deletions(-) diff --git a/BUILD b/BUILD index 90e46170564..4f8a69f34e8 100644 --- a/BUILD +++ b/BUILD @@ -4154,9 +4154,12 @@ grpc_cc_library( "grpc_base", "grpc_trace", "json", + "memory_quota", "promise", "ref_counted", "ref_counted_ptr", + "resource_quota", + "resource_quota_trace", "try_seq", "tsi_base", ], diff --git a/src/core/lib/security/transport/secure_endpoint.cc b/src/core/lib/security/transport/secure_endpoint.cc index c60ecfba37e..74ec0d10712 100644 --- a/src/core/lib/security/transport/secure_endpoint.cc +++ b/src/core/lib/security/transport/secure_endpoint.cc @@ -33,6 +33,9 @@ #include "src/core/lib/gprpp/memory.h" #include "src/core/lib/iomgr/sockaddr.h" #include "src/core/lib/profiling/timers.h" +#include "src/core/lib/resource_quota/api.h" +#include "src/core/lib/resource_quota/memory_quota.h" +#include "src/core/lib/resource_quota/trace.h" #include "src/core/lib/security/transport/tsi_error.h" #include "src/core/lib/slice/slice_internal.h" #include "src/core/lib/slice/slice_string_helpers.h" @@ -48,6 +51,7 @@ struct secure_endpoint { tsi_frame_protector* protector, tsi_zero_copy_grpc_protector* zero_copy_protector, grpc_endpoint* transport, grpc_slice* leftover_slices, + const grpc_channel_args* channel_args, size_t leftover_nslices) : wrapped_ep(transport), protector(protector), @@ -62,6 +66,17 @@ struct secure_endpoint { grpc_slice_ref_internal(leftover_slices[i])); } grpc_slice_buffer_init(&output_buffer); + memory_owner = + grpc_core::ResourceQuotaFromChannelArgs(channel_args) + ->memory_quota() + ->CreateMemoryOwner(absl::StrCat(grpc_endpoint_get_peer(transport), + ":secure_endpoint")); + self_reservation = memory_owner.MakeReservation(sizeof(this)); + read_staging_buffer = + memory_owner.MakeSlice(grpc_core::MemoryRequest(STAGING_BUFFER_SIZE)); + write_staging_buffer = + memory_owner.MakeSlice(grpc_core::MemoryRequest(STAGING_BUFFER_SIZE)); + has_posted_reclaimer.store(false, std::memory_order_relaxed); gpr_ref_init(&ref, 1); } @@ -82,6 +97,8 @@ struct secure_endpoint { struct tsi_frame_protector* protector; struct tsi_zero_copy_grpc_protector* zero_copy_protector; gpr_mu protector_mu; + absl::Mutex read_mu; + absl::Mutex write_mu; /* saved upper level callbacks and user_data. */ grpc_closure* read_cb = nullptr; grpc_closure* write_cb = nullptr; @@ -91,9 +108,12 @@ struct secure_endpoint { /* saved handshaker leftover data to unprotect. */ grpc_slice_buffer leftover_bytes; /* buffers for read and write */ - grpc_slice read_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE); - grpc_slice write_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE); + grpc_slice read_staging_buffer ABSL_GUARDED_BY(read_mu); + grpc_slice write_staging_buffer ABSL_GUARDED_BY(write_mu); grpc_slice_buffer output_buffer; + grpc_core::MemoryOwner memory_owner; + grpc_core::MemoryAllocator::Reservation self_reservation; + std::atomic has_posted_reclaimer; gpr_refcount ref; }; @@ -143,10 +163,46 @@ static void secure_endpoint_unref(secure_endpoint* ep) { static void secure_endpoint_ref(secure_endpoint* ep) { gpr_ref(&ep->ref); } #endif +static void maybe_post_reclaimer(secure_endpoint* ep) { + if (!ep->has_posted_reclaimer) { + SECURE_ENDPOINT_REF(ep, "benign_reclaimer"); + ep->has_posted_reclaimer.exchange(true, std::memory_order_relaxed); + ep->memory_owner.PostReclaimer( + grpc_core::ReclamationPass::kBenign, + [ep](absl::optional sweep) { + if (sweep.has_value()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, + "secure endpoint: benign reclamation to free memory"); + } + grpc_slice temp_read_slice; + grpc_slice temp_write_slice; + + ep->read_mu.Lock(); + temp_read_slice = ep->read_staging_buffer; + ep->read_staging_buffer = grpc_empty_slice(); + ep->read_mu.Unlock(); + + ep->write_mu.Lock(); + temp_write_slice = ep->write_staging_buffer; + ep->write_staging_buffer = grpc_empty_slice(); + ep->write_mu.Unlock(); + + grpc_slice_unref_internal(temp_read_slice); + grpc_slice_unref_internal(temp_write_slice); + ep->has_posted_reclaimer.exchange(false, std::memory_order_relaxed); + } + SECURE_ENDPOINT_UNREF(ep, "benign_reclaimer"); + }); + } +} + static void flush_read_staging_buffer(secure_endpoint* ep, uint8_t** cur, - uint8_t** end) { - grpc_slice_buffer_add(ep->read_buffer, ep->read_staging_buffer); - ep->read_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE); + uint8_t** end) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(ep->read_mu) { + grpc_slice_buffer_add_indexed(ep->read_buffer, ep->read_staging_buffer); + ep->read_staging_buffer = + ep->memory_owner.MakeSlice(grpc_core::MemoryRequest(STAGING_BUFFER_SIZE)); *cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer); *end = GRPC_SLICE_END_PTR(ep->read_staging_buffer); } @@ -171,68 +227,73 @@ static void on_read(void* user_data, grpc_error_handle error) { uint8_t keep_looping = 0; tsi_result result = TSI_OK; secure_endpoint* ep = static_cast(user_data); - uint8_t* cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer); - uint8_t* end = GRPC_SLICE_END_PTR(ep->read_staging_buffer); - if (error != GRPC_ERROR_NONE) { - grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer); - call_read_cb(ep, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( - "Secure read failed", &error, 1)); - return; - } + { + absl::MutexLock l(&ep->read_mu); + uint8_t* cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer); + uint8_t* end = GRPC_SLICE_END_PTR(ep->read_staging_buffer); - if (ep->zero_copy_protector != nullptr) { - // Use zero-copy grpc protector to unprotect. - result = tsi_zero_copy_grpc_protector_unprotect( - ep->zero_copy_protector, &ep->source_buffer, ep->read_buffer); - } else { - // Use frame protector to unprotect. - /* TODO(yangg) check error, maybe bail out early */ - for (i = 0; i < ep->source_buffer.count; i++) { - grpc_slice encrypted = ep->source_buffer.slices[i]; - uint8_t* message_bytes = GRPC_SLICE_START_PTR(encrypted); - size_t message_size = GRPC_SLICE_LENGTH(encrypted); - - while (message_size > 0 || keep_looping) { - size_t unprotected_buffer_size_written = static_cast(end - cur); - size_t processed_message_size = message_size; - gpr_mu_lock(&ep->protector_mu); - result = tsi_frame_protector_unprotect( - ep->protector, message_bytes, &processed_message_size, cur, - &unprotected_buffer_size_written); - gpr_mu_unlock(&ep->protector_mu); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Decryption error: %s", - tsi_result_to_string(result)); - break; - } - message_bytes += processed_message_size; - message_size -= processed_message_size; - cur += unprotected_buffer_size_written; - - if (cur == end) { - flush_read_staging_buffer(ep, &cur, &end); - /* Force to enter the loop again to extract buffered bytes in - protector. The bytes could be buffered because of running out of - staging_buffer. If this happens at the end of all slices, doing - another unprotect avoids leaving data in the protector. */ - keep_looping = 1; - } else if (unprotected_buffer_size_written > 0) { - keep_looping = 1; - } else { - keep_looping = 0; + if (error != GRPC_ERROR_NONE) { + grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer); + call_read_cb(ep, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Secure read failed", &error, 1)); + return; + } + + if (ep->zero_copy_protector != nullptr) { + // Use zero-copy grpc protector to unprotect. + result = tsi_zero_copy_grpc_protector_unprotect( + ep->zero_copy_protector, &ep->source_buffer, ep->read_buffer); + } else { + // Use frame protector to unprotect. + /* TODO(yangg) check error, maybe bail out early */ + for (i = 0; i < ep->source_buffer.count; i++) { + grpc_slice encrypted = ep->source_buffer.slices[i]; + uint8_t* message_bytes = GRPC_SLICE_START_PTR(encrypted); + size_t message_size = GRPC_SLICE_LENGTH(encrypted); + + while (message_size > 0 || keep_looping) { + size_t unprotected_buffer_size_written = + static_cast(end - cur); + size_t processed_message_size = message_size; + gpr_mu_lock(&ep->protector_mu); + result = tsi_frame_protector_unprotect( + ep->protector, message_bytes, &processed_message_size, cur, + &unprotected_buffer_size_written); + gpr_mu_unlock(&ep->protector_mu); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Decryption error: %s", + tsi_result_to_string(result)); + break; + } + message_bytes += processed_message_size; + message_size -= processed_message_size; + cur += unprotected_buffer_size_written; + + if (cur == end) { + flush_read_staging_buffer(ep, &cur, &end); + /* Force to enter the loop again to extract buffered bytes in + protector. The bytes could be buffered because of running out of + staging_buffer. If this happens at the end of all slices, doing + another unprotect avoids leaving data in the protector. */ + keep_looping = 1; + } else if (unprotected_buffer_size_written > 0) { + keep_looping = 1; + } else { + keep_looping = 0; + } } + if (result != TSI_OK) break; } - if (result != TSI_OK) break; - } - if (cur != GRPC_SLICE_START_PTR(ep->read_staging_buffer)) { - grpc_slice_buffer_add( - ep->read_buffer, - grpc_slice_split_head( - &ep->read_staging_buffer, - static_cast( - cur - GRPC_SLICE_START_PTR(ep->read_staging_buffer)))); + if (cur != GRPC_SLICE_START_PTR(ep->read_staging_buffer)) { + grpc_slice_buffer_add( + ep->read_buffer, + grpc_slice_split_head( + &ep->read_staging_buffer, + static_cast( + cur - GRPC_SLICE_START_PTR(ep->read_staging_buffer)))); + } } } @@ -270,11 +331,14 @@ static void endpoint_read(grpc_endpoint* secure_ep, grpc_slice_buffer* slices, } static void flush_write_staging_buffer(secure_endpoint* ep, uint8_t** cur, - uint8_t** end) { - grpc_slice_buffer_add(&ep->output_buffer, ep->write_staging_buffer); - ep->write_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE); + uint8_t** end) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(ep->write_mu) { + grpc_slice_buffer_add_indexed(&ep->output_buffer, ep->write_staging_buffer); + ep->write_staging_buffer = + ep->memory_owner.MakeSlice(grpc_core::MemoryRequest(STAGING_BUFFER_SIZE)); *cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer); *end = GRPC_SLICE_END_PTR(ep->write_staging_buffer); + maybe_post_reclaimer(ep); } static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices, @@ -284,75 +348,79 @@ static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices, unsigned i; tsi_result result = TSI_OK; secure_endpoint* ep = reinterpret_cast(secure_ep); - uint8_t* cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer); - uint8_t* end = GRPC_SLICE_END_PTR(ep->write_staging_buffer); - grpc_slice_buffer_reset_and_unref_internal(&ep->output_buffer); + { + absl::MutexLock l(&ep->write_mu); + uint8_t* cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer); + uint8_t* end = GRPC_SLICE_END_PTR(ep->write_staging_buffer); - if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) { - for (i = 0; i < slices->count; i++) { - char* data = - grpc_dump_slice(slices->slices[i], GPR_DUMP_HEX | GPR_DUMP_ASCII); - gpr_log(GPR_INFO, "WRITE %p: %s", ep, data); - gpr_free(data); - } - } - - if (ep->zero_copy_protector != nullptr) { - // Use zero-copy grpc protector to protect. - result = tsi_zero_copy_grpc_protector_protect(ep->zero_copy_protector, - slices, &ep->output_buffer); - } else { - // Use frame protector to protect. - for (i = 0; i < slices->count; i++) { - grpc_slice plain = slices->slices[i]; - uint8_t* message_bytes = GRPC_SLICE_START_PTR(plain); - size_t message_size = GRPC_SLICE_LENGTH(plain); - while (message_size > 0) { - size_t protected_buffer_size_to_send = static_cast(end - cur); - size_t processed_message_size = message_size; - gpr_mu_lock(&ep->protector_mu); - result = tsi_frame_protector_protect(ep->protector, message_bytes, - &processed_message_size, cur, - &protected_buffer_size_to_send); - gpr_mu_unlock(&ep->protector_mu); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Encryption error: %s", - tsi_result_to_string(result)); - break; - } - message_bytes += processed_message_size; - message_size -= processed_message_size; - cur += protected_buffer_size_to_send; + grpc_slice_buffer_reset_and_unref_internal(&ep->output_buffer); - if (cur == end) { - flush_write_staging_buffer(ep, &cur, &end); - } + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) { + for (i = 0; i < slices->count; i++) { + char* data = + grpc_dump_slice(slices->slices[i], GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_INFO, "WRITE %p: %s", ep, data); + gpr_free(data); } - if (result != TSI_OK) break; } - if (result == TSI_OK) { - size_t still_pending_size; - do { - size_t protected_buffer_size_to_send = static_cast(end - cur); - gpr_mu_lock(&ep->protector_mu); - result = tsi_frame_protector_protect_flush( - ep->protector, cur, &protected_buffer_size_to_send, - &still_pending_size); - gpr_mu_unlock(&ep->protector_mu); + + if (ep->zero_copy_protector != nullptr) { + // Use zero-copy grpc protector to protect. + result = tsi_zero_copy_grpc_protector_protect(ep->zero_copy_protector, + slices, &ep->output_buffer); + } else { + // Use frame protector to protect. + for (i = 0; i < slices->count; i++) { + grpc_slice plain = slices->slices[i]; + uint8_t* message_bytes = GRPC_SLICE_START_PTR(plain); + size_t message_size = GRPC_SLICE_LENGTH(plain); + while (message_size > 0) { + size_t protected_buffer_size_to_send = static_cast(end - cur); + size_t processed_message_size = message_size; + gpr_mu_lock(&ep->protector_mu); + result = tsi_frame_protector_protect(ep->protector, message_bytes, + &processed_message_size, cur, + &protected_buffer_size_to_send); + gpr_mu_unlock(&ep->protector_mu); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Encryption error: %s", + tsi_result_to_string(result)); + break; + } + message_bytes += processed_message_size; + message_size -= processed_message_size; + cur += protected_buffer_size_to_send; + + if (cur == end) { + flush_write_staging_buffer(ep, &cur, &end); + } + } if (result != TSI_OK) break; - cur += protected_buffer_size_to_send; - if (cur == end) { - flush_write_staging_buffer(ep, &cur, &end); + } + if (result == TSI_OK) { + size_t still_pending_size; + do { + size_t protected_buffer_size_to_send = static_cast(end - cur); + gpr_mu_lock(&ep->protector_mu); + result = tsi_frame_protector_protect_flush( + ep->protector, cur, &protected_buffer_size_to_send, + &still_pending_size); + gpr_mu_unlock(&ep->protector_mu); + if (result != TSI_OK) break; + cur += protected_buffer_size_to_send; + if (cur == end) { + flush_write_staging_buffer(ep, &cur, &end); + } + } while (still_pending_size > 0); + if (cur != GRPC_SLICE_START_PTR(ep->write_staging_buffer)) { + grpc_slice_buffer_add( + &ep->output_buffer, + grpc_slice_split_head( + &ep->write_staging_buffer, + static_cast( + cur - GRPC_SLICE_START_PTR(ep->write_staging_buffer)))); } - } while (still_pending_size > 0); - if (cur != GRPC_SLICE_START_PTR(ep->write_staging_buffer)) { - grpc_slice_buffer_add( - &ep->output_buffer, - grpc_slice_split_head( - &ep->write_staging_buffer, - static_cast( - cur - GRPC_SLICE_START_PTR(ep->write_staging_buffer)))); } } } @@ -377,6 +445,7 @@ static void endpoint_shutdown(grpc_endpoint* secure_ep, grpc_error_handle why) { static void endpoint_destroy(grpc_endpoint* secure_ep) { secure_endpoint* ep = reinterpret_cast(secure_ep); + ep->memory_owner.Reset(); SECURE_ENDPOINT_UNREF(ep, "destroy"); } @@ -434,9 +503,9 @@ grpc_endpoint* grpc_secure_endpoint_create( struct tsi_frame_protector* protector, struct tsi_zero_copy_grpc_protector* zero_copy_protector, grpc_endpoint* to_wrap, grpc_slice* leftover_slices, - size_t leftover_nslices) { + const grpc_channel_args* channel_args, size_t leftover_nslices) { secure_endpoint* ep = new secure_endpoint(&vtable, protector, zero_copy_protector, to_wrap, - leftover_slices, leftover_nslices); + leftover_slices, channel_args, leftover_nslices); return &ep->base; } diff --git a/src/core/lib/security/transport/secure_endpoint.h b/src/core/lib/security/transport/secure_endpoint.h index b3b8e239327..163d718b4fb 100644 --- a/src/core/lib/security/transport/secure_endpoint.h +++ b/src/core/lib/security/transport/secure_endpoint.h @@ -37,6 +37,6 @@ grpc_endpoint* grpc_secure_endpoint_create( struct tsi_frame_protector* protector, struct tsi_zero_copy_grpc_protector* zero_copy_protector, grpc_endpoint* to_wrap, grpc_slice* leftover_slices, - size_t leftover_nslices); + const grpc_channel_args* channel_args, size_t leftover_nslices); #endif /* GRPC_CORE_LIB_SECURITY_TRANSPORT_SECURE_ENDPOINT_H */ diff --git a/src/core/lib/security/transport/security_handshaker.cc b/src/core/lib/security/transport/security_handshaker.cc index 0d3ff8ea59c..0b71bea54d3 100644 --- a/src/core/lib/security/transport/security_handshaker.cc +++ b/src/core/lib/security/transport/security_handshaker.cc @@ -294,12 +294,14 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) { if (unused_bytes_size > 0) { grpc_slice slice = grpc_slice_from_copied_buffer( reinterpret_cast(unused_bytes), unused_bytes_size); - args_->endpoint = grpc_secure_endpoint_create( - protector, zero_copy_protector, args_->endpoint, &slice, 1); + args_->endpoint = + grpc_secure_endpoint_create(protector, zero_copy_protector, + args_->endpoint, &slice, args_->args, 1); grpc_slice_unref_internal(slice); } else { - args_->endpoint = grpc_secure_endpoint_create( - protector, zero_copy_protector, args_->endpoint, nullptr, 0); + args_->endpoint = + grpc_secure_endpoint_create(protector, zero_copy_protector, + args_->endpoint, nullptr, args_->args, 0); } } else if (unused_bytes_size > 0) { // Not wrapping the endpoint, so just pass along unused bytes. diff --git a/test/core/security/secure_endpoint_test.cc b/test/core/security/secure_endpoint_test.cc index 1e6f6d3abc9..20bfcfbb060 100644 --- a/test/core/security/secure_endpoint_test.cc +++ b/test/core/security/secure_endpoint_test.cc @@ -55,10 +55,14 @@ static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair( grpc_endpoint_test_fixture f; grpc_endpoint_pair tcp; - grpc_arg a[1]; + grpc_arg a[2]; a[0].key = const_cast(GRPC_ARG_TCP_READ_CHUNK_SIZE); a[0].type = GRPC_ARG_INTEGER; a[0].value.integer = static_cast(slice_size); + a[1].key = const_cast(GRPC_ARG_RESOURCE_QUOTA); + a[1].type = GRPC_ARG_POINTER; + a[1].value.pointer.p = grpc_resource_quota_create("test"); + a[1].value.pointer.vtable = grpc_resource_quota_arg_vtable(); grpc_channel_args args = {GPR_ARRAY_SIZE(a), a}; tcp = grpc_iomgr_create_endpoint_pair("fixture", &args); grpc_endpoint_add_to_pollset(tcp.client, g_pollset); @@ -67,7 +71,7 @@ static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair( if (leftover_nslices == 0) { f.client_ep = grpc_secure_endpoint_create(fake_read_protector, fake_read_zero_copy_protector, - tcp.client, nullptr, 0); + tcp.client, nullptr, &args, 0); } else { unsigned i; tsi_result result; @@ -111,14 +115,16 @@ static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair( total_buffer_size - buffer_size); f.client_ep = grpc_secure_endpoint_create( fake_read_protector, fake_read_zero_copy_protector, tcp.client, - &encrypted_leftover, 1); + &encrypted_leftover, &args, 1); grpc_slice_unref(encrypted_leftover); gpr_free(encrypted_buffer); } f.server_ep = grpc_secure_endpoint_create(fake_write_protector, fake_write_zero_copy_protector, - tcp.server, nullptr, 0); + tcp.server, nullptr, &args, 0); + grpc_resource_quota_unref( + static_cast(a[1].value.pointer.p)); return f; }