[chaotic-good] Fix race between call finishing and adding to the stream map (#37749)

Previously, if we pulled server trailing metadata *before* the call was added to the client transport then we'd never call `on_done_` on the spine and consequently never remove the call from the map. This change fixes that edge case.

In fixing it, I noticed a state in `CallState` that was both complicating the fix and completely irrelevant because we respecced earlier this year to say that ServerTrailingMetadata processing cannot be asynchronous, so I'm removing that state also.

Closes #37749

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/37749 from ctiller:flake-fightas-11 847814a286
PiperOrigin-RevId: 676246259
pull/37767/head
Craig Tiller 2 months ago committed by Copybara-Service
parent 48aa0c85ff
commit e3aa78868f
  1. 64
      src/core/ext/transport/chaotic_good/client_transport.cc
  2. 7
      src/core/ext/transport/chaotic_good/server_transport.cc
  3. 5
      src/core/lib/transport/call_filters.h
  4. 24
      src/core/lib/transport/call_spine.h
  5. 55
      src/core/lib/transport/call_state.h
  6. 6
      test/core/transport/call_state_test.cc

@ -250,19 +250,20 @@ void ChaoticGoodClientTransport::AbortWithError() {
}
uint32_t ChaoticGoodClientTransport::MakeStream(CallHandler call_handler) {
ReleasableMutexLock lock(&mu_);
MutexLock lock(&mu_);
const uint32_t stream_id = next_stream_id_++;
const bool on_done_added =
call_handler.OnDone([self = RefAsSubclass<ChaoticGoodClientTransport>(),
stream_id](bool cancelled) {
if (cancelled) {
self->outgoing_frames_.MakeSender().UnbufferedImmediateSend(
CancelFrame{stream_id});
}
MutexLock lock(&self->mu_);
self->stream_map_.erase(stream_id);
});
if (!on_done_added) return 0;
stream_map_.emplace(stream_id, call_handler);
lock.Release();
call_handler.OnDone([self = RefAsSubclass<ChaoticGoodClientTransport>(),
stream_id](bool cancelled) {
if (cancelled) {
self->outgoing_frames_.MakeSender().UnbufferedImmediateSend(
CancelFrame{stream_id});
}
MutexLock lock(&self->mu_);
self->stream_map_.erase(stream_id);
});
return stream_id;
}
@ -322,23 +323,30 @@ void ChaoticGoodClientTransport::StartCall(CallHandler call_handler) {
"outbound_loop", [self = RefAsSubclass<ChaoticGoodClientTransport>(),
call_handler]() mutable {
const uint32_t stream_id = self->MakeStream(call_handler);
return Map(
self->CallOutboundLoop(stream_id, call_handler),
[stream_id, sender = self->outgoing_frames_.MakeSender()](
absl::Status result) mutable {
GRPC_TRACE_LOG(chaotic_good, INFO)
<< "CHAOTIC_GOOD: Call " << stream_id << " finished with "
<< result.ToString();
if (!result.ok()) {
GRPC_TRACE_LOG(chaotic_good, INFO)
<< "CHAOTIC_GOOD: Send cancel";
if (!sender.UnbufferedImmediateSend(CancelFrame{stream_id})) {
GRPC_TRACE_LOG(chaotic_good, INFO)
<< "CHAOTIC_GOOD: Send cancel failed";
}
}
return result;
});
return If(
stream_id != 0,
[stream_id, call_handler = std::move(call_handler),
self = std::move(self)]() {
return Map(
self->CallOutboundLoop(stream_id, call_handler),
[stream_id, sender = self->outgoing_frames_.MakeSender()](
absl::Status result) mutable {
GRPC_TRACE_LOG(chaotic_good, INFO)
<< "CHAOTIC_GOOD: Call " << stream_id
<< " finished with " << result.ToString();
if (!result.ok()) {
GRPC_TRACE_LOG(chaotic_good, INFO)
<< "CHAOTIC_GOOD: Send cancel";
if (!sender.UnbufferedImmediateSend(
CancelFrame{stream_id})) {
GRPC_TRACE_LOG(chaotic_good, INFO)
<< "CHAOTIC_GOOD: Send cancel failed";
}
}
return result;
});
},
[]() { return absl::OkStatus(); });
});
}

@ -439,8 +439,7 @@ absl::Status ChaoticGoodServerTransport::NewStream(
if (stream_id <= last_seen_new_stream_id_) {
return absl::InternalError("Stream id is not increasing");
}
stream_map_.emplace(stream_id, call_initiator);
call_initiator.OnDone(
const bool on_done_added = call_initiator.OnDone(
[self = RefAsSubclass<ChaoticGoodServerTransport>(), stream_id](bool) {
GRPC_TRACE_LOG(chaotic_good, INFO)
<< "CHAOTIC_GOOD " << self.get() << " OnDone " << stream_id;
@ -454,6 +453,10 @@ absl::Status ChaoticGoodServerTransport::NewStream(
});
}
});
if (!on_done_added) {
return absl::CancelledError();
}
stream_map_.emplace(stream_id, call_initiator);
return absl::OkStatus();
}

