diff --git a/src/core/lib/channel/promise_based_filter.cc b/src/core/lib/channel/promise_based_filter.cc index 929d52dce27..cafa665b116 100644 --- a/src/core/lib/channel/promise_based_filter.cc +++ b/src/core/lib/channel/promise_based_filter.cc @@ -403,7 +403,8 @@ void BaseCallData::SendMessage::OnComplete(absl::Status status) { } } -void BaseCallData::SendMessage::Done(const ServerMetadata& metadata) { +void BaseCallData::SendMessage::Done(const ServerMetadata& metadata, + Flusher* flusher) { if (grpc_trace_channel.enabled()) { gpr_log(GPR_INFO, "%s SendMessage.Done st=%s md=%s", base_->LogTag().c_str(), StateString(state_), @@ -419,7 +420,16 @@ void BaseCallData::SendMessage::Done(const ServerMetadata& metadata) { state_ = State::kCancelledButNotYetPolled; break; case State::kGotBatchNoPipe: - case State::kGotBatch: + case State::kGotBatch: { + std::string temp; + batch_.CancelWith( + absl::Status( + static_cast(metadata.get(GrpcStatusMetadata()) + .value_or(GRPC_STATUS_UNKNOWN)), + metadata.GetStringValue("grpc-message", &temp).value_or("")), + flusher); + state_ = State::kCancelledButNotYetPolled; + } break; case State::kBatchCompleted: Crash(absl::StrFormat("ILLEGAL STATE: %s", StateString(state_))); break; @@ -964,7 +974,7 @@ class ClientCallData::PollContext { if (auto* r = absl::get_if(&poll)) { auto md = std::move(*r); if (self_->send_message() != nullptr) { - self_->send_message()->Done(*md); + self_->send_message()->Done(*md, flusher_); } if (self_->receive_message() != nullptr) { self_->receive_message()->Done(*md, flusher_); @@ -1384,7 +1394,7 @@ void ClientCallData::Cancel(grpc_error_handle error, Flusher* flusher) { } } if (send_message() != nullptr) { - send_message()->Done(*ServerMetadataFromStatus(error)); + send_message()->Done(*ServerMetadataFromStatus(error), flusher); } if (receive_message() != nullptr) { receive_message()->Done(*ServerMetadataFromStatus(error), flusher); @@ -1640,7 +1650,7 @@ void ClientCallData::RecvTrailingMetadataReady(grpc_error_handle error) { receive_message()->Done(*recv_trailing_metadata_, &flusher); } if (send_message() != nullptr) { - send_message()->Done(*recv_trailing_metadata_); + send_message()->Done(*recv_trailing_metadata_, &flusher); } // Repoll the promise. ScopedContext context(this); @@ -1983,7 +1993,7 @@ void ServerCallData::Completed(grpc_error_handle error, Flusher* flusher) { } ScopedContext ctx(this); if (send_message() != nullptr) { - send_message()->Done(*ServerMetadataFromStatus(error)); + send_message()->Done(*ServerMetadataFromStatus(error), flusher); } if (receive_message() != nullptr) { receive_message()->Done(*ServerMetadataFromStatus(error), flusher); @@ -2205,7 +2215,7 @@ void ServerCallData::WakeInsideCombiner(Flusher* flusher) { auto* md = UnwrapMetadata(std::move(*r)); bool destroy_md = true; if (send_message() != nullptr) { - send_message()->Done(*md); + send_message()->Done(*md, flusher); } if (receive_message() != nullptr) { receive_message()->Done(*md, flusher); diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index f6932fd5b7c..57fa2608a87 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -292,7 +292,7 @@ class BaseCallData : public Activity, private Wakeable { // work. void WakeInsideCombiner(Flusher* flusher); // Call is completed, we have trailing metadata. Close things out. - void Done(const ServerMetadata& metadata); + void Done(const ServerMetadata& metadata, Flusher* flusher); // Return true if we have a batch captured (for debug logs) bool HaveCapturedBatch() const { return batch_.is_captured(); } // Return true if we're not actively sending a message.