[chaotic-good] Fix race between call finishing and adding to the stream map (#37749)

Previously, if server trailing metadata was pulled *before* the call was added to the client transport's stream map, we would never call `on_done_` on the spine, and consequently would never remove the call from the map. This change fixes that edge case.

In fixing it, I noticed a state in `CallState` that was both complicating the fix and completely irrelevant: we revised the spec earlier this year to say that ServerTrailingMetadata processing cannot be asynchronous. I'm removing that state as well.

Closes #37749

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/37749 from ctiller:flake-fightas-11 847814a286
PiperOrigin-RevId: 676246259
pull/37767/head
Craig Tiller 2 months ago committed by Copybara-Service
parent 48aa0c85ff
commit e3aa78868f
  1. 20
      src/core/ext/transport/chaotic_good/client_transport.cc
  2. 7
      src/core/ext/transport/chaotic_good/server_transport.cc
  3. 5
      src/core/lib/transport/call_filters.h
  4. 24
      src/core/lib/transport/call_spine.h
  5. 55
      src/core/lib/transport/call_state.h
  6. 6
      test/core/transport/call_state_test.cc

@ -250,10 +250,9 @@ void ChaoticGoodClientTransport::AbortWithError() {
} }
uint32_t ChaoticGoodClientTransport::MakeStream(CallHandler call_handler) { uint32_t ChaoticGoodClientTransport::MakeStream(CallHandler call_handler) {
ReleasableMutexLock lock(&mu_); MutexLock lock(&mu_);
const uint32_t stream_id = next_stream_id_++; const uint32_t stream_id = next_stream_id_++;
stream_map_.emplace(stream_id, call_handler); const bool on_done_added =
lock.Release();
call_handler.OnDone([self = RefAsSubclass<ChaoticGoodClientTransport>(), call_handler.OnDone([self = RefAsSubclass<ChaoticGoodClientTransport>(),
stream_id](bool cancelled) { stream_id](bool cancelled) {
if (cancelled) { if (cancelled) {
@ -263,6 +262,8 @@ uint32_t ChaoticGoodClientTransport::MakeStream(CallHandler call_handler) {
MutexLock lock(&self->mu_); MutexLock lock(&self->mu_);
self->stream_map_.erase(stream_id); self->stream_map_.erase(stream_id);
}); });
if (!on_done_added) return 0;
stream_map_.emplace(stream_id, call_handler);
return stream_id; return stream_id;
} }
@ -322,23 +323,30 @@ void ChaoticGoodClientTransport::StartCall(CallHandler call_handler) {
"outbound_loop", [self = RefAsSubclass<ChaoticGoodClientTransport>(), "outbound_loop", [self = RefAsSubclass<ChaoticGoodClientTransport>(),
call_handler]() mutable { call_handler]() mutable {
const uint32_t stream_id = self->MakeStream(call_handler); const uint32_t stream_id = self->MakeStream(call_handler);
return If(
stream_id != 0,
[stream_id, call_handler = std::move(call_handler),
self = std::move(self)]() {
return Map( return Map(
self->CallOutboundLoop(stream_id, call_handler), self->CallOutboundLoop(stream_id, call_handler),
[stream_id, sender = self->outgoing_frames_.MakeSender()]( [stream_id, sender = self->outgoing_frames_.MakeSender()](
absl::Status result) mutable { absl::Status result) mutable {
GRPC_TRACE_LOG(chaotic_good, INFO) GRPC_TRACE_LOG(chaotic_good, INFO)
<< "CHAOTIC_GOOD: Call " << stream_id << " finished with " << "CHAOTIC_GOOD: Call " << stream_id
<< result.ToString(); << " finished with " << result.ToString();
if (!result.ok()) { if (!result.ok()) {
GRPC_TRACE_LOG(chaotic_good, INFO) GRPC_TRACE_LOG(chaotic_good, INFO)
<< "CHAOTIC_GOOD: Send cancel"; << "CHAOTIC_GOOD: Send cancel";
if (!sender.UnbufferedImmediateSend(CancelFrame{stream_id})) { if (!sender.UnbufferedImmediateSend(
CancelFrame{stream_id})) {
GRPC_TRACE_LOG(chaotic_good, INFO) GRPC_TRACE_LOG(chaotic_good, INFO)
<< "CHAOTIC_GOOD: Send cancel failed"; << "CHAOTIC_GOOD: Send cancel failed";
} }
} }
return result; return result;
}); });
},
[]() { return absl::OkStatus(); });
}); });
} }

