From 98b930300ed71f6cc4f7bcdc262735ff76b37dd4 Mon Sep 17 00:00:00 2001 From: Vignesh Babu Date: Thu, 30 Jun 2022 22:57:33 +0000 Subject: [PATCH] Change secure endpoint write code to use max_frame_size to control encrypted frame sizes at the sender. (#29990) * use max_frame_size to control encrypted frame sizes on the sender * Add comment * adding logic to set max_frame_size in chttp2 transport and protecting it under a flag * fix typo * fix review comments * set max frame size usage in endpoint_tests * update endpoint_tests * adding an interception layer to secure_endpoint_test * add comments * reverting some mistaken changes * Automated change: Fix sanity tests * try increasing deadline to check if msan passes * Automated change: Fix sanity tests Co-authored-by: Vignesh2208 --- .../chttp2/transport/chttp2_transport.cc | 21 +++- .../lib/security/transport/secure_endpoint.cc | 31 ++++-- src/core/tsi/fake_transport_security.cc | 5 + src/core/tsi/fake_transport_security.h | 6 ++ test/core/iomgr/endpoint_tests.cc | 20 ++-- test/core/security/secure_endpoint_test.cc | 95 ++++++++++++++++++- 6 files changed, 161 insertions(+), 17 deletions(-) diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc index 2fe922c2a91..d8349e7db44 100644 --- a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc +++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc @@ -62,6 +62,7 @@ #include "src/core/lib/gpr/useful.h" #include "src/core/lib/gprpp/bitset.h" #include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/global_config_env.h" #include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/gprpp/status_helper.h" #include "src/core/lib/gprpp/time.h" @@ -92,6 +93,12 @@ #include "src/core/lib/transport/transport.h" #include "src/core/lib/transport/transport_impl.h" +GPR_GLOBAL_CONFIG_DEFINE_BOOL( + grpc_experimental_enable_peer_state_based_framing, false, + "If set, the max sizes of frames sent to lower layers is controlled based " + "on the peer's memory pressure which is reflected in its max http2 frame " + "size."); + #define DEFAULT_CONNECTION_WINDOW_TARGET (1024 * 1024) #define MAX_WINDOW 0x7fffffffu #define MAX_WRITE_BUFFER_SIZE (64 * 1024 * 1024) @@ -979,14 +986,26 @@ static void write_action_begin_locked(void* gt, static void write_action(void* gt, grpc_error_handle /*error*/) { GPR_TIMER_SCOPE("write_action", 0); + static bool kEnablePeerStateBasedFraming = + GPR_GLOBAL_CONFIG_GET(grpc_experimental_enable_peer_state_based_framing); grpc_chttp2_transport* t = static_cast(gt); void* cl = t->cl; t->cl = nullptr; + // If grpc_experimental_enable_peer_state_based_framing is set to true, + // choose max_frame_size as 2 * max http2 frame size of peer. If peer is under + // high memory pressure, then it would advertise a smaller max http2 frame + // size. With this logic, the sender would automatically reduce the sending + // frame size as well. + int max_frame_size = + kEnablePeerStateBasedFraming + ? 2 * t->settings[GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE] + : INT_MAX; grpc_endpoint_write( t->ep, &t->outbuf, GRPC_CLOSURE_INIT(&t->write_action_end_locked, write_action_end, t, grpc_schedule_on_exec_ctx), - cl, /*max_frame_size=*/INT_MAX); + cl, max_frame_size); } static void write_action_end(void* tp, grpc_error_handle error) { diff --git a/src/core/lib/security/transport/secure_endpoint.cc b/src/core/lib/security/transport/secure_endpoint.cc index b0c8a3558b1..303a0dc1154 100644 --- a/src/core/lib/security/transport/secure_endpoint.cc +++ b/src/core/lib/security/transport/secure_endpoint.cc @@ -21,7 +21,6 @@ #include "src/core/lib/security/transport/secure_endpoint.h" #include -#include #include #include @@ -105,6 +104,7 @@ struct secure_endpoint { } has_posted_reclaimer.store(false, std::memory_order_relaxed); min_progress_size = 1; + grpc_slice_buffer_init(&protector_staging_buffer); gpr_ref_init(&ref, 1); } @@ -117,6 +117,7 @@ struct secure_endpoint { grpc_slice_unref_internal(read_staging_buffer); grpc_slice_unref_internal(write_staging_buffer); grpc_slice_buffer_destroy_internal(&output_buffer); + grpc_slice_buffer_destroy_internal(&protector_staging_buffer); gpr_mu_destroy(&protector_mu); } @@ -143,7 +144,7 @@ struct secure_endpoint { grpc_core::MemoryAllocator::Reservation self_reservation; std::atomic has_posted_reclaimer; int min_progress_size; - + grpc_slice_buffer protector_staging_buffer; gpr_refcount ref; }; } // namespace @@ -384,8 +385,7 @@ static void flush_write_staging_buffer(secure_endpoint* ep, uint8_t** cur, } static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices, - grpc_closure* cb, void* arg, - int /*max_frame_size*/) { + grpc_closure* cb, void* arg, int max_frame_size) { GPR_TIMER_SCOPE("secure_endpoint.endpoint_write", 0); unsigned i; @@ -410,8 +410,25 @@ static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices, if (ep->zero_copy_protector != nullptr) { // Use zero-copy grpc protector to protect. - result = tsi_zero_copy_grpc_protector_protect(ep->zero_copy_protector, - slices, &ep->output_buffer); + result = TSI_OK; + // Break the input slices into chunks of size = max_frame_size and call + // tsi_zero_copy_grpc_protector_protect on each chunk. This ensures that + // the protector cannot create frames larger than the specified + // max_frame_size. + while (slices->length > static_cast(max_frame_size) && + result == TSI_OK) { + grpc_slice_buffer_move_first(slices, + static_cast(max_frame_size), + &ep->protector_staging_buffer); + result = tsi_zero_copy_grpc_protector_protect( + ep->zero_copy_protector, &ep->protector_staging_buffer, + &ep->output_buffer); + } + if (result == TSI_OK && slices->length > 0) { + result = tsi_zero_copy_grpc_protector_protect( + ep->zero_copy_protector, slices, &ep->output_buffer); + } + grpc_slice_buffer_reset_and_unref_internal(&ep->protector_staging_buffer); } else { // Use frame protector to protect. for (i = 0; i < slices->count; i++) { @@ -479,7 +496,7 @@ static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices, } grpc_endpoint_write(ep->wrapped_ep, &ep->output_buffer, cb, arg, - /*max_frame_size=*/INT_MAX); + max_frame_size); } static void endpoint_shutdown(grpc_endpoint* secure_ep, grpc_error_handle why) { diff --git a/src/core/tsi/fake_transport_security.cc b/src/core/tsi/fake_transport_security.cc index 18a42dd7301..2143fb50bfa 100644 --- a/src/core/tsi/fake_transport_security.cc +++ b/src/core/tsi/fake_transport_security.cc @@ -143,6 +143,11 @@ static uint32_t read_frame_size(const grpc_slice_buffer* sb) { return load32_little_endian(frame_size_buffer); } +uint32_t tsi_fake_zero_copy_grpc_protector_next_frame_size( + const grpc_slice_buffer* protected_slices) { + return read_frame_size(protected_slices); +} + static void tsi_fake_frame_reset(tsi_fake_frame* frame, int needs_draining) { frame->offset = 0; frame->needs_draining = needs_draining; diff --git a/src/core/tsi/fake_transport_security.h b/src/core/tsi/fake_transport_security.h index 704d9fb1a12..b2cf5ffc15a 100644 --- a/src/core/tsi/fake_transport_security.h +++ b/src/core/tsi/fake_transport_security.h @@ -21,6 +21,7 @@ #include +#include "src/core/lib/slice/slice_internal.h" #include "src/core/tsi/transport_security_interface.h" /* Value for the TSI_CERTIFICATE_TYPE_PEER_PROPERTY property for FAKE certs. */ @@ -44,4 +45,9 @@ tsi_frame_protector* tsi_create_fake_frame_protector( tsi_zero_copy_grpc_protector* tsi_create_fake_zero_copy_grpc_protector( size_t* max_protected_frame_size); +/* Given a buffer containing slices encrypted by a fake_zero_copy_protector + * it parses these protected slices to return the total frame size of the first + * contained frame */ +uint32_t tsi_fake_zero_copy_grpc_protector_next_frame_size( + const grpc_slice_buffer* protected_slices); #endif /* GRPC_CORE_TSI_FAKE_TRANSPORT_SECURITY_H */ diff --git a/test/core/iomgr/endpoint_tests.cc b/test/core/iomgr/endpoint_tests.cc index ad2488a01d8..4fb94901f05 100644 --- a/test/core/iomgr/endpoint_tests.cc +++ b/test/core/iomgr/endpoint_tests.cc @@ -112,6 +112,7 @@ struct read_and_write_test_state { uint8_t current_write_data; int read_done; int write_done; + int max_write_frame_size; grpc_slice_buffer incoming; grpc_slice_buffer outgoing; grpc_closure done_read; @@ -153,7 +154,7 @@ static void write_scheduler(void* data, grpc_error_handle /* error */) { struct read_and_write_test_state* state = static_cast(data); grpc_endpoint_write(state->write_ep, &state->outgoing, &state->done_write, - nullptr, /*max_frame_size=*/INT_MAX); + nullptr, /*max_frame_size=*/state->max_write_frame_size); } static void read_and_write_test_write_handler(void* data, @@ -197,13 +198,14 @@ static void read_and_write_test_write_handler(void* data, */ static void read_and_write_test(grpc_endpoint_test_config config, size_t num_bytes, size_t write_size, - size_t slice_size, bool shutdown) { + size_t slice_size, int max_write_frame_size, + bool shutdown) { struct read_and_write_test_state state; grpc_endpoint_test_fixture f = begin_test(config, "read_and_write_test", slice_size); grpc_core::ExecCtx exec_ctx; auto deadline = grpc_core::Timestamp::FromTimespecRoundUp( - grpc_timeout_seconds_to_deadline(20)); + grpc_timeout_seconds_to_deadline(60)); gpr_log(GPR_DEBUG, "num_bytes=%" PRIuPTR " write_size=%" PRIuPTR " slice_size=%" PRIuPTR " shutdown=%d", @@ -223,6 +225,7 @@ static void read_and_write_test(grpc_endpoint_test_config config, state.target_bytes = num_bytes; state.bytes_read = 0; state.current_write_size = write_size; + state.max_write_frame_size = max_write_frame_size; state.bytes_written = 0; state.read_done = 0; state.write_done = 0; @@ -305,7 +308,6 @@ static void multiple_shutdown_test(grpc_endpoint_test_config config) { grpc_endpoint_test_fixture f = begin_test(config, "multiple_shutdown_test", 128); int fail_count = 0; - grpc_slice_buffer slice_buffer; grpc_slice_buffer_init(&slice_buffer); @@ -346,11 +348,13 @@ void grpc_endpoint_tests(grpc_endpoint_test_config config, g_pollset = pollset; g_mu = mu; multiple_shutdown_test(config); - read_and_write_test(config, 10000000, 100000, 8192, false); - read_and_write_test(config, 1000000, 100000, 1, false); - read_and_write_test(config, 100000000, 100000, 1, true); + for (int i = 1; i <= 8192; i = i * 2) { + read_and_write_test(config, 10000000, 100000, 8192, i, false); + read_and_write_test(config, 1000000, 100000, 1, i, false); + read_and_write_test(config, 100000000, 100000, 1, i, true); + } for (i = 1; i < 1000; i = std::max(i + 1, i * 5 / 4)) { - read_and_write_test(config, 40320, i, i, false); + read_and_write_test(config, 40320, i, i, i, false); } g_pollset = nullptr; g_mu = nullptr; diff --git a/test/core/security/secure_endpoint_test.cc b/test/core/security/secure_endpoint_test.cc index 03118d82775..708a499668b 100644 --- a/test/core/security/secure_endpoint_test.cc +++ b/test/core/security/secure_endpoint_test.cc @@ -36,6 +36,93 @@ static gpr_mu* g_mu; static grpc_pollset* g_pollset; +#define TSI_FAKE_FRAME_HEADER_SIZE 4 + +typedef struct intercept_endpoint { + grpc_endpoint base; + grpc_endpoint* wrapped_ep; + grpc_slice_buffer staging_buffer; +} intercept_endpoint; + +static void me_read(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, bool urgent, int min_progress_size) { + intercept_endpoint* m = reinterpret_cast(ep); + grpc_endpoint_read(m->wrapped_ep, slices, cb, urgent, min_progress_size); +} + +static void me_write(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, void* arg, int max_frame_size) { + intercept_endpoint* m = reinterpret_cast(ep); + int remaining = slices->length; + while (remaining > 0) { + // Estimate the frame size of the next frame. + int next_frame_size = + tsi_fake_zero_copy_grpc_protector_next_frame_size(slices); + GPR_ASSERT(next_frame_size > TSI_FAKE_FRAME_HEADER_SIZE); + // Ensure the protected data size does not exceed the max_frame_size. + GPR_ASSERT(next_frame_size - TSI_FAKE_FRAME_HEADER_SIZE <= max_frame_size); + // Move this frame into a staging buffer and repeat. + grpc_slice_buffer_move_first(slices, next_frame_size, &m->staging_buffer); + remaining -= next_frame_size; + } + grpc_slice_buffer_swap(&m->staging_buffer, slices); + grpc_endpoint_write(m->wrapped_ep, slices, cb, arg, max_frame_size); +} + +static void me_add_to_pollset(grpc_endpoint* /*ep*/, + grpc_pollset* /*pollset*/) {} + +static void me_add_to_pollset_set(grpc_endpoint* /*ep*/, + grpc_pollset_set* /*pollset*/) {} + +static void me_delete_from_pollset_set(grpc_endpoint* /*ep*/, + grpc_pollset_set* /*pollset*/) {} + +static void me_shutdown(grpc_endpoint* ep, grpc_error_handle why) { + intercept_endpoint* m = reinterpret_cast(ep); + grpc_endpoint_shutdown(m->wrapped_ep, why); +} + +static void me_destroy(grpc_endpoint* ep) { + intercept_endpoint* m = reinterpret_cast(ep); + grpc_endpoint_destroy(m->wrapped_ep); + grpc_slice_buffer_destroy(&m->staging_buffer); + gpr_free(m); +} + +static absl::string_view me_get_peer(grpc_endpoint* /*ep*/) { + return "fake:intercept-endpoint"; +} + +static absl::string_view me_get_local_address(grpc_endpoint* /*ep*/) { + return "fake:intercept-endpoint"; +} + +static int me_get_fd(grpc_endpoint* /*ep*/) { return -1; } + +static bool me_can_track_err(grpc_endpoint* /*ep*/) { return false; } + +static const grpc_endpoint_vtable vtable = {me_read, + me_write, + me_add_to_pollset, + me_add_to_pollset_set, + me_delete_from_pollset_set, + me_shutdown, + me_destroy, + me_get_peer, + me_get_local_address, + me_get_fd, + me_can_track_err}; + +grpc_endpoint* wrap_with_intercept_endpoint(grpc_endpoint* wrapped_ep) { + intercept_endpoint* m = + static_cast(gpr_malloc(sizeof(*m))); + m->base.vtable = &vtable; + m->wrapped_ep = wrapped_ep; + grpc_slice_buffer_init(&m->staging_buffer); + return &m->base; +} + static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair( size_t slice_size, grpc_slice* leftover_slices, size_t leftover_nslices, bool use_zero_copy_protector) { @@ -68,6 +155,13 @@ static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair( grpc_endpoint_add_to_pollset(tcp.client, g_pollset); grpc_endpoint_add_to_pollset(tcp.server, g_pollset); + // TODO(vigneshbabu): Extend the intercept endpoint logic to cover non-zero + // copy based frame protectors as well. + if (use_zero_copy_protector && leftover_nslices == 0) { + tcp.client = wrap_with_intercept_endpoint(tcp.client); + tcp.server = wrap_with_intercept_endpoint(tcp.server); + } + if (leftover_nslices == 0) { f.client_ep = grpc_secure_endpoint_create(fake_read_protector, fake_read_zero_copy_protector, @@ -125,7 +219,6 @@ static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair( tcp.server, nullptr, &args, 0); grpc_resource_quota_unref( static_cast(a[1].value.pointer.p)); - return f; }