From aca1145bb6b91f8309b39e69e51ae688dc7c4bbf Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Mon, 6 Apr 2020 18:26:38 -0700 Subject: [PATCH] Fix C Core tests --- include/grpc/impl/codegen/grpc_types.h | 4 +- .../ext/filters/http/http_filters_plugin.cc | 3 +- .../message_decompress_filter.cc | 35 +++---- test/core/end2end/tests/compressed_payload.cc | 96 +++++++++++++++---- 4 files changed, 101 insertions(+), 37 deletions(-) diff --git a/include/grpc/impl/codegen/grpc_types.h b/include/grpc/impl/codegen/grpc_types.h index c1624ad9b7c..26d5984b53a 100644 --- a/include/grpc/impl/codegen/grpc_types.h +++ b/include/grpc/impl/codegen/grpc_types.h @@ -180,8 +180,8 @@ typedef struct { grpc_byte_buffer_reader. This arg also determines whether max message limits will be applied to the decompressed buffer or the non-decompressed buffer. It is recommended to keep this enabled to protect against zip bomb attacks. */ -#define GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION \ - "grpc.per_message_decompression" +#define GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION_INSIDE_CORE \ + "grpc.per_message_decompression_inside_core" /** Enable/disable support for deadline checking. Defaults to 1, unless GRPC_ARG_MINIMAL_STACK is enabled, in which case it defaults to 0 */ #define GRPC_ARG_ENABLE_DEADLINE_CHECKS "grpc.enable_deadline_checking" diff --git a/src/core/ext/filters/http/http_filters_plugin.cc b/src/core/ext/filters/http/http_filters_plugin.cc index 59749d54546..3c90df4a15c 100644 --- a/src/core/ext/filters/http/http_filters_plugin.cc +++ b/src/core/ext/filters/http/http_filters_plugin.cc @@ -38,7 +38,8 @@ static optional_filter compress_filter = { &grpc_message_compress_filter, GRPC_ARG_ENABLE_PER_MESSAGE_COMPRESSION}; static optional_filter decompress_filter = { - &grpc_message_decompress_filter, GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION}; + &grpc_message_decompress_filter, + GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION_INSIDE_CORE}; static bool is_building_http_like_transport( grpc_channel_stack_builder* builder) { diff --git a/src/core/ext/filters/http/message_decompress/message_decompress_filter.cc b/src/core/ext/filters/http/message_decompress/message_decompress_filter.cc index 2543a86fdad..c30285408f6 100644 --- a/src/core/ext/filters/http/message_decompress/message_decompress_filter.cc +++ b/src/core/ext/filters/http/message_decompress/message_decompress_filter.cc @@ -54,7 +54,6 @@ class CallData { OnRecvInitialMetadataReady, this, grpc_schedule_on_exec_ctx); // Initialize state for recv_message_ready callback - grpc_slice_buffer_init(&recv_slices_); GRPC_CLOSURE_INIT(&on_recv_message_next_done_, OnRecvMessageNextDone, this, grpc_schedule_on_exec_ctx); GRPC_CLOSURE_INIT(&on_recv_message_ready_, OnRecvMessageReady, this, @@ -134,8 +133,6 @@ void CallData::OnRecvInitialMetadataReady(void* arg, grpc_error* error) { calld->recv_initial_metadata_->idx.named.grpc_encoding; if (grpc_encoding != nullptr) { calld->algorithm_ = DecodeMessageCompressionAlgorithm(grpc_encoding->md); - grpc_metadata_batch_remove(calld->recv_initial_metadata_, - GRPC_BATCH_GRPC_ENCODING); } } calld->MaybeResumeOnRecvMessageReady(); @@ -156,15 +153,7 @@ void CallData::MaybeResumeOnRecvMessageReady() { void CallData::OnRecvMessageReady(void* arg, grpc_error* error) { CallData* calld = static_cast(arg); - if (error == GRPC_ERROR_NONE && - calld->algorithm_ != GRPC_MESSAGE_COMPRESS_NONE) { - // recv_message can be NULL if trailing metadata is received instead of - // message. - if (*calld->recv_message_ == nullptr || - (*calld->recv_message_)->length() == 0) { - calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_NONE); - return; - } + if (error == GRPC_ERROR_NONE) { if (calld->original_recv_initial_metadata_ready_ != nullptr) { calld->seen_recv_message_ready_ = true; GRPC_CALL_COMBINER_STOP(calld->call_combiner_, @@ -172,10 +161,20 @@ void CallData::OnRecvMessageReady(void* arg, grpc_error* error) { "OnRecvInitialMetadataReady"); return; } - calld->ContinueReadingRecvMessage(); - } else { - calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error)); + if (calld->algorithm_ != GRPC_MESSAGE_COMPRESS_NONE) { + // recv_message can be NULL if trailing metadata is received instead of + // message, or it's possible that the message was not compressed. + if (*calld->recv_message_ == nullptr || + (*calld->recv_message_)->length() == 0 || + ((*calld->recv_message_)->flags() & GRPC_WRITE_INTERNAL_COMPRESS) == + 0) { + return calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_NONE); + } + grpc_slice_buffer_init(&calld->recv_slices_); + return calld->ContinueReadingRecvMessage(); + } } + calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error)); } void CallData::ContinueReadingRecvMessage() { @@ -219,6 +218,7 @@ void CallData::OnRecvMessageNextDone(void* arg, grpc_error* error) { void CallData::FinishRecvMessage() { grpc_slice_buffer decompressed_slices; + grpc_slice_buffer_init(&decompressed_slices); if (grpc_msg_decompress(algorithm_, &recv_slices_, &decompressed_slices) == 0) { char* msg; @@ -230,10 +230,11 @@ void CallData::FinishRecvMessage() { error_ = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); gpr_free(msg); } else { - uint32_t recv_flags = (*recv_message_)->flags(); + uint32_t recv_flags = + (*recv_message_)->flags() & (~GRPC_WRITE_INTERNAL_COMPRESS); // Swap out the original receive byte stream with our new one and send the // batch down. - recv_replacement_stream_.Init(&recv_slices_, recv_flags); + recv_replacement_stream_.Init(&decompressed_slices, recv_flags); recv_message_->reset(recv_replacement_stream_.get()); recv_message_ = nullptr; } diff --git a/test/core/end2end/tests/compressed_payload.cc b/test/core/end2end/tests/compressed_payload.cc index 81e6bfccd10..6fc6bf985d3 100644 --- a/test/core/end2end/tests/compressed_payload.cc +++ b/test/core/end2end/tests/compressed_payload.cc @@ -97,7 +97,8 @@ static void request_for_disabled_algorithm( uint32_t send_flags_bitmask, grpc_compression_algorithm algorithm_to_disable, grpc_compression_algorithm requested_client_compression_algorithm, - grpc_status_code expected_error, grpc_metadata* client_metadata) { + grpc_status_code expected_error, grpc_metadata* client_metadata, + bool decompress_in_core) { grpc_call* c; grpc_call* s; grpc_slice request_payload_slice; @@ -132,6 +133,21 @@ static void request_for_disabled_algorithm( grpc_core::ExecCtx exec_ctx; server_args = grpc_channel_args_compression_algorithm_set_state( &server_args, algorithm_to_disable, false); + if (!decompress_in_core) { + grpc_arg disable_decompression_in_core_arg = + grpc_channel_arg_integer_create( + const_cast( + GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION_INSIDE_CORE), + 0); + grpc_channel_args* old_client_args = client_args; + grpc_channel_args* old_server_args = server_args; + client_args = grpc_channel_args_copy_and_add( + client_args, &disable_decompression_in_core_arg, 1); + server_args = grpc_channel_args_copy_and_add( + server_args, &disable_decompression_in_core_arg, 1); + grpc_channel_args_destroy(old_client_args); + grpc_channel_args_destroy(old_server_args); + } } f = begin_test(config, test_name, client_args, server_args); @@ -264,7 +280,7 @@ static void request_for_disabled_algorithm( config.tear_down_data(&f); } -static void request_with_payload_template( +static void request_with_payload_template_inner( grpc_end2end_test_config config, const char* test_name, uint32_t client_send_flags_bitmask, grpc_compression_algorithm default_client_channel_compression_algorithm, @@ -273,7 +289,7 @@ static void request_with_payload_template( grpc_compression_algorithm expected_algorithm_from_server, grpc_metadata* client_init_metadata, bool set_server_level, grpc_compression_level server_compression_level, - bool send_message_before_initial_metadata) { + bool send_message_before_initial_metadata, bool decompress_in_core) { grpc_call* c; grpc_call* s; grpc_slice request_payload_slice; @@ -308,11 +324,28 @@ static void request_with_payload_template( grpc_slice response_payload_slice = grpc_slice_from_copied_string(response_str); - client_args = grpc_channel_args_set_channel_default_compression_algorithm( - nullptr, default_client_channel_compression_algorithm); - server_args = grpc_channel_args_set_channel_default_compression_algorithm( - nullptr, default_server_channel_compression_algorithm); - + { + grpc_core::ExecCtx exec_ctx; + client_args = grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, default_client_channel_compression_algorithm); + server_args = grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, default_server_channel_compression_algorithm); + if (!decompress_in_core) { + grpc_arg disable_decompression_in_core_arg = + grpc_channel_arg_integer_create( + const_cast( + GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION_INSIDE_CORE), + 0); + grpc_channel_args* old_client_args = client_args; + grpc_channel_args* old_server_args = server_args; + client_args = grpc_channel_args_copy_and_add( + client_args, &disable_decompression_in_core_arg, 1); + server_args = grpc_channel_args_copy_and_add( + server_args, &disable_decompression_in_core_arg, 1); + grpc_channel_args_destroy(old_client_args); + grpc_channel_args_destroy(old_server_args); + } + } f = begin_test(config, test_name, client_args, server_args); cqv = cq_verifier_create(f.cq); @@ -341,7 +374,6 @@ static void request_with_payload_template( GPR_ASSERT(GRPC_CALL_OK == error); CQ_EXPECT_COMPLETION(cqv, tag(2), true); } - memset(ops, 0, sizeof(ops)); op = ops; op->op = GRPC_OP_SEND_INITIAL_METADATA; @@ -385,7 +417,6 @@ static void request_with_payload_template( GRPC_COMPRESS_DEFLATE) != 0); GPR_ASSERT(GPR_BITGET(grpc_call_test_only_get_encodings_accepted_by_peer(s), GRPC_COMPRESS_GZIP) != 0); - memset(ops, 0, sizeof(ops)); op = ops; op->op = GRPC_OP_SEND_INITIAL_METADATA; @@ -406,7 +437,6 @@ static void request_with_payload_template( error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(101), nullptr); GPR_ASSERT(GRPC_CALL_OK == error); - for (int i = 0; i < 2; i++) { response_payload = grpc_raw_byte_buffer_create(&response_payload_slice, 1); @@ -442,7 +472,8 @@ static void request_with_payload_template( GPR_ASSERT(request_payload_recv->type == GRPC_BB_RAW); GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, request_str)); GPR_ASSERT(request_payload_recv->data.raw.compression == - expected_algorithm_from_client); + (decompress_in_core ? GRPC_COMPRESS_NONE + : expected_algorithm_from_client)); memset(ops, 0, sizeof(ops)); op = ops; @@ -475,11 +506,13 @@ static void request_with_payload_template( if (server_compression_level > GRPC_COMPRESS_LEVEL_NONE) { const grpc_compression_algorithm algo_for_server_level = grpc_call_compression_for_level(s, server_compression_level); - GPR_ASSERT(response_payload_recv->data.raw.compression == - algo_for_server_level); + GPR_ASSERT( + response_payload_recv->data.raw.compression == + (decompress_in_core ? GRPC_COMPRESS_NONE : algo_for_server_level)); } else { GPR_ASSERT(response_payload_recv->data.raw.compression == - expected_algorithm_from_server); + (decompress_in_core ? GRPC_COMPRESS_NONE + : expected_algorithm_from_server)); } grpc_byte_buffer_destroy(request_payload); @@ -487,7 +520,6 @@ static void request_with_payload_template( grpc_byte_buffer_destroy(request_payload_recv); grpc_byte_buffer_destroy(response_payload_recv); } - grpc_slice_unref(request_payload_slice); grpc_slice_unref(response_payload_slice); @@ -547,6 +579,32 @@ static void request_with_payload_template( config.tear_down_data(&f); } +static void request_with_payload_template( + grpc_end2end_test_config config, const char* test_name, + uint32_t client_send_flags_bitmask, + grpc_compression_algorithm default_client_channel_compression_algorithm, + grpc_compression_algorithm default_server_channel_compression_algorithm, + grpc_compression_algorithm expected_algorithm_from_client, + grpc_compression_algorithm expected_algorithm_from_server, + grpc_metadata* client_init_metadata, bool set_server_level, + grpc_compression_level server_compression_level, + bool send_message_before_initial_metadata) { + request_with_payload_template_inner( + config, test_name, client_send_flags_bitmask, + default_client_channel_compression_algorithm, + default_server_channel_compression_algorithm, + expected_algorithm_from_client, expected_algorithm_from_server, + client_init_metadata, set_server_level, server_compression_level, + send_message_before_initial_metadata, false); + request_with_payload_template_inner( + config, test_name, client_send_flags_bitmask, + default_client_channel_compression_algorithm, + default_server_channel_compression_algorithm, + expected_algorithm_from_client, expected_algorithm_from_server, + client_init_metadata, set_server_level, server_compression_level, + send_message_before_initial_metadata, true); +} + static void test_invoke_request_with_exceptionally_uncompressed_payload( grpc_end2end_test_config config) { request_with_payload_template( @@ -634,7 +692,11 @@ static void test_invoke_request_with_disabled_algorithm( request_for_disabled_algorithm(config, "test_invoke_request_with_disabled_algorithm", 0, GRPC_COMPRESS_GZIP, GRPC_COMPRESS_GZIP, - GRPC_STATUS_UNIMPLEMENTED, nullptr); + GRPC_STATUS_UNIMPLEMENTED, nullptr, false); + request_for_disabled_algorithm(config, + "test_invoke_request_with_disabled_algorithm", + 0, GRPC_COMPRESS_GZIP, GRPC_COMPRESS_GZIP, + GRPC_STATUS_UNIMPLEMENTED, nullptr, true); } void compressed_payload(grpc_end2end_test_config config) {