@ -439,8 +439,7 @@ absl::Status ChaoticGoodServerTransport::NewStream(
if (stream_id <= last_seen_new_stream_id_) { if (stream_id <= last_seen_new_stream_id_) {
return absl::InternalError("Stream id is not increasing"); return absl::InternalError("Stream id is not increasing");
} }
stream_map_.emplace(stream_id, call_initiator); const bool on_done_added = call_initiator.OnDone(
call_initiator.OnDone(
[self = RefAsSubclass<ChaoticGoodServerTransport>(), stream_id](bool) { [self = RefAsSubclass<ChaoticGoodServerTransport>(), stream_id](bool) {
GRPC_TRACE_LOG(chaotic_good, INFO) GRPC_TRACE_LOG(chaotic_good, INFO)
<< "CHAOTIC_GOOD " << self.get() << " OnDone " << stream_id; << "CHAOTIC_GOOD " << self.get() << " OnDone " << stream_id;
@ -454,6 +453,10 @@ absl::Status ChaoticGoodServerTransport::NewStream(
}); });
} }
}); });
if (!on_done_added) {
return absl::CancelledError();
}
stream_map_.emplace(stream_id, call_initiator);
return absl::OkStatus(); return absl::OkStatus();
} }

@ -1485,7 +1485,6 @@ class CallFilters {
std::move(value)); std::move(value));
} }
} }
call_state_.FinishPullServerTrailingMetadata();
return value; return value;
}); });
} }
@ -1497,6 +1496,10 @@ class CallFilters {
GRPC_MUST_USE_RESULT auto WasCancelled() { GRPC_MUST_USE_RESULT auto WasCancelled() {
return [this]() { return call_state_.PollWasCancelled(); }; return [this]() { return call_state_.PollWasCancelled(); };
} }
// Returns true if server trailing metadata has been pulled
bool WasServerTrailingMetadataPulled() const {
return call_state_.WasServerTrailingMetadataPulled();
}
// Client & server: fill in final_info with the final status of the call. // Client & server: fill in final_info with the final status of the call.
void Finalize(const grpc_call_final_info* final_info); void Finalize(const grpc_call_final_info* final_info);

@ -54,17 +54,23 @@ class CallSpine final : public Party {
CallFilters& call_filters() { return call_filters_; } CallFilters& call_filters() { return call_filters_; }
// Add a callback to be called when server trailing metadata is received. // Add a callback to be called when server trailing metadata is received and
void OnDone(absl::AnyInvocable<void(bool)> fn) { // return true.
// If CallOnDone has already been invoked, does nothing and returns false.
GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
if (call_filters().WasServerTrailingMetadataPulled()) {
return false;
}
if (on_done_ == nullptr) { if (on_done_ == nullptr) {
on_done_ = std::move(fn); on_done_ = std::move(fn);
return; return true;
} }
on_done_ = [first = std::move(fn), on_done_ = [first = std::move(fn),
next = std::move(on_done_)](bool cancelled) mutable { next = std::move(on_done_)](bool cancelled) mutable {
first(cancelled); first(cancelled);
next(cancelled); next(cancelled);
}; };
return true;
} }
void CallOnDone(bool cancelled) { void CallOnDone(bool cancelled) {
if (on_done_ != nullptr) std::exchange(on_done_, nullptr)(cancelled); if (on_done_ != nullptr) std::exchange(on_done_, nullptr)(cancelled);
@ -232,8 +238,8 @@ class CallInitiator {
spine_->PushServerTrailingMetadata(std::move(status)); spine_->PushServerTrailingMetadata(std::move(status));
} }
void OnDone(absl::AnyInvocable<void(bool)> fn) { GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
spine_->OnDone(std::move(fn)); return spine_->OnDone(std::move(fn));
} }
template <typename PromiseFactory> template <typename PromiseFactory>
@ -281,8 +287,8 @@ class CallHandler {
spine_->PushServerTrailingMetadata(std::move(status)); spine_->PushServerTrailingMetadata(std::move(status));
} }
void OnDone(absl::AnyInvocable<void(bool)> fn) { GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
spine_->OnDone(std::move(fn)); return spine_->OnDone(std::move(fn));
} }
template <typename Promise> template <typename Promise>
@ -336,8 +342,8 @@ class UnstartedCallHandler {
spine_->PushServerTrailingMetadata(std::move(status)); spine_->PushServerTrailingMetadata(std::move(status));
} }
void OnDone(absl::AnyInvocable<void(bool)> fn) { GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
spine_->OnDone(std::move(fn)); return spine_->OnDone(std::move(fn));
} }
template <typename Promise> template <typename Promise>

