diff --git a/BUILD b/BUILD index 34795eb4993..085f6e19aa3 100644 --- a/BUILD +++ b/BUILD @@ -1471,6 +1471,7 @@ grpc_cc_library( "//src/core:iomgr_fwd", "//src/core:iomgr_port", "//src/core:json", + "//src/core:latch", "//src/core:map", "//src/core:match", "//src/core:memory_quota", diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 806c2ae9c0d..a779f07539a 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -3730,6 +3730,7 @@ libs: - src/core/lib/promise/if.h - src/core/lib/promise/interceptor_list.h - src/core/lib/promise/intra_activity_waiter.h + - src/core/lib/promise/latch.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h - src/core/lib/promise/pipe.h @@ -7544,6 +7545,7 @@ targets: - src/core/lib/promise/if.h - src/core/lib/promise/interceptor_list.h - src/core/lib/promise/intra_activity_waiter.h + - src/core/lib/promise/latch.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h - src/core/lib/promise/pipe.h diff --git a/src/core/BUILD b/src/core/BUILD index 4f3d671e2e2..688e7167caf 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -3437,33 +3437,38 @@ grpc_cc_library( "ext/filters/message_size/message_size_filter.h", ], external_deps = [ - "absl/status", + "absl/status:statusor", "absl/strings", "absl/strings:str_format", "absl/types:optional", ], language = "c++", deps = [ + "activity", + "arena", + "arena_promise", "channel_args", "channel_fwd", "channel_init", "channel_stack_type", - "closure", - "error", + "context", "grpc_service_config", "json", "json_args", "json_object_loader", + "latch", + "poll", + "race", "service_config_parser", + "slice", "slice_buffer", - "status_helper", "validation_errors", "//:channel_stack_builder", "//:config", - "//:debug_location", "//:gpr", "//:grpc_base", "//:grpc_public_hdrs", + "//:grpc_trace", ], ) diff --git a/src/core/ext/filters/http/client/http_client_filter.cc b/src/core/ext/filters/http/client/http_client_filter.cc index 58f9a709024..89ff0356cce 100644 --- a/src/core/ext/filters/http/client/http_client_filter.cc +++ b/src/core/ext/filters/http/client/http_client_filter.cc @@ -133,13 +133,13 @@ ArenaPromise HttpClientFilter::MakeCallPromise( return std::move(md); }); - return Race(Map(next_promise_factory(std::move(call_args)), + return Race(initial_metadata_err->Wait(), + Map(next_promise_factory(std::move(call_args)), [](ServerMetadataHandle md) -> ServerMetadataHandle { auto r = CheckServerMetadata(md.get()); if (!r.ok()) return ServerMetadataFromStatus(r); return md; - }), - initial_metadata_err->Wait()); + })); } HttpClientFilter::HttpClientFilter(HttpSchemeMetadata::ValueType scheme, diff --git a/src/core/ext/filters/http/message_compress/compression_filter.cc b/src/core/ext/filters/http/message_compress/compression_filter.cc index 03171819c68..a0f87cfd6be 100644 --- a/src/core/ext/filters/http/message_compress/compression_filter.cc +++ b/src/core/ext/filters/http/message_compress/compression_filter.cc @@ -233,7 +233,7 @@ ArenaPromise ClientCompressionFilter::MakeCallPromise( return CompressMessage(std::move(message), compression_algorithm); }); auto* decompress_args = GetContext()->New( - DecompressArgs{GRPC_COMPRESS_NONE, absl::nullopt}); + DecompressArgs{GRPC_COMPRESS_ALGORITHMS_COUNT, absl::nullopt}); auto* decompress_err = GetContext()->New>(); call_args.server_initial_metadata->InterceptAndMap( @@ -254,8 +254,8 @@ ArenaPromise ClientCompressionFilter::MakeCallPromise( return std::move(*r); }); // Run the next filter, and race it with getting an error from decompression. - return Race(next_promise_factory(std::move(call_args)), - decompress_err->Wait()); + return Race(decompress_err->Wait(), + next_promise_factory(std::move(call_args))); } ArenaPromise ServerCompressionFilter::MakeCallPromise( @@ -269,7 +269,8 @@ ArenaPromise ServerCompressionFilter::MakeCallPromise( this](MessageHandle message) -> absl::optional { auto r = DecompressMessage(std::move(message), decompress_args); if (grpc_call_trace.enabled()) { - gpr_log(GPR_DEBUG, "DecompressMessage returned %s", + gpr_log(GPR_DEBUG, "%s[compression] DecompressMessage returned %s", + Activity::current()->DebugTag().c_str(), r.status().ToString().c_str()); } if (!r.ok()) { @@ -300,8 +301,8 @@ ArenaPromise ServerCompressionFilter::MakeCallPromise( // - decompress incoming messages // - wait for initial metadata to be sent, and then commence compression of // outgoing messages - return Race(next_promise_factory(std::move(call_args)), - decompress_err->Wait()); + return Race(decompress_err->Wait(), + next_promise_factory(std::move(call_args))); } } // namespace grpc_core diff --git a/src/core/ext/filters/message_size/message_size_filter.cc b/src/core/ext/filters/message_size/message_size_filter.cc index 33ff178e5a9..c265ecab7d7 100644 --- a/src/core/ext/filters/message_size/message_size_filter.cc +++ b/src/core/ext/filters/message_size/message_size_filter.cc @@ -18,10 +18,13 @@ #include "src/core/ext/filters/message_size/message_size_filter.h" +#include + +#include #include -#include +#include +#include -#include "absl/status/status.h" #include "absl/strings/str_format.h" #include @@ -32,21 +35,22 @@ #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/channel_stack_builder.h" #include "src/core/lib/config/core_configuration.h" -#include "src/core/lib/gprpp/debug_location.h" -#include "src/core/lib/gprpp/status_helper.h" -#include "src/core/lib/iomgr/call_combiner.h" -#include "src/core/lib/iomgr/closure.h" -#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/latch.h" +#include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/race.h" +#include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/service_config/service_config_call_data.h" +#include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_buffer.h" +#include "src/core/lib/surface/call_trace.h" #include "src/core/lib/surface/channel_init.h" #include "src/core/lib/surface/channel_stack_type.h" +#include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" -static void recv_message_ready(void* user_data, grpc_error_handle error); -static void recv_trailing_metadata_ready(void* user_data, - grpc_error_handle error); - namespace grpc_core { // @@ -124,251 +128,164 @@ size_t MessageSizeParser::ParserIndex() { parser_name()); } -} // namespace grpc_core - -namespace { -struct channel_data { - grpc_core::MessageSizeParsedConfig limits; - const size_t service_config_parser_index{ - grpc_core::MessageSizeParser::ParserIndex()}; -}; +// +// MessageSizeFilter +// -struct call_data { - call_data(grpc_call_element* elem, const channel_data& chand, - const grpc_call_element_args& args) - : call_combiner(args.call_combiner), limits(chand.limits) { - GRPC_CLOSURE_INIT(&recv_message_ready, ::recv_message_ready, elem, - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready, - ::recv_trailing_metadata_ready, elem, - grpc_schedule_on_exec_ctx); - // Get max sizes from channel data, then merge in per-method config values. - // Note: Per-method config is only available on the client, so we - // apply the max request size to the send limit and the max response - // size to the receive limit. - const grpc_core::MessageSizeParsedConfig* config_from_call_context = - grpc_core::MessageSizeParsedConfig::GetFromCallContext( - args.context, chand.service_config_parser_index); - if (config_from_call_context != nullptr) { - absl::optional max_send_size = limits.max_send_size(); - absl::optional max_recv_size = limits.max_recv_size(); - if (config_from_call_context->max_send_size().has_value() && - (!max_send_size.has_value() || - *config_from_call_context->max_send_size() < *max_send_size)) { - max_send_size = *config_from_call_context->max_send_size(); +const grpc_channel_filter ClientMessageSizeFilter::kFilter = + MakePromiseBasedFilter("message_size"); +const grpc_channel_filter ServerMessageSizeFilter::kFilter = + MakePromiseBasedFilter("message_size"); + +class MessageSizeFilter::CallBuilder { + private: + auto Interceptor(uint32_t max_length, bool is_send) { + return [max_length, is_send, + err = err_](MessageHandle msg) -> absl::optional { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[message_size] %s len:%" PRIdPTR " max:%d", + Activity::current()->DebugTag().c_str(), + is_send ? "send" : "recv", msg->payload()->Length(), + max_length); } - if (config_from_call_context->max_recv_size().has_value() && - (!max_recv_size.has_value() || - *config_from_call_context->max_recv_size() < *max_recv_size)) { - max_recv_size = *config_from_call_context->max_recv_size(); + if (msg->payload()->Length() > max_length) { + if (err->is_set()) return std::move(msg); + auto r = GetContext()->MakePooled( + GetContext()); + r->Set(GrpcStatusMetadata(), GRPC_STATUS_RESOURCE_EXHAUSTED); + r->Set(GrpcMessageMetadata(), + Slice::FromCopiedString( + absl::StrFormat("%s message larger than max (%u vs. %d)", + is_send ? "Sent" : "Received", + msg->payload()->Length(), max_length))); + err->Set(std::move(r)); + return absl::nullopt; } - limits = grpc_core::MessageSizeParsedConfig(max_send_size, max_recv_size); - } + return std::move(msg); + }; } - ~call_data() {} - - grpc_core::CallCombiner* call_combiner; - grpc_core::MessageSizeParsedConfig limits; - // Receive closures are chained: we inject this closure as the - // recv_message_ready up-call on transport_stream_op, and remember to - // call our next_recv_message_ready member after handling it. - grpc_closure recv_message_ready; - grpc_closure recv_trailing_metadata_ready; - // The error caused by a message that is too large, or absl::OkStatus() - grpc_error_handle error; - // Used by recv_message_ready. - absl::optional* recv_message = nullptr; - // Original recv_message_ready callback, invoked after our own. - grpc_closure* next_recv_message_ready = nullptr; - // Original recv_trailing_metadata callback, invoked after our own. - grpc_closure* original_recv_trailing_metadata_ready; - bool seen_recv_trailing_metadata = false; - grpc_error_handle recv_trailing_metadata_error; -}; - -} // namespace + public: + explicit CallBuilder(const MessageSizeParsedConfig& limits) + : limits_(limits) {} -// Callback invoked when we receive a message. Here we check the max -// receive message size. -static void recv_message_ready(void* user_data, grpc_error_handle error) { - grpc_call_element* elem = static_cast(user_data); - call_data* calld = static_cast(elem->call_data); - if (calld->recv_message->has_value() && - calld->limits.max_recv_size().has_value() && - (*calld->recv_message)->Length() > - static_cast(*calld->limits.max_recv_size())) { - grpc_error_handle new_error = grpc_error_set_int( - GRPC_ERROR_CREATE(absl::StrFormat( - "Received message larger than max (%u vs. %d)", - (*calld->recv_message)->Length(), *calld->limits.max_recv_size())), - grpc_core::StatusIntProperty::kRpcStatus, - GRPC_STATUS_RESOURCE_EXHAUSTED); - error = grpc_error_add_child(error, new_error); - calld->error = error; + template + void AddSend(T* pipe_end) { + if (!limits_.max_send_size().has_value()) return; + pipe_end->InterceptAndMap(Interceptor(*limits_.max_send_size(), true)); } - // Invoke the next callback. - grpc_closure* closure = calld->next_recv_message_ready; - calld->next_recv_message_ready = nullptr; - if (calld->seen_recv_trailing_metadata) { - // We might potentially see another RECV_MESSAGE op. In that case, we do not - // want to run the recv_trailing_metadata_ready closure again. The newer - // RECV_MESSAGE op cannot cause any errors since the transport has already - // invoked the recv_trailing_metadata_ready closure and all further - // RECV_MESSAGE ops will get null payloads. - calld->seen_recv_trailing_metadata = false; - GRPC_CALL_COMBINER_START(calld->call_combiner, - &calld->recv_trailing_metadata_ready, - calld->recv_trailing_metadata_error, - "continue recv_trailing_metadata_ready"); + template + void AddRecv(T* pipe_end) { + if (!limits_.max_recv_size().has_value()) return; + pipe_end->InterceptAndMap(Interceptor(*limits_.max_recv_size(), false)); } - grpc_core::Closure::Run(DEBUG_LOCATION, closure, error); -} -// Callback invoked on completion of recv_trailing_metadata -// Notifies the recv_trailing_metadata batch of any message size failures -static void recv_trailing_metadata_ready(void* user_data, - grpc_error_handle error) { - grpc_call_element* elem = static_cast(user_data); - call_data* calld = static_cast(elem->call_data); - if (calld->next_recv_message_ready != nullptr) { - calld->seen_recv_trailing_metadata = true; - calld->recv_trailing_metadata_error = error; - GRPC_CALL_COMBINER_STOP(calld->call_combiner, - "deferring recv_trailing_metadata_ready until " - "after recv_message_ready"); - return; + ArenaPromise Run( + CallArgs call_args, NextPromiseFactory next_promise_factory) { + return Race(err_->Wait(), next_promise_factory(std::move(call_args))); } - error = grpc_error_add_child(error, calld->error); - // Invoke the next callback. - grpc_core::Closure::Run(DEBUG_LOCATION, - calld->original_recv_trailing_metadata_ready, error); -} -// Start transport stream op. -static void message_size_start_transport_stream_op_batch( - grpc_call_element* elem, grpc_transport_stream_op_batch* op) { - call_data* calld = static_cast(elem->call_data); - // Check max send message size. - if (op->send_message && calld->limits.max_send_size().has_value() && - op->payload->send_message.send_message->Length() > - static_cast(*calld->limits.max_send_size())) { - grpc_transport_stream_op_batch_finish_with_failure( - op, - grpc_error_set_int(GRPC_ERROR_CREATE(absl::StrFormat( - "Sent message larger than max (%u vs. %d)", - op->payload->send_message.send_message->Length(), - *calld->limits.max_send_size())), - grpc_core::StatusIntProperty::kRpcStatus, - GRPC_STATUS_RESOURCE_EXHAUSTED), - calld->call_combiner); - return; - } - // Inject callback for receiving a message. - if (op->recv_message) { - calld->next_recv_message_ready = - op->payload->recv_message.recv_message_ready; - calld->recv_message = op->payload->recv_message.recv_message; - op->payload->recv_message.recv_message_ready = &calld->recv_message_ready; - } - // Inject callback for receiving trailing metadata. - if (op->recv_trailing_metadata) { - calld->original_recv_trailing_metadata_ready = - op->payload->recv_trailing_metadata.recv_trailing_metadata_ready; - op->payload->recv_trailing_metadata.recv_trailing_metadata_ready = - &calld->recv_trailing_metadata_ready; - } - // Chain to the next filter. - grpc_call_next_op(elem, op); -} + private: + Latch* const err_ = + GetContext()->New>(); + MessageSizeParsedConfig limits_; +}; -// Constructor for call_data. -static grpc_error_handle message_size_init_call_elem( - grpc_call_element* elem, const grpc_call_element_args* args) { - channel_data* chand = static_cast(elem->channel_data); - new (elem->call_data) call_data(elem, *chand, *args); - return absl::OkStatus(); +absl::StatusOr ClientMessageSizeFilter::Create( + const ChannelArgs& args, ChannelFilter::Args) { + return ClientMessageSizeFilter(args); } -// Destructor for call_data. -static void message_size_destroy_call_elem( - grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, - grpc_closure* /*ignored*/) { - call_data* calld = static_cast(elem->call_data); - calld->~call_data(); +absl::StatusOr ServerMessageSizeFilter::Create( + const ChannelArgs& args, ChannelFilter::Args) { + return ServerMessageSizeFilter(args); } -// Constructor for channel_data. -static grpc_error_handle message_size_init_channel_elem( - grpc_channel_element* elem, grpc_channel_element_args* args) { - GPR_ASSERT(!args->is_last); - channel_data* chand = static_cast(elem->channel_data); - new (chand) channel_data(); - chand->limits = grpc_core::MessageSizeParsedConfig::GetFromChannelArgs( - args->channel_args); - return absl::OkStatus(); -} +ArenaPromise ClientMessageSizeFilter::MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) { + // Get max sizes from channel data, then merge in per-method config values. + // Note: Per-method config is only available on the client, so we + // apply the max request size to the send limit and the max response + // size to the receive limit. + MessageSizeParsedConfig limits = this->limits(); + const MessageSizeParsedConfig* config_from_call_context = + MessageSizeParsedConfig::GetFromCallContext( + GetContext(), + service_config_parser_index_); + if (config_from_call_context != nullptr) { + absl::optional max_send_size = limits.max_send_size(); + absl::optional max_recv_size = limits.max_recv_size(); + if (config_from_call_context->max_send_size().has_value() && + (!max_send_size.has_value() || + *config_from_call_context->max_send_size() < *max_send_size)) { + max_send_size = *config_from_call_context->max_send_size(); + } + if (config_from_call_context->max_recv_size().has_value() && + (!max_recv_size.has_value() || + *config_from_call_context->max_recv_size() < *max_recv_size)) { + max_recv_size = *config_from_call_context->max_recv_size(); + } + limits = MessageSizeParsedConfig(max_send_size, max_recv_size); + } -// Destructor for channel_data. -static void message_size_destroy_channel_elem(grpc_channel_element* elem) { - channel_data* chand = static_cast(elem->channel_data); - chand->~channel_data(); + CallBuilder b(limits); + b.AddSend(call_args.client_to_server_messages); + b.AddRecv(call_args.server_to_client_messages); + return b.Run(std::move(call_args), std::move(next_promise_factory)); } -const grpc_channel_filter grpc_message_size_filter = { - message_size_start_transport_stream_op_batch, - nullptr, - grpc_channel_next_op, - sizeof(call_data), - message_size_init_call_elem, - grpc_call_stack_ignore_set_pollset_or_pollset_set, - message_size_destroy_call_elem, - sizeof(channel_data), - message_size_init_channel_elem, - grpc_channel_stack_no_post_init, - message_size_destroy_channel_elem, - grpc_channel_next_get_info, - "message_size"}; +ArenaPromise ServerMessageSizeFilter::MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) { + CallBuilder b(limits()); + b.AddSend(call_args.server_to_client_messages); + b.AddRecv(call_args.client_to_server_messages); + return b.Run(std::move(call_args), std::move(next_promise_factory)); +} +namespace { // Used for GRPC_CLIENT_SUBCHANNEL -static bool maybe_add_message_size_filter_subchannel( - grpc_core::ChannelStackBuilder* builder) { +bool MaybeAddMessageSizeFilterToSubchannel(ChannelStackBuilder* builder) { if (builder->channel_args().WantMinimalStack()) { return true; } - builder->PrependFilter(&grpc_message_size_filter); + builder->PrependFilter(&ClientMessageSizeFilter::kFilter); return true; } -// Used for GRPC_CLIENT_DIRECT_CHANNEL and GRPC_SERVER_CHANNEL. Adds the filter -// only if message size limits or service config is specified. -static bool maybe_add_message_size_filter( - grpc_core::ChannelStackBuilder* builder) { - auto channel_args = builder->channel_args(); - if (channel_args.WantMinimalStack()) { +// Used for GRPC_CLIENT_DIRECT_CHANNEL and GRPC_SERVER_CHANNEL. Adds the +// filter only if message size limits or service config is specified. +auto MaybeAddMessageSizeFilter(const grpc_channel_filter* filter) { + return [filter](ChannelStackBuilder* builder) { + auto channel_args = builder->channel_args(); + if (channel_args.WantMinimalStack()) { + return true; + } + MessageSizeParsedConfig limits = + MessageSizeParsedConfig::GetFromChannelArgs(channel_args); + const bool enable = + limits.max_send_size().has_value() || + limits.max_recv_size().has_value() || + channel_args.GetString(GRPC_ARG_SERVICE_CONFIG).has_value(); + if (enable) builder->PrependFilter(filter); return true; - } - grpc_core::MessageSizeParsedConfig limits = - grpc_core::MessageSizeParsedConfig::GetFromChannelArgs(channel_args); - const bool enable = - limits.max_send_size().has_value() || - limits.max_recv_size().has_value() || - channel_args.GetString(GRPC_ARG_SERVICE_CONFIG).has_value(); - if (enable) builder->PrependFilter(&grpc_message_size_filter); - return true; + }; } -namespace grpc_core { +} // namespace void RegisterMessageSizeFilter(CoreConfiguration::Builder* builder) { MessageSizeParser::Register(builder); - builder->channel_init()->RegisterStage( - GRPC_CLIENT_SUBCHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, - maybe_add_message_size_filter_subchannel); - builder->channel_init()->RegisterStage(GRPC_CLIENT_DIRECT_CHANNEL, - GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, - maybe_add_message_size_filter); - builder->channel_init()->RegisterStage(GRPC_SERVER_CHANNEL, + builder->channel_init()->RegisterStage(GRPC_CLIENT_SUBCHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, - maybe_add_message_size_filter); + MaybeAddMessageSizeFilterToSubchannel); + builder->channel_init()->RegisterStage( + GRPC_CLIENT_DIRECT_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + MaybeAddMessageSizeFilter(&ClientMessageSizeFilter::kFilter)); + builder->channel_init()->RegisterStage( + GRPC_SERVER_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + MaybeAddMessageSizeFilter(&ServerMessageSizeFilter::kFilter)); } } // namespace grpc_core diff --git a/src/core/ext/filters/message_size/message_size_filter.h b/src/core/ext/filters/message_size/message_size_filter.h index e47485a8950..75135a1b75e 100644 --- a/src/core/ext/filters/message_size/message_size_filter.h +++ b/src/core/ext/filters/message_size/message_size_filter.h @@ -24,21 +24,22 @@ #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_fwd.h" -#include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/context.h" +#include "src/core/lib/channel/promise_based_filter.h" #include "src/core/lib/config/core_configuration.h" #include "src/core/lib/gprpp/validation_errors.h" #include "src/core/lib/json/json.h" #include "src/core/lib/json/json_args.h" #include "src/core/lib/json/json_object_loader.h" +#include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/service_config/service_config_parser.h" - -extern const grpc_channel_filter grpc_message_size_filter; +#include "src/core/lib/transport/transport.h" namespace grpc_core { @@ -85,6 +86,50 @@ class MessageSizeParser : public ServiceConfigParser::Parser { absl::optional GetMaxRecvSizeFromChannelArgs(const ChannelArgs& args); absl::optional GetMaxSendSizeFromChannelArgs(const ChannelArgs& args); +class MessageSizeFilter : public ChannelFilter { + protected: + explicit MessageSizeFilter(const ChannelArgs& args) + : limits_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {} + + class CallBuilder; + + const MessageSizeParsedConfig& limits() const { return limits_; } + + private: + MessageSizeParsedConfig limits_; +}; + +class ServerMessageSizeFilter final : public MessageSizeFilter { + public: + static const grpc_channel_filter kFilter; + + static absl::StatusOr Create( + const ChannelArgs& args, ChannelFilter::Args filter_args); + + // Construct a promise for one call. + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) override; + + private: + using MessageSizeFilter::MessageSizeFilter; +}; + +class ClientMessageSizeFilter final : public MessageSizeFilter { + public: + static const grpc_channel_filter kFilter; + + static absl::StatusOr Create( + const ChannelArgs& args, ChannelFilter::Args filter_args); + + // Construct a promise for one call. + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) override; + + private: + const size_t service_config_parser_index_{MessageSizeParser::ParserIndex()}; + using MessageSizeFilter::MessageSizeFilter; +}; + } // namespace grpc_core #endif // GRPC_SRC_CORE_EXT_FILTERS_MESSAGE_SIZE_MESSAGE_SIZE_FILTER_H diff --git a/src/core/ext/transport/inproc/inproc_transport.cc b/src/core/ext/transport/inproc/inproc_transport.cc index dc6b4804f0a..f99d41bfeed 100644 --- a/src/core/ext/transport/inproc/inproc_transport.cc +++ b/src/core/ext/transport/inproc/inproc_transport.cc @@ -766,6 +766,8 @@ void op_state_machine_locked(inproc_stream* s, grpc_error_handle error) { nullptr); s->to_read_trailing_md.Clear(); s->to_read_trailing_md_filled = false; + s->recv_trailing_md_op->payload->recv_trailing_metadata + .recv_trailing_metadata->Set(grpc_core::GrpcStatusFromWire(), true); // We should schedule the recv_trailing_md_op completion if // 1. this stream is the client-side diff --git a/src/core/lib/channel/connected_channel.cc b/src/core/lib/channel/connected_channel.cc index ea3f226311a..f66203b2598 100644 --- a/src/core/lib/channel/connected_channel.cc +++ b/src/core/lib/channel/connected_channel.cc @@ -56,6 +56,7 @@ #include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/status_helper.h" #include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/gprpp/time.h" #include "src/core/lib/iomgr/call_combiner.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/error.h" @@ -73,6 +74,7 @@ #include "src/core/lib/surface/call.h" #include "src/core/lib/surface/call_trace.h" #include "src/core/lib/surface/channel_stack_type.h" +#include "src/core/lib/transport/error_utils.h" #include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" #include "src/core/lib/transport/transport_fwd.h" @@ -478,7 +480,15 @@ class ConnectedChannelStream : public Orphanable { return Match( recv_message_state_, [](Idle) -> std::string { return "IDLE"; }, [](Closed) -> std::string { return "CLOSED"; }, - [](const PendingReceiveMessage&) -> std::string { return "WAITING"; }, + [](const PendingReceiveMessage& m) -> std::string { + if (m.received) { + return absl::StrCat("RECEIVED_FROM_TRANSPORT:", + m.payload.has_value() + ? absl::StrCat(m.payload->Length(), "b") + : "EOS"); + } + return "WAITING"; + }, [](const absl::optional& message) -> std::string { return absl::StrCat( "READY:", message.has_value() @@ -570,13 +580,7 @@ class ConnectedChannelStream : public Orphanable { void RecvMessageBatchDone(grpc_error_handle error) { { MutexLock lock(mu()); - if (error != absl::OkStatus()) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[connected] RecvMessageBatchDone: error=%s", - recv_message_waker_.ActivityDebugTag().c_str(), - StatusToString(error).c_str()); - } - } else if (absl::holds_alternative(recv_message_state_)) { + if (absl::holds_alternative(recv_message_state_)) { if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, "%s[connected] RecvMessageBatchDone: already closed, " @@ -584,14 +588,21 @@ class ConnectedChannelStream : public Orphanable { recv_message_waker_.ActivityDebugTag().c_str()); } } else { - if (grpc_call_trace.enabled()) { + auto pending = + absl::get_if(&recv_message_state_); + GPR_ASSERT(pending != nullptr); + if (!error.ok()) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[connected] RecvMessageBatchDone: error=%s", + recv_message_waker_.ActivityDebugTag().c_str(), + StatusToString(error).c_str()); + } + pending->payload.reset(); + } else if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, "%s[connected] RecvMessageBatchDone: received message", recv_message_waker_.ActivityDebugTag().c_str()); } - auto pending = - absl::get_if(&recv_message_state_); - GPR_ASSERT(pending != nullptr); GPR_ASSERT(pending->received == false); pending->received = true; } @@ -671,6 +682,8 @@ class ClientStream : public ConnectedChannelStream { public: ClientStream(grpc_transport* transport, CallArgs call_args) : ConnectedChannelStream(transport), + client_initial_metadata_outstanding_token_( + std::move(call_args.client_initial_metadata_outstanding)), server_initial_metadata_pipe_(call_args.server_initial_metadata), client_to_server_messages_(call_args.client_to_server_messages), server_to_client_messages_(call_args.server_to_client_messages), @@ -704,12 +717,14 @@ class ClientStream : public ConnectedChannelStream { nullptr, GetContext()); grpc_transport_set_pops(transport(), stream(), GetContext()->polling_entity()); - memset(&metadata_, 0, sizeof(metadata_)); - metadata_.send_initial_metadata = true; - metadata_.recv_initial_metadata = true; - metadata_.recv_trailing_metadata = true; - metadata_.payload = batch_payload(); - metadata_.on_complete = &metadata_batch_done_; + memset(&send_metadata_, 0, sizeof(send_metadata_)); + memset(&recv_metadata_, 0, sizeof(recv_metadata_)); + send_metadata_.send_initial_metadata = true; + recv_metadata_.recv_initial_metadata = true; + recv_metadata_.recv_trailing_metadata = true; + send_metadata_.payload = batch_payload(); + recv_metadata_.payload = batch_payload(); + send_metadata_.on_complete = &send_metadata_batch_done_; batch_payload()->send_initial_metadata.send_initial_metadata = client_initial_metadata_.get(); batch_payload()->send_initial_metadata.peer_string = @@ -734,9 +749,16 @@ class ClientStream : public ConnectedChannelStream { IncrementRefCount("metadata_batch_done"); IncrementRefCount("initial_metadata_ready"); IncrementRefCount("trailing_metadata_ready"); - initial_metadata_waker_ = Activity::current()->MakeOwningWaker(); - trailing_metadata_waker_ = Activity::current()->MakeOwningWaker(); - SchedulePush(&metadata_); + recv_initial_metadata_waker_ = Activity::current()->MakeOwningWaker(); + recv_trailing_metadata_waker_ = Activity::current()->MakeOwningWaker(); + send_initial_metadata_waker_ = Activity::current()->MakeOwningWaker(); + SchedulePush(&send_metadata_); + SchedulePush(&recv_metadata_); + } + if (std::exchange(need_to_clear_client_initial_metadata_outstanding_token_, + false)) { + client_initial_metadata_outstanding_token_.Complete( + client_initial_metadata_send_result_); } if (server_initial_metadata_state_ == ServerInitialMetadataState::kReceivedButNotPushed) { @@ -753,9 +775,21 @@ class ClientStream : public ConnectedChannelStream { server_initial_metadata_push_promise_.reset(); } } + if (server_initial_metadata_state_ == ServerInitialMetadataState::kError) { + server_initial_metadata_pipe_->Close(); + } PollSendMessage(client_to_server_messages_, &client_trailing_metadata_); PollRecvMessage(server_to_client_messages_); - if (server_initial_metadata_state_ == ServerInitialMetadataState::kPushed && + if (grpc_call_trace.enabled()) { + gpr_log( + GPR_INFO, + "%s[connected] Finishing PollConnectedChannel: requesting metadata", + Activity::current()->DebugTag().c_str()); + } + if ((server_initial_metadata_state_ == + ServerInitialMetadataState::kPushed || + server_initial_metadata_state_ == + ServerInitialMetadataState::kError) && !IsPromiseReceiving() && std::exchange(queued_trailing_metadata_, false)) { if (grpc_call_trace.enabled()) { @@ -774,18 +808,32 @@ class ClientStream : public ConnectedChannelStream { } void RecvInitialMetadataReady(grpc_error_handle error) { - GPR_ASSERT(error == absl::OkStatus()); { MutexLock lock(mu()); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[connected] RecvInitialMetadataReady: error=%s", + recv_initial_metadata_waker_.ActivityDebugTag().c_str(), + error.ToString().c_str()); + } server_initial_metadata_state_ = - ServerInitialMetadataState::kReceivedButNotPushed; - initial_metadata_waker_.Wakeup(); + error.ok() ? ServerInitialMetadataState::kReceivedButNotPushed + : ServerInitialMetadataState::kError; + recv_initial_metadata_waker_.Wakeup(); } Unref("initial_metadata_ready"); } void RecvTrailingMetadataReady(grpc_error_handle error) { - GPR_ASSERT(error == absl::OkStatus()); + if (!error.ok()) { + server_trailing_metadata_->Clear(); + grpc_status_code status = GRPC_STATUS_UNKNOWN; + std::string message; + grpc_error_get_status(error, Timestamp::InfFuture(), &status, &message, + nullptr, nullptr); + server_trailing_metadata_->Set(GrpcStatusMetadata(), status); + server_trailing_metadata_->Set(GrpcMessageMetadata(), + Slice::FromCopiedString(message)); + } { MutexLock lock(mu()); queued_trailing_metadata_ = true; @@ -794,16 +842,21 @@ class ClientStream : public ConnectedChannelStream { "%s[connected] RecvTrailingMetadataReady: " "queued_trailing_metadata_ " "set to true; active_ops: %s", - trailing_metadata_waker_.ActivityDebugTag().c_str(), + recv_trailing_metadata_waker_.ActivityDebugTag().c_str(), ActiveOpsString().c_str()); } - trailing_metadata_waker_.Wakeup(); + recv_trailing_metadata_waker_.Wakeup(); } Unref("trailing_metadata_ready"); } - void MetadataBatchDone(grpc_error_handle error) { - GPR_ASSERT(error == absl::OkStatus()); + void SendMetadataBatchDone(grpc_error_handle error) { + { + MutexLock lock(mu()); + need_to_clear_client_initial_metadata_outstanding_token_ = true; + client_initial_metadata_send_result_ = error.ok(); + send_initial_metadata_waker_.Wakeup(); + } Unref("metadata_batch_done"); } @@ -823,6 +876,8 @@ class ClientStream : public ConnectedChannelStream { // has been pushed on the pipe to publish it up the call stack AND removed // by the call at the top. kPushed, + // Received initial metadata with an error status. + kError, }; std::string ActiveOpsString() const override @@ -831,10 +886,10 @@ class ClientStream : public ConnectedChannelStream { if (finished()) ops.push_back("FINISHED"); // Outstanding Operations on Transport std::vector waiting; - if (initial_metadata_waker_ != Waker()) { + if (recv_initial_metadata_waker_ != Waker()) { waiting.push_back("initial_metadata"); } - if (trailing_metadata_waker_ != Waker()) { + if (recv_trailing_metadata_waker_ != Waker()) { waiting.push_back("trailing_metadata"); } if (!waiting.empty()) { @@ -850,6 +905,18 @@ class ClientStream : public ConnectedChannelStream { if (!queued.empty()) { ops.push_back(absl::StrCat("queued:", absl::StrJoin(queued, ","))); } + switch (server_initial_metadata_state_) { + case ServerInitialMetadataState::kNotReceived: + case ServerInitialMetadataState::kReceivedButNotPushed: + case ServerInitialMetadataState::kPushed: + break; + case ServerInitialMetadataState::kPushing: + ops.push_back("server_initial_metadata:PUSHING"); + break; + case ServerInitialMetadataState::kError: + ops.push_back("server_initial_metadata:ERROR"); + break; + } // Send message std::string send_message_state = SendMessageString(); if (send_message_state != "WAITING") { @@ -864,11 +931,17 @@ class ClientStream : public ConnectedChannelStream { } bool requested_metadata_ = false; + bool need_to_clear_client_initial_metadata_outstanding_token_ + ABSL_GUARDED_BY(mu()) = false; + bool client_initial_metadata_send_result_ ABSL_GUARDED_BY(mu()); ServerInitialMetadataState server_initial_metadata_state_ ABSL_GUARDED_BY(mu()) = ServerInitialMetadataState::kNotReceived; bool queued_trailing_metadata_ ABSL_GUARDED_BY(mu()) = false; - Waker initial_metadata_waker_ ABSL_GUARDED_BY(mu()); - Waker trailing_metadata_waker_ ABSL_GUARDED_BY(mu()); + Waker recv_initial_metadata_waker_ ABSL_GUARDED_BY(mu()); + Waker send_initial_metadata_waker_ ABSL_GUARDED_BY(mu()); + Waker recv_trailing_metadata_waker_ ABSL_GUARDED_BY(mu()); + ClientInitialMetadataOutstandingToken + client_initial_metadata_outstanding_token_; PipeSender* server_initial_metadata_pipe_; PipeReceiver* client_to_server_messages_; PipeSender* server_to_client_messages_; @@ -884,9 +957,10 @@ class ClientStream : public ConnectedChannelStream { ServerMetadataHandle server_trailing_metadata_; absl::optional::PushType> server_initial_metadata_push_promise_; - grpc_transport_stream_op_batch metadata_; - grpc_closure metadata_batch_done_ = - MakeMemberClosure( + grpc_transport_stream_op_batch send_metadata_; + grpc_transport_stream_op_batch recv_metadata_; + grpc_closure send_metadata_batch_done_ = + MakeMemberClosure( this, DEBUG_LOCATION); }; @@ -948,6 +1022,7 @@ class ServerStream final : public ConnectedChannelStream { gim.client_initial_metadata.get(); batch_payload()->recv_initial_metadata.recv_initial_metadata_ready = &gim.recv_initial_metadata_ready; + IncrementRefCount("RecvInitialMetadata"); SchedulePush(&gim.recv_initial_metadata); // Fetch trailing metadata (to catch cancellations) @@ -965,8 +1040,9 @@ class ServerStream final : public ConnectedChannelStream { &GetContext()->call_stats()->transport_stream_stats; batch_payload()->recv_trailing_metadata.recv_trailing_metadata_ready = >m.recv_trailing_metadata_ready; - SchedulePush(>m.recv_trailing_metadata); gtm.waker = Activity::current()->MakeOwningWaker(); + IncrementRefCount("RecvTrailingMetadata"); + SchedulePush(>m.recv_trailing_metadata); } Poll PollOnce() { @@ -995,6 +1071,7 @@ class ServerStream final : public ConnectedChannelStream { .emplace(std::move(**md)) .get(); batch_payload()->send_initial_metadata.peer_string = nullptr; + IncrementRefCount("SendInitialMetadata"); SchedulePush(&send_initial_metadata_); return true; } else { @@ -1039,6 +1116,7 @@ class ServerStream final : public ConnectedChannelStream { incoming_messages_ = &pipes_.client_to_server.sender; auto promise = p->next_promise_factory(CallArgs{ std::move(p->client_initial_metadata), + ClientInitialMetadataOutstandingToken::Empty(), &pipes_.server_initial_metadata.sender, &pipes_.client_to_server.receiver, &pipes_.server_to_client.sender}); call_state_.emplace( @@ -1093,6 +1171,7 @@ class ServerStream final : public ConnectedChannelStream { ->as_string_view()), StatusIntProperty::kRpcStatus, status_code); } + IncrementRefCount("SendTrailingMetadata"); SchedulePush(&op); } } @@ -1188,6 +1267,7 @@ class ServerStream final : public ConnectedChannelStream { std::move(getting.next_promise_factory)}; call_state_.emplace(std::move(got)); waker.Wakeup(); + Unref("RecvInitialMetadata"); } void SendTrailingMetadataDone(absl::Status result) { @@ -1200,16 +1280,19 @@ class ServerStream final : public ConnectedChannelStream { waker.ActivityDebugTag().c_str(), result.ToString().c_str(), completing.sent ? "true" : "false", md->DebugString().c_str()); } - md->Set(GrpcStatusFromWire(), completing.sent); if (!result.ok()) { md->Clear(); md->Set(GrpcStatusMetadata(), static_cast(result.code())); md->Set(GrpcMessageMetadata(), Slice::FromCopiedString(result.message())); - md->Set(GrpcStatusFromWire(), false); + md->Set(GrpcCallWasCancelled(), true); + } + if (!md->get(GrpcCallWasCancelled()).has_value()) { + md->Set(GrpcCallWasCancelled(), !completing.sent); } call_state_.emplace(Complete{std::move(md)}); waker.Wakeup(); + Unref("SendTrailingMetadata"); } std::string ActiveOpsString() const override @@ -1261,7 +1344,7 @@ class ServerStream final : public ConnectedChannelStream { return absl::StrJoin(ops, " "); } - void SendInitialMetadataDone() {} + void SendInitialMetadataDone() { Unref("SendInitialMetadata"); } void RecvTrailingMetadataReady(absl::Status error) { MutexLock lock(mu()); @@ -1285,6 +1368,7 @@ class ServerStream final : public ConnectedChannelStream { client_trailing_metadata_state_.emplace( GotClientHalfClose{error}); waker.Wakeup(); + Unref("RecvTrailingMetadata"); } struct Pipes { diff --git a/src/core/lib/channel/promise_based_filter.cc b/src/core/lib/channel/promise_based_filter.cc index 0b0569d8fbd..8930eb0fe0e 100644 --- a/src/core/lib/channel/promise_based_filter.cc +++ b/src/core/lib/channel/promise_based_filter.cc @@ -16,6 +16,8 @@ #include "src/core/lib/channel/promise_based_filter.h" +#include + #include #include #include @@ -215,7 +217,7 @@ void BaseCallData::CapturedBatch::ResumeWith(Flusher* releaser) { // refcnt==0 ==> cancelled if (grpc_trace_channel.enabled()) { gpr_log(GPR_INFO, "%sRESUME BATCH REQUEST CANCELLED", - Activity::current()->DebugTag().c_str()); + releaser->call()->DebugTag().c_str()); } return; } @@ -239,6 +241,10 @@ void BaseCallData::CapturedBatch::CancelWith(grpc_error_handle error, auto* batch = std::exchange(batch_, nullptr); GPR_ASSERT(batch != nullptr); uintptr_t& refcnt = *RefCountField(batch); + gpr_log(GPR_DEBUG, "%sCancelWith: %p refs=%" PRIdPTR " err=%s [%s]", + releaser->call()->DebugTag().c_str(), batch, refcnt, + error.ToString().c_str(), + grpc_transport_stream_op_batch_string(batch).c_str()); if (refcnt == 0) { // refcnt==0 ==> cancelled if (grpc_trace_channel.enabled()) { @@ -325,6 +331,8 @@ const char* BaseCallData::SendMessage::StateString(State state) { return "CANCELLED"; case State::kCancelledButNotYetPolled: return "CANCELLED_BUT_NOT_YET_POLLED"; + case State::kCancelledButNoStatus: + return "CANCELLED_BUT_NO_STATUS"; } return "UNKNOWN"; } @@ -349,6 +357,7 @@ void BaseCallData::SendMessage::StartOp(CapturedBatch batch) { Crash(absl::StrFormat("ILLEGAL STATE: %s", StateString(state_))); case State::kCancelled: case State::kCancelledButNotYetPolled: + case State::kCancelledButNoStatus: return; } batch_ = batch; @@ -376,6 +385,7 @@ void BaseCallData::SendMessage::GotPipe(T* pipe_end) { case State::kForwardedBatch: case State::kBatchCompleted: case State::kPushedToPipe: + case State::kCancelledButNoStatus: Crash(absl::StrFormat("ILLEGAL STATE: %s", StateString(state_))); case State::kCancelled: case State::kCancelledButNotYetPolled: @@ -391,6 +401,7 @@ bool BaseCallData::SendMessage::IsIdle() const { case State::kForwardedBatch: case State::kCancelled: case State::kCancelledButNotYetPolled: + case State::kCancelledButNoStatus: return true; case State::kGotBatchNoPipe: case State::kGotBatch: @@ -419,6 +430,7 @@ void BaseCallData::SendMessage::OnComplete(absl::Status status) { break; case State::kCancelled: case State::kCancelledButNotYetPolled: + case State::kCancelledButNoStatus: flusher.AddClosure(intercepted_on_complete_, status, "forward after cancel"); break; @@ -443,10 +455,14 @@ void BaseCallData::SendMessage::Done(const ServerMetadata& metadata, case State::kCancelledButNotYetPolled: break; case State::kInitial: + state_ = State::kCancelled; + break; case State::kIdle: case State::kForwardedBatch: state_ = State::kCancelledButNotYetPolled; + if (base_->is_current()) base_->ForceImmediateRepoll(); break; + case State::kCancelledButNoStatus: case State::kGotBatchNoPipe: case State::kGotBatch: { std::string temp; @@ -465,6 +481,7 @@ void BaseCallData::SendMessage::Done(const ServerMetadata& metadata, push_.reset(); next_.reset(); state_ = State::kCancelledButNotYetPolled; + if (base_->is_current()) base_->ForceImmediateRepoll(); break; } } @@ -483,6 +500,7 @@ void BaseCallData::SendMessage::WakeInsideCombiner(Flusher* flusher, case State::kIdle: case State::kGotBatchNoPipe: case State::kCancelled: + case State::kCancelledButNoStatus: break; case State::kCancelledButNotYetPolled: interceptor()->Push()->Close(); @@ -524,13 +542,18 @@ void BaseCallData::SendMessage::WakeInsideCombiner(Flusher* flusher, "result.has_value=%s", base_->LogTag().c_str(), p->has_value() ? "true" : "false"); } - GPR_ASSERT(p->has_value()); - batch_->payload->send_message.send_message->Swap((**p)->payload()); - batch_->payload->send_message.flags = (**p)->flags(); - state_ = State::kForwardedBatch; - batch_.ResumeWith(flusher); - next_.reset(); - if (!absl::holds_alternative((*push_)())) push_.reset(); + if (p->has_value()) { + batch_->payload->send_message.send_message->Swap((**p)->payload()); + batch_->payload->send_message.flags = (**p)->flags(); + state_ = State::kForwardedBatch; + batch_.ResumeWith(flusher); + next_.reset(); + if (!absl::holds_alternative((*push_)())) push_.reset(); + } else { + state_ = State::kCancelledButNoStatus; + next_.reset(); + push_.reset(); + } } } break; case State::kForwardedBatch: @@ -1090,11 +1113,14 @@ class ClientCallData::PollContext { // Poll the promise once since we're waiting for it. Poll poll = self_->promise_(); if (grpc_trace_channel.enabled()) { - gpr_log(GPR_INFO, "%s ClientCallData.PollContext.Run: poll=%s", + gpr_log(GPR_INFO, "%s ClientCallData.PollContext.Run: poll=%s; %s", self_->LogTag().c_str(), - PollToString(poll, [](const ServerMetadataHandle& h) { - return h->DebugString(); - }).c_str()); + PollToString(poll, + [](const ServerMetadataHandle& h) { + return h->DebugString(); + }) + .c_str(), + self_->DebugString().c_str()); } if (auto* r = absl::get_if(&poll)) { auto md = std::move(*r); @@ -1274,7 +1300,11 @@ ClientCallData::ClientCallData(grpc_call_element* elem, [args]() { return args->arena->New(args->arena); }, - [args]() { return args->arena->New(args->arena); }) { + [args]() { return args->arena->New(args->arena); }), + initial_metadata_outstanding_token_( + (flags & kFilterIsLast) != 0 + ? ClientInitialMetadataOutstandingToken::New(arena()) + : ClientInitialMetadataOutstandingToken::Empty()) { GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, RecvTrailingMetadataReadyCallback, this, grpc_schedule_on_exec_ctx); @@ -1543,6 +1573,7 @@ void ClientCallData::StartPromise(Flusher* flusher) { promise_ = filter->MakeCallPromise( CallArgs{WrapMetadata(send_initial_metadata_batch_->payload ->send_initial_metadata.send_initial_metadata), + std::move(initial_metadata_outstanding_token_), server_initial_metadata_pipe() == nullptr ? nullptr : &server_initial_metadata_pipe()->sender, @@ -1863,8 +1894,15 @@ struct ServerCallData::SendInitialMetadata { class ServerCallData::PollContext { public: - explicit PollContext(ServerCallData* self, Flusher* flusher) - : self_(self), flusher_(flusher) { + explicit PollContext(ServerCallData* self, Flusher* flusher, + DebugLocation created = DebugLocation()) + : self_(self), flusher_(flusher), created_(created) { + if (self_->poll_ctx_ != nullptr) { + Crash(absl::StrCat( + "PollContext: disallowed recursion. New: ", created_.file(), ":", + created_.line(), "; Old: ", self_->poll_ctx_->created_.file(), ":", + self_->poll_ctx_->created_.line())); + } GPR_ASSERT(self_->poll_ctx_ == nullptr); self_->poll_ctx_ = this; scoped_activity_.Init(self_); @@ -1910,6 +1948,7 @@ class ServerCallData::PollContext { Flusher* const flusher_; bool repoll_ = false; bool have_scoped_activity_; + GPR_NO_UNIQUE_ADDRESS DebugLocation created_; }; const char* ServerCallData::StateString(RecvInitialState state) { @@ -2079,7 +2118,10 @@ void ServerCallData::StartBatch(grpc_transport_stream_op_batch* b) { switch (send_trailing_state_) { case SendTrailingState::kInitial: send_trailing_metadata_batch_ = batch; - if (receive_message() != nullptr) { + if (receive_message() != nullptr && + batch->payload->send_trailing_metadata.send_trailing_metadata + ->get(GrpcStatusMetadata()) + .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) { receive_message()->Done( *batch->payload->send_trailing_metadata.send_trailing_metadata, &flusher); @@ -2138,6 +2180,7 @@ void ServerCallData::Completed(grpc_error_handle error, Flusher* flusher) { if (!error.ok()) { auto* batch = grpc_make_transport_stream_op( NewClosure([call_combiner = call_combiner()](absl::Status) { + gpr_log(GPR_DEBUG, "ON COMPLETE DONE"); GRPC_CALL_COMBINER_STOP(call_combiner, "done-cancel"); })); batch->cancel_stream = true; @@ -2312,6 +2355,7 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) { FakeActivity(this).Run([this, filter] { promise_ = filter->MakeCallPromise( CallArgs{WrapMetadata(recv_initial_metadata_), + ClientInitialMetadataOutstandingToken::Empty(), server_initial_metadata_pipe() == nullptr ? nullptr : &server_initial_metadata_pipe()->sender, @@ -2413,9 +2457,14 @@ void ServerCallData::WakeInsideCombiner(Flusher* flusher) { (send_trailing_metadata_batch_->send_message && send_message()->IsForwarded()))) { send_trailing_state_ = SendTrailingState::kQueued; - send_message()->Done(*send_trailing_metadata_batch_->payload - ->send_trailing_metadata.send_trailing_metadata, - flusher); + if (send_trailing_metadata_batch_->payload->send_trailing_metadata + .send_trailing_metadata->get(GrpcStatusMetadata()) + .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) { + send_message()->Done( + *send_trailing_metadata_batch_->payload->send_trailing_metadata + .send_trailing_metadata, + flusher); + } } } if (receive_message() != nullptr) { diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index bdec3556822..0e9cb164a3e 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -218,7 +218,11 @@ class BaseCallData : public Activity, private Wakeable { void Resume(grpc_transport_stream_op_batch* batch) { GPR_ASSERT(!call_->is_last()); - release_.push_back(batch); + if (batch->HasOp()) { + release_.push_back(batch); + } else if (batch->on_complete != nullptr) { + Complete(batch); + } } void Cancel(grpc_transport_stream_op_batch* batch, @@ -237,6 +241,8 @@ class BaseCallData : public Activity, private Wakeable { call_closures_.Add(closure, error, reason); } + BaseCallData* call() const { return call_; } + private: absl::InlinedVector release_; CallCombinerClosureList call_closures_; @@ -398,6 +404,8 @@ class BaseCallData : public Activity, private Wakeable { kCancelledButNotYetPolled, // We're done. kCancelled, + // We're done, but we haven't gotten a status yet + kCancelledButNoStatus, }; static const char* StateString(State); @@ -664,6 +672,8 @@ class ClientCallData : public BaseCallData { RecvTrailingState recv_trailing_state_ = RecvTrailingState::kInitial; // Polling related data. Non-null if we're actively polling PollContext* poll_ctx_ = nullptr; + // Initial metadata outstanding token + ClientInitialMetadataOutstandingToken initial_metadata_outstanding_token_; }; class ServerCallData : public BaseCallData { diff --git a/src/core/lib/iomgr/call_combiner.h b/src/core/lib/iomgr/call_combiner.h index 50aeb63c780..e314479413b 100644 --- a/src/core/lib/iomgr/call_combiner.h +++ b/src/core/lib/iomgr/call_combiner.h @@ -171,8 +171,8 @@ class CallCombinerClosureList { if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { gpr_log(GPR_INFO, "CallCombinerClosureList executing closure while already " - "holding call_combiner %p: closure=%p error=%s reason=%s", - call_combiner, closures_[0].closure, + "holding call_combiner %p: closure=%s error=%s reason=%s", + call_combiner, closures_[0].closure->DebugString().c_str(), StatusToString(closures_[0].error).c_str(), closures_[0].reason); } // This will release the call combiner. diff --git a/src/core/lib/promise/detail/promise_factory.h b/src/core/lib/promise/detail/promise_factory.h index 12b291a545b..adca5af7f08 100644 --- a/src/core/lib/promise/detail/promise_factory.h +++ b/src/core/lib/promise/detail/promise_factory.h @@ -17,6 +17,7 @@ #include +#include #include #include @@ -106,6 +107,9 @@ class Curried { private: GPR_NO_UNIQUE_ADDRESS F f_; GPR_NO_UNIQUE_ADDRESS Arg arg_; +#ifndef NDEBUG + std::unique_ptr asan_canary_ = std::make_unique(0); +#endif }; // Promote a callable(A) -> T | Poll to a PromiseFactory(A) -> Promise by diff --git a/src/core/lib/promise/interceptor_list.h b/src/core/lib/promise/interceptor_list.h index 2e09f574ae5..6b5576db8cb 100644 --- a/src/core/lib/promise/interceptor_list.h +++ b/src/core/lib/promise/interceptor_list.h @@ -140,7 +140,8 @@ class InterceptorList { async_resolution_.space.get()); async_resolution_.current_factory = async_resolution_.current_factory->next(); - if (async_resolution_.current_factory == nullptr || !p->has_value()) { + if (!p->has_value()) async_resolution_.current_factory = nullptr; + if (async_resolution_.current_factory == nullptr) { return std::move(*p); } async_resolution_.current_factory->MakePromise( diff --git a/src/core/lib/promise/latch.h b/src/core/lib/promise/latch.h index 305cf53ab6e..ee1f9d6e958 100644 --- a/src/core/lib/promise/latch.h +++ b/src/core/lib/promise/latch.h @@ -89,6 +89,8 @@ class Latch { waiter_.Wake(); } + bool is_set() const { return has_value_; } + private: std::string DebugTag() { return absl::StrCat(Activity::current()->DebugTag(), " LATCH[0x", @@ -165,7 +167,7 @@ class Latch { private: std::string DebugTag() { - return absl::StrCat(Activity::current()->DebugTag(), " LATCH[0x", + return absl::StrCat(Activity::current()->DebugTag(), " LATCH(void)[0x", reinterpret_cast(this), "]: "); } diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc index f5938e99e6c..c67b97388ba 100644 --- a/src/core/lib/security/transport/server_auth_filter.cc +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -18,12 +18,12 @@ #include -#include - #include #include +#include #include #include +#include #include #include "absl/status/status.h" @@ -41,6 +41,7 @@ #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/context.h" #include "src/core/lib/channel/promise_based_filter.h" +#include "src/core/lib/debug/trace.h" #include "src/core/lib/gprpp/debug_location.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/status_helper.h" @@ -57,6 +58,7 @@ #include "src/core/lib/security/transport/auth_filters.h" // IWYU pragma: keep #include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/call_trace.h" #include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" @@ -120,12 +122,28 @@ class ServerAuthFilter::RunApplicationCode { // memory later RunApplicationCode(ServerAuthFilter* filter, CallArgs call_args) : state_(GetContext()->ManagedNew(std::move(call_args))) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_ERROR, + "%s[server-auth]: Delegate to application: filter=%p this=%p " + "auth_ctx=%p", + Activity::current()->DebugTag().c_str(), filter, this, + filter->auth_context_.get()); + } filter->server_credentials_->auth_metadata_processor().process( filter->server_credentials_->auth_metadata_processor().state, filter->auth_context_.get(), state_->md.metadata, state_->md.count, OnMdProcessingDone, state_); } + RunApplicationCode(const RunApplicationCode&) = delete; + RunApplicationCode& operator=(const RunApplicationCode&) = delete; + RunApplicationCode(RunApplicationCode&& other) noexcept + : state_(std::exchange(other.state_, nullptr)) {} + RunApplicationCode& operator=(RunApplicationCode&& other) noexcept { + state_ = std::exchange(other.state_, nullptr); + return *this; + } + Poll> operator()() { if (state_->done.load(std::memory_order_acquire)) { return Poll>(std::move(state_->call_args)); diff --git a/src/core/lib/slice/slice.cc b/src/core/lib/slice/slice.cc index 51ee3a83644..6180ef10e56 100644 --- a/src/core/lib/slice/slice.cc +++ b/src/core/lib/slice/slice.cc @@ -480,7 +480,7 @@ int grpc_slice_slice(grpc_slice haystack, grpc_slice needle) { } const uint8_t* last = haystack_bytes + haystack_len - needle_len; - for (const uint8_t* cur = haystack_bytes; cur != last; ++cur) { + for (const uint8_t* cur = haystack_bytes; cur <= last; ++cur) { if (0 == memcmp(cur, needle_bytes, needle_len)) { return static_cast(cur - haystack_bytes); } diff --git a/src/core/lib/surface/call.cc b/src/core/lib/surface/call.cc index 233ea2f1ece..3836e52e6d9 100644 --- a/src/core/lib/surface/call.cc +++ b/src/core/lib/surface/call.cc @@ -85,6 +85,7 @@ #include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/detail/basic_seq.h" +#include "src/core/lib/promise/latch.h" #include "src/core/lib/promise/pipe.h" #include "src/core/lib/promise/poll.h" #include "src/core/lib/resource_quota/arena.h" @@ -134,11 +135,13 @@ class Call : public CppImplOf { virtual void InternalRef(const char* reason) = 0; virtual void InternalUnref(const char* reason) = 0; - virtual grpc_compression_algorithm test_only_compression_algorithm() = 0; - virtual uint32_t test_only_message_flags() = 0; - virtual uint32_t test_only_encodings_accepted_by_peer() = 0; - virtual grpc_compression_algorithm compression_for_level( - grpc_compression_level level) = 0; + grpc_compression_algorithm test_only_compression_algorithm() { + return incoming_compression_algorithm_; + } + uint32_t test_only_message_flags() { return test_only_last_message_flags_; } + CompressionAlgorithmSet encodings_accepted_by_peer() { + return encodings_accepted_by_peer_; + } // This should return nullptr for the promise stack (and alternative means // for that functionality be invented) @@ -205,6 +208,29 @@ class Call : public CppImplOf { void ClearPeerString() { gpr_atm_rel_store(&peer_string_, 0); } + // TODO(ctiller): cancel_func is for cancellation of the call - filter stack + // holds no mutexes here, promise stack does, and so locking is different. + // Remove this and cancel directly once promise conversion is done. + void ProcessIncomingInitialMetadata( + grpc_metadata_batch& md, + absl::FunctionRef cancel_func); + // Fixup outgoing metadata before sending - adds compression, protects + // internal headers against external modification. + void PrepareOutgoingInitialMetadata(const grpc_op& op, + grpc_metadata_batch& md); + void NoteLastMessageFlags(uint32_t flags) { + test_only_last_message_flags_ = flags; + } + grpc_compression_algorithm incoming_compression_algorithm() const { + return incoming_compression_algorithm_; + } + + static void HandleCompressionAlgorithmDisabled( + grpc_compression_algorithm compression_algorithm, + absl::FunctionRef cancel_func) GPR_ATTRIBUTE_NOINLINE; + void HandleCompressionAlgorithmNotAccepted( + grpc_compression_algorithm compression_algorithm) GPR_ATTRIBUTE_NOINLINE; + private: RefCountedPtr channel_; Arena* const arena_; @@ -216,6 +242,13 @@ class Call : public CppImplOf { bool cancellation_is_inherited_ = false; // A char* indicating the peer name. gpr_atm peer_string_ = 0; + // Compression algorithm for *incoming* data + grpc_compression_algorithm incoming_compression_algorithm_ = + GRPC_COMPRESS_NONE; + // Supported encodings (compression algorithms), a bitset. + // Always support no compression. + CompressionAlgorithmSet encodings_accepted_by_peer_{GRPC_COMPRESS_NONE}; + uint32_t test_only_last_message_flags_ = 0; }; Call::ParentCall* Call::GetOrCreateParentCall() { @@ -351,6 +384,92 @@ void Call::DeleteThis() { channel->UpdateCallSizeEstimate(arena->Destroy()); } +void Call::PrepareOutgoingInitialMetadata(const grpc_op& op, + grpc_metadata_batch& md) { + // TODO(juanlishen): If the user has already specified a compression + // algorithm by setting the initial metadata with key of + // GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, we shouldn't override that + // with the compression algorithm mapped from compression level. + // process compression level + grpc_compression_level effective_compression_level = GRPC_COMPRESS_LEVEL_NONE; + bool level_set = false; + if (op.data.send_initial_metadata.maybe_compression_level.is_set) { + effective_compression_level = + op.data.send_initial_metadata.maybe_compression_level.level; + level_set = true; + } else { + const grpc_compression_options copts = channel()->compression_options(); + if (copts.default_level.is_set) { + level_set = true; + effective_compression_level = copts.default_level.level; + } + } + // Currently, only server side supports compression level setting. + if (level_set && !is_client()) { + const grpc_compression_algorithm calgo = + encodings_accepted_by_peer().CompressionAlgorithmForLevel( + effective_compression_level); + // The following metadata will be checked and removed by the message + // compression filter. It will be used as the call's compression + // algorithm. + md.Set(GrpcInternalEncodingRequest(), calgo); + } + // Ignore any te metadata key value pairs specified. + md.Remove(TeMetadata()); +} + +void Call::ProcessIncomingInitialMetadata( + grpc_metadata_batch& md, + absl::FunctionRef cancel_func) { + incoming_compression_algorithm_ = + md.Take(GrpcEncodingMetadata()).value_or(GRPC_COMPRESS_NONE); + encodings_accepted_by_peer_ = + md.Take(GrpcAcceptEncodingMetadata()) + .value_or(CompressionAlgorithmSet{GRPC_COMPRESS_NONE}); + + const grpc_compression_options compression_options = + channel_->compression_options(); + const grpc_compression_algorithm compression_algorithm = + incoming_compression_algorithm_; + if (GPR_UNLIKELY(!CompressionAlgorithmSet::FromUint32( + compression_options.enabled_algorithms_bitset) + .IsSet(compression_algorithm))) { + // check if algorithm is supported by current channel config + HandleCompressionAlgorithmDisabled(compression_algorithm, cancel_func); + } + // GRPC_COMPRESS_NONE is always set. + GPR_DEBUG_ASSERT(encodings_accepted_by_peer_.IsSet(GRPC_COMPRESS_NONE)); + if (GPR_UNLIKELY(!encodings_accepted_by_peer_.IsSet(compression_algorithm))) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_compression_trace)) { + HandleCompressionAlgorithmNotAccepted(compression_algorithm); + } + } +} + +void Call::HandleCompressionAlgorithmNotAccepted( + grpc_compression_algorithm compression_algorithm) { + const char* algo_name = nullptr; + grpc_compression_algorithm_name(compression_algorithm, &algo_name); + gpr_log(GPR_ERROR, + "Compression algorithm ('%s') not present in the " + "accepted encodings (%s)", + algo_name, + std::string(encodings_accepted_by_peer_.ToString()).c_str()); +} + +void Call::HandleCompressionAlgorithmDisabled( + grpc_compression_algorithm compression_algorithm, + absl::FunctionRef cancel_func) { + const char* algo_name = nullptr; + grpc_compression_algorithm_name(compression_algorithm, &algo_name); + std::string error_msg = + absl::StrFormat("Compression algorithm '%s' is disabled.", algo_name); + gpr_log(GPR_ERROR, "%s", error_msg.c_str()); + cancel_func(grpc_error_set_int(absl::UnimplementedError(error_msg), + StatusIntProperty::kRpcStatus, + GRPC_STATUS_UNIMPLEMENTED)); +} + /////////////////////////////////////////////////////////////////////////////// // FilterStackCall // To be removed once promise conversion is complete @@ -409,11 +528,6 @@ class FilterStackCall final : public Call { return context_[elem].value; } - grpc_compression_algorithm compression_for_level( - grpc_compression_level level) override { - return encodings_accepted_by_peer_.CompressionAlgorithmForLevel(level); - } - bool is_trailers_only() const override { bool result = is_trailers_only_; GPR_DEBUG_ASSERT(!result || recv_initial_metadata_.TransportSize() == 0); @@ -431,18 +545,6 @@ class FilterStackCall final : public Call { return authority_metadata->as_string_view(); } - grpc_compression_algorithm test_only_compression_algorithm() override { - return incoming_compression_algorithm_; - } - - uint32_t test_only_message_flags() override { - return test_only_last_message_flags_; - } - - uint32_t test_only_encodings_accepted_by_peer() override { - return encodings_accepted_by_peer_.ToLegacyBitmask(); - } - static size_t InitialSizeEstimate() { return sizeof(FilterStackCall) + sizeof(BatchControl) * kMaxConcurrentBatches; @@ -523,7 +625,6 @@ class FilterStackCall final : public Call { void FinishStep(PendingOp op); void ProcessDataAfterMetadata(); void ReceivingStreamReady(grpc_error_handle error); - void ValidateFilteredMetadata(); void ReceivingInitialMetadataReady(grpc_error_handle error); void ReceivingTrailingMetadataReady(grpc_error_handle error); void FinishBatch(grpc_error_handle error); @@ -548,10 +649,6 @@ class FilterStackCall final : public Call { grpc_closure* start_batch_closure); void SetFinalStatus(grpc_error_handle error); BatchControl* ReuseOrAllocateBatchControl(const grpc_op* ops); - void HandleCompressionAlgorithmDisabled( - grpc_compression_algorithm compression_algorithm) GPR_ATTRIBUTE_NOINLINE; - void HandleCompressionAlgorithmNotAccepted( - grpc_compression_algorithm compression_algorithm) GPR_ATTRIBUTE_NOINLINE; bool PrepareApplicationMetadata(size_t count, grpc_metadata* metadata, bool is_trailing); void PublishAppMetadata(grpc_metadata_batch* b, bool is_trailing); @@ -595,13 +692,6 @@ class FilterStackCall final : public Call { // completed grpc_call_final_info final_info_; - // Compression algorithm for *incoming* data - grpc_compression_algorithm incoming_compression_algorithm_ = - GRPC_COMPRESS_NONE; - // Supported encodings (compression algorithms), a bitset. - // Always support no compression. - CompressionAlgorithmSet encodings_accepted_by_peer_{GRPC_COMPRESS_NONE}; - // Contexts for various subsystems (security, tracing, ...). grpc_call_context_element context_[GRPC_CONTEXT_COUNT] = {}; @@ -615,7 +705,6 @@ class FilterStackCall final : public Call { grpc_closure receiving_stream_ready_; grpc_closure receiving_initial_metadata_ready_; grpc_closure receiving_trailing_metadata_ready_; - uint32_t test_only_last_message_flags_ = 0; // Status about operation of call bool sent_server_trailing_metadata_ = false; gpr_atm cancelled_with_error_ = 0; @@ -1032,11 +1121,8 @@ void FilterStackCall::PublishAppMetadata(grpc_metadata_batch* b, } void FilterStackCall::RecvInitialFilter(grpc_metadata_batch* b) { - incoming_compression_algorithm_ = - b->Take(GrpcEncodingMetadata()).value_or(GRPC_COMPRESS_NONE); - encodings_accepted_by_peer_ = - b->Take(GrpcAcceptEncodingMetadata()) - .value_or(CompressionAlgorithmSet{GRPC_COMPRESS_NONE}); + ProcessIncomingInitialMetadata( + *b, [this](absl::Status err) { CancelWithError(std::move(err)); }); PublishAppMetadata(b, false); } @@ -1203,11 +1289,11 @@ void FilterStackCall::BatchControl::ProcessDataAfterMetadata() { call->receiving_message_ = false; FinishStep(PendingOp::kRecvMessage); } else { - call->test_only_last_message_flags_ = call->receiving_stream_flags_; + call->NoteLastMessageFlags(call->receiving_stream_flags_); if ((call->receiving_stream_flags_ & GRPC_WRITE_INTERNAL_COMPRESS) && - (call->incoming_compression_algorithm_ != GRPC_COMPRESS_NONE)) { + (call->incoming_compression_algorithm() != GRPC_COMPRESS_NONE)) { *call->receiving_buffer_ = grpc_raw_compressed_byte_buffer_create( - nullptr, 0, call->incoming_compression_algorithm_); + nullptr, 0, call->incoming_compression_algorithm()); } else { *call->receiving_buffer_ = grpc_raw_byte_buffer_create(nullptr, 0); } @@ -1248,50 +1334,6 @@ void FilterStackCall::BatchControl::ReceivingStreamReady( } } -void FilterStackCall::HandleCompressionAlgorithmDisabled( - grpc_compression_algorithm compression_algorithm) { - const char* algo_name = nullptr; - grpc_compression_algorithm_name(compression_algorithm, &algo_name); - std::string error_msg = - absl::StrFormat("Compression algorithm '%s' is disabled.", algo_name); - gpr_log(GPR_ERROR, "%s", error_msg.c_str()); - CancelWithStatus(GRPC_STATUS_UNIMPLEMENTED, error_msg.c_str()); -} - -void FilterStackCall::HandleCompressionAlgorithmNotAccepted( - grpc_compression_algorithm compression_algorithm) { - const char* algo_name = nullptr; - grpc_compression_algorithm_name(compression_algorithm, &algo_name); - gpr_log(GPR_ERROR, - "Compression algorithm ('%s') not present in the " - "accepted encodings (%s)", - algo_name, - std::string(encodings_accepted_by_peer_.ToString()).c_str()); -} - -void FilterStackCall::BatchControl::ValidateFilteredMetadata() { - FilterStackCall* call = call_; - - const grpc_compression_options compression_options = - call->channel()->compression_options(); - const grpc_compression_algorithm compression_algorithm = - call->incoming_compression_algorithm_; - if (GPR_UNLIKELY(!CompressionAlgorithmSet::FromUint32( - compression_options.enabled_algorithms_bitset) - .IsSet(compression_algorithm))) { - // check if algorithm is supported by current channel config - call->HandleCompressionAlgorithmDisabled(compression_algorithm); - } - // GRPC_COMPRESS_NONE is always set. - GPR_DEBUG_ASSERT(call->encodings_accepted_by_peer_.IsSet(GRPC_COMPRESS_NONE)); - if (GPR_UNLIKELY( - !call->encodings_accepted_by_peer_.IsSet(compression_algorithm))) { - if (GRPC_TRACE_FLAG_ENABLED(grpc_compression_trace)) { - call->HandleCompressionAlgorithmNotAccepted(compression_algorithm); - } - } -} - void FilterStackCall::BatchControl::ReceivingInitialMetadataReady( grpc_error_handle error) { FilterStackCall* call = call_; @@ -1302,9 +1344,6 @@ void FilterStackCall::BatchControl::ReceivingInitialMetadataReady( grpc_metadata_batch* md = &call->recv_initial_metadata_; call->RecvInitialFilter(md); - // TODO(ctiller): this could be moved into recv_initial_filter now - ValidateFilteredMetadata(); - absl::optional deadline = md->get(GrpcTimeoutMetadata()); if (deadline.has_value() && !call->is_client()) { call_->set_send_deadline(*deadline); @@ -1453,36 +1492,6 @@ grpc_call_error FilterStackCall::StartBatch(const grpc_op* ops, size_t nops, error = GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; goto done_with_error; } - // TODO(juanlishen): If the user has already specified a compression - // algorithm by setting the initial metadata with key of - // GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, we shouldn't override that - // with the compression algorithm mapped from compression level. - // process compression level - grpc_compression_level effective_compression_level = - GRPC_COMPRESS_LEVEL_NONE; - bool level_set = false; - if (op->data.send_initial_metadata.maybe_compression_level.is_set) { - effective_compression_level = - op->data.send_initial_metadata.maybe_compression_level.level; - level_set = true; - } else { - const grpc_compression_options copts = - channel()->compression_options(); - if (copts.default_level.is_set) { - level_set = true; - effective_compression_level = copts.default_level.level; - } - } - // Currently, only server side supports compression level setting. - if (level_set && !is_client()) { - const grpc_compression_algorithm calgo = - encodings_accepted_by_peer_.CompressionAlgorithmForLevel( - effective_compression_level); - // The following metadata will be checked and removed by the message - // compression filter. It will be used as the call's compression - // algorithm. - send_initial_metadata_.Set(GrpcInternalEncodingRequest(), calgo); - } if (op->data.send_initial_metadata.count > INT_MAX) { error = GRPC_CALL_ERROR_INVALID_METADATA; goto done_with_error; @@ -1495,8 +1504,7 @@ grpc_call_error FilterStackCall::StartBatch(const grpc_op* ops, size_t nops, error = GRPC_CALL_ERROR_INVALID_METADATA; goto done_with_error; } - // Ignore any te metadata key value pairs specified. - send_initial_metadata_.Remove(TeMetadata()); + PrepareOutgoingInitialMetadata(*op, send_initial_metadata_); // TODO(ctiller): just make these the same variable? if (is_client() && send_deadline() != Timestamp::InfFuture()) { send_initial_metadata_.Set(GrpcTimeoutMetadata(), send_deadline()); @@ -2029,16 +2037,6 @@ class PromiseBasedCall : public Call, } } - grpc_compression_algorithm test_only_compression_algorithm() override { - abort(); - } - uint32_t test_only_message_flags() override { abort(); } - uint32_t test_only_encodings_accepted_by_peer() override { abort(); } - grpc_compression_algorithm compression_for_level( - grpc_compression_level) override { - abort(); - } - // This should return nullptr for the promise stack (and alternative means // for that functionality be invented) grpc_call_stack* call_stack() override { return nullptr; } @@ -2096,6 +2094,11 @@ class PromiseBasedCall : public Call, ~PromiseBasedCall() override { if (non_owning_wakeable_) non_owning_wakeable_->DropActivity(); if (cq_) GRPC_CQ_INTERNAL_UNREF(cq_, "bind"); + for (int i = 0; i < GRPC_CONTEXT_COUNT; i++) { + if (context_[i].destroy) { + context_[i].destroy(context_[i].value); + } + } } // Enumerates why a Completion is still pending @@ -2103,6 +2106,7 @@ class PromiseBasedCall : public Call, // We're in the midst of starting a batch of operations kStartingBatch = 0, // The following correspond with the batch operations from above + kSendInitialMetadata, kReceiveInitialMetadata, kReceiveStatusOnClient, kReceiveCloseOnServer = kReceiveStatusOnClient, @@ -2116,6 +2120,8 @@ class PromiseBasedCall : public Call, switch (reason) { case PendingOp::kStartingBatch: return "StartingBatch"; + case PendingOp::kSendInitialMetadata: + return "SendInitialMetadata"; case PendingOp::kReceiveInitialMetadata: return "ReceiveInitialMetadata"; case PendingOp::kReceiveStatusOnClient: @@ -2146,10 +2152,23 @@ class PromiseBasedCall : public Call, ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Stringify a completion std::string CompletionString(const Completion& completion) const { + auto& pending = completion_info_[completion.index()].pending; + const char* success_string = ":unknown"; + switch (pending.success) { + case CompletionSuccess::kSuccess: + success_string = ""; + break; + case CompletionSuccess::kFailure: + success_string = ":failure"; + break; + case CompletionSuccess::kForceSuccess: + success_string = ":force-success"; + break; + } return completion.has_value() - ? absl::StrFormat( - "%d:tag=%p", static_cast(completion.index()), - completion_info_[completion.index()].pending.tag) + ? absl::StrFormat("%d:tag=%p%s", + static_cast(completion.index()), + pending.tag, success_string) : "no-completion"; } // Finish one op on the completion. Must have been previously been added. @@ -2159,6 +2178,9 @@ class PromiseBasedCall : public Call, // Mark the completion as failed. Does not finish it. void FailCompletion(const Completion& completion, SourceLocation source_location = {}); + // Mark the completion as infallible. Overrides FailCompletion to report + // success always. + void ForceCompletionSuccess(const Completion& completion); // Run the promise polling loop until it stalls. void Update() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Update the promise state once. @@ -2224,11 +2246,12 @@ class PromiseBasedCall : public Call, ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); void PollRecvMessage(grpc_compression_algorithm compression_algorithm) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - void CancelRecvMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void CancelRecvMessage(SourceLocation = {}) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); void StartSendMessage(const grpc_op& op, const Completion& completion, PipeSender* sender) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - bool PollSendMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void PollSendMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); void CancelSendMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); bool completed() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -2238,14 +2261,18 @@ class PromiseBasedCall : public Call, bool is_sending() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { return outstanding_send_.has_value(); } + bool is_receiving() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return outstanding_recv_.has_value(); + } private: + enum class CompletionSuccess : uint8_t { kSuccess, kFailure, kForceSuccess }; union CompletionInfo { struct Pending { // Bitmask of PendingOps uint8_t pending_op_bits; bool is_closure; - bool success; + CompletionSuccess success; void* tag; } pending; grpc_cq_completion completion; @@ -2258,8 +2285,8 @@ class PromiseBasedCall : public Call, // Ref the Handle (not the activity). void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); } - // Activity is going away... drop its reference and sever the connection - // back. + // Activity is going away... drop its reference and sever the + // connection back. void DropActivity() ABSL_LOCKS_EXCLUDED(mu_) { auto unref = absl::MakeCleanup([this]() { Unref(); }); MutexLock lock(&mu_); @@ -2267,11 +2294,11 @@ class PromiseBasedCall : public Call, call_ = nullptr; } - // Activity needs to wake up (if it still exists!) - wake it up, and drop - // the ref that was kept for this handle. + // Activity needs to wake up (if it still exists!) - wake it up, + // and drop the ref that was kept for this handle. void Wakeup() override ABSL_LOCKS_EXCLUDED(mu_) { - // Drop the ref to the handle at end of scope (we have one ref = one - // wakeup semantics). + // Drop the ref to the handle at end of scope (we have one ref + // = one wakeup semantics). auto unref = absl::MakeCleanup([this]() { Unref(); }); ReleasableMutexLock lock(&mu_); // Note that activity refcount can drop to zero, but we could win the lock @@ -2280,8 +2307,8 @@ class PromiseBasedCall : public Call, PromiseBasedCall* call = call_; if (call != nullptr && call->RefIfNonZero()) { lock.Release(); - // Activity still exists and we have a reference: wake it up, which will - // drop the ref. + // Activity still exists and we have a reference: wake it up, + // which will drop the ref. call->Wakeup(); } } @@ -2302,9 +2329,9 @@ class PromiseBasedCall : public Call, } mutable Mutex mu_; - // We have two initial refs: one for the wakeup that this is created for, - // and will be dropped by Wakeup, and the other for the activity which is - // dropped by DropActivity. + // We have two initial refs: one for the wakeup that this is + // created for, and will be dropped by Wakeup, and the other for + // the activity which is dropped by DropActivity. std::atomic refs_{2}; PromiseBasedCall* call_ ABSL_GUARDED_BY(mu_); }; @@ -2438,15 +2465,16 @@ void* PromiseBasedCall::ContextGet(grpc_context_index elem) const { PromiseBasedCall::Completion PromiseBasedCall::StartCompletion( void* tag, bool is_closure, const grpc_op* ops) { Completion c(BatchSlotForOp(ops[0].op)); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] StartCompletion %s tag=%p", DebugTag().c_str(), - CompletionString(c).c_str(), tag); - } if (!is_closure) { grpc_cq_begin_op(cq(), tag); } completion_info_[c.index()].pending = { - PendingOpBit(PendingOp::kStartingBatch), is_closure, true, tag}; + PendingOpBit(PendingOp::kStartingBatch), is_closure, + CompletionSuccess::kSuccess, tag}; + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[call] StartCompletion %s tag=%p", DebugTag().c_str(), + CompletionString(c).c_str(), tag); + } return c; } @@ -2471,7 +2499,27 @@ void PromiseBasedCall::FailCompletion(const Completion& completion, "%s[call] FailCompletion %s", DebugTag().c_str(), CompletionString(completion).c_str()); } - completion_info_[completion.index()].pending.success = false; + auto& success = completion_info_[completion.index()].pending.success; + switch (success) { + case CompletionSuccess::kSuccess: + success = CompletionSuccess::kFailure; + break; + case CompletionSuccess::kFailure: + case CompletionSuccess::kForceSuccess: + break; + } +} + +void PromiseBasedCall::ForceCompletionSuccess(const Completion& completion) { + auto& success = completion_info_[completion.index()].pending.success; + switch (success) { + case CompletionSuccess::kSuccess: + case CompletionSuccess::kFailure: + success = CompletionSuccess::kForceSuccess; + break; + case CompletionSuccess::kForceSuccess: + break; + } } void PromiseBasedCall::FinishOpOnCompletion(Completion* completion, @@ -2479,7 +2527,8 @@ void PromiseBasedCall::FinishOpOnCompletion(Completion* completion, if (grpc_call_trace.enabled()) { auto pending_op_bits = completion_info_[completion->index()].pending.pending_op_bits; - bool success = completion_info_[completion->index()].pending.success; + bool success = completion_info_[completion->index()].pending.success != + CompletionSuccess::kFailure; std::vector pending; for (size_t i = 0; i < 8 * sizeof(pending_op_bits); i++) { if (static_cast(i) == reason) continue; @@ -2488,7 +2537,9 @@ void PromiseBasedCall::FinishOpOnCompletion(Completion* completion, } } gpr_log( - GPR_INFO, "%s[call] FinishOpOnCompletion tag:%p %s %s %s", + GPR_INFO, + "%s[call] FinishOpOnCompletion tag:%p completion:%s " + "finish:%s %s", DebugTag().c_str(), completion_info_[completion->index()].pending.tag, CompletionString(*completion).c_str(), PendingOpString(reason), (pending.empty() @@ -2501,8 +2552,9 @@ void PromiseBasedCall::FinishOpOnCompletion(Completion* completion, CompletionInfo::Pending& pending = completion_info_[i].pending; GPR_ASSERT(pending.pending_op_bits & PendingOpBit(reason)); pending.pending_op_bits &= ~PendingOpBit(reason); - auto error = pending.success ? absl::OkStatus() : absl::CancelledError(); if (pending.pending_op_bits == 0) { + const bool success = pending.success != CompletionSuccess::kFailure; + auto error = success ? absl::OkStatus() : absl::CancelledError(); if (pending.is_closure) { ExecCtx::Run(DEBUG_LOCATION, static_cast(pending.tag), error); @@ -2579,26 +2631,23 @@ void PromiseBasedCall::StartSendMessage(const grpc_op& op, } } -bool PromiseBasedCall::PollSendMessage() { - if (!outstanding_send_.has_value()) return true; +void PromiseBasedCall::PollSendMessage() { + if (!outstanding_send_.has_value()) return; Poll r = (*outstanding_send_)(); if (const bool* result = absl::get_if(&r)) { if (grpc_call_trace.enabled()) { gpr_log(GPR_DEBUG, "%sPollSendMessage completes %s", DebugTag().c_str(), *result ? "successfully" : "with failure"); } - if (!*result) { - FailCompletion(send_message_completion_); - return false; - } + if (!*result) FailCompletion(send_message_completion_); FinishOpOnCompletion(&send_message_completion_, PendingOp::kSendMessage); outstanding_send_.reset(); } - return true; } void PromiseBasedCall::CancelSendMessage() { if (!outstanding_send_.has_value()) return; + FailCompletion(send_message_completion_); FinishOpOnCompletion(&send_message_completion_, PendingOp::kSendMessage); outstanding_send_.reset(); } @@ -2621,6 +2670,7 @@ void PromiseBasedCall::PollRecvMessage( outstanding_recv_.reset(); if (result->has_value()) { MessageHandle& message = **result; + NoteLastMessageFlags(message->flags()); if ((message->flags() & GRPC_WRITE_INTERNAL_COMPRESS) && (incoming_compression_algorithm != GRPC_COMPRESS_NONE)) { *recv_message_ = grpc_raw_compressed_byte_buffer_create( @@ -2632,7 +2682,8 @@ void PromiseBasedCall::PollRecvMessage( &(*recv_message_)->data.raw.slice_buffer); if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, - "%s[call] PollRecvMessage: outstanding_recv finishes: received " + "%s[call] PollRecvMessage: outstanding_recv " + "finishes: received " "%" PRIdPTR " byte message", DebugTag().c_str(), (*recv_message_)->data.raw.slice_buffer.length); @@ -2640,7 +2691,8 @@ void PromiseBasedCall::PollRecvMessage( } else if (result->cancelled()) { if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, - "%s[call] PollRecvMessage: outstanding_recv finishes: received " + "%s[call] PollRecvMessage: outstanding_recv " + "finishes: received " "end-of-stream with error", DebugTag().c_str()); } @@ -2649,7 +2701,8 @@ void PromiseBasedCall::PollRecvMessage( } else { if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, - "%s[call] PollRecvMessage: outstanding_recv finishes: received " + "%s[call] PollRecvMessage: outstanding_recv " + "finishes: received " "end-of-stream", DebugTag().c_str()); } @@ -2659,8 +2712,10 @@ void PromiseBasedCall::PollRecvMessage( } else if (completed_) { if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, - "%s[call] UpdateOnce: outstanding_recv finishes: promise has " - "completed without queuing a message, forcing end-of-stream", + "%s[call] UpdateOnce: outstanding_recv finishes: " + "promise has " + "completed without queuing a message, forcing " + "end-of-stream", DebugTag().c_str()); } outstanding_recv_.reset(); @@ -2669,10 +2724,18 @@ void PromiseBasedCall::PollRecvMessage( } } -void PromiseBasedCall::CancelRecvMessage() { +void PromiseBasedCall::CancelRecvMessage(SourceLocation location) { if (!outstanding_recv_.has_value()) return; + if (grpc_call_trace.enabled()) { + gpr_log(location.file(), location.line(), GPR_LOG_SEVERITY_DEBUG, + "%s[call] CancelRecvMessage: op:%s outstanding_recv:%s", + DebugTag().c_str(), + CompletionString(recv_message_completion_).c_str(), + outstanding_recv_.has_value() ? "present" : "absent"); + } *recv_message_ = nullptr; outstanding_recv_.reset(); + FailCompletion(recv_message_completion_); FinishOpOnCompletion(&recv_message_completion_, PendingOp::kReceiveMessage); } @@ -2748,9 +2811,9 @@ class ClientPromiseBasedCall final : public PromiseBasedCall { send_initial_metadata_.reset(); recv_status_on_client_ = absl::monostate(); promise_ = ArenaPromise(); - // Need to destroy the pipes under the ScopedContext above, so we move them - // out here and then allow the destructors to run at end of scope, but - // before context. + // Need to destroy the pipes under the ScopedContext above, so we + // move them out here and then allow the destructors to run at + // end of scope, but before context. auto c2s = std::move(client_to_server_messages_); auto s2c = std::move(server_to_client_messages_); auto sim = std::move(server_initial_metadata_); @@ -2802,7 +2865,6 @@ class ClientPromiseBasedCall final : public PromiseBasedCall { arena()}; Pipe client_to_server_messages_ ABSL_GUARDED_BY(mu()){arena()}; Pipe server_to_client_messages_ ABSL_GUARDED_BY(mu()){arena()}; - ClientMetadataHandle send_initial_metadata_; grpc_metadata_array* recv_initial_metadata_ ABSL_GUARDED_BY(mu()) = nullptr; absl::variant> server_initial_metadata_ready_; - absl::optional incoming_compression_algorithm_; + absl::optional + client_initial_metadata_sent_; + Completion send_initial_metadata_completion_ ABSL_GUARDED_BY(mu()); Completion recv_initial_metadata_completion_ ABSL_GUARDED_BY(mu()); Completion recv_status_on_client_completion_ ABSL_GUARDED_BY(mu()); Completion close_send_completion_ ABSL_GUARDED_BY(mu()); @@ -2821,12 +2885,12 @@ class ClientPromiseBasedCall final : public PromiseBasedCall { void ClientPromiseBasedCall::StartPromise( ClientMetadataHandle client_initial_metadata) { GPR_ASSERT(!promise_.has_value()); + auto token = ClientInitialMetadataOutstandingToken::New(); + client_initial_metadata_sent_.emplace(token.Wait()); promise_ = channel()->channel_stack()->MakeClientCallPromise(CallArgs{ - std::move(client_initial_metadata), - &server_initial_metadata_.sender, - &client_to_server_messages_.receiver, - &server_to_client_messages_.sender, - }); + std::move(client_initial_metadata), std::move(token), + &server_initial_metadata_.sender, &client_to_server_messages_.receiver, + &server_to_client_messages_.sender}); } void ClientPromiseBasedCall::CancelWithErrorLocked(grpc_error_handle error) { @@ -2876,13 +2940,22 @@ void ClientPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, const grpc_op& op = ops[op_idx]; switch (op.op) { case GRPC_OP_SEND_INITIAL_METADATA: { - // compression not implemented - GPR_ASSERT( - !op.data.send_initial_metadata.maybe_compression_level.is_set); if (!completed()) { CToMetadata(op.data.send_initial_metadata.metadata, op.data.send_initial_metadata.count, send_initial_metadata_.get()); + PrepareOutgoingInitialMetadata(op, *send_initial_metadata_); + if (send_deadline() != Timestamp::InfFuture()) { + send_initial_metadata_->Set(GrpcTimeoutMetadata(), send_deadline()); + } + send_initial_metadata_->Set( + WaitForReady(), + WaitForReady::ValueType{ + (op.flags & GRPC_INITIAL_METADATA_WAIT_FOR_READY) != 0, + (op.flags & + GRPC_INITIAL_METADATA_WAIT_FOR_READY_EXPLICITLY_SET) != 0}); + send_initial_metadata_completion_ = + AddOpToCompletion(completion, PendingOp::kSendInitialMetadata); StartPromise(std::move(send_initial_metadata_)); } } break; @@ -2897,6 +2970,7 @@ void ClientPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, case GRPC_OP_RECV_STATUS_ON_CLIENT: { recv_status_on_client_completion_ = AddOpToCompletion(completion, PendingOp::kReceiveStatusOnClient); + ForceCompletionSuccess(completion); if (auto* finished_metadata = absl::get_if(&recv_status_on_client_)) { PublishStatus(op.data.recv_status_on_client, @@ -2946,8 +3020,10 @@ grpc_call_error ClientPromiseBasedCall::StartBatch(const grpc_op* ops, } void ClientPromiseBasedCall::PublishInitialMetadata(ServerMetadata* metadata) { - incoming_compression_algorithm_ = - metadata->Take(GrpcEncodingMetadata()).value_or(GRPC_COMPRESS_NONE); + ProcessIncomingInitialMetadata(*metadata, [this](absl::Status status) { + mu()->AssertHeld(); + CancelWithErrorLocked(std::move(status)); + }); server_initial_metadata_ready_.reset(); GPR_ASSERT(recv_initial_metadata_ != nullptr); PublishMetadataArray(metadata, @@ -2967,21 +3043,33 @@ void ClientPromiseBasedCall::UpdateOnce() { PollStateDebugString().c_str(), promise_.has_value() ? "true" : "false"); } + if (client_initial_metadata_sent_.has_value()) { + Poll p = (*client_initial_metadata_sent_)(); + bool* r = absl::get_if(&p); + if (r != nullptr) { + client_initial_metadata_sent_.reset(); + if (!*r) FailCompletion(send_initial_metadata_completion_); + FinishOpOnCompletion(&send_initial_metadata_completion_, + PendingOp::kSendInitialMetadata); + } + } if (server_initial_metadata_ready_.has_value()) { Poll> r = (*server_initial_metadata_ready_)(); if (auto* server_initial_metadata = absl::get_if>(&r)) { - PublishInitialMetadata(server_initial_metadata->value().get()); + if (server_initial_metadata->has_value()) { + PublishInitialMetadata(server_initial_metadata->value().get()); + } else { + ServerMetadata no_metadata{GetContext()}; + PublishInitialMetadata(&no_metadata); + } } else if (completed()) { ServerMetadata no_metadata{GetContext()}; PublishInitialMetadata(&no_metadata); } } - if (!PollSendMessage()) { - Finish(ServerMetadataFromStatus(absl::Status( - absl::StatusCode::kInternal, "Failed to send message to server"))); - } + PollSendMessage(); if (!is_sending() && close_send_completion_.has_value()) { client_to_server_messages_.sender.Close(); FinishOpOnCompletion(&close_send_completion_, @@ -2997,12 +3085,15 @@ void ClientPromiseBasedCall::UpdateOnce() { }).c_str()); } if (auto* result = absl::get_if(&r)) { + if (!server_initial_metadata_ready_.has_value()) { + PollRecvMessage(incoming_compression_algorithm()); + } AcceptTransportStatsFromContext(); Finish(std::move(*result)); } } - if (incoming_compression_algorithm_.has_value()) { - PollRecvMessage(*incoming_compression_algorithm_); + if (!server_initial_metadata_ready_.has_value()) { + PollRecvMessage(incoming_compression_algorithm()); } } @@ -3012,6 +3103,8 @@ void ClientPromiseBasedCall::Finish(ServerMetadataHandle trailing_metadata) { trailing_metadata->DebugString().c_str()); } promise_ = ArenaPromise(); + CancelSendMessage(); + CancelRecvMessage(); ResetDeadline(); set_completed(); if (recv_initial_metadata_ != nullptr) { @@ -3027,8 +3120,18 @@ void ClientPromiseBasedCall::Finish(ServerMetadataHandle trailing_metadata) { (*server_initial_metadata_ready_)(); server_initial_metadata_ready_.reset(); if (auto* result = absl::get_if>(&r)) { - if (pending_initial_metadata) PublishInitialMetadata(result->value().get()); - is_trailers_only_ = false; + if (pending_initial_metadata) { + if (result->has_value()) { + PublishInitialMetadata(result->value().get()); + is_trailers_only_ = false; + } else { + ServerMetadata no_metadata{GetContext()}; + PublishInitialMetadata(&no_metadata); + is_trailers_only_ = true; + } + } else { + is_trailers_only_ = false; + } } else { if (pending_initial_metadata) { ServerMetadata no_metadata{GetContext()}; @@ -3159,9 +3262,10 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { private: class RecvCloseOpCancelState { public: - // Request that receiver be filled in per grpc_op_recv_close_on_server. - // Returns true if the request can be fulfilled immediately. - // Returns false if the request will be fulfilled later. + // Request that receiver be filled in per + // grpc_op_recv_close_on_server. Returns true if the request can + // be fulfilled immediately. Returns false if the request will be + // fulfilled later. bool ReceiveCloseOnServerOpStarted(int* receiver) { switch (state_) { case kUnset: @@ -3179,18 +3283,20 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { } // Mark the call as having completed. - // Returns true if this finishes a previous RequestReceiveCloseOnServer. - bool CompleteCall(bool success) { + // Returns true if this finishes a previous + // RequestReceiveCloseOnServer. + bool CompleteCallWithCancelledSetTo(bool cancelled) { switch (state_) { case kUnset: - state_ = success ? kFinishedWithSuccess : kFinishedWithFailure; + state_ = cancelled ? kFinishedWithFailure : kFinishedWithSuccess; return false; case kFinishedWithFailure: + return false; case kFinishedWithSuccess: abort(); // unreachable default: - *reinterpret_cast(state_) = success ? 0 : 1; - state_ = success ? kFinishedWithSuccess : kFinishedWithFailure; + *reinterpret_cast(state_) = cancelled ? 1 : 0; + state_ = cancelled ? kFinishedWithFailure : kFinishedWithSuccess; return true; } } @@ -3213,8 +3319,9 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { static constexpr uintptr_t kUnset = 0; static constexpr uintptr_t kFinishedWithFailure = 1; static constexpr uintptr_t kFinishedWithSuccess = 2; - // Holds one of kUnset, kFinishedWithFailure, or kFinishedWithSuccess - // OR an int* that wants to receive the final status. + // Holds one of kUnset, kFinishedWithFailure, or + // kFinishedWithSuccess OR an int* that wants to receive the + // final status. uintptr_t state_ = kUnset; }; @@ -3222,6 +3329,10 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { void CommitBatch(const grpc_op* ops, size_t nops, const Completion& completion) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); + void CommitBatchAfterCompletion(const grpc_op* ops, size_t nops, + const Completion& completion) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); + void Finish(ServerMetadataHandle result) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); friend class ServerCallContext; ServerCallContext call_context_; @@ -3239,6 +3350,7 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { ServerMetadataHandle send_trailing_metadata_ ABSL_GUARDED_BY(mu()); grpc_compression_algorithm incoming_compression_algorithm_ ABSL_GUARDED_BY(mu()); + bool force_metadata_send_ ABSL_GUARDED_BY(mu()) = false; RecvCloseOpCancelState recv_close_op_cancel_state_ ABSL_GUARDED_BY(mu()); Completion recv_close_completion_ ABSL_GUARDED_BY(mu()); bool cancel_send_and_receive_ ABSL_GUARDED_BY(mu()) = false; @@ -3259,7 +3371,8 @@ ServerPromiseBasedCall::ServerPromiseBasedCall(Arena* arena, MutexLock lock(mu()); ScopedContext activity_context(this); promise_ = channel()->channel_stack()->MakeServerCallPromise( - CallArgs{nullptr, nullptr, nullptr, nullptr}); + CallArgs{nullptr, ClientInitialMetadataOutstandingToken::Empty(), nullptr, + nullptr, nullptr}); } Poll ServerPromiseBasedCall::PollTopOfCall() { @@ -3282,6 +3395,13 @@ Poll ServerPromiseBasedCall::PollTopOfCall() { PollSendMessage(); PollRecvMessage(incoming_compression_algorithm_); + if (force_metadata_send_) GPR_ASSERT(!is_sending()); + + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[call] PollTopOfCall: is_sending=%s", + DebugTag().c_str(), is_sending() ? "yes" : "no"); + } + if (!is_sending() && send_trailing_metadata_ != nullptr) { server_to_client_messages_->Close(); return std::move(send_trailing_metadata_); @@ -3325,39 +3445,53 @@ void ServerPromiseBasedCall::UpdateOnce() { }).c_str()); } if (auto* result = absl::get_if(&r)) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] UpdateOnce: GotResult %s result:%s", - DebugTag().c_str(), - recv_close_op_cancel_state_.ToString().c_str(), - (*result)->DebugString().c_str()); - } - if (recv_close_op_cancel_state_.CompleteCall( - (*result)->get(GrpcStatusFromWire()).value_or(false))) { - FinishOpOnCompletion(&recv_close_completion_, - PendingOp::kReceiveCloseOnServer); + if (!(*result)->get(GrpcCallWasCancelled()).value_or(false)) { + PollRecvMessage(incoming_compression_algorithm_); } - channelz::ServerNode* channelz_node = server_->channelz_node(); - if (channelz_node != nullptr) { - if ((*result) - ->get(GrpcStatusMetadata()) - .value_or(GRPC_STATUS_UNKNOWN) == GRPC_STATUS_OK) { - channelz_node->RecordCallSucceeded(); - } else { - channelz_node->RecordCallFailed(); - } - } - if (send_status_from_server_completion_.has_value()) { - FinishOpOnCompletion(&send_status_from_server_completion_, - PendingOp::kSendStatusFromServer); - } - CancelSendMessage(); - CancelRecvMessage(); - set_completed(); + Finish(std::move(*result)); promise_ = ArenaPromise(); } } } +void ServerPromiseBasedCall::Finish(ServerMetadataHandle result) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[call] UpdateOnce: GotResult %s result:%s", + DebugTag().c_str(), recv_close_op_cancel_state_.ToString().c_str(), + result->DebugString().c_str()); + } + if (recv_close_op_cancel_state_.CompleteCallWithCancelledSetTo( + result->get(GrpcCallWasCancelled()).value_or(true))) { + FinishOpOnCompletion(&recv_close_completion_, + PendingOp::kReceiveCloseOnServer); + } + channelz::ServerNode* channelz_node = server_->channelz_node(); + if (channelz_node != nullptr) { + if (result->get(GrpcStatusMetadata()).value_or(GRPC_STATUS_UNKNOWN) == + GRPC_STATUS_OK) { + channelz_node->RecordCallSucceeded(); + } else { + channelz_node->RecordCallFailed(); + } + } + if (send_status_from_server_completion_.has_value()) { + FinishOpOnCompletion(&send_status_from_server_completion_, + PendingOp::kSendStatusFromServer); + } + CancelSendMessage(); + CancelRecvMessage(); + set_completed(); + // TODO(ctiller): this will probably need to be inlined somehow for + // performance + InternalRef("propagate_cancel"); + channel()->event_engine()->Run([this]() { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + PropagateCancellationToChildren(); + InternalUnref("propagate_cancel"); + }); +} + grpc_call_error ServerPromiseBasedCall::ValidateBatch(const grpc_op* ops, size_t nops) const { BitSet<8> got_ops; @@ -3400,21 +3534,17 @@ void ServerPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, const grpc_op& op = ops[op_idx]; switch (op.op) { case GRPC_OP_SEND_INITIAL_METADATA: { - // compression not implemented - GPR_ASSERT( - !op.data.send_initial_metadata.maybe_compression_level.is_set); - if (!completed()) { - auto metadata = arena()->MakePooled(arena()); - CToMetadata(op.data.send_initial_metadata.metadata, - op.data.send_initial_metadata.count, metadata.get()); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] Send initial metadata", - DebugTag().c_str()); - } - auto* pipe = absl::get*>( - send_initial_metadata_state_); - send_initial_metadata_state_ = pipe->Push(std::move(metadata)); + auto metadata = arena()->MakePooled(arena()); + PrepareOutgoingInitialMetadata(op, *metadata); + CToMetadata(op.data.send_initial_metadata.metadata, + op.data.send_initial_metadata.count, metadata.get()); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[call] Send initial metadata", + DebugTag().c_str()); } + auto* pipe = absl::get*>( + send_initial_metadata_state_); + send_initial_metadata_state_ = pipe->Push(std::move(metadata)); } break; case GRPC_OP_SEND_MESSAGE: StartSendMessage(op, completion, server_to_client_messages_); @@ -3442,6 +3572,7 @@ void ServerPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, DebugTag().c_str(), recv_close_op_cancel_state_.ToString().c_str()); } + ForceCompletionSuccess(completion); if (!recv_close_op_cancel_state_.ReceiveCloseOnServerOpStarted( op.data.recv_close_on_server.cancelled)) { recv_close_completion_ = @@ -3456,6 +3587,35 @@ void ServerPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, } } +void ServerPromiseBasedCall::CommitBatchAfterCompletion( + const grpc_op* ops, size_t nops, const Completion& completion) { + for (size_t op_idx = 0; op_idx < nops; op_idx++) { + const grpc_op& op = ops[op_idx]; + switch (op.op) { + case GRPC_OP_SEND_INITIAL_METADATA: + case GRPC_OP_SEND_MESSAGE: + case GRPC_OP_RECV_MESSAGE: + case GRPC_OP_SEND_STATUS_FROM_SERVER: + FailCompletion(completion); + break; + case GRPC_OP_RECV_CLOSE_ON_SERVER: + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[call] StartBatch: RecvClose %s", + DebugTag().c_str(), + recv_close_op_cancel_state_.ToString().c_str()); + } + ForceCompletionSuccess(completion); + GPR_ASSERT(recv_close_op_cancel_state_.ReceiveCloseOnServerOpStarted( + op.data.recv_close_on_server.cancelled)); + break; + case GRPC_OP_RECV_STATUS_ON_CLIENT: + case GRPC_OP_SEND_CLOSE_FROM_CLIENT: + case GRPC_OP_RECV_INITIAL_METADATA: + abort(); // unreachable + } + } +} + grpc_call_error ServerPromiseBasedCall::StartBatch(const grpc_op* ops, size_t nops, void* notify_tag, @@ -3472,8 +3632,12 @@ grpc_call_error ServerPromiseBasedCall::StartBatch(const grpc_op* ops, } Completion completion = StartCompletion(notify_tag, is_notify_tag_closure, ops); - CommitBatch(ops, nops, completion); - Update(); + if (!completed()) { + CommitBatch(ops, nops, completion); + Update(); + } else { + CommitBatchAfterCompletion(ops, nops, completion); + } FinishOpOnCompletion(&completion, PendingOp::kStartingBatch); return GRPC_CALL_OK; } @@ -3482,6 +3646,7 @@ void ServerPromiseBasedCall::CancelWithErrorLocked(absl::Status error) { if (!promise_.has_value()) return; cancel_send_and_receive_ = true; send_trailing_metadata_ = ServerMetadataFromStatus(error, arena()); + send_trailing_metadata_->Set(GrpcCallWasCancelled(), true); ForceWakeup(); } @@ -3503,6 +3668,11 @@ ServerCallContext::MakeTopOfServerCallPromise( .value_or(GRPC_COMPRESS_NONE); call_->client_initial_metadata_ = std::move(call_args.client_initial_metadata); + call_->ProcessIncomingInitialMetadata( + *call_->client_initial_metadata_, [this](absl::Status status) { + call_->mu()->AssertHeld(); + call_->CancelWithErrorLocked(std::move(status)); + }); PublishMetadataArray(call_->client_initial_metadata_.get(), publish_initial_metadata); call_->ExternalRef(); @@ -3613,7 +3783,9 @@ uint32_t grpc_call_test_only_get_message_flags(grpc_call* call) { } uint32_t grpc_call_test_only_get_encodings_accepted_by_peer(grpc_call* call) { - return grpc_core::Call::FromC(call)->test_only_encodings_accepted_by_peer(); + return grpc_core::Call::FromC(call) + ->encodings_accepted_by_peer() + .ToLegacyBitmask(); } grpc_core::Arena* grpc_call_get_arena(grpc_call* call) { @@ -3662,7 +3834,9 @@ uint8_t grpc_call_is_client(grpc_call* call) { grpc_compression_algorithm grpc_call_compression_for_level( grpc_call* call, grpc_compression_level level) { - return grpc_core::Call::FromC(call)->compression_for_level(level); + return grpc_core::Call::FromC(call) + ->encodings_accepted_by_peer() + .CompressionAlgorithmForLevel(level); } bool grpc_call_is_trailers_only(const grpc_call* call) { diff --git a/src/core/lib/surface/lame_client.cc b/src/core/lib/surface/lame_client.cc index ecbc9eed098..7fbdf8e64b4 100644 --- a/src/core/lib/surface/lame_client.cc +++ b/src/core/lib/surface/lame_client.cc @@ -79,6 +79,7 @@ ArenaPromise LameClientFilter::MakeCallPromise( if (args.server_to_client_messages != nullptr) { args.server_to_client_messages->Close(); } + args.client_initial_metadata_outstanding.Complete(true); return Immediate(ServerMetadataFromStatus(error_)); } diff --git a/src/core/lib/transport/metadata_batch.h b/src/core/lib/transport/metadata_batch.h index 47884b34f80..64c1e8f912f 100644 --- a/src/core/lib/transport/metadata_batch.h +++ b/src/core/lib/transport/metadata_batch.h @@ -396,6 +396,15 @@ struct GrpcStatusFromWire { static absl::string_view DisplayValue(bool x) { return x ? "true" : "false"; } }; +// Annotation to denote that this call qualifies for cancelled=1 for the +// RECV_CLOSE_ON_SERVER op +struct GrpcCallWasCancelled { + static absl::string_view DebugKey() { return "GrpcCallWasCancelled"; } + static constexpr bool kRepeatable = false; + using ValueType = bool; + static absl::string_view DisplayValue(bool x) { return x ? "true" : "false"; } +}; + // Annotation added by client surface code to denote wait-for-ready state struct WaitForReady { struct ValueType { @@ -1327,7 +1336,8 @@ using grpc_metadata_batch_base = grpc_core::MetadataMap< // Non-encodable things grpc_core::GrpcStreamNetworkState, grpc_core::PeerString, grpc_core::GrpcStatusContext, grpc_core::GrpcStatusFromWire, - grpc_core::WaitForReady, grpc_core::GrpcTrailersOnly>; + grpc_core::GrpcCallWasCancelled, grpc_core::WaitForReady, + grpc_core::GrpcTrailersOnly>; struct grpc_metadata_batch : public grpc_metadata_batch_base { using grpc_metadata_batch_base::grpc_metadata_batch_base; diff --git a/src/core/lib/transport/transport.h b/src/core/lib/transport/transport.h index bcd151cc6af..f24109d8ab0 100644 --- a/src/core/lib/transport/transport.h +++ b/src/core/lib/transport/transport.h @@ -27,6 +27,7 @@ #include #include +#include #include #include "absl/status/status.h" @@ -54,6 +55,7 @@ #include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/detail/status.h" +#include "src/core/lib/promise/latch.h" #include "src/core/lib/promise/pipe.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice_buffer.h" @@ -144,11 +146,70 @@ struct StatusCastImpl { } }; +// Move only type that tracks call startup. +// Allows observation of when client_initial_metadata has been processed by the +// end of the local call stack. +// Interested observers can call Wait() to obtain a promise that will resolve +// when all local client_initial_metadata processing has completed. +// The result of this token is either true on successful completion, or false +// if the metadata was not sent. +// To set a successful completion, call Complete(true). For failure, call +// Complete(false). +// If Complete is not called, the destructor of a still held token will complete +// with failure. +// Transports should hold this token until client_initial_metadata has passed +// any flow control (eg MAX_CONCURRENT_STREAMS for http2). +class ClientInitialMetadataOutstandingToken { + public: + static ClientInitialMetadataOutstandingToken Empty() { + return ClientInitialMetadataOutstandingToken(); + } + static ClientInitialMetadataOutstandingToken New( + Arena* arena = GetContext()) { + ClientInitialMetadataOutstandingToken token; + token.latch_ = arena->New>(); + return token; + } + + ClientInitialMetadataOutstandingToken( + const ClientInitialMetadataOutstandingToken&) = delete; + ClientInitialMetadataOutstandingToken& operator=( + const ClientInitialMetadataOutstandingToken&) = delete; + ClientInitialMetadataOutstandingToken( + ClientInitialMetadataOutstandingToken&& other) noexcept + : latch_(std::exchange(other.latch_, nullptr)) {} + ClientInitialMetadataOutstandingToken& operator=( + ClientInitialMetadataOutstandingToken&& other) noexcept { + latch_ = std::exchange(other.latch_, nullptr); + return *this; + } + ~ClientInitialMetadataOutstandingToken() { + if (latch_ != nullptr) latch_->Set(false); + } + void Complete(bool success) { std::exchange(latch_, nullptr)->Set(success); } + + // Returns a promise that will resolve when this object (or its moved-from + // ancestor) is dropped. + auto Wait() { return latch_->Wait(); } + + private: + ClientInitialMetadataOutstandingToken() = default; + + Latch* latch_ = nullptr; +}; + +using ClientInitialMetadataOutstandingTokenWaitType = + decltype(std::declval().Wait()); + struct CallArgs { // Initial metadata from the client to the server. // During promise setup this can be manipulated by filters (and then // passed on to the next filter). ClientMetadataHandle client_initial_metadata; + // Token indicating that client_initial_metadata is still being processed. + // This should be moved around and only destroyed when the transport is + // satisfied that the metadata has passed any flow control measures it has. + ClientInitialMetadataOutstandingToken client_initial_metadata_outstanding; // Initial metadata from the server to the client. // Set once when it's available. // During promise setup filters can substitute their own latch for this @@ -331,6 +392,12 @@ struct grpc_transport_stream_op_batch { /// Is this stream traced bool is_traced : 1; + bool HasOp() const { + return send_initial_metadata || send_trailing_metadata || send_message || + recv_initial_metadata || recv_message || recv_trailing_metadata || + cancel_stream; + } + //************************************************************************** // remaining fields are initialized and used at the discretion of the // current handler of the op diff --git a/test/core/end2end/fixtures/h2_oauth2_tls12.cc b/test/core/end2end/fixtures/h2_oauth2_tls12.cc index 5d3681cdca8..0a1d17b4fa6 100644 --- a/test/core/end2end/fixtures/h2_oauth2_tls12.cc +++ b/test/core/end2end/fixtures/h2_oauth2_tls12.cc @@ -45,8 +45,6 @@ #define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" static const char oauth2_md[] = "Bearer aaslkfjs424535asdf"; -static const char* client_identity_property_name = "smurf_name"; -static const char* client_identity = "Brainy Smurf"; struct fullstack_secure_fixture_data { std::string localaddr; @@ -70,7 +68,7 @@ typedef struct { size_t pseudo_refcount; } test_processor_state; -static void process_oauth2_success(void* state, grpc_auth_context* ctx, +static void process_oauth2_success(void* state, grpc_auth_context*, const grpc_metadata* md, size_t md_count, grpc_process_auth_metadata_done_cb cb, void* user_data) { @@ -82,10 +80,6 @@ static void process_oauth2_success(void* state, grpc_auth_context* ctx, s = static_cast(state); GPR_ASSERT(s->pseudo_refcount == 1); GPR_ASSERT(oauth2 != nullptr); - grpc_auth_context_add_cstring_property(ctx, client_identity_property_name, - client_identity); - GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - ctx, client_identity_property_name) == 1); cb(user_data, oauth2, 1, nullptr, 0, GRPC_STATUS_OK, nullptr); } diff --git a/test/core/end2end/fixtures/h2_oauth2_tls13.cc b/test/core/end2end/fixtures/h2_oauth2_tls13.cc index df8b3307f97..997f8002bb2 100644 --- a/test/core/end2end/fixtures/h2_oauth2_tls13.cc +++ b/test/core/end2end/fixtures/h2_oauth2_tls13.cc @@ -45,8 +45,6 @@ #define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" static const char oauth2_md[] = "Bearer aaslkfjs424535asdf"; -static const char* client_identity_property_name = "smurf_name"; -static const char* client_identity = "Brainy Smurf"; struct fullstack_secure_fixture_data { std::string localaddr; @@ -70,7 +68,7 @@ typedef struct { size_t pseudo_refcount; } test_processor_state; -static void process_oauth2_success(void* state, grpc_auth_context* ctx, +static void process_oauth2_success(void* state, grpc_auth_context*, const grpc_metadata* md, size_t md_count, grpc_process_auth_metadata_done_cb cb, void* user_data) { @@ -82,10 +80,6 @@ static void process_oauth2_success(void* state, grpc_auth_context* ctx, s = static_cast(state); GPR_ASSERT(s->pseudo_refcount == 1); GPR_ASSERT(oauth2 != nullptr); - grpc_auth_context_add_cstring_property(ctx, client_identity_property_name, - client_identity); - GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - ctx, client_identity_property_name) == 1); cb(user_data, oauth2, 1, nullptr, 0, GRPC_STATUS_OK, nullptr); } diff --git a/test/core/end2end/fixtures/proxy.cc b/test/core/end2end/fixtures/proxy.cc index c1b7da1a446..8e99c5b73a7 100644 --- a/test/core/end2end/fixtures/proxy.cc +++ b/test/core/end2end/fixtures/proxy.cc @@ -210,7 +210,7 @@ static void on_p2s_sent_message(void* arg, int success) { grpc_op op; grpc_call_error err; - grpc_byte_buffer_destroy(pc->c2p_msg); + grpc_byte_buffer_destroy(std::exchange(pc->c2p_msg, nullptr)); if (!pc->proxy->shutdown && success) { op.op = GRPC_OP_RECV_MESSAGE; op.flags = 0; diff --git a/test/core/end2end/tests/filter_init_fails.cc b/test/core/end2end/tests/filter_init_fails.cc index 7665b72c703..5435400bd5c 100644 --- a/test/core/end2end/tests/filter_init_fails.cc +++ b/test/core/end2end/tests/filter_init_fails.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include "absl/status/status.h" @@ -40,7 +41,10 @@ #include "src/core/lib/gprpp/status_helper.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/error.h" +#include "src/core/lib/promise/arena_promise.h" +#include "src/core/lib/promise/promise.h" #include "src/core/lib/surface/channel_stack_type.h" +#include "src/core/lib/transport/transport.h" #include "test/core/end2end/cq_verifier.h" #include "test/core/end2end/end2end_tests.h" #include "test/core/util/test_config.h" @@ -448,12 +452,23 @@ static grpc_error_handle init_channel_elem( static void destroy_channel_elem(grpc_channel_element* /*elem*/) {} static const grpc_channel_filter test_filter = { - grpc_call_next_op, nullptr, - grpc_channel_next_op, 0, - init_call_elem, grpc_call_stack_ignore_set_pollset_or_pollset_set, - destroy_call_elem, 0, - init_channel_elem, grpc_channel_stack_no_post_init, - destroy_channel_elem, grpc_channel_next_get_info, + grpc_call_next_op, + [](grpc_channel_element*, grpc_core::CallArgs, + grpc_core::NextPromiseFactory) + -> grpc_core::ArenaPromise { + return grpc_core::Immediate(grpc_core::ServerMetadataFromStatus( + absl::PermissionDeniedError("access denied"))); + }, + grpc_channel_next_op, + 0, + init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + destroy_call_elem, + 0, + init_channel_elem, + grpc_channel_stack_no_post_init, + destroy_channel_elem, + grpc_channel_next_get_info, "filter_init_fails"}; //****************************************************************************** diff --git a/test/core/end2end/tests/max_message_length.cc b/test/core/end2end/tests/max_message_length.cc index b778ae42b47..99083650c1f 100644 --- a/test/core/end2end/tests/max_message_length.cc +++ b/test/core/end2end/tests/max_message_length.cc @@ -128,6 +128,9 @@ static void test_max_message_length_on_request(grpc_end2end_test_config config, grpc_status_code status; grpc_call_error error; grpc_slice details; + grpc_slice expect_in_details = grpc_slice_from_copied_string( + send_limit ? "Sent message larger than max (11 vs. 5)" + : "Received message larger than max (11 vs. 5)"); int was_cancelled = 2; grpc_channel_args* client_args = nullptr; @@ -266,13 +269,10 @@ static void test_max_message_length_on_request(grpc_end2end_test_config config, done: GPR_ASSERT(status == GRPC_STATUS_RESOURCE_EXHAUSTED); - GPR_ASSERT( - grpc_slice_str_cmp( - details, send_limit - ? "Sent message larger than max (11 vs. 5)" - : "Received message larger than max (11 vs. 5)") == 0); + GPR_ASSERT(grpc_slice_slice(details, expect_in_details) >= 0); grpc_slice_unref(details); + grpc_slice_unref(expect_in_details); grpc_metadata_array_destroy(&initial_metadata_recv); grpc_metadata_array_destroy(&trailing_metadata_recv); grpc_metadata_array_destroy(&request_metadata_recv); @@ -316,6 +316,9 @@ static void test_max_message_length_on_response(grpc_end2end_test_config config, grpc_status_code status; grpc_call_error error; grpc_slice details; + grpc_slice expect_in_details = grpc_slice_from_copied_string( + send_limit ? "Sent message larger than max (11 vs. 5)" + : "Received message larger than max (11 vs. 5)"); int was_cancelled = 2; grpc_channel_args* client_args = nullptr; @@ -455,13 +458,10 @@ static void test_max_message_length_on_response(grpc_end2end_test_config config, GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); GPR_ASSERT(status == GRPC_STATUS_RESOURCE_EXHAUSTED); - GPR_ASSERT( - grpc_slice_str_cmp( - details, send_limit - ? "Sent message larger than max (11 vs. 5)" - : "Received message larger than max (11 vs. 5)") == 0); + GPR_ASSERT(grpc_slice_slice(details, expect_in_details) >= 0); grpc_slice_unref(details); + grpc_slice_unref(expect_in_details); grpc_metadata_array_destroy(&initial_metadata_recv); grpc_metadata_array_destroy(&trailing_metadata_recv); grpc_metadata_array_destroy(&request_metadata_recv); diff --git a/test/core/end2end/tests/streaming_error_response.cc b/test/core/end2end/tests/streaming_error_response.cc index 641549025b8..d03f869da75 100644 --- a/test/core/end2end/tests/streaming_error_response.cc +++ b/test/core/end2end/tests/streaming_error_response.cc @@ -40,10 +40,14 @@ static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, const char* test_name, grpc_channel_args* client_args, grpc_channel_args* server_args, - bool request_status_early) { + bool request_status_early, + bool recv_message_separately) { grpc_end2end_test_fixture f; - gpr_log(GPR_INFO, "Running test: %s/%s/request_status_early=%s", test_name, - config.name, request_status_early ? "true" : "false"); + gpr_log( + GPR_INFO, + "Running test: %s/%s/request_status_early=%s/recv_message_separately=%s", + test_name, config.name, request_status_early ? "true" : "false", + recv_message_separately ? "true" : "false"); f = config.create_fixture(client_args, server_args); config.init_server(&f, server_args); config.init_client(&f, client_args); @@ -108,7 +112,7 @@ static void test(grpc_end2end_test_config config, bool request_status_early, grpc_raw_byte_buffer_create(&response_payload2_slice, 1); grpc_end2end_test_fixture f = begin_test(config, "streaming_error_response", nullptr, nullptr, - request_status_early); + request_status_early, recv_message_separately); grpc_core::CqVerifier cqv(f.cq); grpc_op ops[6]; grpc_op* op; diff --git a/test/core/filters/client_auth_filter_test.cc b/test/core/filters/client_auth_filter_test.cc index 54cbddb012f..a65a6fb55e0 100644 --- a/test/core/filters/client_auth_filter_test.cc +++ b/test/core/filters/client_auth_filter_test.cc @@ -155,7 +155,8 @@ TEST_F(ClientAuthFilterTest, CallCredsFails) { auto promise = filter->MakeCallPromise( CallArgs{ClientMetadataHandle(&initial_metadata_batch_, Arena::PooledDeleter(nullptr)), - nullptr, nullptr, nullptr}, + ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr, + nullptr}, [&](CallArgs /*call_args*/) { return ArenaPromise( [&]() -> Poll { @@ -185,7 +186,8 @@ TEST_F(ClientAuthFilterTest, RewritesInvalidStatusFromCallCreds) { auto promise = filter->MakeCallPromise( CallArgs{ClientMetadataHandle(&initial_metadata_batch_, Arena::PooledDeleter(nullptr)), - nullptr, nullptr, nullptr}, + ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr, + nullptr}, [&](CallArgs /*call_args*/) { return ArenaPromise( [&]() -> Poll { diff --git a/test/core/filters/client_authority_filter_test.cc b/test/core/filters/client_authority_filter_test.cc index 5509be006fc..c300f6403ac 100644 --- a/test/core/filters/client_authority_filter_test.cc +++ b/test/core/filters/client_authority_filter_test.cc @@ -72,7 +72,8 @@ TEST(ClientAuthorityFilterTest, PromiseCompletesImmediatelyAndSetsAuthority) { auto promise = filter.MakeCallPromise( CallArgs{ClientMetadataHandle(&initial_metadata_batch, Arena::PooledDeleter(nullptr)), - nullptr, nullptr, nullptr}, + ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr, + nullptr}, [&](CallArgs call_args) { EXPECT_EQ(call_args.client_initial_metadata ->get_pointer(HttpAuthorityMetadata()) @@ -107,7 +108,8 @@ TEST(ClientAuthorityFilterTest, auto promise = filter.MakeCallPromise( CallArgs{ClientMetadataHandle(&initial_metadata_batch, Arena::PooledDeleter(nullptr)), - nullptr, nullptr, nullptr}, + ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr, + nullptr}, [&](CallArgs call_args) { EXPECT_EQ(call_args.client_initial_metadata ->get_pointer(HttpAuthorityMetadata()) diff --git a/test/core/filters/filter_fuzzer.cc b/test/core/filters/filter_fuzzer.cc index 72b2435fd83..0739813927e 100644 --- a/test/core/filters/filter_fuzzer.cc +++ b/test/core/filters/filter_fuzzer.cc @@ -477,6 +477,7 @@ class MainLoop { auto* server_initial_metadata = arena_->New>(); CallArgs call_args{std::move(*LoadMetadata(client_initial_metadata, &client_initial_metadata_)), + ClientInitialMetadataOutstandingToken::Empty(), &server_initial_metadata->sender, nullptr, nullptr}; if (is_client) { promise_ = main_loop_->channel_stack_->MakeClientCallPromise(