diff --git a/src/core/lib/transport/call_filters.cc b/src/core/lib/transport/call_filters.cc index 6ebfe125b69..db849b2798b 100644 --- a/src/core/lib/transport/call_filters.cc +++ b/src/core/lib/transport/call_filters.cc @@ -17,6 +17,7 @@ #include "src/core/lib/transport/call_filters.h" #include "src/core/lib/gprpp/crash.h" +#include "src/core/lib/transport/metadata.h" namespace grpc_core { @@ -165,21 +166,10 @@ template class InfallibleOperationExecutor; /////////////////////////////////////////////////////////////////////////////// // CallFilters -CallFilters::CallFilters() : stack_(nullptr), call_data_(nullptr) {} - -CallFilters::CallFilters(RefCountedPtr stack) - : stack_(std::move(stack)), - call_data_(gpr_malloc_aligned(stack_->data_.call_data_size, - stack_->data_.call_data_alignment)) { - for (const auto& constructor : stack_->data_.filter_constructor) { - constructor.call_init(Offset(call_data_, constructor.call_offset), - constructor.channel_data); - } - client_initial_metadata_state_.Start(); - client_to_server_message_state_.Start(); - server_initial_metadata_state_.Start(); - server_to_client_message_state_.Start(); -} +CallFilters::CallFilters(ClientMetadataHandle client_initial_metadata) + : stack_(nullptr), + call_data_(nullptr), + client_initial_metadata_(std::move(client_initial_metadata)) {} CallFilters::~CallFilters() { if (call_data_ != nullptr) { diff --git a/src/core/lib/transport/call_filters.h b/src/core/lib/transport/call_filters.h index 3df2c50e7a3..013db5e0b3b 100644 --- a/src/core/lib/transport/call_filters.h +++ b/src/core/lib/transport/call_filters.h @@ -1093,6 +1093,9 @@ class InfallibleOperationExecutor { // augment it to provide all the functionality that we must. class PipeState { public: + struct StartPushed {}; + PipeState() = default; + explicit PipeState(StartPushed) : state_(ValueState::kQueued) {} // Start the pipe: allows pulls to proceed void Start(); // A push operation is beginning @@ -1245,8 +1248,7 @@ class CallFilters { filters_detail::StackData data_; }; - CallFilters(); - explicit CallFilters(RefCountedPtr stack); + explicit CallFilters(ClientMetadataHandle client_initial_metadata); ~CallFilters(); CallFilters(const CallFilters&) = delete; @@ -1256,7 +1258,9 @@ class CallFilters { void SetStack(RefCountedPtr stack); - GRPC_MUST_USE_RESULT auto PushClientInitialMetadata(ClientMetadataHandle md); + ClientMetadata* unprocessed_client_initial_metadata() { + return client_initial_metadata_.get(); + } GRPC_MUST_USE_RESULT auto PullClientInitialMetadata(); GRPC_MUST_USE_RESULT auto PushServerInitialMetadata(ServerMetadataHandle md); GRPC_MUST_USE_RESULT auto PullServerInitialMetadata(); @@ -1374,6 +1378,60 @@ class CallFilters { }; }; + class PullClientInitialMetadataPromise { + public: + explicit PullClientInitialMetadataPromise(CallFilters* filters) + : filters_(filters) {} + + PullClientInitialMetadataPromise(const PullClientInitialMetadataPromise&) = + delete; + PullClientInitialMetadataPromise& operator=( + const PullClientInitialMetadataPromise&) = delete; + PullClientInitialMetadataPromise( + PullClientInitialMetadataPromise&& other) noexcept + : filters_(std::exchange(other.filters_, nullptr)), + executor_(std::move(other.executor_)) {} + PullClientInitialMetadataPromise& operator=( + PullClientInitialMetadataPromise&&) = delete; + + Poll> operator()() { + if (executor_.IsRunning()) { + return FinishOperationExecutor(executor_.Step(filters_->call_data_)); + } + auto p = state().PollPull(); + auto* r = p.value_if_ready(); + gpr_log(GPR_INFO, "%s", r == nullptr ? "PENDING" : r->ToString().c_str()); + if (r == nullptr) return Pending{}; + if (!r->ok()) { + filters_->CancelDueToFailedPipeOperation(); + return Failure{}; + } + GPR_ASSERT(filters_->client_initial_metadata_ != nullptr); + return FinishOperationExecutor(executor_.Start( + &filters_->stack_->data_.client_initial_metadata, + std::move(filters_->client_initial_metadata_), filters_->call_data_)); + } + + private: + filters_detail::PipeState& state() { + return filters_->client_initial_metadata_state_; + } + + Poll> FinishOperationExecutor( + Poll> p) { + auto* r = p.value_if_ready(); + if (r == nullptr) return Pending{}; + GPR_DEBUG_ASSERT(!executor_.IsRunning()); + state().AckPull(); + if (r->ok != nullptr) return std::move(r->ok); + filters_->PushServerTrailingMetadata(std::move(r->error)); + return Failure{}; + } + + CallFilters* filters_; + filters_detail::OperationExecutor executor_; + }; + class PullServerTrailingMetadataPromise { public: explicit PullServerTrailingMetadataPromise(CallFilters* filters) @@ -1411,31 +1469,27 @@ class CallFilters { RefCountedPtr stack_; - filters_detail::PipeState client_initial_metadata_state_; + filters_detail::PipeState client_initial_metadata_state_{ + filters_detail::PipeState::StartPushed{}}; filters_detail::PipeState server_initial_metadata_state_; filters_detail::PipeState client_to_server_message_state_; filters_detail::PipeState server_to_client_message_state_; IntraActivityWaiter server_trailing_metadata_waiter_; void* call_data_; + ClientMetadataHandle client_initial_metadata_; // The following void*'s are pointers to a `Push` object (from above). // They are used to track the current push operation for each pipe. // It would be lovely for them to be typed pointers, but that would require // a recursive type definition since the location of this field needs to be // a template argument to the `Push` object itself. - void* client_initial_metadata_push_ = nullptr; void* server_initial_metadata_push_ = nullptr; void* client_to_server_message_push_ = nullptr; void* server_to_client_message_push_ = nullptr; ServerMetadataHandle server_trailing_metadata_; - using ClientInitialMetadataPromises = - PipePromise<&CallFilters::client_initial_metadata_state_, - &CallFilters::client_initial_metadata_push_, - ClientMetadataHandle, - &filters_detail::StackData::client_initial_metadata>; using ServerInitialMetadataPromises = PipePromise<&CallFilters::server_initial_metadata_state_, &CallFilters::server_initial_metadata_push_, @@ -1451,14 +1505,8 @@ class CallFilters { &filters_detail::StackData::server_to_client_messages>; }; -inline auto CallFilters::PushClientInitialMetadata(ClientMetadataHandle md) { - GPR_ASSERT(md != nullptr); - return [p = ClientInitialMetadataPromises::Push{ - this, std::move(md)}]() mutable { return p(); }; -} - inline auto CallFilters::PullClientInitialMetadata() { - return ClientInitialMetadataPromises::Pull{this}; + return PullClientInitialMetadataPromise(this); } inline auto CallFilters::PushServerInitialMetadata(ServerMetadataHandle md) { diff --git a/test/core/surface/channel_init_test.cc b/test/core/surface/channel_init_test.cc index 4939745c36c..3282c0e7916 100644 --- a/test/core/surface/channel_init_test.cc +++ b/test/core/surface/channel_init_test.cc @@ -255,7 +255,10 @@ TEST(ChannelInitTest, CanCreateFilterWithCall) { segment->AddToCallFilterStack(stack_builder); segment = absl::CancelledError(); // force the segment to be destroyed auto stack = stack_builder.Build(); - { CallFilters call_filters(stack); } + { + CallFilters call_filters(Arena::MakePooled()); + call_filters.SetStack(std::move(stack)); + } EXPECT_EQ(p, 1); } diff --git a/test/core/transport/call_filters_test.cc b/test/core/transport/call_filters_test.cc index 1db57026314..d2a97643f29 100644 --- a/test/core/transport/call_filters_test.cc +++ b/test/core/transport/call_filters_test.cc @@ -1402,24 +1402,18 @@ TEST(CallFiltersTest, UnaryCall) { CallFilters::StackBuilder builder; builder.Add(&f1); builder.Add(&f2); - CallFilters filters(builder.Build()); auto memory_allocator = MakeMemoryQuota("test-quota")->CreateMemoryAllocator("foo"); auto arena = MakeScopedArena(1024, &memory_allocator); + CallFilters filters(Arena::MakePooled()); + filters.SetStack(builder.Build()); promise_detail::Context ctx(arena.get()); StrictMock activity; activity.Activate(); - // Push client initial metadata - auto push_client_initial_metadata = - filters.PushClientInitialMetadata(Arena::MakePooled()); - EXPECT_THAT(push_client_initial_metadata(), IsPending()); + // Pull client initial metadata auto pull_client_initial_metadata = filters.PullClientInitialMetadata(); - // Pull client initial metadata, expect a wakeup - EXPECT_CALL(activity, WakeupRequested()); EXPECT_THAT(pull_client_initial_metadata(), IsReady()); Mock::VerifyAndClearExpectations(&activity); - // Push should be done - EXPECT_THAT(push_client_initial_metadata(), IsReady(Success{})); // Push client to server message auto push_client_to_server_message = filters.PushClientToServerMessage( Arena::MakePooled(SliceBuffer(), 0));