@ -52,7 +52,7 @@ class CallState {
Poll<ValueOrFailure<bool>> PollPullServerToClientMessageAvailable(); Poll<ValueOrFailure<bool>> PollPullServerToClientMessageAvailable();
void FinishPullServerToClientMessage(); void FinishPullServerToClientMessage();
Poll<Empty> PollServerTrailingMetadataAvailable(); Poll<Empty> PollServerTrailingMetadataAvailable();
void FinishPullServerTrailingMetadata(); bool WasServerTrailingMetadataPulled() const;
Poll<bool> PollWasCancelled(); Poll<bool> PollWasCancelled();
// Debug // Debug
std::string DebugString() const; std::string DebugString() const;
@ -147,8 +147,6 @@ class CallState {
kReading, kReading,
// Main call loop: processing one message // Main call loop: processing one message
kProcessingServerToClientMessage, kProcessingServerToClientMessage,
// Processing server trailing metadata
kProcessingServerTrailingMetadata,
kTerminated, kTerminated,
}; };
static const char* ServerToClientPullStateString( static const char* ServerToClientPullStateString(
@ -172,8 +170,6 @@ class CallState {
return "Reading"; return "Reading";
case ServerToClientPullState::kProcessingServerToClientMessage: case ServerToClientPullState::kProcessingServerToClientMessage:
return "ProcessingServerToClientMessage"; return "ProcessingServerToClientMessage";
case ServerToClientPullState::kProcessingServerTrailingMetadata:
return "ProcessingServerTrailingMetadata";
case ServerToClientPullState::kTerminated: case ServerToClientPullState::kTerminated:
return "Terminated"; return "Terminated";
} }
@ -294,7 +290,6 @@ GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline void CallState::Start() {
case ServerToClientPullState::kReading: case ServerToClientPullState::kReading:
case ServerToClientPullState::kProcessingServerToClientMessage: case ServerToClientPullState::kProcessingServerToClientMessage:
LOG(FATAL) << "Start called twice"; LOG(FATAL) << "Start called twice";
case ServerToClientPullState::kProcessingServerTrailingMetadata:
case ServerToClientPullState::kTerminated: case ServerToClientPullState::kTerminated:
break; break;
} }
@ -644,7 +639,6 @@ CallState::PollPullServerInitialMetadataAvailable() {
case ServerToClientPullState::kIdle: case ServerToClientPullState::kIdle:
case ServerToClientPullState::kReading: case ServerToClientPullState::kReading:
case ServerToClientPullState::kProcessingServerToClientMessage: case ServerToClientPullState::kProcessingServerToClientMessage:
case ServerToClientPullState::kProcessingServerTrailingMetadata:
LOG(FATAL) << "PollPullServerInitialMetadataAvailable called twice"; LOG(FATAL) << "PollPullServerInitialMetadataAvailable called twice";
case ServerToClientPullState::kTerminated: case ServerToClientPullState::kTerminated:
return false; return false;
@ -703,7 +697,6 @@ CallState::FinishPullServerInitialMetadata() {
case ServerToClientPullState::kIdle: case ServerToClientPullState::kIdle:
case ServerToClientPullState::kReading: case ServerToClientPullState::kReading:
case ServerToClientPullState::kProcessingServerToClientMessage: case ServerToClientPullState::kProcessingServerToClientMessage:
case ServerToClientPullState::kProcessingServerTrailingMetadata:
LOG(FATAL) << "Out of order FinishPullServerInitialMetadata"; LOG(FATAL) << "Out of order FinishPullServerInitialMetadata";
case ServerToClientPullState::kTerminated: case ServerToClientPullState::kTerminated:
return; return;
@ -766,9 +759,6 @@ CallState::PollPullServerToClientMessageAvailable() {
case ServerToClientPullState::kProcessingServerToClientMessage: case ServerToClientPullState::kProcessingServerToClientMessage:
LOG(FATAL) << "PollPullServerToClientMessageAvailable called while " LOG(FATAL) << "PollPullServerToClientMessageAvailable called while "
"processing a message"; "processing a message";
case ServerToClientPullState::kProcessingServerTrailingMetadata:
LOG(FATAL) << "PollPullServerToClientMessageAvailable called while "
"processing trailing metadata";
case ServerToClientPullState::kTerminated: case ServerToClientPullState::kTerminated:
return Failure{}; return Failure{};
} }
@ -826,9 +816,6 @@ CallState::FinishPullServerToClientMessage() {
server_to_client_pull_state_ = ServerToClientPullState::kIdle; server_to_client_pull_state_ = ServerToClientPullState::kIdle;
server_to_client_pull_waiter_.Wake(); server_to_client_pull_waiter_.Wake();
break; break;
case ServerToClientPullState::kProcessingServerTrailingMetadata:
LOG(FATAL) << "FinishPullServerToClientMessage called while processing "
"trailing metadata";
case ServerToClientPullState::kTerminated: case ServerToClientPullState::kTerminated:
break; break;
} }
@ -875,10 +862,7 @@ CallState::PollServerTrailingMetadataAvailable() {
case ServerToClientPushState::kFinished: case ServerToClientPushState::kFinished:
if (server_trailing_metadata_state_ != if (server_trailing_metadata_state_ !=
ServerTrailingMetadataState::kNotPushed) { ServerTrailingMetadataState::kNotPushed) {
server_to_client_pull_state_ = break; // Ready for processing
ServerToClientPullState::kProcessingServerTrailingMetadata;
server_to_client_pull_waiter_.Wake();
return Empty{};
} }
ABSL_FALLTHROUGH_INTENDED; ABSL_FALLTHROUGH_INTENDED;
case ServerToClientPushState::kPushedServerInitialMetadata: case ServerToClientPushState::kPushedServerInitialMetadata:
@ -894,26 +878,14 @@ CallState::PollServerTrailingMetadataAvailable() {
case ServerToClientPullState::kIdle: case ServerToClientPullState::kIdle:
if (server_trailing_metadata_state_ != if (server_trailing_metadata_state_ !=
ServerTrailingMetadataState::kNotPushed) { ServerTrailingMetadataState::kNotPushed) {
server_to_client_pull_state_ = break; // Ready for processing
ServerToClientPullState::kProcessingServerTrailingMetadata;
server_to_client_pull_waiter_.Wake();
return Empty{};
} }
return server_trailing_metadata_waiter_.pending(); return server_trailing_metadata_waiter_.pending();
case ServerToClientPullState::kProcessingServerTrailingMetadata:
LOG(FATAL) << "PollServerTrailingMetadataAvailable called twice";
case ServerToClientPullState::kTerminated: case ServerToClientPullState::kTerminated:
return Empty{}; break;
} }
Crash("Unreachable"); server_to_client_pull_state_ = ServerToClientPullState::kTerminated;
} server_to_client_pull_waiter_.Wake();
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline void
CallState::FinishPullServerTrailingMetadata() {
GRPC_TRACE_LOG(call_state, INFO)
<< "[call_state] FinishPullServerTrailingMetadata: "
<< GRPC_DUMP_ARGS(this, server_trailing_metadata_state_,
server_trailing_metadata_waiter_.DebugString());
switch (server_trailing_metadata_state_) { switch (server_trailing_metadata_state_) {
case ServerTrailingMetadataState::kNotPushed: case ServerTrailingMetadataState::kNotPushed:
LOG(FATAL) << "FinishPullServerTrailingMetadata called before " LOG(FATAL) << "FinishPullServerTrailingMetadata called before "
@ -931,6 +903,21 @@ CallState::FinishPullServerTrailingMetadata() {
case ServerTrailingMetadataState::kPulledCancel: case ServerTrailingMetadataState::kPulledCancel:
LOG(FATAL) << "FinishPullServerTrailingMetadata called twice"; LOG(FATAL) << "FinishPullServerTrailingMetadata called twice";
} }
return Empty{};
}
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline bool
CallState::WasServerTrailingMetadataPulled() const {
switch (server_trailing_metadata_state_) {
case ServerTrailingMetadataState::kNotPushed:
case ServerTrailingMetadataState::kPushed:
case ServerTrailingMetadataState::kPushedCancel:
return false;
case ServerTrailingMetadataState::kPulled:
case ServerTrailingMetadataState::kPulledCancel:
return true;
}
GPR_UNREACHABLE_CODE(Crash("unreachable"));
} }
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline Poll<bool> GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline Poll<bool>

