diff --git a/src/core/ext/transport/chaotic_good/client_transport.cc b/src/core/ext/transport/chaotic_good/client_transport.cc index 3249ff33b3a..2c2c208f31d 100644 --- a/src/core/ext/transport/chaotic_good/client_transport.cc +++ b/src/core/ext/transport/chaotic_good/client_transport.cc @@ -250,19 +250,20 @@ void ChaoticGoodClientTransport::AbortWithError() { } uint32_t ChaoticGoodClientTransport::MakeStream(CallHandler call_handler) { - ReleasableMutexLock lock(&mu_); + MutexLock lock(&mu_); const uint32_t stream_id = next_stream_id_++; + const bool on_done_added = + call_handler.OnDone([self = RefAsSubclass(), + stream_id](bool cancelled) { + if (cancelled) { + self->outgoing_frames_.MakeSender().UnbufferedImmediateSend( + CancelFrame{stream_id}); + } + MutexLock lock(&self->mu_); + self->stream_map_.erase(stream_id); + }); + if (!on_done_added) return 0; stream_map_.emplace(stream_id, call_handler); - lock.Release(); - call_handler.OnDone([self = RefAsSubclass(), - stream_id](bool cancelled) { - if (cancelled) { - self->outgoing_frames_.MakeSender().UnbufferedImmediateSend( - CancelFrame{stream_id}); - } - MutexLock lock(&self->mu_); - self->stream_map_.erase(stream_id); - }); return stream_id; } @@ -322,23 +323,30 @@ void ChaoticGoodClientTransport::StartCall(CallHandler call_handler) { "outbound_loop", [self = RefAsSubclass(), call_handler]() mutable { const uint32_t stream_id = self->MakeStream(call_handler); - return Map( - self->CallOutboundLoop(stream_id, call_handler), - [stream_id, sender = self->outgoing_frames_.MakeSender()]( - absl::Status result) mutable { - GRPC_TRACE_LOG(chaotic_good, INFO) - << "CHAOTIC_GOOD: Call " << stream_id << " finished with " - << result.ToString(); - if (!result.ok()) { - GRPC_TRACE_LOG(chaotic_good, INFO) - << "CHAOTIC_GOOD: Send cancel"; - if (!sender.UnbufferedImmediateSend(CancelFrame{stream_id})) { - GRPC_TRACE_LOG(chaotic_good, INFO) - << "CHAOTIC_GOOD: Send cancel failed"; - } - } - return result; - }); + return If( + stream_id != 0, + [stream_id, call_handler = std::move(call_handler), + self = std::move(self)]() { + return Map( + self->CallOutboundLoop(stream_id, call_handler), + [stream_id, sender = self->outgoing_frames_.MakeSender()]( + absl::Status result) mutable { + GRPC_TRACE_LOG(chaotic_good, INFO) + << "CHAOTIC_GOOD: Call " << stream_id + << " finished with " << result.ToString(); + if (!result.ok()) { + GRPC_TRACE_LOG(chaotic_good, INFO) + << "CHAOTIC_GOOD: Send cancel"; + if (!sender.UnbufferedImmediateSend( + CancelFrame{stream_id})) { + GRPC_TRACE_LOG(chaotic_good, INFO) + << "CHAOTIC_GOOD: Send cancel failed"; + } + } + return result; + }); + }, + []() { return absl::OkStatus(); }); }); } diff --git a/src/core/ext/transport/chaotic_good/server_transport.cc b/src/core/ext/transport/chaotic_good/server_transport.cc index d6c5faf3574..0261223e823 100644 --- a/src/core/ext/transport/chaotic_good/server_transport.cc +++ b/src/core/ext/transport/chaotic_good/server_transport.cc @@ -439,8 +439,7 @@ absl::Status ChaoticGoodServerTransport::NewStream( if (stream_id <= last_seen_new_stream_id_) { return absl::InternalError("Stream id is not increasing"); } - stream_map_.emplace(stream_id, call_initiator); - call_initiator.OnDone( + const bool on_done_added = call_initiator.OnDone( [self = RefAsSubclass(), stream_id](bool) { GRPC_TRACE_LOG(chaotic_good, INFO) << "CHAOTIC_GOOD " << self.get() << " OnDone " << stream_id; @@ -454,6 +453,10 @@ absl::Status ChaoticGoodServerTransport::NewStream( }); } }); + if (!on_done_added) { + return absl::CancelledError(); + } + stream_map_.emplace(stream_id, call_initiator); return absl::OkStatus(); } diff --git a/src/core/lib/transport/call_filters.h b/src/core/lib/transport/call_filters.h index 1cbe428de5f..7f2b0b2f3a6 100644 --- a/src/core/lib/transport/call_filters.h +++ b/src/core/lib/transport/call_filters.h @@ -1485,7 +1485,6 @@ class CallFilters { std::move(value)); } } - call_state_.FinishPullServerTrailingMetadata(); return value; }); } @@ -1497,6 +1496,10 @@ class CallFilters { GRPC_MUST_USE_RESULT auto WasCancelled() { return [this]() { return call_state_.PollWasCancelled(); }; } + // Returns true if server trailing metadata has been pulled + bool WasServerTrailingMetadataPulled() const { + return call_state_.WasServerTrailingMetadataPulled(); + } // Client & server: fill in final_info with the final status of the call. void Finalize(const grpc_call_final_info* final_info); diff --git a/src/core/lib/transport/call_spine.h b/src/core/lib/transport/call_spine.h index 99944b301fc..b1051049a43 100644 --- a/src/core/lib/transport/call_spine.h +++ b/src/core/lib/transport/call_spine.h @@ -54,17 +54,23 @@ class CallSpine final : public Party { CallFilters& call_filters() { return call_filters_; } - // Add a callback to be called when server trailing metadata is received. - void OnDone(absl::AnyInvocable fn) { + // Add a callback to be called when server trailing metadata is received and + // return true. + // If CallOnDone has already been invoked, does nothing and returns false. + GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable fn) { + if (call_filters().WasServerTrailingMetadataPulled()) { + return false; + } if (on_done_ == nullptr) { on_done_ = std::move(fn); - return; + return true; } on_done_ = [first = std::move(fn), next = std::move(on_done_)](bool cancelled) mutable { first(cancelled); next(cancelled); }; + return true; } void CallOnDone(bool cancelled) { if (on_done_ != nullptr) std::exchange(on_done_, nullptr)(cancelled); @@ -232,8 +238,8 @@ class CallInitiator { spine_->PushServerTrailingMetadata(std::move(status)); } - void OnDone(absl::AnyInvocable fn) { - spine_->OnDone(std::move(fn)); + GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable fn) { + return spine_->OnDone(std::move(fn)); } template @@ -281,8 +287,8 @@ class CallHandler { spine_->PushServerTrailingMetadata(std::move(status)); } - void OnDone(absl::AnyInvocable fn) { - spine_->OnDone(std::move(fn)); + GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable fn) { + return spine_->OnDone(std::move(fn)); } template @@ -336,8 +342,8 @@ class UnstartedCallHandler { spine_->PushServerTrailingMetadata(std::move(status)); } - void OnDone(absl::AnyInvocable fn) { - spine_->OnDone(std::move(fn)); + GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable fn) { + return spine_->OnDone(std::move(fn)); } template diff --git a/src/core/lib/transport/call_state.h b/src/core/lib/transport/call_state.h index 3e9acd99c16..67facf8c007 100644 --- a/src/core/lib/transport/call_state.h +++ b/src/core/lib/transport/call_state.h @@ -52,7 +52,7 @@ class CallState { Poll> PollPullServerToClientMessageAvailable(); void FinishPullServerToClientMessage(); Poll PollServerTrailingMetadataAvailable(); - void FinishPullServerTrailingMetadata(); + bool WasServerTrailingMetadataPulled() const; Poll PollWasCancelled(); // Debug std::string DebugString() const; @@ -147,8 +147,6 @@ class CallState { kReading, // Main call loop: processing one message kProcessingServerToClientMessage, - // Processing server trailing metadata - kProcessingServerTrailingMetadata, kTerminated, }; static const char* ServerToClientPullStateString( @@ -172,8 +170,6 @@ class CallState { return "Reading"; case ServerToClientPullState::kProcessingServerToClientMessage: return "ProcessingServerToClientMessage"; - case ServerToClientPullState::kProcessingServerTrailingMetadata: - return "ProcessingServerTrailingMetadata"; case ServerToClientPullState::kTerminated: return "Terminated"; } @@ -294,7 +290,6 @@ GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline void CallState::Start() { case ServerToClientPullState::kReading: case ServerToClientPullState::kProcessingServerToClientMessage: LOG(FATAL) << "Start called twice"; - case ServerToClientPullState::kProcessingServerTrailingMetadata: case ServerToClientPullState::kTerminated: break; } @@ -644,7 +639,6 @@ CallState::PollPullServerInitialMetadataAvailable() { case ServerToClientPullState::kIdle: case ServerToClientPullState::kReading: case ServerToClientPullState::kProcessingServerToClientMessage: - case ServerToClientPullState::kProcessingServerTrailingMetadata: LOG(FATAL) << "PollPullServerInitialMetadataAvailable called twice"; case ServerToClientPullState::kTerminated: return false; @@ -703,7 +697,6 @@ CallState::FinishPullServerInitialMetadata() { case ServerToClientPullState::kIdle: case ServerToClientPullState::kReading: case ServerToClientPullState::kProcessingServerToClientMessage: - case ServerToClientPullState::kProcessingServerTrailingMetadata: LOG(FATAL) << "Out of order FinishPullServerInitialMetadata"; case ServerToClientPullState::kTerminated: return; @@ -766,9 +759,6 @@ CallState::PollPullServerToClientMessageAvailable() { case ServerToClientPullState::kProcessingServerToClientMessage: LOG(FATAL) << "PollPullServerToClientMessageAvailable called while " "processing a message"; - case ServerToClientPullState::kProcessingServerTrailingMetadata: - LOG(FATAL) << "PollPullServerToClientMessageAvailable called while " - "processing trailing metadata"; case ServerToClientPullState::kTerminated: return Failure{}; } @@ -826,9 +816,6 @@ CallState::FinishPullServerToClientMessage() { server_to_client_pull_state_ = ServerToClientPullState::kIdle; server_to_client_pull_waiter_.Wake(); break; - case ServerToClientPullState::kProcessingServerTrailingMetadata: - LOG(FATAL) << "FinishPullServerToClientMessage called while processing " - "trailing metadata"; case ServerToClientPullState::kTerminated: break; } @@ -875,10 +862,7 @@ CallState::PollServerTrailingMetadataAvailable() { case ServerToClientPushState::kFinished: if (server_trailing_metadata_state_ != ServerTrailingMetadataState::kNotPushed) { - server_to_client_pull_state_ = - ServerToClientPullState::kProcessingServerTrailingMetadata; - server_to_client_pull_waiter_.Wake(); - return Empty{}; + break; // Ready for processing } ABSL_FALLTHROUGH_INTENDED; case ServerToClientPushState::kPushedServerInitialMetadata: @@ -894,26 +878,14 @@ CallState::PollServerTrailingMetadataAvailable() { case ServerToClientPullState::kIdle: if (server_trailing_metadata_state_ != ServerTrailingMetadataState::kNotPushed) { - server_to_client_pull_state_ = - ServerToClientPullState::kProcessingServerTrailingMetadata; - server_to_client_pull_waiter_.Wake(); - return Empty{}; + break; // Ready for processing } return server_trailing_metadata_waiter_.pending(); - case ServerToClientPullState::kProcessingServerTrailingMetadata: - LOG(FATAL) << "PollServerTrailingMetadataAvailable called twice"; case ServerToClientPullState::kTerminated: - return Empty{}; + break; } - Crash("Unreachable"); -} - -GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline void -CallState::FinishPullServerTrailingMetadata() { - GRPC_TRACE_LOG(call_state, INFO) - << "[call_state] FinishPullServerTrailingMetadata: " - << GRPC_DUMP_ARGS(this, server_trailing_metadata_state_, - server_trailing_metadata_waiter_.DebugString()); + server_to_client_pull_state_ = ServerToClientPullState::kTerminated; + server_to_client_pull_waiter_.Wake(); switch (server_trailing_metadata_state_) { case ServerTrailingMetadataState::kNotPushed: LOG(FATAL) << "FinishPullServerTrailingMetadata called before " @@ -931,6 +903,21 @@ CallState::FinishPullServerTrailingMetadata() { case ServerTrailingMetadataState::kPulledCancel: LOG(FATAL) << "FinishPullServerTrailingMetadata called twice"; } + return Empty{}; +} + +GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline bool +CallState::WasServerTrailingMetadataPulled() const { + switch (server_trailing_metadata_state_) { + case ServerTrailingMetadataState::kNotPushed: + case ServerTrailingMetadataState::kPushed: + case ServerTrailingMetadataState::kPushedCancel: + return false; + case ServerTrailingMetadataState::kPulled: + case ServerTrailingMetadataState::kPulledCancel: + return true; + } + GPR_UNREACHABLE_CODE(Crash("unreachable")); } GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline Poll diff --git a/test/core/transport/call_state_test.cc b/test/core/transport/call_state_test.cc index cc35b5b7ec5..c8dcca2825b 100644 --- a/test/core/transport/call_state_test.cc +++ b/test/core/transport/call_state_test.cc @@ -245,7 +245,6 @@ TEST(CallStateTest, ReceiveTrailersOnly) { EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false)); state.FinishPullServerInitialMetadata(); EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady()); - state.FinishPullServerTrailingMetadata(); } TEST(CallStateTest, ReceiveTrailersOnlySkipsInitialMetadataOnUnstartedCalls) { @@ -256,7 +255,6 @@ TEST(CallStateTest, ReceiveTrailersOnlySkipsInitialMetadataOnUnstartedCalls) { EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false)); state.FinishPullServerInitialMetadata(); EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady()); - state.FinishPullServerTrailingMetadata(); } TEST(CallStateTest, RecallNoCancellation) { @@ -268,8 +266,6 @@ TEST(CallStateTest, RecallNoCancellation) { EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false)); state.FinishPullServerInitialMetadata(); EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady()); - EXPECT_THAT(state.PollWasCancelled(), IsPending()); - EXPECT_WAKEUP(activity, state.FinishPullServerTrailingMetadata()); EXPECT_THAT(state.PollWasCancelled(), IsReady(false)); } @@ -282,8 +278,6 @@ TEST(CallStateTest, RecallCancellation) { EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false)); state.FinishPullServerInitialMetadata(); EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady()); - EXPECT_THAT(state.PollWasCancelled(), IsPending()); - EXPECT_WAKEUP(activity, state.FinishPullServerTrailingMetadata()); EXPECT_THAT(state.PollWasCancelled(), IsReady(true)); }