diff --git a/src/core/ext/filters/http/message_compress/message_compress_filter.c b/src/core/ext/filters/http/message_compress/message_compress_filter.c index 1cbc8526865..23b100d8ee6 100644 --- a/src/core/ext/filters/http/message_compress/message_compress_filter.c +++ b/src/core/ext/filters/http/message_compress/message_compress_filter.c @@ -51,6 +51,10 @@ int grpc_compression_trace = 0; +#define INITIAL_METADATA_UNSEEN 0 +#define HAS_COMPRESSION_ALGORITHM 1 +#define NO_COMPRESSION_ALGORITHM 2 + typedef struct call_data { grpc_slice_buffer slices; /**< Buffers up input slices to be compressed */ grpc_linked_mdelem compression_algorithm_storage; @@ -59,8 +63,16 @@ typedef struct call_data { /** Compression algorithm we'll try to use. It may be given by incoming * metadata, or by the channel's default compression settings. */ grpc_compression_algorithm compression_algorithm; - /** If true, contents of \a compression_algorithm are authoritative */ - int has_compression_algorithm; + + /* Atomic recording the state of initial metadata; allowed values: + INITIAL_METADATA_UNSEEN - initial metadata op not seen + HAS_COMPRESSION_ALGORITHM - initial metadata seen; compression algorithm + set + NO_COMPRESSION_ALGORITHM - initial metadata seen; no compression algorithm + set + pointer - a stalled op containing a send_message that's waiting on initial + metadata */ + gpr_atm send_initial_metadata_state; grpc_transport_stream_op_batch *send_op; uint32_t send_length; @@ -81,14 +93,15 @@ typedef struct channel_data { uint32_t supported_compression_algorithms; } channel_data; -static int skip_compression(grpc_call_element *elem, uint32_t flags) { +static bool skip_compression(grpc_call_element *elem, uint32_t flags, + bool has_compression_algorithm) { call_data *calld = elem->call_data; channel_data *channeld = elem->channel_data; if (flags & (GRPC_WRITE_NO_COMPRESS | GRPC_WRITE_INTERNAL_COMPRESS)) { return 1; } - if (calld->has_compression_algorithm) { + if (has_compression_algorithm) { if (calld->compression_algorithm == GRPC_COMPRESS_NONE) { return 1; } @@ -101,12 +114,14 @@ static int skip_compression(grpc_call_element *elem, uint32_t flags) { /** Filter initial metadata */ static grpc_error *process_send_initial_metadata( grpc_exec_ctx *exec_ctx, grpc_call_element *elem, - grpc_metadata_batch *initial_metadata) GRPC_MUST_USE_RESULT; + grpc_metadata_batch *initial_metadata, + bool *has_compression_algorithm) GRPC_MUST_USE_RESULT; static grpc_error *process_send_initial_metadata( grpc_exec_ctx *exec_ctx, grpc_call_element *elem, - grpc_metadata_batch *initial_metadata) { + grpc_metadata_batch *initial_metadata, bool *has_compression_algorithm) { call_data *calld = elem->call_data; channel_data *channeld = elem->channel_data; + *has_compression_algorithm = false; /* Parse incoming request for compression. If any, it'll be available * at calld->compression_algorithm */ if (initial_metadata->idx.named.grpc_internal_encoding_request != NULL) { @@ -130,7 +145,7 @@ static grpc_error *process_send_initial_metadata( gpr_free(val); calld->compression_algorithm = GRPC_COMPRESS_NONE; } - calld->has_compression_algorithm = 1; + *has_compression_algorithm = true; grpc_metadata_batch_remove( exec_ctx, initial_metadata, @@ -140,7 +155,7 @@ static grpc_error *process_send_initial_metadata( * exceptionally skipping compression, fall back to the channel * default */ calld->compression_algorithm = channeld->default_compression_algorithm; - calld->has_compression_algorithm = 1; /* GPR_TRUE */ + *has_compression_algorithm = true; } grpc_error *error = GRPC_ERROR_NONE; @@ -251,20 +266,54 @@ static void compress_start_transport_stream_op_batch( GPR_TIMER_BEGIN("compress_start_transport_stream_op_batch", 0); if (op->send_initial_metadata) { + bool has_compression_algorithm; grpc_error *error = process_send_initial_metadata( exec_ctx, elem, - op->payload->send_initial_metadata.send_initial_metadata); + op->payload->send_initial_metadata.send_initial_metadata, + &has_compression_algorithm); if (error != GRPC_ERROR_NONE) { grpc_transport_stream_op_batch_finish_with_failure(exec_ctx, op, error); return; } + gpr_atm cur = gpr_atm_acq_load(&calld->send_initial_metadata_state); + GPR_ASSERT(cur != HAS_COMPRESSION_ALGORITHM && + cur != NO_COMPRESSION_ALGORITHM); + gpr_atm_rel_store(&calld->send_initial_metadata_state, + has_compression_algorithm ? HAS_COMPRESSION_ALGORITHM + : NO_COMPRESSION_ALGORITHM); + if (cur != INITIAL_METADATA_UNSEEN) { + grpc_call_next_op(exec_ctx, elem, (grpc_transport_stream_op_batch *)op); + } } - if (op->send_message && - !skip_compression(elem, op->payload->send_message.send_message->flags)) { - calld->send_op = op; - calld->send_length = op->payload->send_message.send_message->length; - calld->send_flags = op->payload->send_message.send_message->flags; - continue_send_message(exec_ctx, elem); + if (op->send_message) { + gpr_atm cur; + retry_send: + cur = gpr_atm_acq_load(&calld->send_initial_metadata_state); + switch (cur) { + case INITIAL_METADATA_UNSEEN: + if (!gpr_atm_rel_cas(&calld->send_initial_metadata_state, cur, + (gpr_atm)op)) { + goto retry_send; + } + break; + case HAS_COMPRESSION_ALGORITHM: + case NO_COMPRESSION_ALGORITHM: + if (!skip_compression(elem, + op->payload->send_message.send_message->flags, + cur == HAS_COMPRESSION_ALGORITHM)) { + calld->send_op = op; + calld->send_length = op->payload->send_message.send_message->length; + calld->send_flags = op->payload->send_message.send_message->flags; + continue_send_message(exec_ctx, elem); + } else { + /* pass control down the stack */ + grpc_call_next_op(exec_ctx, elem, op); + } + break; + default: + /* >1 send_message concurrently */ + GPR_UNREACHABLE_CODE(break); + } } else { /* pass control down the stack */ grpc_call_next_op(exec_ctx, elem, op); @@ -282,7 +331,6 @@ static grpc_error *init_call_elem(grpc_exec_ctx *exec_ctx, /* initialize members */ grpc_slice_buffer_init(&calld->slices); - calld->has_compression_algorithm = 0; grpc_closure_init(&calld->got_slice, got_slice, elem, grpc_schedule_on_exec_ctx); grpc_closure_init(&calld->send_done, send_done, elem,