diff --git a/BUILD b/BUILD index 1159511776b..06d0845a763 100644 --- a/BUILD +++ b/BUILD @@ -1487,7 +1487,6 @@ grpc_cc_library( "//src/core:iomgr_fwd", "//src/core:iomgr_port", "//src/core:json", - "//src/core:latch", "//src/core:map", "//src/core:match", "//src/core:memory_quota", diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 580fb9253ee..c4d635f9f24 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -3730,7 +3730,6 @@ libs: - src/core/lib/promise/if.h - src/core/lib/promise/interceptor_list.h - src/core/lib/promise/intra_activity_waiter.h - - src/core/lib/promise/latch.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h - src/core/lib/promise/pipe.h @@ -7544,7 +7543,6 @@ targets: - src/core/lib/promise/if.h - src/core/lib/promise/interceptor_list.h - src/core/lib/promise/intra_activity_waiter.h - - src/core/lib/promise/latch.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h - src/core/lib/promise/pipe.h diff --git a/src/core/BUILD b/src/core/BUILD index 74a87fd481a..310ac860b23 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -3460,38 +3460,33 @@ grpc_cc_library( "ext/filters/message_size/message_size_filter.h", ], external_deps = [ - "absl/status:statusor", + "absl/status", "absl/strings", "absl/strings:str_format", "absl/types:optional", ], language = "c++", deps = [ - "activity", - "arena", - "arena_promise", "channel_args", "channel_fwd", "channel_init", "channel_stack_type", - "context", + "closure", + "error", "grpc_service_config", "json", "json_args", "json_object_loader", - "latch", - "poll", - "race", "service_config_parser", - "slice", "slice_buffer", + "status_helper", "validation_errors", "//:channel_stack_builder", "//:config", + "//:debug_location", "//:gpr", "//:grpc_base", "//:grpc_public_hdrs", - "//:grpc_trace", ], ) diff --git a/src/core/ext/filters/http/client/http_client_filter.cc b/src/core/ext/filters/http/client/http_client_filter.cc index 89ff0356cce..58f9a709024 100644 --- a/src/core/ext/filters/http/client/http_client_filter.cc +++ b/src/core/ext/filters/http/client/http_client_filter.cc @@ -133,13 +133,13 @@ ArenaPromise HttpClientFilter::MakeCallPromise( return std::move(md); }); - return Race(initial_metadata_err->Wait(), - Map(next_promise_factory(std::move(call_args)), + return Race(Map(next_promise_factory(std::move(call_args)), [](ServerMetadataHandle md) -> ServerMetadataHandle { auto r = CheckServerMetadata(md.get()); if (!r.ok()) return ServerMetadataFromStatus(r); return md; - })); + }), + initial_metadata_err->Wait()); } HttpClientFilter::HttpClientFilter(HttpSchemeMetadata::ValueType scheme, diff --git a/src/core/ext/filters/http/message_compress/compression_filter.cc b/src/core/ext/filters/http/message_compress/compression_filter.cc index a0f87cfd6be..03171819c68 100644 --- a/src/core/ext/filters/http/message_compress/compression_filter.cc +++ b/src/core/ext/filters/http/message_compress/compression_filter.cc @@ -233,7 +233,7 @@ ArenaPromise ClientCompressionFilter::MakeCallPromise( return CompressMessage(std::move(message), compression_algorithm); }); auto* decompress_args = GetContext()->New( - DecompressArgs{GRPC_COMPRESS_ALGORITHMS_COUNT, absl::nullopt}); + DecompressArgs{GRPC_COMPRESS_NONE, absl::nullopt}); auto* decompress_err = GetContext()->New>(); call_args.server_initial_metadata->InterceptAndMap( @@ -254,8 +254,8 @@ ArenaPromise ClientCompressionFilter::MakeCallPromise( return std::move(*r); }); // Run the next filter, and race it with getting an error from 
decompression. - return Race(decompress_err->Wait(), - next_promise_factory(std::move(call_args))); + return Race(next_promise_factory(std::move(call_args)), + decompress_err->Wait()); } ArenaPromise ServerCompressionFilter::MakeCallPromise( @@ -269,8 +269,7 @@ ArenaPromise ServerCompressionFilter::MakeCallPromise( this](MessageHandle message) -> absl::optional { auto r = DecompressMessage(std::move(message), decompress_args); if (grpc_call_trace.enabled()) { - gpr_log(GPR_DEBUG, "%s[compression] DecompressMessage returned %s", - Activity::current()->DebugTag().c_str(), + gpr_log(GPR_DEBUG, "DecompressMessage returned %s", r.status().ToString().c_str()); } if (!r.ok()) { @@ -301,8 +300,8 @@ ArenaPromise ServerCompressionFilter::MakeCallPromise( // - decompress incoming messages // - wait for initial metadata to be sent, and then commence compression of // outgoing messages - return Race(decompress_err->Wait(), - next_promise_factory(std::move(call_args))); + return Race(next_promise_factory(std::move(call_args)), + decompress_err->Wait()); } } // namespace grpc_core diff --git a/src/core/ext/filters/message_size/message_size_filter.cc b/src/core/ext/filters/message_size/message_size_filter.cc index c265ecab7d7..33ff178e5a9 100644 --- a/src/core/ext/filters/message_size/message_size_filter.cc +++ b/src/core/ext/filters/message_size/message_size_filter.cc @@ -18,13 +18,10 @@ #include "src/core/ext/filters/message_size/message_size_filter.h" -#include - -#include #include -#include -#include +#include +#include "absl/status/status.h" #include "absl/strings/str_format.h" #include @@ -35,22 +32,21 @@ #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/channel_stack_builder.h" #include "src/core/lib/config/core_configuration.h" -#include "src/core/lib/debug/trace.h" -#include "src/core/lib/promise/activity.h" -#include "src/core/lib/promise/context.h" -#include "src/core/lib/promise/latch.h" -#include "src/core/lib/promise/poll.h" -#include "src/core/lib/promise/race.h" -#include "src/core/lib/resource_quota/arena.h" +#include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/status_helper.h" +#include "src/core/lib/iomgr/call_combiner.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/error.h" #include "src/core/lib/service_config/service_config_call_data.h" -#include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_buffer.h" -#include "src/core/lib/surface/call_trace.h" #include "src/core/lib/surface/channel_init.h" #include "src/core/lib/surface/channel_stack_type.h" -#include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" +static void recv_message_ready(void* user_data, grpc_error_handle error); +static void recv_trailing_metadata_ready(void* user_data, + grpc_error_handle error); + namespace grpc_core { // @@ -128,164 +124,251 @@ size_t MessageSizeParser::ParserIndex() { parser_name()); } -// -// MessageSizeFilter -// +} // namespace grpc_core -const grpc_channel_filter ClientMessageSizeFilter::kFilter = - MakePromiseBasedFilter("message_size"); -const grpc_channel_filter ServerMessageSizeFilter::kFilter = - MakePromiseBasedFilter("message_size"); - -class MessageSizeFilter::CallBuilder { - private: - auto Interceptor(uint32_t max_length, bool is_send) { - return [max_length, is_send, - err = err_](MessageHandle msg) -> absl::optional { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[message_size] %s len:%" PRIdPTR " max:%d", - 
Activity::current()->DebugTag().c_str(), - is_send ? "send" : "recv", msg->payload()->Length(), - max_length); +namespace { +struct channel_data { + grpc_core::MessageSizeParsedConfig limits; + const size_t service_config_parser_index{ + grpc_core::MessageSizeParser::ParserIndex()}; +}; + +struct call_data { + call_data(grpc_call_element* elem, const channel_data& chand, + const grpc_call_element_args& args) + : call_combiner(args.call_combiner), limits(chand.limits) { + GRPC_CLOSURE_INIT(&recv_message_ready, ::recv_message_ready, elem, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready, + ::recv_trailing_metadata_ready, elem, + grpc_schedule_on_exec_ctx); + // Get max sizes from channel data, then merge in per-method config values. + // Note: Per-method config is only available on the client, so we + // apply the max request size to the send limit and the max response + // size to the receive limit. + const grpc_core::MessageSizeParsedConfig* config_from_call_context = + grpc_core::MessageSizeParsedConfig::GetFromCallContext( + args.context, chand.service_config_parser_index); + if (config_from_call_context != nullptr) { + absl::optional max_send_size = limits.max_send_size(); + absl::optional max_recv_size = limits.max_recv_size(); + if (config_from_call_context->max_send_size().has_value() && + (!max_send_size.has_value() || + *config_from_call_context->max_send_size() < *max_send_size)) { + max_send_size = *config_from_call_context->max_send_size(); } - if (msg->payload()->Length() > max_length) { - if (err->is_set()) return std::move(msg); - auto r = GetContext()->MakePooled( - GetContext()); - r->Set(GrpcStatusMetadata(), GRPC_STATUS_RESOURCE_EXHAUSTED); - r->Set(GrpcMessageMetadata(), - Slice::FromCopiedString( - absl::StrFormat("%s message larger than max (%u vs. %d)", - is_send ? "Sent" : "Received", - msg->payload()->Length(), max_length))); - err->Set(std::move(r)); - return absl::nullopt; + if (config_from_call_context->max_recv_size().has_value() && + (!max_recv_size.has_value() || + *config_from_call_context->max_recv_size() < *max_recv_size)) { + max_recv_size = *config_from_call_context->max_recv_size(); } - return std::move(msg); - }; + limits = grpc_core::MessageSizeParsedConfig(max_send_size, max_recv_size); + } } - public: - explicit CallBuilder(const MessageSizeParsedConfig& limits) - : limits_(limits) {} + ~call_data() {} + + grpc_core::CallCombiner* call_combiner; + grpc_core::MessageSizeParsedConfig limits; + // Receive closures are chained: we inject this closure as the + // recv_message_ready up-call on transport_stream_op, and remember to + // call our next_recv_message_ready member after handling it. + grpc_closure recv_message_ready; + grpc_closure recv_trailing_metadata_ready; + // The error caused by a message that is too large, or absl::OkStatus() + grpc_error_handle error; + // Used by recv_message_ready. + absl::optional* recv_message = nullptr; + // Original recv_message_ready callback, invoked after our own. + grpc_closure* next_recv_message_ready = nullptr; + // Original recv_trailing_metadata callback, invoked after our own. + grpc_closure* original_recv_trailing_metadata_ready; + bool seen_recv_trailing_metadata = false; + grpc_error_handle recv_trailing_metadata_error; +}; + +} // namespace - template - void AddSend(T* pipe_end) { - if (!limits_.max_send_size().has_value()) return; - pipe_end->InterceptAndMap(Interceptor(*limits_.max_send_size(), true)); +// Callback invoked when we receive a message. 
Here we check the max +// receive message size. +static void recv_message_ready(void* user_data, grpc_error_handle error) { + grpc_call_element* elem = static_cast(user_data); + call_data* calld = static_cast(elem->call_data); + if (calld->recv_message->has_value() && + calld->limits.max_recv_size().has_value() && + (*calld->recv_message)->Length() > + static_cast(*calld->limits.max_recv_size())) { + grpc_error_handle new_error = grpc_error_set_int( + GRPC_ERROR_CREATE(absl::StrFormat( + "Received message larger than max (%u vs. %d)", + (*calld->recv_message)->Length(), *calld->limits.max_recv_size())), + grpc_core::StatusIntProperty::kRpcStatus, + GRPC_STATUS_RESOURCE_EXHAUSTED); + error = grpc_error_add_child(error, new_error); + calld->error = error; } - template - void AddRecv(T* pipe_end) { - if (!limits_.max_recv_size().has_value()) return; - pipe_end->InterceptAndMap(Interceptor(*limits_.max_recv_size(), false)); + // Invoke the next callback. + grpc_closure* closure = calld->next_recv_message_ready; + calld->next_recv_message_ready = nullptr; + if (calld->seen_recv_trailing_metadata) { + // We might potentially see another RECV_MESSAGE op. In that case, we do not + // want to run the recv_trailing_metadata_ready closure again. The newer + // RECV_MESSAGE op cannot cause any errors since the transport has already + // invoked the recv_trailing_metadata_ready closure and all further + // RECV_MESSAGE ops will get null payloads. + calld->seen_recv_trailing_metadata = false; + GRPC_CALL_COMBINER_START(calld->call_combiner, + &calld->recv_trailing_metadata_ready, + calld->recv_trailing_metadata_error, + "continue recv_trailing_metadata_ready"); } + grpc_core::Closure::Run(DEBUG_LOCATION, closure, error); +} - ArenaPromise Run( - CallArgs call_args, NextPromiseFactory next_promise_factory) { - return Race(err_->Wait(), next_promise_factory(std::move(call_args))); +// Callback invoked on completion of recv_trailing_metadata +// Notifies the recv_trailing_metadata batch of any message size failures +static void recv_trailing_metadata_ready(void* user_data, + grpc_error_handle error) { + grpc_call_element* elem = static_cast(user_data); + call_data* calld = static_cast(elem->call_data); + if (calld->next_recv_message_ready != nullptr) { + calld->seen_recv_trailing_metadata = true; + calld->recv_trailing_metadata_error = error; + GRPC_CALL_COMBINER_STOP(calld->call_combiner, + "deferring recv_trailing_metadata_ready until " + "after recv_message_ready"); + return; } + error = grpc_error_add_child(error, calld->error); + // Invoke the next callback. + grpc_core::Closure::Run(DEBUG_LOCATION, + calld->original_recv_trailing_metadata_ready, error); +} - private: - Latch* const err_ = - GetContext()->New>(); - MessageSizeParsedConfig limits_; -}; - -absl::StatusOr ClientMessageSizeFilter::Create( - const ChannelArgs& args, ChannelFilter::Args) { - return ClientMessageSizeFilter(args); +// Start transport stream op. +static void message_size_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* op) { + call_data* calld = static_cast(elem->call_data); + // Check max send message size. + if (op->send_message && calld->limits.max_send_size().has_value() && + op->payload->send_message.send_message->Length() > + static_cast(*calld->limits.max_send_size())) { + grpc_transport_stream_op_batch_finish_with_failure( + op, + grpc_error_set_int(GRPC_ERROR_CREATE(absl::StrFormat( + "Sent message larger than max (%u vs. 
%d)", + op->payload->send_message.send_message->Length(), + *calld->limits.max_send_size())), + grpc_core::StatusIntProperty::kRpcStatus, + GRPC_STATUS_RESOURCE_EXHAUSTED), + calld->call_combiner); + return; + } + // Inject callback for receiving a message. + if (op->recv_message) { + calld->next_recv_message_ready = + op->payload->recv_message.recv_message_ready; + calld->recv_message = op->payload->recv_message.recv_message; + op->payload->recv_message.recv_message_ready = &calld->recv_message_ready; + } + // Inject callback for receiving trailing metadata. + if (op->recv_trailing_metadata) { + calld->original_recv_trailing_metadata_ready = + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &calld->recv_trailing_metadata_ready; + } + // Chain to the next filter. + grpc_call_next_op(elem, op); } -absl::StatusOr ServerMessageSizeFilter::Create( - const ChannelArgs& args, ChannelFilter::Args) { - return ServerMessageSizeFilter(args); +// Constructor for call_data. +static grpc_error_handle message_size_init_call_elem( + grpc_call_element* elem, const grpc_call_element_args* args) { + channel_data* chand = static_cast(elem->channel_data); + new (elem->call_data) call_data(elem, *chand, *args); + return absl::OkStatus(); } -ArenaPromise ClientMessageSizeFilter::MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) { - // Get max sizes from channel data, then merge in per-method config values. - // Note: Per-method config is only available on the client, so we - // apply the max request size to the send limit and the max response - // size to the receive limit. - MessageSizeParsedConfig limits = this->limits(); - const MessageSizeParsedConfig* config_from_call_context = - MessageSizeParsedConfig::GetFromCallContext( - GetContext(), - service_config_parser_index_); - if (config_from_call_context != nullptr) { - absl::optional max_send_size = limits.max_send_size(); - absl::optional max_recv_size = limits.max_recv_size(); - if (config_from_call_context->max_send_size().has_value() && - (!max_send_size.has_value() || - *config_from_call_context->max_send_size() < *max_send_size)) { - max_send_size = *config_from_call_context->max_send_size(); - } - if (config_from_call_context->max_recv_size().has_value() && - (!max_recv_size.has_value() || - *config_from_call_context->max_recv_size() < *max_recv_size)) { - max_recv_size = *config_from_call_context->max_recv_size(); - } - limits = MessageSizeParsedConfig(max_send_size, max_recv_size); - } +// Destructor for call_data. +static void message_size_destroy_call_elem( + grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + call_data* calld = static_cast(elem->call_data); + calld->~call_data(); +} - CallBuilder b(limits); - b.AddSend(call_args.client_to_server_messages); - b.AddRecv(call_args.server_to_client_messages); - return b.Run(std::move(call_args), std::move(next_promise_factory)); +// Constructor for channel_data. 
+static grpc_error_handle message_size_init_channel_elem( + grpc_channel_element* elem, grpc_channel_element_args* args) { + GPR_ASSERT(!args->is_last); + channel_data* chand = static_cast(elem->channel_data); + new (chand) channel_data(); + chand->limits = grpc_core::MessageSizeParsedConfig::GetFromChannelArgs( + args->channel_args); + return absl::OkStatus(); } -ArenaPromise ServerMessageSizeFilter::MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) { - CallBuilder b(limits()); - b.AddSend(call_args.server_to_client_messages); - b.AddRecv(call_args.client_to_server_messages); - return b.Run(std::move(call_args), std::move(next_promise_factory)); +// Destructor for channel_data. +static void message_size_destroy_channel_elem(grpc_channel_element* elem) { + channel_data* chand = static_cast(elem->channel_data); + chand->~channel_data(); } -namespace { +const grpc_channel_filter grpc_message_size_filter = { + message_size_start_transport_stream_op_batch, + nullptr, + grpc_channel_next_op, + sizeof(call_data), + message_size_init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + message_size_destroy_call_elem, + sizeof(channel_data), + message_size_init_channel_elem, + grpc_channel_stack_no_post_init, + message_size_destroy_channel_elem, + grpc_channel_next_get_info, + "message_size"}; + // Used for GRPC_CLIENT_SUBCHANNEL -bool MaybeAddMessageSizeFilterToSubchannel(ChannelStackBuilder* builder) { +static bool maybe_add_message_size_filter_subchannel( + grpc_core::ChannelStackBuilder* builder) { if (builder->channel_args().WantMinimalStack()) { return true; } - builder->PrependFilter(&ClientMessageSizeFilter::kFilter); + builder->PrependFilter(&grpc_message_size_filter); return true; } -// Used for GRPC_CLIENT_DIRECT_CHANNEL and GRPC_SERVER_CHANNEL. Adds the -// filter only if message size limits or service config is specified. -auto MaybeAddMessageSizeFilter(const grpc_channel_filter* filter) { - return [filter](ChannelStackBuilder* builder) { - auto channel_args = builder->channel_args(); - if (channel_args.WantMinimalStack()) { - return true; - } - MessageSizeParsedConfig limits = - MessageSizeParsedConfig::GetFromChannelArgs(channel_args); - const bool enable = - limits.max_send_size().has_value() || - limits.max_recv_size().has_value() || - channel_args.GetString(GRPC_ARG_SERVICE_CONFIG).has_value(); - if (enable) builder->PrependFilter(filter); +// Used for GRPC_CLIENT_DIRECT_CHANNEL and GRPC_SERVER_CHANNEL. Adds the filter +// only if message size limits or service config is specified. 
+static bool maybe_add_message_size_filter( + grpc_core::ChannelStackBuilder* builder) { + auto channel_args = builder->channel_args(); + if (channel_args.WantMinimalStack()) { return true; - }; + } + grpc_core::MessageSizeParsedConfig limits = + grpc_core::MessageSizeParsedConfig::GetFromChannelArgs(channel_args); + const bool enable = + limits.max_send_size().has_value() || + limits.max_recv_size().has_value() || + channel_args.GetString(GRPC_ARG_SERVICE_CONFIG).has_value(); + if (enable) builder->PrependFilter(&grpc_message_size_filter); + return true; } -} // namespace +namespace grpc_core { void RegisterMessageSizeFilter(CoreConfiguration::Builder* builder) { MessageSizeParser::Register(builder); - builder->channel_init()->RegisterStage(GRPC_CLIENT_SUBCHANNEL, - GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, - MaybeAddMessageSizeFilterToSubchannel); - builder->channel_init()->RegisterStage( - GRPC_CLIENT_DIRECT_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, - MaybeAddMessageSizeFilter(&ClientMessageSizeFilter::kFilter)); builder->channel_init()->RegisterStage( - GRPC_SERVER_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, - MaybeAddMessageSizeFilter(&ServerMessageSizeFilter::kFilter)); + GRPC_CLIENT_SUBCHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + maybe_add_message_size_filter_subchannel); + builder->channel_init()->RegisterStage(GRPC_CLIENT_DIRECT_CHANNEL, + GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + maybe_add_message_size_filter); + builder->channel_init()->RegisterStage(GRPC_SERVER_CHANNEL, + GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + maybe_add_message_size_filter); } } // namespace grpc_core diff --git a/src/core/ext/filters/message_size/message_size_filter.h b/src/core/ext/filters/message_size/message_size_filter.h index 75135a1b75e..e47485a8950 100644 --- a/src/core/ext/filters/message_size/message_size_filter.h +++ b/src/core/ext/filters/message_size/message_size_filter.h @@ -24,22 +24,21 @@ #include -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_fwd.h" +#include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/context.h" -#include "src/core/lib/channel/promise_based_filter.h" #include "src/core/lib/config/core_configuration.h" #include "src/core/lib/gprpp/validation_errors.h" #include "src/core/lib/json/json.h" #include "src/core/lib/json/json_args.h" #include "src/core/lib/json/json_object_loader.h" -#include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/service_config/service_config_parser.h" -#include "src/core/lib/transport/transport.h" + +extern const grpc_channel_filter grpc_message_size_filter; namespace grpc_core { @@ -86,50 +85,6 @@ class MessageSizeParser : public ServiceConfigParser::Parser { absl::optional GetMaxRecvSizeFromChannelArgs(const ChannelArgs& args); absl::optional GetMaxSendSizeFromChannelArgs(const ChannelArgs& args); -class MessageSizeFilter : public ChannelFilter { - protected: - explicit MessageSizeFilter(const ChannelArgs& args) - : limits_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {} - - class CallBuilder; - - const MessageSizeParsedConfig& limits() const { return limits_; } - - private: - MessageSizeParsedConfig limits_; -}; - -class ServerMessageSizeFilter final : public MessageSizeFilter { - public: - static const grpc_channel_filter kFilter; - - static absl::StatusOr Create( - const ChannelArgs& args, ChannelFilter::Args filter_args); - - // Construct a promise for one call. 
- ArenaPromise MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) override; - - private: - using MessageSizeFilter::MessageSizeFilter; -}; - -class ClientMessageSizeFilter final : public MessageSizeFilter { - public: - static const grpc_channel_filter kFilter; - - static absl::StatusOr Create( - const ChannelArgs& args, ChannelFilter::Args filter_args); - - // Construct a promise for one call. - ArenaPromise MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) override; - - private: - const size_t service_config_parser_index_{MessageSizeParser::ParserIndex()}; - using MessageSizeFilter::MessageSizeFilter; -}; - } // namespace grpc_core #endif // GRPC_SRC_CORE_EXT_FILTERS_MESSAGE_SIZE_MESSAGE_SIZE_FILTER_H diff --git a/src/core/ext/transport/inproc/inproc_transport.cc b/src/core/ext/transport/inproc/inproc_transport.cc index f99d41bfeed..dc6b4804f0a 100644 --- a/src/core/ext/transport/inproc/inproc_transport.cc +++ b/src/core/ext/transport/inproc/inproc_transport.cc @@ -766,8 +766,6 @@ void op_state_machine_locked(inproc_stream* s, grpc_error_handle error) { nullptr); s->to_read_trailing_md.Clear(); s->to_read_trailing_md_filled = false; - s->recv_trailing_md_op->payload->recv_trailing_metadata - .recv_trailing_metadata->Set(grpc_core::GrpcStatusFromWire(), true); // We should schedule the recv_trailing_md_op completion if // 1. this stream is the client-side diff --git a/src/core/lib/channel/connected_channel.cc b/src/core/lib/channel/connected_channel.cc index f66203b2598..ea3f226311a 100644 --- a/src/core/lib/channel/connected_channel.cc +++ b/src/core/lib/channel/connected_channel.cc @@ -56,7 +56,6 @@ #include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/status_helper.h" #include "src/core/lib/gprpp/sync.h" -#include "src/core/lib/gprpp/time.h" #include "src/core/lib/iomgr/call_combiner.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/error.h" @@ -74,7 +73,6 @@ #include "src/core/lib/surface/call.h" #include "src/core/lib/surface/call_trace.h" #include "src/core/lib/surface/channel_stack_type.h" -#include "src/core/lib/transport/error_utils.h" #include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" #include "src/core/lib/transport/transport_fwd.h" @@ -480,15 +478,7 @@ class ConnectedChannelStream : public Orphanable { return Match( recv_message_state_, [](Idle) -> std::string { return "IDLE"; }, [](Closed) -> std::string { return "CLOSED"; }, - [](const PendingReceiveMessage& m) -> std::string { - if (m.received) { - return absl::StrCat("RECEIVED_FROM_TRANSPORT:", - m.payload.has_value() - ? 
absl::StrCat(m.payload->Length(), "b") - : "EOS"); - } - return "WAITING"; - }, + [](const PendingReceiveMessage&) -> std::string { return "WAITING"; }, [](const absl::optional& message) -> std::string { return absl::StrCat( "READY:", message.has_value() @@ -580,7 +570,13 @@ class ConnectedChannelStream : public Orphanable { void RecvMessageBatchDone(grpc_error_handle error) { { MutexLock lock(mu()); - if (absl::holds_alternative(recv_message_state_)) { + if (error != absl::OkStatus()) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[connected] RecvMessageBatchDone: error=%s", + recv_message_waker_.ActivityDebugTag().c_str(), + StatusToString(error).c_str()); + } + } else if (absl::holds_alternative(recv_message_state_)) { if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, "%s[connected] RecvMessageBatchDone: already closed, " @@ -588,21 +584,14 @@ class ConnectedChannelStream : public Orphanable { recv_message_waker_.ActivityDebugTag().c_str()); } } else { - auto pending = - absl::get_if(&recv_message_state_); - GPR_ASSERT(pending != nullptr); - if (!error.ok()) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[connected] RecvMessageBatchDone: error=%s", - recv_message_waker_.ActivityDebugTag().c_str(), - StatusToString(error).c_str()); - } - pending->payload.reset(); - } else if (grpc_call_trace.enabled()) { + if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, "%s[connected] RecvMessageBatchDone: received message", recv_message_waker_.ActivityDebugTag().c_str()); } + auto pending = + absl::get_if(&recv_message_state_); + GPR_ASSERT(pending != nullptr); GPR_ASSERT(pending->received == false); pending->received = true; } @@ -682,8 +671,6 @@ class ClientStream : public ConnectedChannelStream { public: ClientStream(grpc_transport* transport, CallArgs call_args) : ConnectedChannelStream(transport), - client_initial_metadata_outstanding_token_( - std::move(call_args.client_initial_metadata_outstanding)), server_initial_metadata_pipe_(call_args.server_initial_metadata), client_to_server_messages_(call_args.client_to_server_messages), server_to_client_messages_(call_args.server_to_client_messages), @@ -717,14 +704,12 @@ class ClientStream : public ConnectedChannelStream { nullptr, GetContext()); grpc_transport_set_pops(transport(), stream(), GetContext()->polling_entity()); - memset(&send_metadata_, 0, sizeof(send_metadata_)); - memset(&recv_metadata_, 0, sizeof(recv_metadata_)); - send_metadata_.send_initial_metadata = true; - recv_metadata_.recv_initial_metadata = true; - recv_metadata_.recv_trailing_metadata = true; - send_metadata_.payload = batch_payload(); - recv_metadata_.payload = batch_payload(); - send_metadata_.on_complete = &send_metadata_batch_done_; + memset(&metadata_, 0, sizeof(metadata_)); + metadata_.send_initial_metadata = true; + metadata_.recv_initial_metadata = true; + metadata_.recv_trailing_metadata = true; + metadata_.payload = batch_payload(); + metadata_.on_complete = &metadata_batch_done_; batch_payload()->send_initial_metadata.send_initial_metadata = client_initial_metadata_.get(); batch_payload()->send_initial_metadata.peer_string = @@ -749,16 +734,9 @@ class ClientStream : public ConnectedChannelStream { IncrementRefCount("metadata_batch_done"); IncrementRefCount("initial_metadata_ready"); IncrementRefCount("trailing_metadata_ready"); - recv_initial_metadata_waker_ = Activity::current()->MakeOwningWaker(); - recv_trailing_metadata_waker_ = Activity::current()->MakeOwningWaker(); - send_initial_metadata_waker_ = 
Activity::current()->MakeOwningWaker(); - SchedulePush(&send_metadata_); - SchedulePush(&recv_metadata_); - } - if (std::exchange(need_to_clear_client_initial_metadata_outstanding_token_, - false)) { - client_initial_metadata_outstanding_token_.Complete( - client_initial_metadata_send_result_); + initial_metadata_waker_ = Activity::current()->MakeOwningWaker(); + trailing_metadata_waker_ = Activity::current()->MakeOwningWaker(); + SchedulePush(&metadata_); } if (server_initial_metadata_state_ == ServerInitialMetadataState::kReceivedButNotPushed) { @@ -775,21 +753,9 @@ class ClientStream : public ConnectedChannelStream { server_initial_metadata_push_promise_.reset(); } } - if (server_initial_metadata_state_ == ServerInitialMetadataState::kError) { - server_initial_metadata_pipe_->Close(); - } PollSendMessage(client_to_server_messages_, &client_trailing_metadata_); PollRecvMessage(server_to_client_messages_); - if (grpc_call_trace.enabled()) { - gpr_log( - GPR_INFO, - "%s[connected] Finishing PollConnectedChannel: requesting metadata", - Activity::current()->DebugTag().c_str()); - } - if ((server_initial_metadata_state_ == - ServerInitialMetadataState::kPushed || - server_initial_metadata_state_ == - ServerInitialMetadataState::kError) && + if (server_initial_metadata_state_ == ServerInitialMetadataState::kPushed && !IsPromiseReceiving() && std::exchange(queued_trailing_metadata_, false)) { if (grpc_call_trace.enabled()) { @@ -808,32 +774,18 @@ class ClientStream : public ConnectedChannelStream { } void RecvInitialMetadataReady(grpc_error_handle error) { + GPR_ASSERT(error == absl::OkStatus()); { MutexLock lock(mu()); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_DEBUG, "%s[connected] RecvInitialMetadataReady: error=%s", - recv_initial_metadata_waker_.ActivityDebugTag().c_str(), - error.ToString().c_str()); - } server_initial_metadata_state_ = - error.ok() ? 
ServerInitialMetadataState::kReceivedButNotPushed - : ServerInitialMetadataState::kError; - recv_initial_metadata_waker_.Wakeup(); + ServerInitialMetadataState::kReceivedButNotPushed; + initial_metadata_waker_.Wakeup(); } Unref("initial_metadata_ready"); } void RecvTrailingMetadataReady(grpc_error_handle error) { - if (!error.ok()) { - server_trailing_metadata_->Clear(); - grpc_status_code status = GRPC_STATUS_UNKNOWN; - std::string message; - grpc_error_get_status(error, Timestamp::InfFuture(), &status, &message, - nullptr, nullptr); - server_trailing_metadata_->Set(GrpcStatusMetadata(), status); - server_trailing_metadata_->Set(GrpcMessageMetadata(), - Slice::FromCopiedString(message)); - } + GPR_ASSERT(error == absl::OkStatus()); { MutexLock lock(mu()); queued_trailing_metadata_ = true; @@ -842,21 +794,16 @@ class ClientStream : public ConnectedChannelStream { "%s[connected] RecvTrailingMetadataReady: " "queued_trailing_metadata_ " "set to true; active_ops: %s", - recv_trailing_metadata_waker_.ActivityDebugTag().c_str(), + trailing_metadata_waker_.ActivityDebugTag().c_str(), ActiveOpsString().c_str()); } - recv_trailing_metadata_waker_.Wakeup(); + trailing_metadata_waker_.Wakeup(); } Unref("trailing_metadata_ready"); } - void SendMetadataBatchDone(grpc_error_handle error) { - { - MutexLock lock(mu()); - need_to_clear_client_initial_metadata_outstanding_token_ = true; - client_initial_metadata_send_result_ = error.ok(); - send_initial_metadata_waker_.Wakeup(); - } + void MetadataBatchDone(grpc_error_handle error) { + GPR_ASSERT(error == absl::OkStatus()); Unref("metadata_batch_done"); } @@ -876,8 +823,6 @@ class ClientStream : public ConnectedChannelStream { // has been pushed on the pipe to publish it up the call stack AND removed // by the call at the top. kPushed, - // Received initial metadata with an error status. 
- kError, }; std::string ActiveOpsString() const override @@ -886,10 +831,10 @@ class ClientStream : public ConnectedChannelStream { if (finished()) ops.push_back("FINISHED"); // Outstanding Operations on Transport std::vector waiting; - if (recv_initial_metadata_waker_ != Waker()) { + if (initial_metadata_waker_ != Waker()) { waiting.push_back("initial_metadata"); } - if (recv_trailing_metadata_waker_ != Waker()) { + if (trailing_metadata_waker_ != Waker()) { waiting.push_back("trailing_metadata"); } if (!waiting.empty()) { @@ -905,18 +850,6 @@ class ClientStream : public ConnectedChannelStream { if (!queued.empty()) { ops.push_back(absl::StrCat("queued:", absl::StrJoin(queued, ","))); } - switch (server_initial_metadata_state_) { - case ServerInitialMetadataState::kNotReceived: - case ServerInitialMetadataState::kReceivedButNotPushed: - case ServerInitialMetadataState::kPushed: - break; - case ServerInitialMetadataState::kPushing: - ops.push_back("server_initial_metadata:PUSHING"); - break; - case ServerInitialMetadataState::kError: - ops.push_back("server_initial_metadata:ERROR"); - break; - } // Send message std::string send_message_state = SendMessageString(); if (send_message_state != "WAITING") { @@ -931,17 +864,11 @@ class ClientStream : public ConnectedChannelStream { } bool requested_metadata_ = false; - bool need_to_clear_client_initial_metadata_outstanding_token_ - ABSL_GUARDED_BY(mu()) = false; - bool client_initial_metadata_send_result_ ABSL_GUARDED_BY(mu()); ServerInitialMetadataState server_initial_metadata_state_ ABSL_GUARDED_BY(mu()) = ServerInitialMetadataState::kNotReceived; bool queued_trailing_metadata_ ABSL_GUARDED_BY(mu()) = false; - Waker recv_initial_metadata_waker_ ABSL_GUARDED_BY(mu()); - Waker send_initial_metadata_waker_ ABSL_GUARDED_BY(mu()); - Waker recv_trailing_metadata_waker_ ABSL_GUARDED_BY(mu()); - ClientInitialMetadataOutstandingToken - client_initial_metadata_outstanding_token_; + Waker initial_metadata_waker_ ABSL_GUARDED_BY(mu()); + Waker trailing_metadata_waker_ ABSL_GUARDED_BY(mu()); PipeSender* server_initial_metadata_pipe_; PipeReceiver* client_to_server_messages_; PipeSender* server_to_client_messages_; @@ -957,10 +884,9 @@ class ClientStream : public ConnectedChannelStream { ServerMetadataHandle server_trailing_metadata_; absl::optional::PushType> server_initial_metadata_push_promise_; - grpc_transport_stream_op_batch send_metadata_; - grpc_transport_stream_op_batch recv_metadata_; - grpc_closure send_metadata_batch_done_ = - MakeMemberClosure( + grpc_transport_stream_op_batch metadata_; + grpc_closure metadata_batch_done_ = + MakeMemberClosure( this, DEBUG_LOCATION); }; @@ -1022,7 +948,6 @@ class ServerStream final : public ConnectedChannelStream { gim.client_initial_metadata.get(); batch_payload()->recv_initial_metadata.recv_initial_metadata_ready = &gim.recv_initial_metadata_ready; - IncrementRefCount("RecvInitialMetadata"); SchedulePush(&gim.recv_initial_metadata); // Fetch trailing metadata (to catch cancellations) @@ -1040,9 +965,8 @@ class ServerStream final : public ConnectedChannelStream { &GetContext()->call_stats()->transport_stream_stats; batch_payload()->recv_trailing_metadata.recv_trailing_metadata_ready = >m.recv_trailing_metadata_ready; - gtm.waker = Activity::current()->MakeOwningWaker(); - IncrementRefCount("RecvTrailingMetadata"); SchedulePush(>m.recv_trailing_metadata); + gtm.waker = Activity::current()->MakeOwningWaker(); } Poll PollOnce() { @@ -1071,7 +995,6 @@ class ServerStream final : public ConnectedChannelStream { 
.emplace(std::move(**md)) .get(); batch_payload()->send_initial_metadata.peer_string = nullptr; - IncrementRefCount("SendInitialMetadata"); SchedulePush(&send_initial_metadata_); return true; } else { @@ -1116,7 +1039,6 @@ class ServerStream final : public ConnectedChannelStream { incoming_messages_ = &pipes_.client_to_server.sender; auto promise = p->next_promise_factory(CallArgs{ std::move(p->client_initial_metadata), - ClientInitialMetadataOutstandingToken::Empty(), &pipes_.server_initial_metadata.sender, &pipes_.client_to_server.receiver, &pipes_.server_to_client.sender}); call_state_.emplace( @@ -1171,7 +1093,6 @@ class ServerStream final : public ConnectedChannelStream { ->as_string_view()), StatusIntProperty::kRpcStatus, status_code); } - IncrementRefCount("SendTrailingMetadata"); SchedulePush(&op); } } @@ -1267,7 +1188,6 @@ class ServerStream final : public ConnectedChannelStream { std::move(getting.next_promise_factory)}; call_state_.emplace(std::move(got)); waker.Wakeup(); - Unref("RecvInitialMetadata"); } void SendTrailingMetadataDone(absl::Status result) { @@ -1280,19 +1200,16 @@ class ServerStream final : public ConnectedChannelStream { waker.ActivityDebugTag().c_str(), result.ToString().c_str(), completing.sent ? "true" : "false", md->DebugString().c_str()); } + md->Set(GrpcStatusFromWire(), completing.sent); if (!result.ok()) { md->Clear(); md->Set(GrpcStatusMetadata(), static_cast(result.code())); md->Set(GrpcMessageMetadata(), Slice::FromCopiedString(result.message())); - md->Set(GrpcCallWasCancelled(), true); - } - if (!md->get(GrpcCallWasCancelled()).has_value()) { - md->Set(GrpcCallWasCancelled(), !completing.sent); + md->Set(GrpcStatusFromWire(), false); } call_state_.emplace(Complete{std::move(md)}); waker.Wakeup(); - Unref("SendTrailingMetadata"); } std::string ActiveOpsString() const override @@ -1344,7 +1261,7 @@ class ServerStream final : public ConnectedChannelStream { return absl::StrJoin(ops, " "); } - void SendInitialMetadataDone() { Unref("SendInitialMetadata"); } + void SendInitialMetadataDone() {} void RecvTrailingMetadataReady(absl::Status error) { MutexLock lock(mu()); @@ -1368,7 +1285,6 @@ class ServerStream final : public ConnectedChannelStream { client_trailing_metadata_state_.emplace( GotClientHalfClose{error}); waker.Wakeup(); - Unref("RecvTrailingMetadata"); } struct Pipes { diff --git a/src/core/lib/channel/promise_based_filter.cc b/src/core/lib/channel/promise_based_filter.cc index 74aa3f489af..7fb044c0fb8 100644 --- a/src/core/lib/channel/promise_based_filter.cc +++ b/src/core/lib/channel/promise_based_filter.cc @@ -16,8 +16,6 @@ #include "src/core/lib/channel/promise_based_filter.h" -#include - #include #include #include @@ -217,7 +215,7 @@ void BaseCallData::CapturedBatch::ResumeWith(Flusher* releaser) { // refcnt==0 ==> cancelled if (grpc_trace_channel.enabled()) { gpr_log(GPR_INFO, "%sRESUME BATCH REQUEST CANCELLED", - releaser->call()->DebugTag().c_str()); + Activity::current()->DebugTag().c_str()); } return; } @@ -241,10 +239,6 @@ void BaseCallData::CapturedBatch::CancelWith(grpc_error_handle error, auto* batch = std::exchange(batch_, nullptr); GPR_ASSERT(batch != nullptr); uintptr_t& refcnt = *RefCountField(batch); - gpr_log(GPR_DEBUG, "%sCancelWith: %p refs=%" PRIdPTR " err=%s [%s]", - releaser->call()->DebugTag().c_str(), batch, refcnt, - error.ToString().c_str(), - grpc_transport_stream_op_batch_string(batch).c_str()); if (refcnt == 0) { // refcnt==0 ==> cancelled if (grpc_trace_channel.enabled()) { @@ -331,8 +325,6 @@ const char* 
BaseCallData::SendMessage::StateString(State state) { return "CANCELLED"; case State::kCancelledButNotYetPolled: return "CANCELLED_BUT_NOT_YET_POLLED"; - case State::kCancelledButNoStatus: - return "CANCELLED_BUT_NO_STATUS"; } return "UNKNOWN"; } @@ -357,7 +349,6 @@ void BaseCallData::SendMessage::StartOp(CapturedBatch batch) { Crash(absl::StrFormat("ILLEGAL STATE: %s", StateString(state_))); case State::kCancelled: case State::kCancelledButNotYetPolled: - case State::kCancelledButNoStatus: return; } batch_ = batch; @@ -385,7 +376,6 @@ void BaseCallData::SendMessage::GotPipe(T* pipe_end) { case State::kForwardedBatch: case State::kBatchCompleted: case State::kPushedToPipe: - case State::kCancelledButNoStatus: Crash(absl::StrFormat("ILLEGAL STATE: %s", StateString(state_))); case State::kCancelled: case State::kCancelledButNotYetPolled: @@ -401,7 +391,6 @@ bool BaseCallData::SendMessage::IsIdle() const { case State::kForwardedBatch: case State::kCancelled: case State::kCancelledButNotYetPolled: - case State::kCancelledButNoStatus: return true; case State::kGotBatchNoPipe: case State::kGotBatch: @@ -430,7 +419,6 @@ void BaseCallData::SendMessage::OnComplete(absl::Status status) { break; case State::kCancelled: case State::kCancelledButNotYetPolled: - case State::kCancelledButNoStatus: flusher.AddClosure(intercepted_on_complete_, status, "forward after cancel"); break; @@ -455,14 +443,10 @@ void BaseCallData::SendMessage::Done(const ServerMetadata& metadata, case State::kCancelledButNotYetPolled: break; case State::kInitial: - state_ = State::kCancelled; - break; case State::kIdle: case State::kForwardedBatch: state_ = State::kCancelledButNotYetPolled; - if (base_->is_current()) base_->ForceImmediateRepoll(); break; - case State::kCancelledButNoStatus: case State::kGotBatchNoPipe: case State::kGotBatch: { std::string temp; @@ -481,7 +465,6 @@ void BaseCallData::SendMessage::Done(const ServerMetadata& metadata, push_.reset(); next_.reset(); state_ = State::kCancelledButNotYetPolled; - if (base_->is_current()) base_->ForceImmediateRepoll(); break; } } @@ -500,7 +483,6 @@ void BaseCallData::SendMessage::WakeInsideCombiner(Flusher* flusher, case State::kIdle: case State::kGotBatchNoPipe: case State::kCancelled: - case State::kCancelledButNoStatus: break; case State::kCancelledButNotYetPolled: interceptor()->Push()->Close(); @@ -542,18 +524,13 @@ void BaseCallData::SendMessage::WakeInsideCombiner(Flusher* flusher, "result.has_value=%s", base_->LogTag().c_str(), p->has_value() ? "true" : "false"); } - if (p->has_value()) { - batch_->payload->send_message.send_message->Swap((**p)->payload()); - batch_->payload->send_message.flags = (**p)->flags(); - state_ = State::kForwardedBatch; - batch_.ResumeWith(flusher); - next_.reset(); - if (!absl::holds_alternative((*push_)())) push_.reset(); - } else { - state_ = State::kCancelledButNoStatus; - next_.reset(); - push_.reset(); - } + GPR_ASSERT(p->has_value()); + batch_->payload->send_message.send_message->Swap((**p)->payload()); + batch_->payload->send_message.flags = (**p)->flags(); + state_ = State::kForwardedBatch; + batch_.ResumeWith(flusher); + next_.reset(); + if (!absl::holds_alternative((*push_)())) push_.reset(); } } break; case State::kForwardedBatch: @@ -1113,14 +1090,11 @@ class ClientCallData::PollContext { // Poll the promise once since we're waiting for it. 
Poll poll = self_->promise_(); if (grpc_trace_channel.enabled()) { - gpr_log(GPR_INFO, "%s ClientCallData.PollContext.Run: poll=%s; %s", + gpr_log(GPR_INFO, "%s ClientCallData.PollContext.Run: poll=%s", self_->LogTag().c_str(), - PollToString(poll, - [](const ServerMetadataHandle& h) { - return h->DebugString(); - }) - .c_str(), - self_->DebugString().c_str()); + PollToString(poll, [](const ServerMetadataHandle& h) { + return h->DebugString(); + }).c_str()); } if (auto* r = absl::get_if(&poll)) { auto md = std::move(*r); @@ -1300,11 +1274,7 @@ ClientCallData::ClientCallData(grpc_call_element* elem, [args]() { return args->arena->New(args->arena); }, - [args]() { return args->arena->New(args->arena); }), - initial_metadata_outstanding_token_( - (flags & kFilterIsLast) != 0 - ? ClientInitialMetadataOutstandingToken::New(arena()) - : ClientInitialMetadataOutstandingToken::Empty()) { + [args]() { return args->arena->New(args->arena); }) { GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, RecvTrailingMetadataReadyCallback, this, grpc_schedule_on_exec_ctx); @@ -1573,7 +1543,6 @@ void ClientCallData::StartPromise(Flusher* flusher) { promise_ = filter->MakeCallPromise( CallArgs{WrapMetadata(send_initial_metadata_batch_->payload ->send_initial_metadata.send_initial_metadata), - std::move(initial_metadata_outstanding_token_), server_initial_metadata_pipe() == nullptr ? nullptr : &server_initial_metadata_pipe()->sender, @@ -1894,15 +1863,8 @@ struct ServerCallData::SendInitialMetadata { class ServerCallData::PollContext { public: - explicit PollContext(ServerCallData* self, Flusher* flusher, - DebugLocation created = DebugLocation()) - : self_(self), flusher_(flusher), created_(created) { - if (self_->poll_ctx_ != nullptr) { - Crash(absl::StrCat( - "PollContext: disallowed recursion. 
New: ", created_.file(), ":", - created_.line(), "; Old: ", self_->poll_ctx_->created_.file(), ":", - self_->poll_ctx_->created_.line())); - } + explicit PollContext(ServerCallData* self, Flusher* flusher) + : self_(self), flusher_(flusher) { GPR_ASSERT(self_->poll_ctx_ == nullptr); self_->poll_ctx_ = this; scoped_activity_.Init(self_); @@ -1948,7 +1910,6 @@ class ServerCallData::PollContext { Flusher* const flusher_; bool repoll_ = false; bool have_scoped_activity_; - GPR_NO_UNIQUE_ADDRESS DebugLocation created_; }; const char* ServerCallData::StateString(RecvInitialState state) { @@ -2118,10 +2079,7 @@ void ServerCallData::StartBatch(grpc_transport_stream_op_batch* b) { switch (send_trailing_state_) { case SendTrailingState::kInitial: send_trailing_metadata_batch_ = batch; - if (receive_message() != nullptr && - batch->payload->send_trailing_metadata.send_trailing_metadata - ->get(GrpcStatusMetadata()) - .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) { + if (receive_message() != nullptr) { receive_message()->Done( *batch->payload->send_trailing_metadata.send_trailing_metadata, &flusher); @@ -2178,12 +2136,9 @@ void ServerCallData::Completed(grpc_error_handle error, Flusher* flusher) { case SendTrailingState::kForwarded: send_trailing_state_ = SendTrailingState::kCancelled; if (!error.ok()) { - call_stack()->IncrementRefCount(); auto* batch = grpc_make_transport_stream_op( - NewClosure([call_combiner = call_combiner(), - call_stack = call_stack()](absl::Status) { + NewClosure([call_combiner = call_combiner()](absl::Status) { GRPC_CALL_COMBINER_STOP(call_combiner, "done-cancel"); - call_stack->Unref(); })); batch->cancel_stream = true; batch->payload->cancel_stream.cancel_error = error; @@ -2357,7 +2312,6 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) { FakeActivity(this).Run([this, filter] { promise_ = filter->MakeCallPromise( CallArgs{WrapMetadata(recv_initial_metadata_), - ClientInitialMetadataOutstandingToken::Empty(), server_initial_metadata_pipe() == nullptr ? 
nullptr : &server_initial_metadata_pipe()->sender, @@ -2459,14 +2413,9 @@ void ServerCallData::WakeInsideCombiner(Flusher* flusher) { (send_trailing_metadata_batch_->send_message && send_message()->IsForwarded()))) { send_trailing_state_ = SendTrailingState::kQueued; - if (send_trailing_metadata_batch_->payload->send_trailing_metadata - .send_trailing_metadata->get(GrpcStatusMetadata()) - .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) { - send_message()->Done( - *send_trailing_metadata_batch_->payload->send_trailing_metadata - .send_trailing_metadata, - flusher); - } + send_message()->Done(*send_trailing_metadata_batch_->payload + ->send_trailing_metadata.send_trailing_metadata, + flusher); } } if (receive_message() != nullptr) { diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index f91e56439b9..5597349337e 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -218,11 +218,7 @@ class BaseCallData : public Activity, private Wakeable { void Resume(grpc_transport_stream_op_batch* batch) { GPR_ASSERT(!call_->is_last()); - if (batch->HasOp()) { - release_.push_back(batch); - } else if (batch->on_complete != nullptr) { - Complete(batch); - } + release_.push_back(batch); } void Cancel(grpc_transport_stream_op_batch* batch, @@ -241,8 +237,6 @@ class BaseCallData : public Activity, private Wakeable { call_closures_.Add(closure, error, reason); } - BaseCallData* call() const { return call_; } - private: absl::InlinedVector release_; CallCombinerClosureList call_closures_; @@ -404,8 +398,6 @@ class BaseCallData : public Activity, private Wakeable { kCancelledButNotYetPolled, // We're done. kCancelled, - // We're done, but we haven't gotten a status yet - kCancelledButNoStatus, }; static const char* StateString(State); @@ -672,8 +664,6 @@ class ClientCallData : public BaseCallData { RecvTrailingState recv_trailing_state_ = RecvTrailingState::kInitial; // Polling related data. Non-null if we're actively polling PollContext* poll_ctx_ = nullptr; - // Initial metadata outstanding token - ClientInitialMetadataOutstandingToken initial_metadata_outstanding_token_; }; class ServerCallData : public BaseCallData { diff --git a/src/core/lib/iomgr/call_combiner.h b/src/core/lib/iomgr/call_combiner.h index e314479413b..50aeb63c780 100644 --- a/src/core/lib/iomgr/call_combiner.h +++ b/src/core/lib/iomgr/call_combiner.h @@ -171,8 +171,8 @@ class CallCombinerClosureList { if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { gpr_log(GPR_INFO, "CallCombinerClosureList executing closure while already " - "holding call_combiner %p: closure=%s error=%s reason=%s", - call_combiner, closures_[0].closure->DebugString().c_str(), + "holding call_combiner %p: closure=%p error=%s reason=%s", + call_combiner, closures_[0].closure, StatusToString(closures_[0].error).c_str(), closures_[0].reason); } // This will release the call combiner. 
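// Annotation (not part of the patch): several filter hunks earlier in this
// patch reorder the arms of Race() — the downstream call promise produced by
// next_promise_factory() is now listed, and therefore polled, before the
// error latch's Wait(). A minimal standalone sketch of that poll-order race,
// using illustrative std:: types rather than gRPC's actual Race/Latch API:
// when both arms are ready on the same poll, the first arm listed wins.
#include <functional>
#include <iostream>
#include <optional>
#include <string>

// A "promise" here is anything we can poll; std::nullopt means "still pending".
using Promise = std::function<std::optional<std::string>()>;

// Poll each arm in order and return the first ready result.
std::optional<std::string> PollRace(Promise first, Promise second) {
  if (auto r = first()) return r;
  if (auto r = second()) return r;
  return std::nullopt;  // both arms still pending
}

int main() {
  Promise call_result = [] { return std::optional<std::string>("OK from next filter"); };
  Promise error_latch = [] { return std::optional<std::string>("RESOURCE_EXHAUSTED"); };
  // Both arms are ready; the listing order decides which value the race reports.
  std::cout << *PollRace(call_result, error_latch) << "\n";  // "OK from next filter"
  std::cout << *PollRace(error_latch, call_result) << "\n";  // "RESOURCE_EXHAUSTED"
}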
diff --git a/src/core/lib/promise/detail/promise_factory.h b/src/core/lib/promise/detail/promise_factory.h index adca5af7f08..12b291a545b 100644 --- a/src/core/lib/promise/detail/promise_factory.h +++ b/src/core/lib/promise/detail/promise_factory.h @@ -17,7 +17,6 @@ #include -#include #include #include @@ -107,9 +106,6 @@ class Curried { private: GPR_NO_UNIQUE_ADDRESS F f_; GPR_NO_UNIQUE_ADDRESS Arg arg_; -#ifndef NDEBUG - std::unique_ptr asan_canary_ = std::make_unique(0); -#endif }; // Promote a callable(A) -> T | Poll to a PromiseFactory(A) -> Promise by diff --git a/src/core/lib/promise/interceptor_list.h b/src/core/lib/promise/interceptor_list.h index 6b5576db8cb..2e09f574ae5 100644 --- a/src/core/lib/promise/interceptor_list.h +++ b/src/core/lib/promise/interceptor_list.h @@ -140,8 +140,7 @@ class InterceptorList { async_resolution_.space.get()); async_resolution_.current_factory = async_resolution_.current_factory->next(); - if (!p->has_value()) async_resolution_.current_factory = nullptr; - if (async_resolution_.current_factory == nullptr) { + if (async_resolution_.current_factory == nullptr || !p->has_value()) { return std::move(*p); } async_resolution_.current_factory->MakePromise( diff --git a/src/core/lib/promise/latch.h b/src/core/lib/promise/latch.h index ee1f9d6e958..305cf53ab6e 100644 --- a/src/core/lib/promise/latch.h +++ b/src/core/lib/promise/latch.h @@ -89,8 +89,6 @@ class Latch { waiter_.Wake(); } - bool is_set() const { return has_value_; } - private: std::string DebugTag() { return absl::StrCat(Activity::current()->DebugTag(), " LATCH[0x", @@ -167,7 +165,7 @@ class Latch { private: std::string DebugTag() { - return absl::StrCat(Activity::current()->DebugTag(), " LATCH(void)[0x", + return absl::StrCat(Activity::current()->DebugTag(), " LATCH[0x", reinterpret_cast(this), "]: "); } diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc index c67b97388ba..f5938e99e6c 100644 --- a/src/core/lib/security/transport/server_auth_filter.cc +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -18,12 +18,12 @@ #include +#include + #include #include -#include #include #include -#include #include #include "absl/status/status.h" @@ -41,7 +41,6 @@ #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/context.h" #include "src/core/lib/channel/promise_based_filter.h" -#include "src/core/lib/debug/trace.h" #include "src/core/lib/gprpp/debug_location.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/status_helper.h" @@ -58,7 +57,6 @@ #include "src/core/lib/security/transport/auth_filters.h" // IWYU pragma: keep #include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_internal.h" -#include "src/core/lib/surface/call_trace.h" #include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" @@ -122,28 +120,12 @@ class ServerAuthFilter::RunApplicationCode { // memory later RunApplicationCode(ServerAuthFilter* filter, CallArgs call_args) : state_(GetContext()->ManagedNew(std::move(call_args))) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_ERROR, - "%s[server-auth]: Delegate to application: filter=%p this=%p " - "auth_ctx=%p", - Activity::current()->DebugTag().c_str(), filter, this, - filter->auth_context_.get()); - } filter->server_credentials_->auth_metadata_processor().process( filter->server_credentials_->auth_metadata_processor().state, filter->auth_context_.get(), state_->md.metadata, state_->md.count, 
OnMdProcessingDone, state_); } - RunApplicationCode(const RunApplicationCode&) = delete; - RunApplicationCode& operator=(const RunApplicationCode&) = delete; - RunApplicationCode(RunApplicationCode&& other) noexcept - : state_(std::exchange(other.state_, nullptr)) {} - RunApplicationCode& operator=(RunApplicationCode&& other) noexcept { - state_ = std::exchange(other.state_, nullptr); - return *this; - } - Poll> operator()() { if (state_->done.load(std::memory_order_acquire)) { return Poll>(std::move(state_->call_args)); diff --git a/src/core/lib/slice/slice.cc b/src/core/lib/slice/slice.cc index 6180ef10e56..51ee3a83644 100644 --- a/src/core/lib/slice/slice.cc +++ b/src/core/lib/slice/slice.cc @@ -480,7 +480,7 @@ int grpc_slice_slice(grpc_slice haystack, grpc_slice needle) { } const uint8_t* last = haystack_bytes + haystack_len - needle_len; - for (const uint8_t* cur = haystack_bytes; cur <= last; ++cur) { + for (const uint8_t* cur = haystack_bytes; cur != last; ++cur) { if (0 == memcmp(cur, needle_bytes, needle_len)) { return static_cast(cur - haystack_bytes); } diff --git a/src/core/lib/surface/call.cc b/src/core/lib/surface/call.cc index d1110193409..346c3fb0db1 100644 --- a/src/core/lib/surface/call.cc +++ b/src/core/lib/surface/call.cc @@ -85,7 +85,6 @@ #include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/detail/basic_seq.h" -#include "src/core/lib/promise/latch.h" #include "src/core/lib/promise/pipe.h" #include "src/core/lib/promise/poll.h" #include "src/core/lib/resource_quota/arena.h" @@ -135,13 +134,11 @@ class Call : public CppImplOf { virtual void InternalRef(const char* reason) = 0; virtual void InternalUnref(const char* reason) = 0; - grpc_compression_algorithm test_only_compression_algorithm() { - return incoming_compression_algorithm_; - } - uint32_t test_only_message_flags() { return test_only_last_message_flags_; } - CompressionAlgorithmSet encodings_accepted_by_peer() { - return encodings_accepted_by_peer_; - } + virtual grpc_compression_algorithm test_only_compression_algorithm() = 0; + virtual uint32_t test_only_message_flags() = 0; + virtual uint32_t test_only_encodings_accepted_by_peer() = 0; + virtual grpc_compression_algorithm compression_for_level( + grpc_compression_level level) = 0; // This should return nullptr for the promise stack (and alternative means // for that functionality be invented) @@ -208,29 +205,6 @@ class Call : public CppImplOf { void ClearPeerString() { gpr_atm_rel_store(&peer_string_, 0); } - // TODO(ctiller): cancel_func is for cancellation of the call - filter stack - // holds no mutexes here, promise stack does, and so locking is different. - // Remove this and cancel directly once promise conversion is done. - void ProcessIncomingInitialMetadata( - grpc_metadata_batch& md, - absl::FunctionRef cancel_func); - // Fixup outgoing metadata before sending - adds compression, protects - // internal headers against external modification. 
- void PrepareOutgoingInitialMetadata(const grpc_op& op, - grpc_metadata_batch& md); - void NoteLastMessageFlags(uint32_t flags) { - test_only_last_message_flags_ = flags; - } - grpc_compression_algorithm incoming_compression_algorithm() const { - return incoming_compression_algorithm_; - } - - static void HandleCompressionAlgorithmDisabled( - grpc_compression_algorithm compression_algorithm, - absl::FunctionRef cancel_func) GPR_ATTRIBUTE_NOINLINE; - void HandleCompressionAlgorithmNotAccepted( - grpc_compression_algorithm compression_algorithm) GPR_ATTRIBUTE_NOINLINE; - private: RefCountedPtr channel_; Arena* const arena_; @@ -242,13 +216,6 @@ class Call : public CppImplOf { bool cancellation_is_inherited_ = false; // A char* indicating the peer name. gpr_atm peer_string_ = 0; - // Compression algorithm for *incoming* data - grpc_compression_algorithm incoming_compression_algorithm_ = - GRPC_COMPRESS_NONE; - // Supported encodings (compression algorithms), a bitset. - // Always support no compression. - CompressionAlgorithmSet encodings_accepted_by_peer_{GRPC_COMPRESS_NONE}; - uint32_t test_only_last_message_flags_ = 0; }; Call::ParentCall* Call::GetOrCreateParentCall() { @@ -385,92 +352,6 @@ void Call::DeleteThis() { arena->Destroy(); } -void Call::PrepareOutgoingInitialMetadata(const grpc_op& op, - grpc_metadata_batch& md) { - // TODO(juanlishen): If the user has already specified a compression - // algorithm by setting the initial metadata with key of - // GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, we shouldn't override that - // with the compression algorithm mapped from compression level. - // process compression level - grpc_compression_level effective_compression_level = GRPC_COMPRESS_LEVEL_NONE; - bool level_set = false; - if (op.data.send_initial_metadata.maybe_compression_level.is_set) { - effective_compression_level = - op.data.send_initial_metadata.maybe_compression_level.level; - level_set = true; - } else { - const grpc_compression_options copts = channel()->compression_options(); - if (copts.default_level.is_set) { - level_set = true; - effective_compression_level = copts.default_level.level; - } - } - // Currently, only server side supports compression level setting. - if (level_set && !is_client()) { - const grpc_compression_algorithm calgo = - encodings_accepted_by_peer().CompressionAlgorithmForLevel( - effective_compression_level); - // The following metadata will be checked and removed by the message - // compression filter. It will be used as the call's compression - // algorithm. - md.Set(GrpcInternalEncodingRequest(), calgo); - } - // Ignore any te metadata key value pairs specified. - md.Remove(TeMetadata()); -} - -void Call::ProcessIncomingInitialMetadata( - grpc_metadata_batch& md, - absl::FunctionRef cancel_func) { - incoming_compression_algorithm_ = - md.Take(GrpcEncodingMetadata()).value_or(GRPC_COMPRESS_NONE); - encodings_accepted_by_peer_ = - md.Take(GrpcAcceptEncodingMetadata()) - .value_or(CompressionAlgorithmSet{GRPC_COMPRESS_NONE}); - - const grpc_compression_options compression_options = - channel_->compression_options(); - const grpc_compression_algorithm compression_algorithm = - incoming_compression_algorithm_; - if (GPR_UNLIKELY(!CompressionAlgorithmSet::FromUint32( - compression_options.enabled_algorithms_bitset) - .IsSet(compression_algorithm))) { - // check if algorithm is supported by current channel config - HandleCompressionAlgorithmDisabled(compression_algorithm, cancel_func); - } - // GRPC_COMPRESS_NONE is always set. 
- GPR_DEBUG_ASSERT(encodings_accepted_by_peer_.IsSet(GRPC_COMPRESS_NONE)); - if (GPR_UNLIKELY(!encodings_accepted_by_peer_.IsSet(compression_algorithm))) { - if (GRPC_TRACE_FLAG_ENABLED(grpc_compression_trace)) { - HandleCompressionAlgorithmNotAccepted(compression_algorithm); - } - } -} - -void Call::HandleCompressionAlgorithmNotAccepted( - grpc_compression_algorithm compression_algorithm) { - const char* algo_name = nullptr; - grpc_compression_algorithm_name(compression_algorithm, &algo_name); - gpr_log(GPR_ERROR, - "Compression algorithm ('%s') not present in the " - "accepted encodings (%s)", - algo_name, - std::string(encodings_accepted_by_peer_.ToString()).c_str()); -} - -void Call::HandleCompressionAlgorithmDisabled( - grpc_compression_algorithm compression_algorithm, - absl::FunctionRef cancel_func) { - const char* algo_name = nullptr; - grpc_compression_algorithm_name(compression_algorithm, &algo_name); - std::string error_msg = - absl::StrFormat("Compression algorithm '%s' is disabled.", algo_name); - gpr_log(GPR_ERROR, "%s", error_msg.c_str()); - cancel_func(grpc_error_set_int(absl::UnimplementedError(error_msg), - StatusIntProperty::kRpcStatus, - GRPC_STATUS_UNIMPLEMENTED)); -} - /////////////////////////////////////////////////////////////////////////////// // FilterStackCall // To be removed once promise conversion is complete @@ -529,6 +410,11 @@ class FilterStackCall final : public Call { return context_[elem].value; } + grpc_compression_algorithm compression_for_level( + grpc_compression_level level) override { + return encodings_accepted_by_peer_.CompressionAlgorithmForLevel(level); + } + bool is_trailers_only() const override { bool result = is_trailers_only_; GPR_DEBUG_ASSERT(!result || recv_initial_metadata_.TransportSize() == 0); @@ -546,6 +432,18 @@ class FilterStackCall final : public Call { return authority_metadata->as_string_view(); } + grpc_compression_algorithm test_only_compression_algorithm() override { + return incoming_compression_algorithm_; + } + + uint32_t test_only_message_flags() override { + return test_only_last_message_flags_; + } + + uint32_t test_only_encodings_accepted_by_peer() override { + return encodings_accepted_by_peer_.ToLegacyBitmask(); + } + static size_t InitialSizeEstimate() { return sizeof(FilterStackCall) + sizeof(BatchControl) * kMaxConcurrentBatches; @@ -626,6 +524,7 @@ class FilterStackCall final : public Call { void FinishStep(PendingOp op); void ProcessDataAfterMetadata(); void ReceivingStreamReady(grpc_error_handle error); + void ValidateFilteredMetadata(); void ReceivingInitialMetadataReady(grpc_error_handle error); void ReceivingTrailingMetadataReady(grpc_error_handle error); void FinishBatch(grpc_error_handle error); @@ -650,6 +549,10 @@ class FilterStackCall final : public Call { grpc_closure* start_batch_closure); void SetFinalStatus(grpc_error_handle error); BatchControl* ReuseOrAllocateBatchControl(const grpc_op* ops); + void HandleCompressionAlgorithmDisabled( + grpc_compression_algorithm compression_algorithm) GPR_ATTRIBUTE_NOINLINE; + void HandleCompressionAlgorithmNotAccepted( + grpc_compression_algorithm compression_algorithm) GPR_ATTRIBUTE_NOINLINE; bool PrepareApplicationMetadata(size_t count, grpc_metadata* metadata, bool is_trailing); void PublishAppMetadata(grpc_metadata_batch* b, bool is_trailing); @@ -693,6 +596,13 @@ class FilterStackCall final : public Call { // completed grpc_call_final_info final_info_; + // Compression algorithm for *incoming* data + grpc_compression_algorithm 
incoming_compression_algorithm_ = + GRPC_COMPRESS_NONE; + // Supported encodings (compression algorithms), a bitset. + // Always support no compression. + CompressionAlgorithmSet encodings_accepted_by_peer_{GRPC_COMPRESS_NONE}; + // Contexts for various subsystems (security, tracing, ...). grpc_call_context_element context_[GRPC_CONTEXT_COUNT] = {}; @@ -706,6 +616,7 @@ class FilterStackCall final : public Call { grpc_closure receiving_stream_ready_; grpc_closure receiving_initial_metadata_ready_; grpc_closure receiving_trailing_metadata_ready_; + uint32_t test_only_last_message_flags_ = 0; // Status about operation of call bool sent_server_trailing_metadata_ = false; gpr_atm cancelled_with_error_ = 0; @@ -1122,8 +1033,11 @@ void FilterStackCall::PublishAppMetadata(grpc_metadata_batch* b, } void FilterStackCall::RecvInitialFilter(grpc_metadata_batch* b) { - ProcessIncomingInitialMetadata( - *b, [this](absl::Status err) { CancelWithError(std::move(err)); }); + incoming_compression_algorithm_ = + b->Take(GrpcEncodingMetadata()).value_or(GRPC_COMPRESS_NONE); + encodings_accepted_by_peer_ = + b->Take(GrpcAcceptEncodingMetadata()) + .value_or(CompressionAlgorithmSet{GRPC_COMPRESS_NONE}); PublishAppMetadata(b, false); } @@ -1290,11 +1204,11 @@ void FilterStackCall::BatchControl::ProcessDataAfterMetadata() { call->receiving_message_ = false; FinishStep(PendingOp::kRecvMessage); } else { - call->NoteLastMessageFlags(call->receiving_stream_flags_); + call->test_only_last_message_flags_ = call->receiving_stream_flags_; if ((call->receiving_stream_flags_ & GRPC_WRITE_INTERNAL_COMPRESS) && - (call->incoming_compression_algorithm() != GRPC_COMPRESS_NONE)) { + (call->incoming_compression_algorithm_ != GRPC_COMPRESS_NONE)) { *call->receiving_buffer_ = grpc_raw_compressed_byte_buffer_create( - nullptr, 0, call->incoming_compression_algorithm()); + nullptr, 0, call->incoming_compression_algorithm_); } else { *call->receiving_buffer_ = grpc_raw_byte_buffer_create(nullptr, 0); } @@ -1335,6 +1249,50 @@ void FilterStackCall::BatchControl::ReceivingStreamReady( } } +void FilterStackCall::HandleCompressionAlgorithmDisabled( + grpc_compression_algorithm compression_algorithm) { + const char* algo_name = nullptr; + grpc_compression_algorithm_name(compression_algorithm, &algo_name); + std::string error_msg = + absl::StrFormat("Compression algorithm '%s' is disabled.", algo_name); + gpr_log(GPR_ERROR, "%s", error_msg.c_str()); + CancelWithStatus(GRPC_STATUS_UNIMPLEMENTED, error_msg.c_str()); +} + +void FilterStackCall::HandleCompressionAlgorithmNotAccepted( + grpc_compression_algorithm compression_algorithm) { + const char* algo_name = nullptr; + grpc_compression_algorithm_name(compression_algorithm, &algo_name); + gpr_log(GPR_ERROR, + "Compression algorithm ('%s') not present in the " + "accepted encodings (%s)", + algo_name, + std::string(encodings_accepted_by_peer_.ToString()).c_str()); +} + +void FilterStackCall::BatchControl::ValidateFilteredMetadata() { + FilterStackCall* call = call_; + + const grpc_compression_options compression_options = + call->channel()->compression_options(); + const grpc_compression_algorithm compression_algorithm = + call->incoming_compression_algorithm_; + if (GPR_UNLIKELY(!CompressionAlgorithmSet::FromUint32( + compression_options.enabled_algorithms_bitset) + .IsSet(compression_algorithm))) { + // check if algorithm is supported by current channel config + call->HandleCompressionAlgorithmDisabled(compression_algorithm); + } + // GRPC_COMPRESS_NONE is always set. 
+ GPR_DEBUG_ASSERT(call->encodings_accepted_by_peer_.IsSet(GRPC_COMPRESS_NONE)); + if (GPR_UNLIKELY( + !call->encodings_accepted_by_peer_.IsSet(compression_algorithm))) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_compression_trace)) { + call->HandleCompressionAlgorithmNotAccepted(compression_algorithm); + } + } +} + void FilterStackCall::BatchControl::ReceivingInitialMetadataReady( grpc_error_handle error) { FilterStackCall* call = call_; @@ -1345,6 +1303,9 @@ void FilterStackCall::BatchControl::ReceivingInitialMetadataReady( grpc_metadata_batch* md = &call->recv_initial_metadata_; call->RecvInitialFilter(md); + // TODO(ctiller): this could be moved into recv_initial_filter now + ValidateFilteredMetadata(); + absl::optional deadline = md->get(GrpcTimeoutMetadata()); if (deadline.has_value() && !call->is_client()) { call_->set_send_deadline(*deadline); @@ -1493,6 +1454,36 @@ grpc_call_error FilterStackCall::StartBatch(const grpc_op* ops, size_t nops, error = GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; goto done_with_error; } + // TODO(juanlishen): If the user has already specified a compression + // algorithm by setting the initial metadata with key of + // GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, we shouldn't override that + // with the compression algorithm mapped from compression level. + // process compression level + grpc_compression_level effective_compression_level = + GRPC_COMPRESS_LEVEL_NONE; + bool level_set = false; + if (op->data.send_initial_metadata.maybe_compression_level.is_set) { + effective_compression_level = + op->data.send_initial_metadata.maybe_compression_level.level; + level_set = true; + } else { + const grpc_compression_options copts = + channel()->compression_options(); + if (copts.default_level.is_set) { + level_set = true; + effective_compression_level = copts.default_level.level; + } + } + // Currently, only server side supports compression level setting. + if (level_set && !is_client()) { + const grpc_compression_algorithm calgo = + encodings_accepted_by_peer_.CompressionAlgorithmForLevel( + effective_compression_level); + // The following metadata will be checked and removed by the message + // compression filter. It will be used as the call's compression + // algorithm. + send_initial_metadata_.Set(GrpcInternalEncodingRequest(), calgo); + } if (op->data.send_initial_metadata.count > INT_MAX) { error = GRPC_CALL_ERROR_INVALID_METADATA; goto done_with_error; @@ -1505,7 +1496,8 @@ grpc_call_error FilterStackCall::StartBatch(const grpc_op* ops, size_t nops, error = GRPC_CALL_ERROR_INVALID_METADATA; goto done_with_error; } - PrepareOutgoingInitialMetadata(*op, send_initial_metadata_); + // Ignore any te metadata key value pairs specified. + send_initial_metadata_.Remove(TeMetadata()); // TODO(ctiller): just make these the same variable? 
if (is_client() && send_deadline() != Timestamp::InfFuture()) { send_initial_metadata_.Set(GrpcTimeoutMetadata(), send_deadline()); @@ -2038,6 +2030,16 @@ class PromiseBasedCall : public Call, } } + grpc_compression_algorithm test_only_compression_algorithm() override { + abort(); + } + uint32_t test_only_message_flags() override { abort(); } + uint32_t test_only_encodings_accepted_by_peer() override { abort(); } + grpc_compression_algorithm compression_for_level( + grpc_compression_level) override { + abort(); + } + // This should return nullptr for the promise stack (and alternative means // for that functionality be invented) grpc_call_stack* call_stack() override { return nullptr; } @@ -2095,11 +2097,6 @@ class PromiseBasedCall : public Call, ~PromiseBasedCall() override { if (non_owning_wakeable_) non_owning_wakeable_->DropActivity(); if (cq_) GRPC_CQ_INTERNAL_UNREF(cq_, "bind"); - for (int i = 0; i < GRPC_CONTEXT_COUNT; i++) { - if (context_[i].destroy) { - context_[i].destroy(context_[i].value); - } - } } // Enumerates why a Completion is still pending @@ -2107,7 +2104,6 @@ class PromiseBasedCall : public Call, // We're in the midst of starting a batch of operations kStartingBatch = 0, // The following correspond with the batch operations from above - kSendInitialMetadata, kReceiveInitialMetadata, kReceiveStatusOnClient, kReceiveCloseOnServer = kReceiveStatusOnClient, @@ -2121,8 +2117,6 @@ class PromiseBasedCall : public Call, switch (reason) { case PendingOp::kStartingBatch: return "StartingBatch"; - case PendingOp::kSendInitialMetadata: - return "SendInitialMetadata"; case PendingOp::kReceiveInitialMetadata: return "ReceiveInitialMetadata"; case PendingOp::kReceiveStatusOnClient: @@ -2153,23 +2147,10 @@ class PromiseBasedCall : public Call, ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Stringify a completion std::string CompletionString(const Completion& completion) const { - auto& pending = completion_info_[completion.index()].pending; - const char* success_string = ":unknown"; - switch (pending.success) { - case CompletionSuccess::kSuccess: - success_string = ""; - break; - case CompletionSuccess::kFailure: - success_string = ":failure"; - break; - case CompletionSuccess::kForceSuccess: - success_string = ":force-success"; - break; - } return completion.has_value() - ? absl::StrFormat("%d:tag=%p%s", - static_cast(completion.index()), - pending.tag, success_string) + ? absl::StrFormat( + "%d:tag=%p", static_cast(completion.index()), + completion_info_[completion.index()].pending.tag) : "no-completion"; } // Finish one op on the completion. Must have been previously been added. @@ -2179,9 +2160,6 @@ class PromiseBasedCall : public Call, // Mark the completion as failed. Does not finish it. void FailCompletion(const Completion& completion, SourceLocation source_location = {}); - // Mark the completion as infallible. Overrides FailCompletion to report - // success always. - void ForceCompletionSuccess(const Completion& completion); // Run the promise polling loop until it stalls. void Update() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Update the promise state once. 
@@ -2245,13 +2223,13 @@ class PromiseBasedCall : public Call, void StartRecvMessage(const grpc_op& op, const Completion& completion, PipeReceiver* receiver) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - void PollRecvMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - void CancelRecvMessage(SourceLocation = {}) + void PollRecvMessage(grpc_compression_algorithm compression_algorithm) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void CancelRecvMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); void StartSendMessage(const grpc_op& op, const Completion& completion, PipeSender* sender) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - void PollSendMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + bool PollSendMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); void CancelSendMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); bool completed() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -2261,18 +2239,14 @@ class PromiseBasedCall : public Call, bool is_sending() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { return outstanding_send_.has_value(); } - bool is_receiving() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return outstanding_recv_.has_value(); - } private: - enum class CompletionSuccess : uint8_t { kSuccess, kFailure, kForceSuccess }; union CompletionInfo { struct Pending { // Bitmask of PendingOps uint8_t pending_op_bits; bool is_closure; - CompletionSuccess success; + bool success; void* tag; } pending; grpc_cq_completion completion; @@ -2285,8 +2259,8 @@ class PromiseBasedCall : public Call, // Ref the Handle (not the activity). void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); } - // Activity is going away... drop its reference and sever the - // connection back. + // Activity is going away... drop its reference and sever the connection + // back. void DropActivity() ABSL_LOCKS_EXCLUDED(mu_) { auto unref = absl::MakeCleanup([this]() { Unref(); }); MutexLock lock(&mu_); @@ -2329,9 +2303,9 @@ class PromiseBasedCall : public Call, } mutable Mutex mu_; - // We have two initial refs: one for the wakeup that this is - // created for, and will be dropped by Wakeup, and the other for - // the activity which is dropped by DropActivity. + // We have two initial refs: one for the wakeup that this is created for, + // and will be dropped by Wakeup, and the other for the activity which is + // dropped by DropActivity. 
std::atomic refs_{2}; PromiseBasedCall* call_ ABSL_GUARDED_BY(mu_); }; @@ -2465,16 +2439,15 @@ void* PromiseBasedCall::ContextGet(grpc_context_index elem) const { PromiseBasedCall::Completion PromiseBasedCall::StartCompletion( void* tag, bool is_closure, const grpc_op* ops) { Completion c(BatchSlotForOp(ops[0].op)); - if (!is_closure) { - grpc_cq_begin_op(cq(), tag); - } - completion_info_[c.index()].pending = { - PendingOpBit(PendingOp::kStartingBatch), is_closure, - CompletionSuccess::kSuccess, tag}; if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, "%s[call] StartCompletion %s tag=%p", DebugTag().c_str(), CompletionString(c).c_str(), tag); } + if (!is_closure) { + grpc_cq_begin_op(cq(), tag); + } + completion_info_[c.index()].pending = { + PendingOpBit(PendingOp::kStartingBatch), is_closure, true, tag}; return c; } @@ -2499,27 +2472,7 @@ void PromiseBasedCall::FailCompletion(const Completion& completion, "%s[call] FailCompletion %s", DebugTag().c_str(), CompletionString(completion).c_str()); } - auto& success = completion_info_[completion.index()].pending.success; - switch (success) { - case CompletionSuccess::kSuccess: - success = CompletionSuccess::kFailure; - break; - case CompletionSuccess::kFailure: - case CompletionSuccess::kForceSuccess: - break; - } -} - -void PromiseBasedCall::ForceCompletionSuccess(const Completion& completion) { - auto& success = completion_info_[completion.index()].pending.success; - switch (success) { - case CompletionSuccess::kSuccess: - case CompletionSuccess::kFailure: - success = CompletionSuccess::kForceSuccess; - break; - case CompletionSuccess::kForceSuccess: - break; - } + completion_info_[completion.index()].pending.success = false; } void PromiseBasedCall::FinishOpOnCompletion(Completion* completion, @@ -2527,8 +2480,7 @@ void PromiseBasedCall::FinishOpOnCompletion(Completion* completion, if (grpc_call_trace.enabled()) { auto pending_op_bits = completion_info_[completion->index()].pending.pending_op_bits; - bool success = completion_info_[completion->index()].pending.success != - CompletionSuccess::kFailure; + bool success = completion_info_[completion->index()].pending.success; std::vector pending; for (size_t i = 0; i < 8 * sizeof(pending_op_bits); i++) { if (static_cast(i) == reason) continue; @@ -2537,9 +2489,7 @@ void PromiseBasedCall::FinishOpOnCompletion(Completion* completion, } } gpr_log( - GPR_INFO, - "%s[call] FinishOpOnCompletion tag:%p completion:%s " - "finish:%s %s", + GPR_INFO, "%s[call] FinishOpOnCompletion tag:%p %s %s %s", DebugTag().c_str(), completion_info_[completion->index()].pending.tag, CompletionString(*completion).c_str(), PendingOpString(reason), (pending.empty() @@ -2552,9 +2502,8 @@ void PromiseBasedCall::FinishOpOnCompletion(Completion* completion, CompletionInfo::Pending& pending = completion_info_[i].pending; GPR_ASSERT(pending.pending_op_bits & PendingOpBit(reason)); pending.pending_op_bits &= ~PendingOpBit(reason); + auto error = pending.success ? absl::OkStatus() : absl::CancelledError(); if (pending.pending_op_bits == 0) { - const bool success = pending.success != CompletionSuccess::kFailure; - auto error = success ? 
absl::OkStatus() : absl::CancelledError(); if (pending.is_closure) { ExecCtx::Run(DEBUG_LOCATION, static_cast(pending.tag), error); @@ -2631,23 +2580,26 @@ void PromiseBasedCall::StartSendMessage(const grpc_op& op, } } -void PromiseBasedCall::PollSendMessage() { - if (!outstanding_send_.has_value()) return; +bool PromiseBasedCall::PollSendMessage() { + if (!outstanding_send_.has_value()) return true; Poll r = (*outstanding_send_)(); if (const bool* result = absl::get_if(&r)) { if (grpc_call_trace.enabled()) { gpr_log(GPR_DEBUG, "%sPollSendMessage completes %s", DebugTag().c_str(), *result ? "successfully" : "with failure"); } - if (!*result) FailCompletion(send_message_completion_); + if (!*result) { + FailCompletion(send_message_completion_); + return false; + } FinishOpOnCompletion(&send_message_completion_, PendingOp::kSendMessage); outstanding_send_.reset(); } + return true; } void PromiseBasedCall::CancelSendMessage() { if (!outstanding_send_.has_value()) return; - FailCompletion(send_message_completion_); FinishOpOnCompletion(&send_message_completion_, PendingOp::kSendMessage); outstanding_send_.reset(); } @@ -2662,18 +2614,18 @@ void PromiseBasedCall::StartRecvMessage(const grpc_op& op, outstanding_recv_.emplace(receiver->Next()); } -void PromiseBasedCall::PollRecvMessage() { +void PromiseBasedCall::PollRecvMessage( + grpc_compression_algorithm incoming_compression_algorithm) { if (!outstanding_recv_.has_value()) return; Poll> r = (*outstanding_recv_)(); if (auto* result = absl::get_if>(&r)) { outstanding_recv_.reset(); if (result->has_value()) { MessageHandle& message = **result; - NoteLastMessageFlags(message->flags()); if ((message->flags() & GRPC_WRITE_INTERNAL_COMPRESS) && - (incoming_compression_algorithm() != GRPC_COMPRESS_NONE)) { + (incoming_compression_algorithm != GRPC_COMPRESS_NONE)) { *recv_message_ = grpc_raw_compressed_byte_buffer_create( - nullptr, 0, incoming_compression_algorithm()); + nullptr, 0, incoming_compression_algorithm); } else { *recv_message_ = grpc_raw_byte_buffer_create(nullptr, 0); } @@ -2681,8 +2633,7 @@ void PromiseBasedCall::PollRecvMessage() { &(*recv_message_)->data.raw.slice_buffer); if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, - "%s[call] PollRecvMessage: outstanding_recv " - "finishes: received " + "%s[call] PollRecvMessage: outstanding_recv finishes: received " "%" PRIdPTR " byte message", DebugTag().c_str(), (*recv_message_)->data.raw.slice_buffer.length); @@ -2690,8 +2641,7 @@ void PromiseBasedCall::PollRecvMessage() { } else if (result->cancelled()) { if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, - "%s[call] PollRecvMessage: outstanding_recv " - "finishes: received " + "%s[call] PollRecvMessage: outstanding_recv finishes: received " "end-of-stream with error", DebugTag().c_str()); } @@ -2700,8 +2650,7 @@ void PromiseBasedCall::PollRecvMessage() { } else { if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, - "%s[call] PollRecvMessage: outstanding_recv " - "finishes: received " + "%s[call] PollRecvMessage: outstanding_recv finishes: received " "end-of-stream", DebugTag().c_str()); } @@ -2711,10 +2660,8 @@ void PromiseBasedCall::PollRecvMessage() { } else if (completed_) { if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, - "%s[call] UpdateOnce: outstanding_recv finishes: " - "promise has " - "completed without queuing a message, forcing " - "end-of-stream", + "%s[call] UpdateOnce: outstanding_recv finishes: promise has " + "completed without queuing a message, forcing end-of-stream", DebugTag().c_str()); } 
outstanding_recv_.reset(); @@ -2723,18 +2670,10 @@ void PromiseBasedCall::PollRecvMessage() { } } -void PromiseBasedCall::CancelRecvMessage(SourceLocation location) { +void PromiseBasedCall::CancelRecvMessage() { if (!outstanding_recv_.has_value()) return; - if (grpc_call_trace.enabled()) { - gpr_log(location.file(), location.line(), GPR_LOG_SEVERITY_DEBUG, - "%s[call] CancelRecvMessage: op:%s outstanding_recv:%s", - DebugTag().c_str(), - CompletionString(recv_message_completion_).c_str(), - outstanding_recv_.has_value() ? "present" : "absent"); - } *recv_message_ = nullptr; outstanding_recv_.reset(); - FailCompletion(recv_message_completion_); FinishOpOnCompletion(&recv_message_completion_, PendingOp::kReceiveMessage); } @@ -2810,9 +2749,9 @@ class ClientPromiseBasedCall final : public PromiseBasedCall { send_initial_metadata_.reset(); recv_status_on_client_ = absl::monostate(); promise_ = ArenaPromise(); - // Need to destroy the pipes under the ScopedContext above, so we - // move them out here and then allow the destructors to run at - // end of scope, but before context. + // Need to destroy the pipes under the ScopedContext above, so we move them + // out here and then allow the destructors to run at end of scope, but + // before context. auto c2s = std::move(client_to_server_messages_); auto s2c = std::move(server_to_client_messages_); auto sim = std::move(server_initial_metadata_); @@ -2864,6 +2803,7 @@ class ClientPromiseBasedCall final : public PromiseBasedCall { arena()}; Pipe client_to_server_messages_ ABSL_GUARDED_BY(mu()){arena()}; Pipe server_to_client_messages_ ABSL_GUARDED_BY(mu()){arena()}; + ClientMetadataHandle send_initial_metadata_; grpc_metadata_array* recv_initial_metadata_ ABSL_GUARDED_BY(mu()) = nullptr; absl::variant> server_initial_metadata_ready_; - absl::optional - client_initial_metadata_sent_; - Completion send_initial_metadata_completion_ ABSL_GUARDED_BY(mu()); + absl::optional incoming_compression_algorithm_; Completion recv_initial_metadata_completion_ ABSL_GUARDED_BY(mu()); Completion recv_status_on_client_completion_ ABSL_GUARDED_BY(mu()); Completion close_send_completion_ ABSL_GUARDED_BY(mu()); @@ -2884,12 +2822,12 @@ class ClientPromiseBasedCall final : public PromiseBasedCall { void ClientPromiseBasedCall::StartPromise( ClientMetadataHandle client_initial_metadata) { GPR_ASSERT(!promise_.has_value()); - auto token = ClientInitialMetadataOutstandingToken::New(); - client_initial_metadata_sent_.emplace(token.Wait()); promise_ = channel()->channel_stack()->MakeClientCallPromise(CallArgs{ - std::move(client_initial_metadata), std::move(token), - &server_initial_metadata_.sender, &client_to_server_messages_.receiver, - &server_to_client_messages_.sender}); + std::move(client_initial_metadata), + &server_initial_metadata_.sender, + &client_to_server_messages_.receiver, + &server_to_client_messages_.sender, + }); } void ClientPromiseBasedCall::CancelWithErrorLocked(grpc_error_handle error) { @@ -2939,22 +2877,13 @@ void ClientPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, const grpc_op& op = ops[op_idx]; switch (op.op) { case GRPC_OP_SEND_INITIAL_METADATA: { + // compression not implemented + GPR_ASSERT( + !op.data.send_initial_metadata.maybe_compression_level.is_set); if (!completed()) { CToMetadata(op.data.send_initial_metadata.metadata, op.data.send_initial_metadata.count, send_initial_metadata_.get()); - PrepareOutgoingInitialMetadata(op, *send_initial_metadata_); - if (send_deadline() != Timestamp::InfFuture()) { - 
send_initial_metadata_->Set(GrpcTimeoutMetadata(), send_deadline()); - } - send_initial_metadata_->Set( - WaitForReady(), - WaitForReady::ValueType{ - (op.flags & GRPC_INITIAL_METADATA_WAIT_FOR_READY) != 0, - (op.flags & - GRPC_INITIAL_METADATA_WAIT_FOR_READY_EXPLICITLY_SET) != 0}); - send_initial_metadata_completion_ = - AddOpToCompletion(completion, PendingOp::kSendInitialMetadata); StartPromise(std::move(send_initial_metadata_)); } } break; @@ -2969,7 +2898,6 @@ void ClientPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, case GRPC_OP_RECV_STATUS_ON_CLIENT: { recv_status_on_client_completion_ = AddOpToCompletion(completion, PendingOp::kReceiveStatusOnClient); - ForceCompletionSuccess(completion); if (auto* finished_metadata = absl::get_if(&recv_status_on_client_)) { PublishStatus(op.data.recv_status_on_client, @@ -3019,10 +2947,8 @@ grpc_call_error ClientPromiseBasedCall::StartBatch(const grpc_op* ops, } void ClientPromiseBasedCall::PublishInitialMetadata(ServerMetadata* metadata) { - ProcessIncomingInitialMetadata(*metadata, [this](absl::Status status) { - mu()->AssertHeld(); - CancelWithErrorLocked(std::move(status)); - }); + incoming_compression_algorithm_ = + metadata->Take(GrpcEncodingMetadata()).value_or(GRPC_COMPRESS_NONE); server_initial_metadata_ready_.reset(); GPR_ASSERT(recv_initial_metadata_ != nullptr); PublishMetadataArray(metadata, @@ -3042,33 +2968,21 @@ void ClientPromiseBasedCall::UpdateOnce() { PollStateDebugString().c_str(), promise_.has_value() ? "true" : "false"); } - if (client_initial_metadata_sent_.has_value()) { - Poll p = (*client_initial_metadata_sent_)(); - bool* r = absl::get_if(&p); - if (r != nullptr) { - client_initial_metadata_sent_.reset(); - if (!*r) FailCompletion(send_initial_metadata_completion_); - FinishOpOnCompletion(&send_initial_metadata_completion_, - PendingOp::kSendInitialMetadata); - } - } if (server_initial_metadata_ready_.has_value()) { Poll> r = (*server_initial_metadata_ready_)(); if (auto* server_initial_metadata = absl::get_if>(&r)) { - if (server_initial_metadata->has_value()) { - PublishInitialMetadata(server_initial_metadata->value().get()); - } else { - ServerMetadata no_metadata{GetContext()}; - PublishInitialMetadata(&no_metadata); - } + PublishInitialMetadata(server_initial_metadata->value().get()); } else if (completed()) { ServerMetadata no_metadata{GetContext()}; PublishInitialMetadata(&no_metadata); } } - PollSendMessage(); + if (!PollSendMessage()) { + Finish(ServerMetadataFromStatus(absl::Status( + absl::StatusCode::kInternal, "Failed to send message to server"))); + } if (!is_sending() && close_send_completion_.has_value()) { client_to_server_messages_.sender.Close(); FinishOpOnCompletion(&close_send_completion_, @@ -3084,15 +2998,12 @@ void ClientPromiseBasedCall::UpdateOnce() { }).c_str()); } if (auto* result = absl::get_if(&r)) { - if (!server_initial_metadata_ready_.has_value()) { - PollRecvMessage(); - } AcceptTransportStatsFromContext(); Finish(std::move(*result)); } } - if (!server_initial_metadata_ready_.has_value()) { - PollRecvMessage(); + if (incoming_compression_algorithm_.has_value()) { + PollRecvMessage(*incoming_compression_algorithm_); } } @@ -3102,8 +3013,6 @@ void ClientPromiseBasedCall::Finish(ServerMetadataHandle trailing_metadata) { trailing_metadata->DebugString().c_str()); } promise_ = ArenaPromise(); - CancelSendMessage(); - CancelRecvMessage(); ResetDeadline(); set_completed(); if (recv_initial_metadata_ != nullptr) { @@ -3119,18 +3028,8 @@ void 
ClientPromiseBasedCall::Finish(ServerMetadataHandle trailing_metadata) { (*server_initial_metadata_ready_)(); server_initial_metadata_ready_.reset(); if (auto* result = absl::get_if>(&r)) { - if (pending_initial_metadata) { - if (result->has_value()) { - PublishInitialMetadata(result->value().get()); - is_trailers_only_ = false; - } else { - ServerMetadata no_metadata{GetContext()}; - PublishInitialMetadata(&no_metadata); - is_trailers_only_ = true; - } - } else { - is_trailers_only_ = false; - } + if (pending_initial_metadata) PublishInitialMetadata(result->value().get()); + is_trailers_only_ = false; } else { if (pending_initial_metadata) { ServerMetadata no_metadata{GetContext()}; @@ -3261,10 +3160,9 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { private: class RecvCloseOpCancelState { public: - // Request that receiver be filled in per - // grpc_op_recv_close_on_server. Returns true if the request can - // be fulfilled immediately. Returns false if the request will be - // fulfilled later. + // Request that receiver be filled in per grpc_op_recv_close_on_server. + // Returns true if the request can be fulfilled immediately. + // Returns false if the request will be fulfilled later. bool ReceiveCloseOnServerOpStarted(int* receiver) { switch (state_) { case kUnset: @@ -3282,20 +3180,18 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { } // Mark the call as having completed. - // Returns true if this finishes a previous - // RequestReceiveCloseOnServer. - bool CompleteCallWithCancelledSetTo(bool cancelled) { + // Returns true if this finishes a previous RequestReceiveCloseOnServer. + bool CompleteCall(bool success) { switch (state_) { case kUnset: - state_ = cancelled ? kFinishedWithFailure : kFinishedWithSuccess; + state_ = success ? kFinishedWithSuccess : kFinishedWithFailure; return false; case kFinishedWithFailure: - return false; case kFinishedWithSuccess: abort(); // unreachable default: - *reinterpret_cast(state_) = cancelled ? 1 : 0; - state_ = cancelled ? kFinishedWithFailure : kFinishedWithSuccess; + *reinterpret_cast(state_) = success ? 0 : 1; + state_ = success ? kFinishedWithSuccess : kFinishedWithFailure; return true; } } @@ -3318,9 +3214,8 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { static constexpr uintptr_t kUnset = 0; static constexpr uintptr_t kFinishedWithFailure = 1; static constexpr uintptr_t kFinishedWithSuccess = 2; - // Holds one of kUnset, kFinishedWithFailure, or - // kFinishedWithSuccess OR an int* that wants to receive the - // final status. + // Holds one of kUnset, kFinishedWithFailure, or kFinishedWithSuccess + // OR an int* that wants to receive the final status. 
uintptr_t state_ = kUnset; }; @@ -3328,10 +3223,6 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { void CommitBatch(const grpc_op* ops, size_t nops, const Completion& completion) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); - void CommitBatchAfterCompletion(const grpc_op* ops, size_t nops, - const Completion& completion) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); - void Finish(ServerMetadataHandle result) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); friend class ServerCallContext; ServerCallContext call_context_; @@ -3347,7 +3238,8 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { SendInitialMetadataState send_initial_metadata_state_ ABSL_GUARDED_BY(mu()) = absl::monostate{}; ServerMetadataHandle send_trailing_metadata_ ABSL_GUARDED_BY(mu()); - bool force_metadata_send_ ABSL_GUARDED_BY(mu()) = false; + grpc_compression_algorithm incoming_compression_algorithm_ + ABSL_GUARDED_BY(mu()); RecvCloseOpCancelState recv_close_op_cancel_state_ ABSL_GUARDED_BY(mu()); Completion recv_close_completion_ ABSL_GUARDED_BY(mu()); bool cancel_send_and_receive_ ABSL_GUARDED_BY(mu()) = false; @@ -3368,8 +3260,7 @@ ServerPromiseBasedCall::ServerPromiseBasedCall(Arena* arena, MutexLock lock(mu()); ScopedContext activity_context(this); promise_ = channel()->channel_stack()->MakeServerCallPromise( - CallArgs{nullptr, ClientInitialMetadataOutstandingToken::Empty(), nullptr, - nullptr, nullptr}); + CallArgs{nullptr, nullptr, nullptr, nullptr}); } Poll ServerPromiseBasedCall::PollTopOfCall() { @@ -3390,14 +3281,7 @@ Poll ServerPromiseBasedCall::PollTopOfCall() { } PollSendMessage(); - PollRecvMessage(); - - if (force_metadata_send_) GPR_ASSERT(!is_sending()); - - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] PollTopOfCall: is_sending=%s", - DebugTag().c_str(), is_sending() ? 
"yes" : "no"); - } + PollRecvMessage(incoming_compression_algorithm_); if (!is_sending() && send_trailing_metadata_ != nullptr) { server_to_client_messages_->Close(); @@ -3442,53 +3326,39 @@ void ServerPromiseBasedCall::UpdateOnce() { }).c_str()); } if (auto* result = absl::get_if(&r)) { - if (!(*result)->get(GrpcCallWasCancelled()).value_or(false)) { - PollRecvMessage(); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[call] UpdateOnce: GotResult %s result:%s", + DebugTag().c_str(), + recv_close_op_cancel_state_.ToString().c_str(), + (*result)->DebugString().c_str()); } - Finish(std::move(*result)); + if (recv_close_op_cancel_state_.CompleteCall( + (*result)->get(GrpcStatusFromWire()).value_or(false))) { + FinishOpOnCompletion(&recv_close_completion_, + PendingOp::kReceiveCloseOnServer); + } + channelz::ServerNode* channelz_node = server_->channelz_node(); + if (channelz_node != nullptr) { + if ((*result) + ->get(GrpcStatusMetadata()) + .value_or(GRPC_STATUS_UNKNOWN) == GRPC_STATUS_OK) { + channelz_node->RecordCallSucceeded(); + } else { + channelz_node->RecordCallFailed(); + } + } + if (send_status_from_server_completion_.has_value()) { + FinishOpOnCompletion(&send_status_from_server_completion_, + PendingOp::kSendStatusFromServer); + } + CancelSendMessage(); + CancelRecvMessage(); + set_completed(); promise_ = ArenaPromise(); } } } -void ServerPromiseBasedCall::Finish(ServerMetadataHandle result) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] UpdateOnce: GotResult %s result:%s", - DebugTag().c_str(), recv_close_op_cancel_state_.ToString().c_str(), - result->DebugString().c_str()); - } - if (recv_close_op_cancel_state_.CompleteCallWithCancelledSetTo( - result->get(GrpcCallWasCancelled()).value_or(true))) { - FinishOpOnCompletion(&recv_close_completion_, - PendingOp::kReceiveCloseOnServer); - } - channelz::ServerNode* channelz_node = server_->channelz_node(); - if (channelz_node != nullptr) { - if (result->get(GrpcStatusMetadata()).value_or(GRPC_STATUS_UNKNOWN) == - GRPC_STATUS_OK) { - channelz_node->RecordCallSucceeded(); - } else { - channelz_node->RecordCallFailed(); - } - } - if (send_status_from_server_completion_.has_value()) { - FinishOpOnCompletion(&send_status_from_server_completion_, - PendingOp::kSendStatusFromServer); - } - CancelSendMessage(); - CancelRecvMessage(); - set_completed(); - // TODO(ctiller): this will probably need to be inlined somehow for - // performance - InternalRef("propagate_cancel"); - channel()->event_engine()->Run([this]() { - ApplicationCallbackExecCtx callback_exec_ctx; - ExecCtx exec_ctx; - PropagateCancellationToChildren(); - InternalUnref("propagate_cancel"); - }); -} - grpc_call_error ServerPromiseBasedCall::ValidateBatch(const grpc_op* ops, size_t nops) const { BitSet<8> got_ops; @@ -3531,17 +3401,21 @@ void ServerPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, const grpc_op& op = ops[op_idx]; switch (op.op) { case GRPC_OP_SEND_INITIAL_METADATA: { - auto metadata = arena()->MakePooled(arena()); - PrepareOutgoingInitialMetadata(op, *metadata); - CToMetadata(op.data.send_initial_metadata.metadata, - op.data.send_initial_metadata.count, metadata.get()); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] Send initial metadata", - DebugTag().c_str()); + // compression not implemented + GPR_ASSERT( + !op.data.send_initial_metadata.maybe_compression_level.is_set); + if (!completed()) { + auto metadata = arena()->MakePooled(arena()); + CToMetadata(op.data.send_initial_metadata.metadata, + 
op.data.send_initial_metadata.count, metadata.get()); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[call] Send initial metadata", + DebugTag().c_str()); + } + auto* pipe = absl::get*>( + send_initial_metadata_state_); + send_initial_metadata_state_ = pipe->Push(std::move(metadata)); } - auto* pipe = absl::get*>( - send_initial_metadata_state_); - send_initial_metadata_state_ = pipe->Push(std::move(metadata)); } break; case GRPC_OP_SEND_MESSAGE: StartSendMessage(op, completion, server_to_client_messages_); @@ -3569,7 +3443,6 @@ void ServerPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, DebugTag().c_str(), recv_close_op_cancel_state_.ToString().c_str()); } - ForceCompletionSuccess(completion); if (!recv_close_op_cancel_state_.ReceiveCloseOnServerOpStarted( op.data.recv_close_on_server.cancelled)) { recv_close_completion_ = @@ -3584,35 +3457,6 @@ void ServerPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, } } -void ServerPromiseBasedCall::CommitBatchAfterCompletion( - const grpc_op* ops, size_t nops, const Completion& completion) { - for (size_t op_idx = 0; op_idx < nops; op_idx++) { - const grpc_op& op = ops[op_idx]; - switch (op.op) { - case GRPC_OP_SEND_INITIAL_METADATA: - case GRPC_OP_SEND_MESSAGE: - case GRPC_OP_RECV_MESSAGE: - case GRPC_OP_SEND_STATUS_FROM_SERVER: - FailCompletion(completion); - break; - case GRPC_OP_RECV_CLOSE_ON_SERVER: - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] StartBatch: RecvClose %s", - DebugTag().c_str(), - recv_close_op_cancel_state_.ToString().c_str()); - } - ForceCompletionSuccess(completion); - GPR_ASSERT(recv_close_op_cancel_state_.ReceiveCloseOnServerOpStarted( - op.data.recv_close_on_server.cancelled)); - break; - case GRPC_OP_RECV_STATUS_ON_CLIENT: - case GRPC_OP_SEND_CLOSE_FROM_CLIENT: - case GRPC_OP_RECV_INITIAL_METADATA: - abort(); // unreachable - } - } -} - grpc_call_error ServerPromiseBasedCall::StartBatch(const grpc_op* ops, size_t nops, void* notify_tag, @@ -3629,12 +3473,8 @@ grpc_call_error ServerPromiseBasedCall::StartBatch(const grpc_op* ops, } Completion completion = StartCompletion(notify_tag, is_notify_tag_closure, ops); - if (!completed()) { - CommitBatch(ops, nops, completion); - Update(); - } else { - CommitBatchAfterCompletion(ops, nops, completion); - } + CommitBatch(ops, nops, completion); + Update(); FinishOpOnCompletion(&completion, PendingOp::kStartingBatch); return GRPC_CALL_OK; } @@ -3643,7 +3483,6 @@ void ServerPromiseBasedCall::CancelWithErrorLocked(absl::Status error) { if (!promise_.has_value()) return; cancel_send_and_receive_ = true; send_trailing_metadata_ = ServerMetadataFromStatus(error, arena()); - send_trailing_metadata_->Set(GrpcCallWasCancelled(), true); ForceWakeup(); } @@ -3660,13 +3499,11 @@ ServerCallContext::MakeTopOfServerCallPromise( call_->server_to_client_messages_ = call_args.server_to_client_messages; call_->client_to_server_messages_ = call_args.client_to_server_messages; call_->send_initial_metadata_state_ = call_args.server_initial_metadata; + call_->incoming_compression_algorithm_ = + call_args.client_initial_metadata->get(GrpcEncodingMetadata()) + .value_or(GRPC_COMPRESS_NONE); call_->client_initial_metadata_ = std::move(call_args.client_initial_metadata); - call_->ProcessIncomingInitialMetadata( - *call_->client_initial_metadata_, [this](absl::Status status) { - call_->mu()->AssertHeld(); - call_->CancelWithErrorLocked(std::move(status)); - }); PublishMetadataArray(call_->client_initial_metadata_.get(), publish_initial_metadata); 
call_->ExternalRef(); @@ -3777,9 +3614,7 @@ uint32_t grpc_call_test_only_get_message_flags(grpc_call* call) { } uint32_t grpc_call_test_only_get_encodings_accepted_by_peer(grpc_call* call) { - return grpc_core::Call::FromC(call) - ->encodings_accepted_by_peer() - .ToLegacyBitmask(); + return grpc_core::Call::FromC(call)->test_only_encodings_accepted_by_peer(); } grpc_core::Arena* grpc_call_get_arena(grpc_call* call) { @@ -3828,9 +3663,7 @@ uint8_t grpc_call_is_client(grpc_call* call) { grpc_compression_algorithm grpc_call_compression_for_level( grpc_call* call, grpc_compression_level level) { - return grpc_core::Call::FromC(call) - ->encodings_accepted_by_peer() - .CompressionAlgorithmForLevel(level); + return grpc_core::Call::FromC(call)->compression_for_level(level); } bool grpc_call_is_trailers_only(const grpc_call* call) { diff --git a/src/core/lib/surface/lame_client.cc b/src/core/lib/surface/lame_client.cc index 7fbdf8e64b4..ecbc9eed098 100644 --- a/src/core/lib/surface/lame_client.cc +++ b/src/core/lib/surface/lame_client.cc @@ -79,7 +79,6 @@ ArenaPromise LameClientFilter::MakeCallPromise( if (args.server_to_client_messages != nullptr) { args.server_to_client_messages->Close(); } - args.client_initial_metadata_outstanding.Complete(true); return Immediate(ServerMetadataFromStatus(error_)); } diff --git a/src/core/lib/transport/metadata_batch.h b/src/core/lib/transport/metadata_batch.h index 64c1e8f912f..47884b34f80 100644 --- a/src/core/lib/transport/metadata_batch.h +++ b/src/core/lib/transport/metadata_batch.h @@ -396,15 +396,6 @@ struct GrpcStatusFromWire { static absl::string_view DisplayValue(bool x) { return x ? "true" : "false"; } }; -// Annotation to denote that this call qualifies for cancelled=1 for the -// RECV_CLOSE_ON_SERVER op -struct GrpcCallWasCancelled { - static absl::string_view DebugKey() { return "GrpcCallWasCancelled"; } - static constexpr bool kRepeatable = false; - using ValueType = bool; - static absl::string_view DisplayValue(bool x) { return x ? "true" : "false"; } -}; - // Annotation added by client surface code to denote wait-for-ready state struct WaitForReady { struct ValueType { @@ -1336,8 +1327,7 @@ using grpc_metadata_batch_base = grpc_core::MetadataMap< // Non-encodable things grpc_core::GrpcStreamNetworkState, grpc_core::PeerString, grpc_core::GrpcStatusContext, grpc_core::GrpcStatusFromWire, - grpc_core::GrpcCallWasCancelled, grpc_core::WaitForReady, - grpc_core::GrpcTrailersOnly>; + grpc_core::WaitForReady, grpc_core::GrpcTrailersOnly>; struct grpc_metadata_batch : public grpc_metadata_batch_base { using grpc_metadata_batch_base::grpc_metadata_batch_base; diff --git a/src/core/lib/transport/transport.h b/src/core/lib/transport/transport.h index f24109d8ab0..bcd151cc6af 100644 --- a/src/core/lib/transport/transport.h +++ b/src/core/lib/transport/transport.h @@ -27,7 +27,6 @@ #include #include -#include #include #include "absl/status/status.h" @@ -55,7 +54,6 @@ #include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/detail/status.h" -#include "src/core/lib/promise/latch.h" #include "src/core/lib/promise/pipe.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice_buffer.h" @@ -146,70 +144,11 @@ struct StatusCastImpl { } }; -// Move only type that tracks call startup. -// Allows observation of when client_initial_metadata has been processed by the -// end of the local call stack. 
-// Interested observers can call Wait() to obtain a promise that will resolve -// when all local client_initial_metadata processing has completed. -// The result of this token is either true on successful completion, or false -// if the metadata was not sent. -// To set a successful completion, call Complete(true). For failure, call -// Complete(false). -// If Complete is not called, the destructor of a still held token will complete -// with failure. -// Transports should hold this token until client_initial_metadata has passed -// any flow control (eg MAX_CONCURRENT_STREAMS for http2). -class ClientInitialMetadataOutstandingToken { - public: - static ClientInitialMetadataOutstandingToken Empty() { - return ClientInitialMetadataOutstandingToken(); - } - static ClientInitialMetadataOutstandingToken New( - Arena* arena = GetContext()) { - ClientInitialMetadataOutstandingToken token; - token.latch_ = arena->New>(); - return token; - } - - ClientInitialMetadataOutstandingToken( - const ClientInitialMetadataOutstandingToken&) = delete; - ClientInitialMetadataOutstandingToken& operator=( - const ClientInitialMetadataOutstandingToken&) = delete; - ClientInitialMetadataOutstandingToken( - ClientInitialMetadataOutstandingToken&& other) noexcept - : latch_(std::exchange(other.latch_, nullptr)) {} - ClientInitialMetadataOutstandingToken& operator=( - ClientInitialMetadataOutstandingToken&& other) noexcept { - latch_ = std::exchange(other.latch_, nullptr); - return *this; - } - ~ClientInitialMetadataOutstandingToken() { - if (latch_ != nullptr) latch_->Set(false); - } - void Complete(bool success) { std::exchange(latch_, nullptr)->Set(success); } - - // Returns a promise that will resolve when this object (or its moved-from - // ancestor) is dropped. - auto Wait() { return latch_->Wait(); } - - private: - ClientInitialMetadataOutstandingToken() = default; - - Latch* latch_ = nullptr; -}; - -using ClientInitialMetadataOutstandingTokenWaitType = - decltype(std::declval().Wait()); - struct CallArgs { // Initial metadata from the client to the server. // During promise setup this can be manipulated by filters (and then // passed on to the next filter). ClientMetadataHandle client_initial_metadata; - // Token indicating that client_initial_metadata is still being processed. - // This should be moved around and only destroyed when the transport is - // satisfied that the metadata has passed any flow control measures it has. - ClientInitialMetadataOutstandingToken client_initial_metadata_outstanding; // Initial metadata from the server to the client. // Set once when it's available. 
// During promise setup filters can substitute their own latch for this @@ -392,12 +331,6 @@ struct grpc_transport_stream_op_batch { /// Is this stream traced bool is_traced : 1; - bool HasOp() const { - return send_initial_metadata || send_trailing_metadata || send_message || - recv_initial_metadata || recv_message || recv_trailing_metadata || - cancel_stream; - } - //************************************************************************** // remaining fields are initialized and used at the discretion of the // current handler of the op diff --git a/test/core/end2end/fixtures/h2_oauth2_tls12.cc b/test/core/end2end/fixtures/h2_oauth2_tls12.cc index 0a1d17b4fa6..5d3681cdca8 100644 --- a/test/core/end2end/fixtures/h2_oauth2_tls12.cc +++ b/test/core/end2end/fixtures/h2_oauth2_tls12.cc @@ -45,6 +45,8 @@ #define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" static const char oauth2_md[] = "Bearer aaslkfjs424535asdf"; +static const char* client_identity_property_name = "smurf_name"; +static const char* client_identity = "Brainy Smurf"; struct fullstack_secure_fixture_data { std::string localaddr; @@ -68,7 +70,7 @@ typedef struct { size_t pseudo_refcount; } test_processor_state; -static void process_oauth2_success(void* state, grpc_auth_context*, +static void process_oauth2_success(void* state, grpc_auth_context* ctx, const grpc_metadata* md, size_t md_count, grpc_process_auth_metadata_done_cb cb, void* user_data) { @@ -80,6 +82,10 @@ static void process_oauth2_success(void* state, grpc_auth_context*, s = static_cast(state); GPR_ASSERT(s->pseudo_refcount == 1); GPR_ASSERT(oauth2 != nullptr); + grpc_auth_context_add_cstring_property(ctx, client_identity_property_name, + client_identity); + GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( + ctx, client_identity_property_name) == 1); cb(user_data, oauth2, 1, nullptr, 0, GRPC_STATUS_OK, nullptr); } diff --git a/test/core/end2end/fixtures/h2_oauth2_tls13.cc b/test/core/end2end/fixtures/h2_oauth2_tls13.cc index 997f8002bb2..df8b3307f97 100644 --- a/test/core/end2end/fixtures/h2_oauth2_tls13.cc +++ b/test/core/end2end/fixtures/h2_oauth2_tls13.cc @@ -45,6 +45,8 @@ #define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" static const char oauth2_md[] = "Bearer aaslkfjs424535asdf"; +static const char* client_identity_property_name = "smurf_name"; +static const char* client_identity = "Brainy Smurf"; struct fullstack_secure_fixture_data { std::string localaddr; @@ -68,7 +70,7 @@ typedef struct { size_t pseudo_refcount; } test_processor_state; -static void process_oauth2_success(void* state, grpc_auth_context*, +static void process_oauth2_success(void* state, grpc_auth_context* ctx, const grpc_metadata* md, size_t md_count, grpc_process_auth_metadata_done_cb cb, void* user_data) { @@ -80,6 +82,10 @@ static void process_oauth2_success(void* state, grpc_auth_context*, s = static_cast(state); GPR_ASSERT(s->pseudo_refcount == 1); GPR_ASSERT(oauth2 != nullptr); + grpc_auth_context_add_cstring_property(ctx, client_identity_property_name, + client_identity); + GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( + ctx, client_identity_property_name) == 1); cb(user_data, oauth2, 1, nullptr, 0, GRPC_STATUS_OK, nullptr); } diff --git a/test/core/end2end/fixtures/proxy.cc b/test/core/end2end/fixtures/proxy.cc index 8e99c5b73a7..c1b7da1a446 100644 --- a/test/core/end2end/fixtures/proxy.cc +++ b/test/core/end2end/fixtures/proxy.cc @@ -210,7 +210,7 @@ static void on_p2s_sent_message(void* arg, int success) { grpc_op op; grpc_call_error 
err; - grpc_byte_buffer_destroy(std::exchange(pc->c2p_msg, nullptr)); + grpc_byte_buffer_destroy(pc->c2p_msg); if (!pc->proxy->shutdown && success) { op.op = GRPC_OP_RECV_MESSAGE; op.flags = 0; diff --git a/test/core/end2end/tests/filter_init_fails.cc b/test/core/end2end/tests/filter_init_fails.cc index 5435400bd5c..7665b72c703 100644 --- a/test/core/end2end/tests/filter_init_fails.cc +++ b/test/core/end2end/tests/filter_init_fails.cc @@ -21,7 +21,6 @@ #include #include -#include #include #include "absl/status/status.h" @@ -41,10 +40,7 @@ #include "src/core/lib/gprpp/status_helper.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/error.h" -#include "src/core/lib/promise/arena_promise.h" -#include "src/core/lib/promise/promise.h" #include "src/core/lib/surface/channel_stack_type.h" -#include "src/core/lib/transport/transport.h" #include "test/core/end2end/cq_verifier.h" #include "test/core/end2end/end2end_tests.h" #include "test/core/util/test_config.h" @@ -452,23 +448,12 @@ static grpc_error_handle init_channel_elem( static void destroy_channel_elem(grpc_channel_element* /*elem*/) {} static const grpc_channel_filter test_filter = { - grpc_call_next_op, - [](grpc_channel_element*, grpc_core::CallArgs, - grpc_core::NextPromiseFactory) - -> grpc_core::ArenaPromise { - return grpc_core::Immediate(grpc_core::ServerMetadataFromStatus( - absl::PermissionDeniedError("access denied"))); - }, - grpc_channel_next_op, - 0, - init_call_elem, - grpc_call_stack_ignore_set_pollset_or_pollset_set, - destroy_call_elem, - 0, - init_channel_elem, - grpc_channel_stack_no_post_init, - destroy_channel_elem, - grpc_channel_next_get_info, + grpc_call_next_op, nullptr, + grpc_channel_next_op, 0, + init_call_elem, grpc_call_stack_ignore_set_pollset_or_pollset_set, + destroy_call_elem, 0, + init_channel_elem, grpc_channel_stack_no_post_init, + destroy_channel_elem, grpc_channel_next_get_info, "filter_init_fails"}; //****************************************************************************** diff --git a/test/core/end2end/tests/max_message_length.cc b/test/core/end2end/tests/max_message_length.cc index 99083650c1f..b778ae42b47 100644 --- a/test/core/end2end/tests/max_message_length.cc +++ b/test/core/end2end/tests/max_message_length.cc @@ -128,9 +128,6 @@ static void test_max_message_length_on_request(grpc_end2end_test_config config, grpc_status_code status; grpc_call_error error; grpc_slice details; - grpc_slice expect_in_details = grpc_slice_from_copied_string( - send_limit ? "Sent message larger than max (11 vs. 5)" - : "Received message larger than max (11 vs. 5)"); int was_cancelled = 2; grpc_channel_args* client_args = nullptr; @@ -269,10 +266,13 @@ static void test_max_message_length_on_request(grpc_end2end_test_config config, done: GPR_ASSERT(status == GRPC_STATUS_RESOURCE_EXHAUSTED); - GPR_ASSERT(grpc_slice_slice(details, expect_in_details) >= 0); + GPR_ASSERT( + grpc_slice_str_cmp( + details, send_limit + ? "Sent message larger than max (11 vs. 5)" + : "Received message larger than max (11 vs. 5)") == 0); grpc_slice_unref(details); - grpc_slice_unref(expect_in_details); grpc_metadata_array_destroy(&initial_metadata_recv); grpc_metadata_array_destroy(&trailing_metadata_recv); grpc_metadata_array_destroy(&request_metadata_recv); @@ -316,9 +316,6 @@ static void test_max_message_length_on_response(grpc_end2end_test_config config, grpc_status_code status; grpc_call_error error; grpc_slice details; - grpc_slice expect_in_details = grpc_slice_from_copied_string( - send_limit ? 
"Sent message larger than max (11 vs. 5)" - : "Received message larger than max (11 vs. 5)"); int was_cancelled = 2; grpc_channel_args* client_args = nullptr; @@ -458,10 +455,13 @@ static void test_max_message_length_on_response(grpc_end2end_test_config config, GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); GPR_ASSERT(status == GRPC_STATUS_RESOURCE_EXHAUSTED); - GPR_ASSERT(grpc_slice_slice(details, expect_in_details) >= 0); + GPR_ASSERT( + grpc_slice_str_cmp( + details, send_limit + ? "Sent message larger than max (11 vs. 5)" + : "Received message larger than max (11 vs. 5)") == 0); grpc_slice_unref(details); - grpc_slice_unref(expect_in_details); grpc_metadata_array_destroy(&initial_metadata_recv); grpc_metadata_array_destroy(&trailing_metadata_recv); grpc_metadata_array_destroy(&request_metadata_recv); diff --git a/test/core/end2end/tests/streaming_error_response.cc b/test/core/end2end/tests/streaming_error_response.cc index d03f869da75..641549025b8 100644 --- a/test/core/end2end/tests/streaming_error_response.cc +++ b/test/core/end2end/tests/streaming_error_response.cc @@ -40,14 +40,10 @@ static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, const char* test_name, grpc_channel_args* client_args, grpc_channel_args* server_args, - bool request_status_early, - bool recv_message_separately) { + bool request_status_early) { grpc_end2end_test_fixture f; - gpr_log( - GPR_INFO, - "Running test: %s/%s/request_status_early=%s/recv_message_separately=%s", - test_name, config.name, request_status_early ? "true" : "false", - recv_message_separately ? "true" : "false"); + gpr_log(GPR_INFO, "Running test: %s/%s/request_status_early=%s", test_name, + config.name, request_status_early ? "true" : "false"); f = config.create_fixture(client_args, server_args); config.init_server(&f, server_args); config.init_client(&f, client_args); @@ -112,7 +108,7 @@ static void test(grpc_end2end_test_config config, bool request_status_early, grpc_raw_byte_buffer_create(&response_payload2_slice, 1); grpc_end2end_test_fixture f = begin_test(config, "streaming_error_response", nullptr, nullptr, - request_status_early, recv_message_separately); + request_status_early); grpc_core::CqVerifier cqv(f.cq); grpc_op ops[6]; grpc_op* op; diff --git a/test/core/filters/client_auth_filter_test.cc b/test/core/filters/client_auth_filter_test.cc index a65a6fb55e0..54cbddb012f 100644 --- a/test/core/filters/client_auth_filter_test.cc +++ b/test/core/filters/client_auth_filter_test.cc @@ -155,8 +155,7 @@ TEST_F(ClientAuthFilterTest, CallCredsFails) { auto promise = filter->MakeCallPromise( CallArgs{ClientMetadataHandle(&initial_metadata_batch_, Arena::PooledDeleter(nullptr)), - ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr, - nullptr}, + nullptr, nullptr, nullptr}, [&](CallArgs /*call_args*/) { return ArenaPromise( [&]() -> Poll { @@ -186,8 +185,7 @@ TEST_F(ClientAuthFilterTest, RewritesInvalidStatusFromCallCreds) { auto promise = filter->MakeCallPromise( CallArgs{ClientMetadataHandle(&initial_metadata_batch_, Arena::PooledDeleter(nullptr)), - ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr, - nullptr}, + nullptr, nullptr, nullptr}, [&](CallArgs /*call_args*/) { return ArenaPromise( [&]() -> Poll { diff --git a/test/core/filters/client_authority_filter_test.cc b/test/core/filters/client_authority_filter_test.cc index c300f6403ac..5509be006fc 100644 --- a/test/core/filters/client_authority_filter_test.cc +++ 
@@ -72,8 +72,7 @@ TEST(ClientAuthorityFilterTest, PromiseCompletesImmediatelyAndSetsAuthority) {
   auto promise = filter.MakeCallPromise(
       CallArgs{ClientMetadataHandle(&initial_metadata_batch,
                                     Arena::PooledDeleter(nullptr)),
-               ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr,
-               nullptr},
+               nullptr, nullptr, nullptr},
       [&](CallArgs call_args) {
         EXPECT_EQ(call_args.client_initial_metadata
                       ->get_pointer(HttpAuthorityMetadata())
@@ -108,8 +107,7 @@ TEST(ClientAuthorityFilterTest,
   auto promise = filter.MakeCallPromise(
       CallArgs{ClientMetadataHandle(&initial_metadata_batch,
                                     Arena::PooledDeleter(nullptr)),
-               ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr,
-               nullptr},
+               nullptr, nullptr, nullptr},
       [&](CallArgs call_args) {
         EXPECT_EQ(call_args.client_initial_metadata
                       ->get_pointer(HttpAuthorityMetadata())
diff --git a/test/core/filters/filter_fuzzer.cc b/test/core/filters/filter_fuzzer.cc
index 5ddf733bac6..443d72bca8f 100644
--- a/test/core/filters/filter_fuzzer.cc
+++ b/test/core/filters/filter_fuzzer.cc
@@ -477,7 +477,6 @@ class MainLoop {
     auto* server_initial_metadata = arena_->New<Pipe<ServerMetadataHandle>>();
     CallArgs call_args{std::move(*LoadMetadata(client_initial_metadata,
                                                &client_initial_metadata_)),
-                       ClientInitialMetadataOutstandingToken::Empty(),
                        &server_initial_metadata->sender, nullptr, nullptr};
     if (is_client) {
       promise_ = main_loop_->channel_stack_->MakeClientCallPromise(