@ -1485,7 +1485,6 @@ class CallFilters {
std::move(value));
}
}
call_state_.FinishPullServerTrailingMetadata();
return value;
});
}
@ -1497,6 +1496,10 @@ class CallFilters {
GRPC_MUST_USE_RESULT auto WasCancelled() {
return [this]() { return call_state_.PollWasCancelled(); };
}
// Returns true if server trailing metadata has been pulled
bool WasServerTrailingMetadataPulled() const {
return call_state_.WasServerTrailingMetadataPulled();
}
// Client & server: fill in final_info with the final status of the call.
void Finalize(const grpc_call_final_info* final_info);

@ -54,17 +54,23 @@ class CallSpine final : public Party {
CallFilters& call_filters() { return call_filters_; }
// Add a callback to be called when server trailing metadata is received.
void OnDone(absl::AnyInvocable<void(bool)> fn) {
// Add a callback to be called when server trailing metadata is received and
// return true.
// If CallOnDone has already been invoked, does nothing and returns false.
GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
if (call_filters().WasServerTrailingMetadataPulled()) {
return false;
}
if (on_done_ == nullptr) {
on_done_ = std::move(fn);
return;
return true;
}
on_done_ = [first = std::move(fn),
next = std::move(on_done_)](bool cancelled) mutable {
first(cancelled);
next(cancelled);
};
return true;
}
void CallOnDone(bool cancelled) {
if (on_done_ != nullptr) std::exchange(on_done_, nullptr)(cancelled);
@ -232,8 +238,8 @@ class CallInitiator {
spine_->PushServerTrailingMetadata(std::move(status));
}
void OnDone(absl::AnyInvocable<void(bool)> fn) {
spine_->OnDone(std::move(fn));
GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
return spine_->OnDone(std::move(fn));
}
template <typename PromiseFactory>
@ -281,8 +287,8 @@ class CallHandler {
spine_->PushServerTrailingMetadata(std::move(status));
}
void OnDone(absl::AnyInvocable<void(bool)> fn) {
spine_->OnDone(std::move(fn));
GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
return spine_->OnDone(std::move(fn));
}
template <typename Promise>
@ -336,8 +342,8 @@ class UnstartedCallHandler {
spine_->PushServerTrailingMetadata(std::move(status));
}
void OnDone(absl::AnyInvocable<void(bool)> fn) {
spine_->OnDone(std::move(fn));
GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
return spine_->OnDone(std::move(fn));
}
template <typename Promise>

@ -52,7 +52,7 @@ class CallState {
Poll<ValueOrFailure<bool>> PollPullServerToClientMessageAvailable();
void FinishPullServerToClientMessage();
Poll<Empty> PollServerTrailingMetadataAvailable();
void FinishPullServerTrailingMetadata();
bool WasServerTrailingMetadataPulled() const;
Poll<bool> PollWasCancelled();
// Debug
std::string DebugString() const;
@ -147,8 +147,6 @@ class CallState {
kReading,
// Main call loop: processing one message
kProcessingServerToClientMessage,
// Processing server trailing metadata
kProcessingServerTrailingMetadata,
kTerminated,
};
static const char* ServerToClientPullStateString(
@ -172,8 +170,6 @@ class CallState {
return "Reading";
case ServerToClientPullState::kProcessingServerToClientMessage:
return "ProcessingServerToClientMessage";
case ServerToClientPullState::kProcessingServerTrailingMetadata:
return "ProcessingServerTrailingMetadata";
case ServerToClientPullState::kTerminated:
return "Terminated";
}
@ -294,7 +290,6 @@ GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline void CallState::Start() {
case ServerToClientPullState::kReading:
case ServerToClientPullState::kProcessingServerToClientMessage:
LOG(FATAL) << "Start called twice";
case ServerToClientPullState::kProcessingServerTrailingMetadata:
case ServerToClientPullState::kTerminated:
break;
}
@ -644,7 +639,6 @@ CallState::PollPullServerInitialMetadataAvailable() {
case ServerToClientPullState::kIdle:
case ServerToClientPullState::kReading:
case ServerToClientPullState::kProcessingServerToClientMessage:
case ServerToClientPullState::kProcessingServerTrailingMetadata:
LOG(FATAL) << "PollPullServerInitialMetadataAvailable called twice";
case ServerToClientPullState::kTerminated:
return false;
@ -703,7 +697,6 @@ CallState::FinishPullServerInitialMetadata() {
case ServerToClientPullState::kIdle:
case ServerToClientPullState::kReading:
case ServerToClientPullState::kProcessingServerToClientMessage:
case ServerToClientPullState::kProcessingServerTrailingMetadata:
LOG(FATAL) << "Out of order FinishPullServerInitialMetadata";
case ServerToClientPullState::kTerminated:
return;
@ -766,9 +759,6 @@ CallState::PollPullServerToClientMessageAvailable() {
case ServerToClientPullState::kProcessingServerToClientMessage:
LOG(FATAL) << "PollPullServerToClientMessageAvailable called while "
"processing a message";
case ServerToClientPullState::kProcessingServerTrailingMetadata:
LOG(FATAL) << "PollPullServerToClientMessageAvailable called while "
"processing trailing metadata";
case ServerToClientPullState::kTerminated:
return Failure{};
}
@ -826,9 +816,6 @@ CallState::FinishPullServerToClientMessage() {
server_to_client_pull_state_ = ServerToClientPullState::kIdle;
server_to_client_pull_waiter_.Wake();
break;
case ServerToClientPullState::kProcessingServerTrailingMetadata:
LOG(FATAL) << "FinishPullServerToClientMessage called while processing "
"trailing metadata";
case ServerToClientPullState::kTerminated:
break;
}
@ -875,10 +862,7 @@ CallState::PollServerTrailingMetadataAvailable() {
case ServerToClientPushState::kFinished:
if (server_trailing_metadata_state_ !=
ServerTrailingMetadataState::kNotPushed) {
server_to_client_pull_state_ =
ServerToClientPullState::kProcessingServerTrailingMetadata;
server_to_client_pull_waiter_.Wake();
return Empty{};
break; // Ready for processing
}
ABSL_FALLTHROUGH_INTENDED;
case ServerToClientPushState::kPushedServerInitialMetadata:
@ -894,26 +878,14 @@ CallState::PollServerTrailingMetadataAvailable() {
case ServerToClientPullState::kIdle:
if (server_trailing_metadata_state_ !=
ServerTrailingMetadataState::kNotPushed) {
server_to_client_pull_state_ =
ServerToClientPullState::kProcessingServerTrailingMetadata;
server_to_client_pull_waiter_.Wake();
return Empty{};
break; // Ready for processing
}
return server_trailing_metadata_waiter_.pending();
case ServerToClientPullState::kProcessingServerTrailingMetadata:
LOG(FATAL) << "PollServerTrailingMetadataAvailable called twice";
case ServerToClientPullState::kTerminated:
return Empty{};
break;
}
Crash("Unreachable");
}
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline void
CallState::FinishPullServerTrailingMetadata() {
GRPC_TRACE_LOG(call_state, INFO)
<< "[call_state] FinishPullServerTrailingMetadata: "
<< GRPC_DUMP_ARGS(this, server_trailing_metadata_state_,
server_trailing_metadata_waiter_.DebugString());
server_to_client_pull_state_ = ServerToClientPullState::kTerminated;
server_to_client_pull_waiter_.Wake();
switch (server_trailing_metadata_state_) {
case ServerTrailingMetadataState::kNotPushed:
LOG(FATAL) << "FinishPullServerTrailingMetadata called before "
@ -931,6 +903,21 @@ CallState::FinishPullServerTrailingMetadata() {
case ServerTrailingMetadataState::kPulledCancel:
LOG(FATAL) << "FinishPullServerTrailingMetadata called twice";
}
return Empty{};
}
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline bool
CallState::WasServerTrailingMetadataPulled() const {
switch (server_trailing_metadata_state_) {
case ServerTrailingMetadataState::kNotPushed:
case ServerTrailingMetadataState::kPushed:
case ServerTrailingMetadataState::kPushedCancel:
return false;
case ServerTrailingMetadataState::kPulled:
case ServerTrailingMetadataState::kPulledCancel:
return true;
}
GPR_UNREACHABLE_CODE(Crash("unreachable"));
}
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION inline Poll<bool>

@ -245,7 +245,6 @@ TEST(CallStateTest, ReceiveTrailersOnly) {
EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false));
state.FinishPullServerInitialMetadata();
EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady());
state.FinishPullServerTrailingMetadata();
}
TEST(CallStateTest, ReceiveTrailersOnlySkipsInitialMetadataOnUnstartedCalls) {
@ -256,7 +255,6 @@ TEST(CallStateTest, ReceiveTrailersOnlySkipsInitialMetadataOnUnstartedCalls) {
EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false));
state.FinishPullServerInitialMetadata();
EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady());
state.FinishPullServerTrailingMetadata();
}
TEST(CallStateTest, RecallNoCancellation) {
@ -268,8 +266,6 @@ TEST(CallStateTest, RecallNoCancellation) {
EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false));
state.FinishPullServerInitialMetadata();
EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady());
EXPECT_THAT(state.PollWasCancelled(), IsPending());
EXPECT_WAKEUP(activity, state.FinishPullServerTrailingMetadata());
EXPECT_THAT(state.PollWasCancelled(), IsReady(false));
}
@ -282,8 +278,6 @@ TEST(CallStateTest, RecallCancellation) {
EXPECT_THAT(state.PollPullServerInitialMetadataAvailable(), IsReady(false));
state.FinishPullServerInitialMetadata();
EXPECT_THAT(state.PollServerTrailingMetadataAvailable(), IsReady());
EXPECT_THAT(state.PollWasCancelled(), IsPending());
EXPECT_WAKEUP(activity, state.FinishPullServerTrailingMetadata());
EXPECT_THAT(state.PollWasCancelled(), IsReady(true));
}

Loading…
Cancel
Save