diff --git a/CMakeLists.txt b/CMakeLists.txt index 30472acaacf..b401fc53b49 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1329,6 +1329,7 @@ if(gRPC_BUILD_TESTS) if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) add_dependencies(buildtests_cxx remove_stream_from_stalled_lists_test) endif() + add_dependencies(buildtests_cxx request_buffer_test) add_dependencies(buildtests_cxx request_with_flags_test) add_dependencies(buildtests_cxx request_with_payload_test) if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) @@ -24090,6 +24091,106 @@ endif() endif() if(gRPC_BUILD_TESTS) +add_executable(request_buffer_test + src/core/call/request_buffer.cc + src/core/ext/upb-gen/google/protobuf/any.upb_minitable.c + src/core/ext/upb-gen/google/rpc/status.upb_minitable.c + src/core/lib/channel/channel_args.cc + src/core/lib/compression/compression.cc + src/core/lib/compression/compression_internal.cc + src/core/lib/debug/trace.cc + src/core/lib/debug/trace_flags.cc + src/core/lib/experiments/config.cc + src/core/lib/experiments/experiments.cc + src/core/lib/iomgr/closure.cc + src/core/lib/iomgr/combiner.cc + src/core/lib/iomgr/error.cc + src/core/lib/iomgr/exec_ctx.cc + src/core/lib/iomgr/executor.cc + src/core/lib/iomgr/iomgr_internal.cc + src/core/lib/promise/activity.cc + src/core/lib/promise/party.cc + src/core/lib/resource_quota/arena.cc + src/core/lib/resource_quota/connection_quota.cc + src/core/lib/resource_quota/memory_quota.cc + src/core/lib/resource_quota/periodic_update.cc + src/core/lib/resource_quota/resource_quota.cc + src/core/lib/resource_quota/thread_quota.cc + src/core/lib/slice/percent_encoding.cc + src/core/lib/slice/slice.cc + src/core/lib/slice/slice_buffer.cc + src/core/lib/slice/slice_string_helpers.cc + src/core/lib/surface/channel_stack_type.cc + src/core/lib/transport/call_arena_allocator.cc + src/core/lib/transport/call_filters.cc + src/core/lib/transport/call_final_info.cc + src/core/lib/transport/call_spine.cc + src/core/lib/transport/call_state.cc + src/core/lib/transport/error_utils.cc + src/core/lib/transport/message.cc + src/core/lib/transport/metadata.cc + src/core/lib/transport/metadata_batch.cc + src/core/lib/transport/parsed_metadata.cc + src/core/lib/transport/status_conversion.cc + src/core/lib/transport/timeout_encoding.cc + src/core/util/dump_args.cc + src/core/util/glob.cc + src/core/util/latent_see.cc + src/core/util/per_cpu.cc + src/core/util/ref_counted_string.cc + src/core/util/status_helper.cc + src/core/util/time.cc + test/core/call/request_buffer_test.cc +) +if(WIN32 AND MSVC) + if(BUILD_SHARED_LIBS) + target_compile_definitions(request_buffer_test + PRIVATE + "GPR_DLL_IMPORTS" + ) + endif() +endif() +target_compile_features(request_buffer_test PUBLIC cxx_std_14) +target_include_directories(request_buffer_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_XXHASH_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(request_buffer_test + ${_gRPC_ALLTARGETS_LIBRARIES} + gtest + upb_mini_descriptor_lib + upb_wire_lib + absl::config + absl::no_destructor + absl::flat_hash_map + absl::inlined_vector + absl::function_ref + absl::hash + absl::type_traits + absl::statusor + absl::utility + gpr +) + + +endif() +if(gRPC_BUILD_TESTS) + add_executable(request_with_flags_test src/core/ext/transport/chaotic_good/client/chaotic_good_connector.cc src/core/ext/transport/chaotic_good/client_transport.cc diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index c70e708cea4..f256a8355c3 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -15297,6 +15297,174 @@ targets: - linux - posix - mac +- name: request_buffer_test + gtest: true + build: test + language: c++ + headers: + - src/core/call/request_buffer.h + - src/core/ext/upb-gen/google/protobuf/any.upb.h + - src/core/ext/upb-gen/google/protobuf/any.upb_minitable.h + - src/core/ext/upb-gen/google/rpc/status.upb.h + - src/core/ext/upb-gen/google/rpc/status.upb_minitable.h + - src/core/lib/channel/channel_args.h + - src/core/lib/compression/compression_internal.h + - src/core/lib/debug/trace.h + - src/core/lib/debug/trace_flags.h + - src/core/lib/debug/trace_impl.h + - src/core/lib/event_engine/event_engine_context.h + - src/core/lib/experiments/config.h + - src/core/lib/experiments/experiments.h + - src/core/lib/iomgr/closure.h + - src/core/lib/iomgr/combiner.h + - src/core/lib/iomgr/error.h + - src/core/lib/iomgr/exec_ctx.h + - src/core/lib/iomgr/executor.h + - src/core/lib/iomgr/iomgr_internal.h + - src/core/lib/promise/activity.h + - src/core/lib/promise/context.h + - src/core/lib/promise/detail/basic_seq.h + - src/core/lib/promise/detail/promise_factory.h + - src/core/lib/promise/detail/promise_like.h + - src/core/lib/promise/detail/seq_state.h + - src/core/lib/promise/detail/status.h + - src/core/lib/promise/exec_ctx_wakeup_scheduler.h + - src/core/lib/promise/for_each.h + - src/core/lib/promise/if.h + - src/core/lib/promise/interceptor_list.h + - src/core/lib/promise/latch.h + - src/core/lib/promise/loop.h + - src/core/lib/promise/map.h + - src/core/lib/promise/party.h + - src/core/lib/promise/pipe.h + - src/core/lib/promise/poll.h + - src/core/lib/promise/prioritized_race.h + - src/core/lib/promise/promise.h + - src/core/lib/promise/race.h + - src/core/lib/promise/seq.h + - src/core/lib/promise/status_flag.h + - src/core/lib/promise/try_seq.h + - src/core/lib/promise/wait_set.h + - src/core/lib/resource_quota/arena.h + - src/core/lib/resource_quota/connection_quota.h + - src/core/lib/resource_quota/memory_quota.h + - src/core/lib/resource_quota/periodic_update.h + - src/core/lib/resource_quota/resource_quota.h + - src/core/lib/resource_quota/thread_quota.h + - src/core/lib/slice/percent_encoding.h + - src/core/lib/slice/slice.h + - src/core/lib/slice/slice_buffer.h + - src/core/lib/slice/slice_internal.h + - src/core/lib/slice/slice_refcount.h + - src/core/lib/slice/slice_string_helpers.h + - src/core/lib/surface/channel_stack_type.h + - src/core/lib/transport/call_arena_allocator.h + - src/core/lib/transport/call_filters.h + - src/core/lib/transport/call_final_info.h + - src/core/lib/transport/call_spine.h + - src/core/lib/transport/call_state.h + - src/core/lib/transport/custom_metadata.h + - src/core/lib/transport/error_utils.h + - src/core/lib/transport/http2_errors.h + - src/core/lib/transport/message.h + - src/core/lib/transport/metadata.h + - src/core/lib/transport/metadata_batch.h + - src/core/lib/transport/metadata_compression_traits.h + - src/core/lib/transport/parsed_metadata.h + - src/core/lib/transport/simple_slice_based_metadata.h + - src/core/lib/transport/status_conversion.h + - src/core/lib/transport/timeout_encoding.h + - src/core/util/atomic_utils.h + - src/core/util/avl.h + - src/core/util/bitset.h + - src/core/util/chunked_vector.h + - src/core/util/cpp_impl_of.h + - src/core/util/down_cast.h + - src/core/util/dual_ref_counted.h + - src/core/util/dump_args.h + - src/core/util/glob.h + - src/core/util/if_list.h + - src/core/util/latent_see.h + - src/core/util/manual_constructor.h + - src/core/util/orphanable.h + - src/core/util/packed_table.h + - src/core/util/per_cpu.h + - src/core/util/ref_counted.h + - src/core/util/ref_counted_ptr.h + - src/core/util/ref_counted_string.h + - src/core/util/ring_buffer.h + - src/core/util/sorted_pack.h + - src/core/util/spinlock.h + - src/core/util/status_helper.h + - src/core/util/table.h + - src/core/util/time.h + - src/core/util/type_list.h + - test/core/promise/poll_matcher.h + - third_party/upb/upb/generated_code_support.h + src: + - src/core/call/request_buffer.cc + - src/core/ext/upb-gen/google/protobuf/any.upb_minitable.c + - src/core/ext/upb-gen/google/rpc/status.upb_minitable.c + - src/core/lib/channel/channel_args.cc + - src/core/lib/compression/compression.cc + - src/core/lib/compression/compression_internal.cc + - src/core/lib/debug/trace.cc + - src/core/lib/debug/trace_flags.cc + - src/core/lib/experiments/config.cc + - src/core/lib/experiments/experiments.cc + - src/core/lib/iomgr/closure.cc + - src/core/lib/iomgr/combiner.cc + - src/core/lib/iomgr/error.cc + - src/core/lib/iomgr/exec_ctx.cc + - src/core/lib/iomgr/executor.cc + - src/core/lib/iomgr/iomgr_internal.cc + - src/core/lib/promise/activity.cc + - src/core/lib/promise/party.cc + - src/core/lib/resource_quota/arena.cc + - src/core/lib/resource_quota/connection_quota.cc + - src/core/lib/resource_quota/memory_quota.cc + - src/core/lib/resource_quota/periodic_update.cc + - src/core/lib/resource_quota/resource_quota.cc + - src/core/lib/resource_quota/thread_quota.cc + - src/core/lib/slice/percent_encoding.cc + - src/core/lib/slice/slice.cc + - src/core/lib/slice/slice_buffer.cc + - src/core/lib/slice/slice_string_helpers.cc + - src/core/lib/surface/channel_stack_type.cc + - src/core/lib/transport/call_arena_allocator.cc + - src/core/lib/transport/call_filters.cc + - src/core/lib/transport/call_final_info.cc + - src/core/lib/transport/call_spine.cc + - src/core/lib/transport/call_state.cc + - src/core/lib/transport/error_utils.cc + - src/core/lib/transport/message.cc + - src/core/lib/transport/metadata.cc + - src/core/lib/transport/metadata_batch.cc + - src/core/lib/transport/parsed_metadata.cc + - src/core/lib/transport/status_conversion.cc + - src/core/lib/transport/timeout_encoding.cc + - src/core/util/dump_args.cc + - src/core/util/glob.cc + - src/core/util/latent_see.cc + - src/core/util/per_cpu.cc + - src/core/util/ref_counted_string.cc + - src/core/util/status_helper.cc + - src/core/util/time.cc + - test/core/call/request_buffer_test.cc + deps: + - gtest + - upb_mini_descriptor_lib + - upb_wire_lib + - absl/base:config + - absl/base:no_destructor + - absl/container:flat_hash_map + - absl/container:inlined_vector + - absl/functional:function_ref + - absl/hash:hash + - absl/meta:type_traits + - absl/status:statusor + - absl/utility:utility + - gpr - name: request_with_flags_test gtest: true build: test diff --git a/src/core/BUILD b/src/core/BUILD index 8220654c837..f5949c8e05f 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -1608,6 +1608,23 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "request_buffer", + srcs = [ + "call/request_buffer.cc", + ], + hdrs = [ + "call/request_buffer.h", + ], + external_deps = ["absl/types:optional"], + deps = [ + "call_spine", + "message", + "metadata", + "wait_set", + ], +) + grpc_cc_library( name = "slice_refcount", hdrs = [ diff --git a/src/core/call/request_buffer.cc b/src/core/call/request_buffer.cc new file mode 100644 index 00000000000..257743cbb90 --- /dev/null +++ b/src/core/call/request_buffer.cc @@ -0,0 +1,168 @@ +// Copyright 2024 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/call/request_buffer.h" + +#include + +#include "absl/types/optional.h" + +namespace grpc_core { + +ValueOrFailure RequestBuffer::PushClientInitialMetadata( + ClientMetadataHandle md) { + MutexLock lock(&mu_); + if (absl::get_if(&state_)) return Failure{}; + auto& buffering = absl::get(state_); + CHECK_EQ(buffering.initial_metadata.get(), nullptr); + buffering.initial_metadata = std::move(md); + buffering.buffered += buffering.initial_metadata->TransportSize(); + WakeupAsyncAllPullers(); + return buffering.buffered; +} + +Poll> RequestBuffer::PollPushMessage( + MessageHandle& message) { + MutexLock lock(&mu_); + if (absl::get_if(&state_)) return Failure{}; + size_t buffered = 0; + if (auto* buffering = absl::get_if(&state_)) { + if (winner_ != nullptr) return PendingPush(); + buffering->buffered += message->payload()->Length(); + buffered = buffering->buffered; + buffering->messages.push_back(std::move(message)); + } else { + auto& streaming = absl::get(state_); + CHECK_EQ(streaming.end_of_stream, false); + if (streaming.message != nullptr) { + return PendingPush(); + } + streaming.message = std::move(message); + } + WakeupAsyncAllPullers(); + return buffered; +} + +StatusFlag RequestBuffer::FinishSends() { + MutexLock lock(&mu_); + if (absl::get_if(&state_)) return Failure{}; + if (auto* buffering = absl::get_if(&state_)) { + Buffered buffered(std::move(buffering->initial_metadata), + std::move(buffering->messages)); + state_.emplace(std::move(buffered)); + } else { + auto& streaming = absl::get(state_); + CHECK_EQ(streaming.end_of_stream, false); + streaming.end_of_stream = true; + } + WakeupAsyncAllPullers(); + return Success{}; +} + +void RequestBuffer::Cancel(absl::Status error) { + MutexLock lock(&mu_); + if (absl::holds_alternative(state_)) return; + state_.emplace(std::move(error)); + WakeupAsyncAllPullers(); +} + +void RequestBuffer::Commit(Reader* winner) { + MutexLock lock(&mu_); + CHECK_EQ(winner_, nullptr); + winner_ = winner; + if (auto* buffering = absl::get_if(&state_)) { + if (buffering->initial_metadata != nullptr && + winner->message_index_ == buffering->messages.size() && + winner->pulled_client_initial_metadata_) { + state_.emplace(); + } + } else if (auto* buffered = absl::get_if(&state_)) { + CHECK_NE(buffered->initial_metadata.get(), nullptr); + if (winner->message_index_ == buffered->messages.size()) { + state_.emplace().end_of_stream = true; + } + } + WakeupAsyncAllPullersExcept(winner); +} + +void RequestBuffer::WakeupAsyncAllPullersExcept(Reader* except_reader) { + for (auto wakeup_reader : readers_) { + if (wakeup_reader == except_reader) continue; + wakeup_reader->pull_waker_.WakeupAsync(); + } +} + +Poll> +RequestBuffer::Reader::PollPullClientInitialMetadata() { + MutexLock lock(&buffer_->mu_); + if (buffer_->winner_ != nullptr && buffer_->winner_ != this) { + error_ = absl::CancelledError("Another call was chosen"); + return Failure{}; + } + if (auto* buffering = absl::get_if(&buffer_->state_)) { + if (buffering->initial_metadata.get() == nullptr) { + return buffer_->PendingPull(this); + } + pulled_client_initial_metadata_ = true; + auto result = ClaimObject(buffering->initial_metadata); + buffer_->MaybeSwitchToStreaming(); + return result; + } + if (auto* buffered = absl::get_if(&buffer_->state_)) { + pulled_client_initial_metadata_ = true; + return ClaimObject(buffered->initial_metadata); + } + error_ = absl::get(buffer_->state_).error; + return Failure{}; +} + +Poll>> +RequestBuffer::Reader::PollPullMessage() { + ReleasableMutexLock lock(&buffer_->mu_); + if (buffer_->winner_ != nullptr && buffer_->winner_ != this) { + error_ = absl::CancelledError("Another call was chosen"); + return Failure{}; + } + if (auto* buffering = absl::get_if(&buffer_->state_)) { + if (message_index_ == buffering->messages.size()) { + return buffer_->PendingPull(this); + } + const auto idx = message_index_; + auto result = ClaimObject(buffering->messages[idx]); + ++message_index_; + buffer_->MaybeSwitchToStreaming(); + return result; + } + if (auto* buffered = absl::get_if(&buffer_->state_)) { + if (message_index_ == buffered->messages.size()) return absl::nullopt; + const auto idx = message_index_; + ++message_index_; + return ClaimObject(buffered->messages[idx]); + } + if (auto* streaming = absl::get_if(&buffer_->state_)) { + if (streaming->message == nullptr) { + if (streaming->end_of_stream) return absl::nullopt; + return buffer_->PendingPull(this); + } + auto msg = std::move(streaming->message); + auto waker = std::move(buffer_->push_waker_); + lock.Release(); + waker.Wakeup(); + return msg; + } + error_ = absl::get(buffer_->state_).error; + return Failure{}; +} + +} // namespace grpc_core diff --git a/src/core/call/request_buffer.h b/src/core/call/request_buffer.h new file mode 100644 index 00000000000..719a562bdba --- /dev/null +++ b/src/core/call/request_buffer.h @@ -0,0 +1,182 @@ +// Copyright 2024 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_SRC_CORE_CALL_REQUEST_BUFFER_H +#define GRPC_SRC_CORE_CALL_REQUEST_BUFFER_H + +#include + +#include "src/core/lib/transport/call_spine.h" +#include "src/core/lib/transport/message.h" +#include "src/core/lib/transport/metadata.h" + +namespace grpc_core { + +// Outbound request buffer. +// Collects client->server metadata and messages whilst in its initial buffering +// mode. In buffering mode it can have zero or more Reader objects attached to +// it. +// The buffer can later be switched to committed mode, at which point it +// will have exactly one Reader object attached to it. +// Callers can choose to switch to committed mode based upon policy of their +// choice. +class RequestBuffer { + public: + // One reader of the request buffer. + class Reader { + public: + explicit Reader(RequestBuffer* buffer) ABSL_LOCKS_EXCLUDED(buffer->mu_) + : buffer_(buffer) { + buffer->AddReader(this); + } + ~Reader() ABSL_LOCKS_EXCLUDED(buffer_->mu_) { buffer_->RemoveReader(this); } + + Reader(const Reader&) = delete; + Reader& operator=(const Reader&) = delete; + + // Pull client initial metadata. Returns a promise that resolves to + // ValueOrFailure. + GRPC_MUST_USE_RESULT auto PullClientInitialMetadata() { + return [this]() { return PollPullClientInitialMetadata(); }; + } + // Pull a message. Returns a promise that resolves to a + // ValueOrFailure>. + GRPC_MUST_USE_RESULT auto PullMessage() { + return [this]() { return PollPullMessage(); }; + } + + absl::Status TakeError() { return std::move(error_); } + + private: + friend class RequestBuffer; + + Poll> PollPullClientInitialMetadata(); + Poll>> PollPullMessage(); + + template + T ClaimObject(T& object) ABSL_EXCLUSIVE_LOCKS_REQUIRED(buffer_->mu_) { + if (buffer_->winner_ == this) return std::move(object); + return CopyObject(object); + } + + ClientMetadataHandle CopyObject(const ClientMetadataHandle& md) { + return Arena::MakePooled(md->Copy()); + } + + MessageHandle CopyObject(const MessageHandle& msg) { + return Arena::MakePooled(msg->payload()->Copy(), msg->flags()); + } + + RequestBuffer* const buffer_; + bool pulled_client_initial_metadata_ = false; + size_t message_index_ = 0; + absl::Status error_; + Waker pull_waker_; + }; + + // Push ClientInitialMetadata into the buffer. + // This is instantaneous, and returns success with the amount of data + // buffered, or failure. + ValueOrFailure PushClientInitialMetadata(ClientMetadataHandle md); + // Resolves to a ValueOrFailure where the size_t is the amount of data + // buffered (or 0 if we're in committed mode). + GRPC_MUST_USE_RESULT auto PushMessage(MessageHandle message) { + return [this, message = std::move(message)]() mutable { + return PollPushMessage(message); + }; + } + // Push end of stream (client half-closure). + StatusFlag FinishSends(); + // Cancel the request, propagate failure to all readers. + void Cancel(absl::Status error = absl::CancelledError()); + + // Switch to committed mode - needs to be called exactly once with the winning + // reader. All other readers will see failure. + void Commit(Reader* winner); + + private: + // Buffering state: we're collecting metadata and messages. + struct Buffering { + // Initial metadata, or nullptr if not yet received. + ClientMetadataHandle initial_metadata; + // Buffered messages. + absl::InlinedVector messages; + // Amount of data buffered. + size_t buffered = 0; + }; + // Buffered state: all messages have been collected (the client has finished + // sending). + struct Buffered { + Buffered(ClientMetadataHandle md, + absl::InlinedVector msgs) + : initial_metadata(std::move(md)), messages(std::move(msgs)) {} + ClientMetadataHandle initial_metadata; + absl::InlinedVector messages; + }; + // Streaming state: we're streaming messages to the server. + // This implies winner_ is set. + struct Streaming { + MessageHandle message; + bool end_of_stream = false; + }; + // Cancelled state: the request has been cancelled. + struct Cancelled { + explicit Cancelled(absl::Status error) : error(std::move(error)) {} + absl::Status error; + }; + using State = absl::variant; + + Poll> PollPushMessage(MessageHandle& message); + Pending PendingPull(Reader* reader) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + reader->pull_waker_ = Activity::current()->MakeOwningWaker(); + return Pending{}; + } + Pending PendingPush() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + push_waker_ = Activity::current()->MakeOwningWaker(); + return Pending{}; + } + void MaybeSwitchToStreaming() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto& buffering = absl::get(state_); + if (winner_ == nullptr) return; + if (winner_->message_index_ < buffering.messages.size()) return; + state_.emplace(); + push_waker_.Wakeup(); + } + + void WakeupAsyncAllPullers() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + WakeupAsyncAllPullersExcept(nullptr); + } + void WakeupAsyncAllPullersExcept(Reader* except_reader) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + void AddReader(Reader* reader) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + readers_.insert(reader); + } + + void RemoveReader(Reader* reader) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + readers_.erase(reader); + } + + Mutex mu_; + Reader* winner_ ABSL_GUARDED_BY(mu_){nullptr}; + State state_ ABSL_GUARDED_BY(mu_){Buffering{}}; + // TODO(ctiller): change this to an intrusively linked list to avoid + // allocations. + absl::flat_hash_set readers_ ABSL_GUARDED_BY(mu_); + Waker push_waker_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace grpc_core + +#endif // GRPC_SRC_CORE_CALL_REQUEST_BUFFER_H diff --git a/src/core/lib/transport/message.h b/src/core/lib/transport/message.h index 6390e9886be..f7375d060a9 100644 --- a/src/core/lib/transport/message.h +++ b/src/core/lib/transport/message.h @@ -49,6 +49,11 @@ class Message { std::string DebugString() const; + template + friend void AbslStringify(Sink& sink, const Message& message) { + sink.Append(message.DebugString()); + } + private: SliceBuffer payload_; uint32_t flags_ = 0; diff --git a/test/core/call/BUILD b/test/core/call/BUILD index 3c726fbf280..75a590cd3e7 100644 --- a/test/core/call/BUILD +++ b/test/core/call/BUILD @@ -87,3 +87,16 @@ grpc_cc_benchmark( "//src/core:default_event_engine", ], ) + +grpc_cc_test( + name = "request_buffer_test", + srcs = [ + "request_buffer_test.cc", + ], + external_deps = ["gtest"], + language = "C++", + deps = [ + "//src/core:request_buffer", + "//test/core/promise:poll_matcher", + ], +) diff --git a/test/core/call/request_buffer_test.cc b/test/core/call/request_buffer_test.cc new file mode 100644 index 00000000000..98aba0ac495 --- /dev/null +++ b/test/core/call/request_buffer_test.cc @@ -0,0 +1,722 @@ +// Copyright 2024 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/call/request_buffer.h" + +#include "gtest/gtest.h" + +#include "test/core/promise/poll_matcher.h" + +using testing::Mock; +using testing::StrictMock; + +namespace grpc_core { + +namespace { +void CrashOnParseError(absl::string_view error, const Slice& data) { + LOG(FATAL) << "Failed to parse " << error << " from " + << data.as_string_view(); +} + +// A mock activity that can be activated and deactivated. +class MockActivity : public Activity, public Wakeable { + public: + MOCK_METHOD(void, WakeupRequested, ()); + + void ForceImmediateRepoll(WakeupMask /*mask*/) override { WakeupRequested(); } + void Orphan() override {} + Waker MakeOwningWaker() override { return Waker(this, 0); } + Waker MakeNonOwningWaker() override { return Waker(this, 0); } + void Wakeup(WakeupMask /*mask*/) override { WakeupRequested(); } + void WakeupAsync(WakeupMask /*mask*/) override { WakeupRequested(); } + void Drop(WakeupMask /*mask*/) override {} + std::string DebugTag() const override { return "MockActivity"; } + std::string ActivityDebugTag(WakeupMask /*mask*/) const override { + return DebugTag(); + } + + void Activate() { + if (scoped_activity_ == nullptr) { + scoped_activity_ = std::make_unique(this); + } + } + + void Deactivate() { scoped_activity_.reset(); } + + private: + std::unique_ptr scoped_activity_; +}; + +#define EXPECT_WAKEUP(activity, statement) \ + EXPECT_CALL((activity), WakeupRequested()).Times(::testing::AtLeast(1)); \ + statement; \ + Mock::VerifyAndClearExpectations(&(activity)); + +ClientMetadataHandle TestMetadata() { + ClientMetadataHandle md = Arena::MakePooledForOverwrite(); + md->Append("key", Slice::FromStaticString("value"), CrashOnParseError); + return md; +} + +MessageHandle TestMessage(int index = 0) { + return Arena::MakePooled( + SliceBuffer(Slice::FromCopiedString(absl::StrCat("message ", index))), 0); +} + +MATCHER(IsTestMetadata, "") { + if (arg == nullptr) return false; + std::string backing; + if (arg->GetStringValue("key", &backing) != "value") { + *result_listener << arg->DebugString(); + return false; + } + return true; +} + +MATCHER(IsTestMessage, "") { + if (arg == nullptr) return false; + if (arg->flags() != 0) { + *result_listener << "flags: " << arg->flags(); + return false; + } + if (arg->payload()->JoinIntoString() != "message 0") { + *result_listener << "payload: " << arg->payload()->JoinIntoString(); + return false; + } + return true; +} + +MATCHER_P(IsTestMessage, index, "") { + if (arg == nullptr) return false; + if (arg->flags() != 0) { + *result_listener << "flags: " << arg->flags(); + return false; + } + if (arg->payload()->JoinIntoString() != absl::StrCat("message ", index)) { + *result_listener << "payload: " << arg->payload()->JoinIntoString(); + return false; + } + return true; +} + +} // namespace + +TEST(RequestBufferTest, NoOp) { RequestBuffer buffer; } + +TEST(RequestBufferTest, PushThenPullClientInitialMetadata) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + auto poll = reader.PullClientInitialMetadata()(); + ASSERT_THAT(poll, IsReady()); + auto value = std::move(poll.value()); + ASSERT_TRUE(value.ok()); + EXPECT_THAT(*value, IsTestMetadata()); +} + +TEST(RequestBufferTest, PushThenFinishThenPullClientInitialMetadata) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + buffer.FinishSends(); + RequestBuffer::Reader reader(&buffer); + auto poll = reader.PullClientInitialMetadata()(); + ASSERT_THAT(poll, IsReady()); + auto value = std::move(poll.value()); + ASSERT_TRUE(value.ok()); + EXPECT_THAT(*value, IsTestMetadata()); +} + +TEST(RequestBufferTest, PullThenPushClientInitialMetadata) { + StrictMock activity; + RequestBuffer buffer; + RequestBuffer::Reader reader(&buffer); + activity.Activate(); + auto poller = reader.PullClientInitialMetadata(); + auto poll = poller(); + EXPECT_THAT(poll, IsPending()); + ClientMetadataHandle md = Arena::MakePooledForOverwrite(); + md->Append("key", Slice::FromStaticString("value"), CrashOnParseError); + EXPECT_WAKEUP(activity, + EXPECT_EQ(buffer.PushClientInitialMetadata(std::move(md)), 40)); + poll = poller(); + ASSERT_THAT(poll, IsReady()); + auto value = std::move(poll.value()); + ASSERT_TRUE(value.ok()); + EXPECT_THAT(*value, IsTestMetadata()); +} + +TEST(RequestBufferTest, PushThenPullMessage) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsReady(49)); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); +} + +TEST(RequestBufferTest, PushThenPullMessageStreamBeforeInitialMetadata) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsReady(49)); + RequestBuffer::Reader reader(&buffer); + buffer.Commit(&reader); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); +} + +TEST(RequestBufferTest, PushThenPullMessageStreamBeforeFirstMessage) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsReady(49)); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + buffer.Commit(&reader); + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); +} + +TEST(RequestBufferTest, PullThenPushMessage) { + StrictMock activity; + activity.Activate(); + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + EXPECT_THAT(poll_msg, IsPending()); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_WAKEUP(activity, EXPECT_THAT(pusher(), IsReady(49))); + poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); +} + +TEST(RequestBufferTest, PullThenPushMessageSwitchBeforePullMessage) { + StrictMock activity; + activity.Activate(); + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + buffer.Commit(&reader); + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + EXPECT_THAT(poll_msg, IsPending()); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_WAKEUP(activity, EXPECT_THAT(pusher(), IsReady(0))); + poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); +} + +TEST(RequestBufferTest, PullThenPushMessageSwitchBeforePushMessage) { + StrictMock activity; + activity.Activate(); + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + EXPECT_THAT(poll_msg, IsPending()); + buffer.Commit(&reader); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_WAKEUP(activity, EXPECT_THAT(pusher(), IsReady(0))); + poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); +} + +TEST(RequestBufferTest, PullThenPushMessageSwitchAfterPushMessage) { + StrictMock activity; + activity.Activate(); + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + EXPECT_THAT(poll_msg, IsPending()); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_WAKEUP(activity, EXPECT_THAT(pusher(), IsReady(49))); + buffer.Commit(&reader); + poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); +} + +TEST(RequestBufferTest, PullEndOfStream) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsReady(49)); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); + EXPECT_EQ(buffer.FinishSends(), Success{}); + auto pull_msg2 = reader.PullMessage(); + poll_msg = pull_msg2(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_FALSE(poll_msg.value().value().has_value()); +} + +TEST(RequestBufferTest, PullEndOfStreamSwitchBeforePullMessage) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsReady(49)); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + buffer.Commit(&reader); + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); + EXPECT_EQ(buffer.FinishSends(), Success{}); + auto pull_msg2 = reader.PullMessage(); + poll_msg = pull_msg2(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_FALSE(poll_msg.value().value().has_value()); +} + +TEST(RequestBufferTest, PullEndOfStreamSwitchBeforePushMessage) { + StrictMock activity; + activity.Activate(); + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + buffer.Commit(&reader); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsPending()); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_WAKEUP(activity, + EXPECT_THAT(pull_md(), IsReady())); // value tested elsewhere + EXPECT_THAT(pusher(), IsReady(0)); + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); + EXPECT_EQ(buffer.FinishSends(), Success{}); + auto pull_msg2 = reader.PullMessage(); + poll_msg = pull_msg2(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_FALSE(poll_msg.value().value().has_value()); +} + +TEST(RequestBufferTest, PullEndOfStreamQueuedWithMessage) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsReady(49)); + EXPECT_EQ(buffer.FinishSends(), Success{}); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); + auto pull_msg2 = reader.PullMessage(); + poll_msg = pull_msg2(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_FALSE(poll_msg.value().value().has_value()); +} + +TEST(RequestBufferTest, + PullEndOfStreamQueuedWithMessageSwitchBeforePushMessage) { + StrictMock activity; + activity.Activate(); + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + buffer.Commit(&reader); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsPending()); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_WAKEUP(activity, + EXPECT_THAT(pull_md(), IsReady())); // value tested elsewhere + EXPECT_THAT(pusher(), IsReady(0)); + EXPECT_EQ(buffer.FinishSends(), Success{}); + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); + auto pull_msg2 = reader.PullMessage(); + poll_msg = pull_msg2(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_FALSE(poll_msg.value().value().has_value()); +} + +TEST(RequestBufferTest, + PullEndOfStreamQueuedWithMessageSwitchBeforePullMessage) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsReady(49)); + EXPECT_EQ(buffer.FinishSends(), Success{}); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + buffer.Commit(&reader); + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); + auto pull_msg2 = reader.PullMessage(); + poll_msg = pull_msg2(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_FALSE(poll_msg.value().value().has_value()); +} + +TEST(RequestBufferTest, + PullEndOfStreamQueuedWithMessageSwitchDuringPullMessage) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsReady(49)); + EXPECT_EQ(buffer.FinishSends(), Success{}); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + auto pull_msg = reader.PullMessage(); + buffer.Commit(&reader); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); + auto pull_msg2 = reader.PullMessage(); + poll_msg = pull_msg2(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_FALSE(poll_msg.value().value().has_value()); +} + +TEST(RequestBufferTest, PushThenPullMessageRepeatedly) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + for (int i = 0; i < 10; i++) { + auto pusher = buffer.PushMessage(TestMessage(i)); + EXPECT_THAT(pusher(), IsReady(40 + 9 * (i + 1))); + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage(i)); + } +} + +TEST(RequestBufferTest, PushSomeSwitchThenPushPullMessages) { + StrictMock activity; + activity.Activate(); + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + for (int i = 0; i < 10; i++) { + auto pusher = buffer.PushMessage(TestMessage(i)); + EXPECT_THAT(pusher(), IsReady(40 + 9 * (i + 1))); + } + buffer.Commit(&reader); + for (int i = 0; i < 10; i++) { + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage(i)); + } + for (int i = 0; i < 10; i++) { + auto pusher = buffer.PushMessage(TestMessage(i)); + EXPECT_THAT(pusher(), IsReady(0)); + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage(i)); + } +} + +TEST(RequestBufferTest, HedgeReadMetadata) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader1(&buffer); + RequestBuffer::Reader reader2(&buffer); + auto pull_md1 = reader1.PullClientInitialMetadata(); + auto pull_md2 = reader2.PullClientInitialMetadata(); + auto poll_md1 = pull_md1(); + auto poll_md2 = pull_md2(); + ASSERT_THAT(poll_md1, IsReady()); + ASSERT_THAT(poll_md2, IsReady()); + auto value1 = std::move(poll_md1.value()); + auto value2 = std::move(poll_md2.value()); + ASSERT_TRUE(value1.ok()); + ASSERT_TRUE(value2.ok()); + EXPECT_THAT(*value1, IsTestMetadata()); + EXPECT_THAT(*value2, IsTestMetadata()); +} + +TEST(RequestBufferTest, HedgeReadMetadataSwitchBeforeFirstRead) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader1(&buffer); + buffer.Commit(&reader1); + RequestBuffer::Reader reader2(&buffer); + auto pull_md1 = reader1.PullClientInitialMetadata(); + auto pull_md2 = reader2.PullClientInitialMetadata(); + auto poll_md1 = pull_md1(); + auto poll_md2 = pull_md2(); + ASSERT_THAT(poll_md1, IsReady()); + ASSERT_THAT(poll_md2, IsReady()); + auto value1 = std::move(poll_md1.value()); + auto value2 = std::move(poll_md2.value()); + ASSERT_TRUE(value1.ok()); + EXPECT_FALSE(value2.ok()); + EXPECT_THAT(*value1, IsTestMetadata()); +} + +TEST(RequestBufferTest, HedgeReadMetadataLate) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader1(&buffer); + auto pull_md1 = reader1.PullClientInitialMetadata(); + auto poll_md1 = pull_md1(); + ASSERT_THAT(poll_md1, IsReady()); + auto value1 = std::move(poll_md1.value()); + ASSERT_TRUE(value1.ok()); + EXPECT_THAT(*value1, IsTestMetadata()); + RequestBuffer::Reader reader2(&buffer); + auto pull_md2 = reader2.PullClientInitialMetadata(); + auto poll_md2 = pull_md2(); + ASSERT_THAT(poll_md2, IsReady()); + auto value2 = std::move(poll_md2.value()); + ASSERT_TRUE(value2.ok()); + EXPECT_THAT(*value2, IsTestMetadata()); +} + +TEST(RequestBufferTest, HedgeReadMetadataLateSwitchAfterPullInitialMetadata) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader1(&buffer); + auto pull_md1 = reader1.PullClientInitialMetadata(); + auto poll_md1 = pull_md1(); + ASSERT_THAT(poll_md1, IsReady()); + auto value1 = std::move(poll_md1.value()); + ASSERT_TRUE(value1.ok()); + EXPECT_THAT(*value1, IsTestMetadata()); + RequestBuffer::Reader reader2(&buffer); + buffer.Commit(&reader1); + auto pull_md2 = reader2.PullClientInitialMetadata(); + auto poll_md2 = pull_md2(); + ASSERT_THAT(poll_md2, IsReady()); + auto value2 = std::move(poll_md2.value()); + EXPECT_FALSE(value2.ok()); +} + +TEST(RequestBufferTest, StreamingPushBeforeLastMessagePulled) { + StrictMock activity; + activity.Activate(); + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + buffer.Commit(&reader); + auto pusher1 = buffer.PushMessage(TestMessage(1)); + EXPECT_THAT(pusher1(), IsReady(0)); + auto pusher2 = buffer.PushMessage(TestMessage(2)); + EXPECT_THAT(pusher2(), IsPending()); + auto pull1 = reader.PullMessage(); + EXPECT_WAKEUP(activity, auto poll1 = pull1()); + ASSERT_THAT(poll1, IsReady()); + ASSERT_TRUE(poll1.value().ok()); + ASSERT_TRUE(poll1.value().value().has_value()); + EXPECT_THAT(poll1.value().value().value(), IsTestMessage(1)); + auto pull2 = reader.PullMessage(); + auto poll2 = pull2(); + EXPECT_THAT(poll2, IsPending()); + EXPECT_WAKEUP(activity, EXPECT_THAT(pusher2(), IsReady(0))); + poll2 = pull2(); + ASSERT_THAT(poll2, IsReady()); + ASSERT_TRUE(poll2.value().ok()); + ASSERT_TRUE(poll2.value().value().has_value()); + EXPECT_THAT(poll2.value().value().value(), IsTestMessage(2)); +} + +TEST(RequestBufferTest, SwitchAfterEndOfStream) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsReady(49)); + EXPECT_EQ(buffer.FinishSends(), Success{}); + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); + buffer.Commit(&reader); + auto pull_msg2 = reader.PullMessage(); + poll_msg = pull_msg2(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + EXPECT_FALSE(poll_msg.value().value().has_value()); +} + +TEST(RequestBufferTest, NothingAfterEndOfStream) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + EXPECT_THAT(pull_md(), IsReady()); // value tested elsewhere + auto pusher = buffer.PushMessage(TestMessage()); + EXPECT_THAT(pusher(), IsReady(49)); + EXPECT_EQ(buffer.FinishSends(), Success{}); + auto pull_msg = reader.PullMessage(); + auto poll_msg = pull_msg(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + ASSERT_TRUE(poll_msg.value().value().has_value()); + EXPECT_THAT(poll_msg.value().value().value(), IsTestMessage()); + auto pull_msg2 = reader.PullMessage(); + poll_msg = pull_msg2(); + ASSERT_THAT(poll_msg, IsReady()); + ASSERT_TRUE(poll_msg.value().ok()); + EXPECT_FALSE(poll_msg.value().value().has_value()); +} + +TEST(RequestBufferTest, CancelBeforeInitialMetadataPush) { + RequestBuffer buffer; + buffer.Cancel(); + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), Failure{}); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + auto poll_md = pull_md(); + ASSERT_THAT(poll_md, IsReady()); + ASSERT_FALSE(poll_md.value().ok()); +} + +TEST(RequestBufferTest, CancelBeforeInitialMetadataPull) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + buffer.Cancel(); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + auto poll_md = pull_md(); + ASSERT_THAT(poll_md, IsReady()); + ASSERT_FALSE(poll_md.value().ok()); +} + +TEST(RequestBufferTest, CancelBeforeMessagePush) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + buffer.Cancel(); + auto pusher = buffer.PushMessage(TestMessage()); + auto poll = pusher(); + ASSERT_THAT(poll, IsReady()); + ASSERT_FALSE(poll.value().ok()); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + auto poll_md = pull_md(); + ASSERT_THAT(poll_md, IsReady()); + ASSERT_FALSE(poll_md.value().ok()); +} + +TEST(RequestBufferTest, CancelBeforeMessagePushButAfterInitialMetadataPull) { + RequestBuffer buffer; + EXPECT_EQ(buffer.PushClientInitialMetadata(TestMetadata()), 40); + RequestBuffer::Reader reader(&buffer); + auto pull_md = reader.PullClientInitialMetadata(); + auto poll_md = pull_md(); + ASSERT_THAT(poll_md, IsReady()); + ASSERT_TRUE(poll_md.value().ok()); + EXPECT_THAT(*poll_md.value(), IsTestMetadata()); + buffer.Cancel(); + auto pusher = buffer.PushMessage(TestMessage()); + auto poll = pusher(); + ASSERT_THAT(poll, IsReady()); + ASSERT_FALSE(poll.value().ok()); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index 16c73265ad2..6253d9f900c 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -7765,6 +7765,30 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "request_buffer_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": true + }, { "args": [], "benchmark": false,