@ -245,7 +245,6 @@ TEST(CallStateTest, ReceiveTrailersOnly) {
EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false)); EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false));
state.FinishPullServerInitialMetadata(); state.FinishPullServerInitialMetadata();
EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady()); EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady());
state.FinishPullServerTrailingMetadata();
} }
TEST(CallStateTest, ReceiveTrailersOnlySkipsInitialMetadataOnUnstartedCalls) { TEST(CallStateTest, ReceiveTrailersOnlySkipsInitialMetadataOnUnstartedCalls) {
@ -256,7 +255,6 @@ TEST(CallStateTest, ReceiveTrailersOnlySkipsInitialMetadataOnUnstartedCalls) {
EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false)); EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false));
state.FinishPullServerInitialMetadata(); state.FinishPullServerInitialMetadata();
EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady()); EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady());
state.FinishPullServerTrailingMetadata();
} }
TEST(CallStateTest, RecallNoCancellation) { TEST(CallStateTest, RecallNoCancellation) {
@ -268,8 +266,6 @@ TEST(CallStateTest, RecallNoCancellation) {
EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false)); EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false));
state.FinishPullServerInitialMetadata(); state.FinishPullServerInitialMetadata();
EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady()); EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady());
EXPECT_THAT(state.PollWasCancelled(), IsPending());
EXPECT_WAKEUP(activity, state.FinishPullServerTrailingMetadata());
EXPECT_THAT(state.PollWasCancelled(), IsReady(false)); EXPECT_THAT(state.PollWasCancelled(), IsReady(false));
} }
@ -282,8 +278,6 @@ TEST(CallStateTest, RecallCancellation) {
EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false)); EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false));
state.FinishPullServerInitialMetadata(); state.FinishPullServerInitialMetadata();
EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady()); EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady());
EXPECT_THAT(state.PollWasCancelled(), IsPending());
EXPECT_WAKEUP(activity, state.FinishPullServerTrailingMetadata());
EXPECT_THAT(state.PollWasCancelled(), IsReady(true)); EXPECT_THAT(state.PollWasCancelled(), IsReady(true));
} }

Loading…
Cancel
Save