diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 8b8cad64122..39a48a8a44d 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -13235,6 +13235,7 @@ targets: - src/core/lib/promise/join.h - src/core/lib/promise/map.h - src/core/lib/promise/poll.h + - test/core/promise/poll_matcher.h src: - src/core/lib/debug/trace.cc - src/core/lib/debug/trace_flags.cc @@ -14286,6 +14287,7 @@ targets: - src/core/lib/promise/poll.h - src/core/lib/promise/promise.h - src/core/lib/promise/wait_set.h + - test/core/promise/poll_matcher.h src: - src/core/lib/promise/activity.cc - test/core/promise/mpsc_test.cc diff --git a/src/core/BUILD b/src/core/BUILD index f6b6b13ea10..1b40042d771 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -500,7 +500,10 @@ grpc_cc_library( grpc_cc_library( name = "poll", - external_deps = ["absl/log:check"], + external_deps = [ + "absl/log:check", + "absl/strings:str_format", + ], language = "c++", public_hdrs = [ "lib/promise/poll.h", diff --git a/src/core/lib/promise/mpsc.h b/src/core/lib/promise/mpsc.h index ec8132e843a..26244b43734 100644 --- a/src/core/lib/promise/mpsc.h +++ b/src/core/lib/promise/mpsc.h @@ -18,6 +18,8 @@ #include #include +#include +#include #include #include @@ -48,6 +50,9 @@ class Center : public RefCounted> { // Construct the center with a maximum queue size. explicit Center(size_t max_queued) : max_queued_(max_queued) {} + static constexpr const uint64_t kClosedBatch = + std::numeric_limits::max(); + // Poll for new items. // - Returns true if new items were obtained, in which case they are contained // in dest in the order they were added. Wakes up all pending senders since @@ -67,45 +72,39 @@ class Center : public RefCounted> { } dest.swap(queue_); queue_.clear(); + if (batch_ != kClosedBatch) ++batch_; auto wakeups = send_wakers_.TakeWakeupSet(); lock.Release(); wakeups.Wakeup(); return true; } - // Poll to send one item. - // Returns pending if no send slot was available. - // Returns true if the item was sent. - // Returns false if the receiver has been closed. - Poll PollSend(T& t) { - ReleasableMutexLock lock(&mu_); - if (receiver_closed_) return Poll(false); - if (queue_.size() < max_queued_) { - queue_.push_back(std::move(t)); - auto receive_waker = std::move(receive_waker_); - lock.Release(); - receive_waker.Wakeup(); - return Poll(true); - } - send_wakers_.AddPending(GetContext()->MakeNonOwningWaker()); - return Pending{}; - } - - bool ImmediateSend(T t) { + // Returns the batch number that the item was sent in, or kClosedBatch if the + // pipe is closed. + uint64_t Send(T t) { ReleasableMutexLock lock(&mu_); - if (receiver_closed_) return false; + if (batch_ == kClosedBatch) return kClosedBatch; queue_.push_back(std::move(t)); auto receive_waker = std::move(receive_waker_); + const uint64_t batch = queue_.size() <= max_queued_ ? batch_ : batch_ + 1; lock.Release(); receive_waker.Wakeup(); - return true; + return batch; + } + + // Poll until a particular batch number is received. + Poll PollReceiveBatch(uint64_t batch) { + ReleasableMutexLock lock(&mu_); + if (batch_ >= batch) return Empty{}; + send_wakers_.AddPending(GetContext()->MakeNonOwningWaker()); + return Pending{}; } // Mark that the receiver is closed. void ReceiverClosed() { ReleasableMutexLock lock(&mu_); - if (receiver_closed_) return; - receiver_closed_ = true; + if (batch_ == kClosedBatch) return; + batch_ = kClosedBatch; auto wakeups = send_wakers_.TakeWakeupSet(); lock.Release(); wakeups.Wakeup(); @@ -115,7 +114,9 @@ class Center : public RefCounted> { Mutex mu_; const size_t max_queued_; std::vector queue_ ABSL_GUARDED_BY(mu_); - bool receiver_closed_ ABSL_GUARDED_BY(mu_) = false; + // Every time we give queue_ to the receiver, we increment batch_. + // When the receiver is closed we set batch_ to kClosedBatch. + uint64_t batch_ ABSL_GUARDED_BY(mu_) = 1; Waker receive_waker_ ABSL_GUARDED_BY(mu_); WaitSet send_wakers_ ABSL_GUARDED_BY(mu_); }; @@ -138,14 +139,23 @@ class MpscSender { // Resolves to true if sent, false if the receiver was closed (and the value // will never be successfully sent). auto Send(T t) { - return [center = center_, t = std::move(t)]() mutable -> Poll { + return [center = center_, t = std::move(t), + batch = uint64_t(0)]() mutable -> Poll { if (center == nullptr) return false; - return center->PollSend(t); + if (batch == 0) { + batch = center->Send(std::move(t)); + CHECK_NE(batch, 0u); + if (batch == mpscpipe_detail::Center::kClosedBatch) return false; + } + auto p = center->PollReceiveBatch(batch); + if (p.pending()) return Pending{}; + return true; }; } bool UnbufferedImmediateSend(T t) { - return center_->ImmediateSend(std::move(t)); + return center_->Send(std::move(t)) != + mpscpipe_detail::Center::kClosedBatch; } private: diff --git a/src/core/lib/promise/poll.h b/src/core/lib/promise/poll.h index 57a4a65e83a..f23340d9ebe 100644 --- a/src/core/lib/promise/poll.h +++ b/src/core/lib/promise/poll.h @@ -19,6 +19,7 @@ #include #include "absl/log/check.h" +#include "absl/strings/str_format.h" #include #include @@ -252,6 +253,15 @@ std::string PollToString( return t_to_string(poll.value()); } +template +void AbslStringify(Sink& sink, const Poll& poll) { + if (poll.pending()) { + absl::Format(&sink, "<>"); + return; + } + absl::Format(&sink, "%v", poll.value()); +} + } // namespace grpc_core #endif // GRPC_SRC_CORE_LIB_PROMISE_POLL_H diff --git a/test/core/promise/BUILD b/test/core/promise/BUILD index 973dfaa88d3..aca4b5b3631 100644 --- a/test/core/promise/BUILD +++ b/test/core/promise/BUILD @@ -241,6 +241,7 @@ grpc_cc_test( uses_event_engine = False, uses_polling = False, deps = [ + "poll_matcher", "//src/core:join", "//src/core:poll", ], @@ -487,6 +488,7 @@ grpc_cc_test( uses_event_engine = False, uses_polling = False, deps = [ + "poll_matcher", "//:gpr", "//:promise", "//src/core:activity", diff --git a/test/core/promise/join_test.cc b/test/core/promise/join_test.cc index 288d3a64539..7e60d65732f 100644 --- a/test/core/promise/join_test.cc +++ b/test/core/promise/join_test.cc @@ -17,25 +17,26 @@ #include #include +#include "gmock/gmock.h" #include "gtest/gtest.h" #include "src/core/lib/promise/poll.h" +#include "test/core/promise/poll_matcher.h" namespace grpc_core { TEST(JoinTest, Join1) { - EXPECT_EQ(Join([] { return 3; })(), - (Poll>(std::make_tuple(3)))); + EXPECT_THAT(Join([] { return 3; })(), IsReady(std::make_tuple(3))); } TEST(JoinTest, Join2) { - EXPECT_EQ(Join([] { return 3; }, [] { return 4; })(), - (Poll>(std::make_tuple(3, 4)))); + EXPECT_THAT(Join([] { return 3; }, [] { return 4; })(), + IsReady(std::make_tuple(3, 4))); } TEST(JoinTest, Join3) { - EXPECT_EQ(Join([] { return 3; }, [] { return 4; }, [] { return 5; })(), - (Poll>(std::make_tuple(3, 4, 5)))); + EXPECT_THAT(Join([] { return 3; }, [] { return 4; }, [] { return 5; })(), + IsReady(std::make_tuple(3, 4, 5))); } } // namespace grpc_core diff --git a/test/core/promise/mpsc_test.cc b/test/core/promise/mpsc_test.cc index 3d6e669a173..0a486edff72 100644 --- a/test/core/promise/mpsc_test.cc +++ b/test/core/promise/mpsc_test.cc @@ -25,6 +25,7 @@ #include "src/core/lib/promise/activity.h" #include "src/core/lib/promise/promise.h" +#include "test/core/promise/poll_matcher.h" using testing::Mock; using testing::StrictMock; @@ -63,8 +64,17 @@ struct Payload { return (x == nullptr && other.x == nullptr) || (x != nullptr && other.x != nullptr && *x == *other.x); } + bool operator!=(const Payload& other) const { return !(*this == other); } + explicit Payload(std::unique_ptr x) : x(std::move(x)) {} + Payload(const Payload& other) + : x(other.x ? std::make_unique(*other.x) : nullptr) {} + + friend std::ostream& operator<<(std::ostream& os, const Payload& payload) { + if (payload.x == nullptr) return os << "Payload{nullptr}"; + return os << "Payload{" << *payload.x << "}"; + } }; -Payload MakePayload(int value) { return {std::make_unique(value)}; } +Payload MakePayload(int value) { return Payload{std::make_unique(value)}; } TEST(MpscTest, NoOp) { MpscReceiver receiver(1); } @@ -76,14 +86,14 @@ TEST(MpscTest, MakeSender) { TEST(MpscTest, SendOneThingInstantly) { MpscReceiver receiver(1); MpscSender sender = receiver.MakeSender(); - EXPECT_EQ(NowOrNever(sender.Send(MakePayload(1))), true); + EXPECT_THAT(sender.Send(MakePayload(1))(), IsReady(true)); } TEST(MpscTest, SendOneThingInstantlyAndReceiveInstantly) { MpscReceiver receiver(1); MpscSender sender = receiver.MakeSender(); - EXPECT_EQ(NowOrNever(sender.Send(MakePayload(1))), true); - EXPECT_EQ(NowOrNever(receiver.Next()), MakePayload(1)); + EXPECT_THAT(sender.Send(MakePayload(1))(), IsReady(true)); + EXPECT_THAT(receiver.Next()(), IsReady(MakePayload(1))); } TEST(MpscTest, SendingLotsOfThingsGivesPushback) { @@ -92,8 +102,8 @@ TEST(MpscTest, SendingLotsOfThingsGivesPushback) { MpscSender sender = receiver.MakeSender(); activity1.Activate(); - EXPECT_EQ(NowOrNever(sender.Send(MakePayload(1))), true); - EXPECT_EQ(NowOrNever(sender.Send(MakePayload(2))), absl::nullopt); + EXPECT_THAT(sender.Send(MakePayload(1))(), IsReady(true)); + EXPECT_THAT(sender.Send(MakePayload(2))(), IsPending()); activity1.Deactivate(); EXPECT_CALL(activity1, WakeupRequested()); @@ -106,28 +116,23 @@ TEST(MpscTest, ReceivingAfterBlockageWakesUp) { MpscSender sender = receiver.MakeSender(); activity1.Activate(); - EXPECT_EQ(NowOrNever(sender.Send(MakePayload(1))), true); + EXPECT_THAT(sender.Send(MakePayload(1))(), IsReady(true)); auto send2 = sender.Send(MakePayload(2)); - EXPECT_EQ(send2(), Poll(Pending{})); + EXPECT_THAT(send2(), IsPending()); activity1.Deactivate(); activity2.Activate(); EXPECT_CALL(activity1, WakeupRequested()); - EXPECT_EQ(NowOrNever(receiver.Next()), MakePayload(1)); + EXPECT_THAT(receiver.Next()(), IsReady(MakePayload(1))); Mock::VerifyAndClearExpectations(&activity1); auto receive2 = receiver.Next(); - EXPECT_EQ(receive2(), Poll(Pending{})); + EXPECT_THAT(receive2(), IsReady(MakePayload(2))); activity2.Deactivate(); activity1.Activate(); - EXPECT_CALL(activity2, WakeupRequested()); - EXPECT_EQ(send2(), Poll(true)); + EXPECT_THAT(send2(), Poll(true)); Mock::VerifyAndClearExpectations(&activity2); activity1.Deactivate(); - - activity2.Activate(); - EXPECT_EQ(receive2(), Poll(MakePayload(2))); - activity2.Deactivate(); } TEST(MpscTest, BigBufferAllowsBurst) { @@ -135,10 +140,10 @@ TEST(MpscTest, BigBufferAllowsBurst) { MpscSender sender = receiver.MakeSender(); for (int i = 0; i < 25; i++) { - EXPECT_EQ(NowOrNever(sender.Send(MakePayload(i))), true); + EXPECT_THAT(sender.Send(MakePayload(i))(), IsReady(true)); } for (int i = 0; i < 25; i++) { - EXPECT_EQ(NowOrNever(receiver.Next()), MakePayload(i)); + EXPECT_THAT(receiver.Next()(), IsReady(MakePayload(i))); } } @@ -146,7 +151,7 @@ TEST(MpscTest, ClosureIsVisibleToSenders) { auto receiver = std::make_unique>(1); MpscSender sender = receiver->MakeSender(); receiver.reset(); - EXPECT_EQ(NowOrNever(sender.Send(MakePayload(1))), false); + EXPECT_THAT(sender.Send(MakePayload(1))(), IsReady(false)); } TEST(MpscTest, ImmediateSendWorks) { @@ -163,15 +168,15 @@ TEST(MpscTest, ImmediateSendWorks) { EXPECT_EQ(sender.UnbufferedImmediateSend(MakePayload(7)), true); activity.Activate(); - EXPECT_EQ(NowOrNever(receiver.Next()), MakePayload(1)); - EXPECT_EQ(NowOrNever(receiver.Next()), MakePayload(2)); - EXPECT_EQ(NowOrNever(receiver.Next()), MakePayload(3)); - EXPECT_EQ(NowOrNever(receiver.Next()), MakePayload(4)); - EXPECT_EQ(NowOrNever(receiver.Next()), MakePayload(5)); - EXPECT_EQ(NowOrNever(receiver.Next()), MakePayload(6)); - EXPECT_EQ(NowOrNever(receiver.Next()), MakePayload(7)); + EXPECT_THAT(receiver.Next()(), IsReady(MakePayload(1))); + EXPECT_THAT(receiver.Next()(), IsReady(MakePayload(2))); + EXPECT_THAT(receiver.Next()(), IsReady(MakePayload(3))); + EXPECT_THAT(receiver.Next()(), IsReady(MakePayload(4))); + EXPECT_THAT(receiver.Next()(), IsReady(MakePayload(5))); + EXPECT_THAT(receiver.Next()(), IsReady(MakePayload(6))); + EXPECT_THAT(receiver.Next()(), IsReady(MakePayload(7))); auto receive2 = receiver.Next(); - EXPECT_EQ(receive2(), Poll(Pending{})); + EXPECT_THAT(receive2(), IsPending()); activity.Deactivate(); }