diff --git a/src/core/lib/promise/party.cc b/src/core/lib/promise/party.cc index 4a47bb33687..77e06b6ec49 100644 --- a/src/core/lib/promise/party.cc +++ b/src/core/lib/promise/party.cc @@ -37,6 +37,11 @@ grpc_core::DebugOnlyTraceFlag grpc_trace_party_state(false, "party_state"); namespace grpc_core { +namespace { +// TODO(ctiller): Once all activities are parties we can remove this. +thread_local Party* g_current_party_ = nullptr; +} // namespace + /////////////////////////////////////////////////////////////////////////////// // PartySyncUsingAtomics @@ -210,11 +215,32 @@ void Party::ForceImmediateRepoll(WakeupMask mask) { } void Party::RunLocked() { + // If there is a party running, then we don't run it immediately + // but instead add it to the end of the list of parties to run. + // This enables a fairly straightforward batching of work from a + // call to a transport (or back again). + if (g_current_party_ != nullptr) { + Party* after = g_current_party_; + while (after->run_next_ != nullptr) { + after = after->run_next_; + } + after->run_next_ = this; + return; + } auto body = [this]() { - if (RunParty()) { + GPR_DEBUG_ASSERT(g_current_party_ == nullptr); + g_current_party_ = this; + const bool done = RunParty(); + GPR_DEBUG_ASSERT(g_current_party_ == this); + Party* run_next = std::exchange(run_next_, nullptr); + g_current_party_ = nullptr; + if (done) { ScopedActivity activity(this); PartyOver(); } + if (run_next != nullptr) { + run_next->RunLocked(); + } }; #ifdef GRPC_MAXIMIZE_THREADYNESS Thread thd( diff --git a/src/core/lib/promise/party.h b/src/core/lib/promise/party.h index 76c76f22990..7a4bcab14dc 100644 --- a/src/core/lib/promise/party.h +++ b/src/core/lib/promise/party.h @@ -628,6 +628,7 @@ class Party : public Activity, private Wakeable { Arena* const arena_; uint8_t currently_polling_ = kNotPolling; + Party* run_next_ = nullptr; // All current participants, using a tagged format. // If the lower bit is unset, then this is a Participant*. // If the lower bit is set, then this is a ParticipantFactory*. diff --git a/test/core/promise/party_test.cc b/test/core/promise/party_test.cc index 6bdf6696da7..de9c88ce46b 100644 --- a/test/core/promise/party_test.cc +++ b/test/core/promise/party_test.cc @@ -285,8 +285,7 @@ TEST_F(PartyTest, CanSpawnAndRun) { "TestSpawn", [i = 10]() mutable -> Poll { EXPECT_EQ(GetContext()->DebugTag(), "TestParty"); - gpr_log(GPR_DEBUG, "i=%d", i); - GPR_ASSERT(i > 0); + EXPECT_GT(i, 0); GetContext()->ForceImmediateRepoll(); --i; if (i == 0) return 42; @@ -786,6 +785,39 @@ TEST_F(PartyTest, ThreadStressTestWithInnerSpawn) { } } +TEST_F(PartyTest, NestedWakeup) { + auto party1 = MakeRefCounted(); + auto party2 = MakeRefCounted(); + int whats_going_on = 0; + Notification n; + party1->Spawn( + "p1", + [&]() { + EXPECT_EQ(whats_going_on, 0); + whats_going_on = 1; + party2->Spawn( + "p2", + [&]() { + EXPECT_EQ(whats_going_on, 3); + whats_going_on = 4; + return Empty{}; + }, + [&](Empty) { + EXPECT_EQ(whats_going_on, 4); + whats_going_on = 5; + n.Notify(); + }); + EXPECT_EQ(whats_going_on, 1); + whats_going_on = 2; + return Empty{}; + }, + [&](Empty) { + EXPECT_EQ(whats_going_on, 2); + whats_going_on = 3; + }); + n.WaitForNotification(); +} + } // namespace grpc_core int main(int argc, char** argv) {