- FusedSet With(P x) {
- return FusedSet
(std::move(x), std::move(*this));
- }
-
- private:
- union {
- T wrapper_;
- };
-};
-
-template <>
-class FusedSet<> {
- public:
- static constexpr size_t Size() { return 0; }
- static constexpr uint8_t NecessaryBits() { return 0; }
-
- template
- Poll Run(uint8_t) {
- return Pending{};
- }
- template
- void Destroy(uint8_t) {}
-
- template
- FusedSet With(P x) {
- return FusedSet
(std::move(x));
- }
-};
-
-template
-class TryConcurrently {
- public:
- TryConcurrently(Main main, PreMain pre_main, PostMain post_main)
- : done_bits_(0),
- pre_main_(std::move(pre_main)),
- post_main_(std::move(post_main)) {
- Construct(&main_, std::move(main));
- }
-
- TryConcurrently(const TryConcurrently&) = delete;
- TryConcurrently& operator=(const TryConcurrently&) = delete;
- TryConcurrently(TryConcurrently&& other) noexcept
- : done_bits_(0),
- pre_main_(std::move(other.pre_main_)),
- post_main_(std::move(other.post_main_)) {
- GPR_DEBUG_ASSERT(other.done_bits_ == 0);
- other.done_bits_ = HelperBits();
- Construct(&main_, std::move(other.main_));
- }
- TryConcurrently& operator=(TryConcurrently&& other) noexcept {
- GPR_DEBUG_ASSERT(other.done_bits_ == 0);
- done_bits_ = 0;
- other.done_bits_ = HelperBits();
- pre_main_ = std::move(other.pre_main_);
- post_main_ = std::move(other.post_main_);
- Construct(&main_, std::move(other.main_));
- return *this;
- }
-
- ~TryConcurrently() {
- if (done_bits_ & 1) {
- Destruct(&result_);
- } else {
- Destruct(&main_);
- }
- pre_main_.template Destroy<1>(done_bits_);
- post_main_.template Destroy<1 + PreMain::Size()>(done_bits_);
- }
-
- using Result =
- typename PollTraits>()())>::Type;
-
- Poll operator()() {
- auto r = pre_main_.template Run(done_bits_);
- if (auto* status = absl::get_if(&r)) {
- GPR_DEBUG_ASSERT(!IsStatusOk(*status));
- return std::move(*status);
- }
- if ((done_bits_ & 1) == 0) {
- auto p = main_();
- if (auto* status = absl::get_if(&p)) {
- done_bits_ |= 1;
- Destruct(&main_);
- Construct(&result_, std::move(*status));
- }
- }
- r = post_main_.template Run(done_bits_);
- if (auto* status = absl::get_if(&r)) {
- GPR_DEBUG_ASSERT(!IsStatusOk(*status));
- return std::move(*status);
- }
- if ((done_bits_ & NecessaryBits()) == NecessaryBits()) {
- return std::move(result_);
- }
- return Pending{};
- }
-
- template
- auto NecessaryPush(P p);
- template
- auto NecessaryPull(P p);
- template
- auto Push(P p);
- template
- auto Pull(P p);
-
- private:
- // Bitmask for done_bits_ specifying which promises must be completed prior to
- // returning ok.
- constexpr uint8_t NecessaryBits() {
- return 1 | (PreMain::NecessaryBits() << 1) |
- (PostMain::NecessaryBits() << (1 + PreMain::Size()));
- }
- // Bitmask for done_bits_ specifying what all of the promises being complete
- // would look like.
- constexpr uint8_t AllBits() {
- return (1 << (1 + PreMain::Size() + PostMain::Size())) - 1;
- }
- // Bitmask of done_bits_ specifying which bits correspond to helper promises -
- // that is all promises that are not the main one.
- constexpr uint8_t HelperBits() { return AllBits() ^ 1; }
-
- // done_bits signifies which operations have completed.
- // Bit 0 is set if main_ has completed.
- // The next higher bits correspond one per pre-main promise.
- // The next higher bits correspond one per post-main promise.
- // So, going from most significant bit to least significant:
- // +--------------+-------------+--------+
- // |post_main bits|pre_main bits|main bit|
- // +--------------+-------------+--------+
- uint8_t done_bits_;
- PreMain pre_main_;
- union {
- PromiseLike main_;
- Result result_;
- };
- PostMain post_main_;
-};
-
-template
-auto MakeTryConcurrently(Main&& main, PreMain&& pre_main,
- PostMain&& post_main) {
- return TryConcurrently(
- std::forward(main), std::forward(pre_main),
- std::forward(post_main));
-}
-
-template
-template
-auto TryConcurrently::NecessaryPush(P p) {
- GPR_DEBUG_ASSERT(done_bits_ == 0);
- done_bits_ = HelperBits();
- return MakeTryConcurrently(std::move(main_),
- pre_main_.With(Necessary{std::move(p)}),
- std::move(post_main_));
-}
-
-template
-template
-auto TryConcurrently::NecessaryPull(P p) {
- GPR_DEBUG_ASSERT(done_bits_ == 0);
- done_bits_ = HelperBits();
- return MakeTryConcurrently(std::move(main_), std::move(pre_main_),
- post_main_.With(Necessary{std::move(p)}));
-}
-
-template
-template
-auto TryConcurrently::Push(P p) {
- GPR_DEBUG_ASSERT(done_bits_ == 0);
- done_bits_ = HelperBits();
- return MakeTryConcurrently(std::move(main_),
- pre_main_.With(Helper{std::move(p)}),
- std::move(post_main_));
-}
-
-template
-template
-auto TryConcurrently::Pull(P p) {
- GPR_DEBUG_ASSERT(done_bits_ == 0);
- done_bits_ = HelperBits();
- return MakeTryConcurrently(std::move(main_), std::move(pre_main_),
- post_main_.With(Helper{std::move(p)}));
-}
-
-} // namespace promise_detail
-
-// TryConcurrently runs a set of promises concurrently.
-// There is a structure to the promises:
-// - A 'main' promise dominates the others - it must complete before the
-// overall promise successfully completes. Its result is chosen in the event
-// of successful completion.
-// - A set of (optional) push and pull promises to aid main. Push promises are
-// polled before main, pull promises are polled after. In this way we can
-// avoid overall wakeup churn - sending a message will tend to push things
-// down the promise tree as its polled, so that send should be in a push
-// promise - then as the main promise is polled and it calls into things
-// lower in the stack they'll already see things there (this reasoning holds
-// for receiving things and the pull promises too!).
-// - Each push and pull promise is either necessary or optional.
-// Necessary promises must complete successfully before the overall promise
-// completes. Optional promises will just be cancelled once the main promise
-// completes and any necessary helpers.
-// - If any of the promises fail, the overall promise fails immediately.
-// API:
-// This function, TryConcurrently, is used to create a TryConcurrently promise.
-// It takes a single argument, being the main promise. That promise also has
-// a set of methods for attaching push and pull promises. The act of attachment
-// returns a new TryConcurrently promise with previous contained promises moved
-// out.
-// The methods exposed:
-// - Push, NecessaryPush: attach a push promise (with the first variant being
-// optional, the second necessary).
-// - Pull, NecessaryPull: attach a pull promise, with variants as above.
-// Example:
-// TryConcurrently(call_next_filter(std::move(call_args)))
-// .Push(send_messages_promise)
-// .Pull(recv_messages_promise)
-template
-auto TryConcurrently(Main main) {
- return promise_detail::MakeTryConcurrently(std::move(main),
- promise_detail::FusedSet<>(),
- promise_detail::FusedSet<>());
-}
-
-} // namespace grpc_core
-
-#endif // GRPC_CORE_LIB_PROMISE_TRY_CONCURRENTLY_H
diff --git a/test/core/promise/BUILD b/test/core/promise/BUILD
index 15f9b286ab3..831b932f650 100644
--- a/test/core/promise/BUILD
+++ b/test/core/promise/BUILD
@@ -406,8 +406,8 @@ grpc_cc_test(
)
grpc_cc_test(
- name = "try_concurrently_test",
- srcs = ["try_concurrently_test.cc"],
+ name = "call_push_pull_test",
+ srcs = ["call_push_pull_test.cc"],
external_deps = [
"gtest",
"absl/status",
@@ -417,6 +417,6 @@ grpc_cc_test(
uses_event_engine = False,
uses_polling = False,
deps = [
- "//src/core:try_concurrently",
+ "//src/core:call_push_pull",
],
)
diff --git a/test/core/promise/arena_promise_test.cc b/test/core/promise/arena_promise_test.cc
index f8d49b495da..b7c37900d80 100644
--- a/test/core/promise/arena_promise_test.cc
+++ b/test/core/promise/arena_promise_test.cc
@@ -14,7 +14,6 @@
#include "src/core/lib/promise/arena_promise.h"
-#include
#include
#include "absl/types/variant.h"
@@ -72,23 +71,6 @@ TEST(ArenaPromiseTest, MoveAssignmentWorks) {
p = ArenaPromise();
}
-TEST(ArenaPromiseTest, AllocatedUniquePtrWorks) {
- ExecCtx exec_ctx;
- auto arena = MakeScopedArena(1024, g_memory_allocator);
- TestContext context(arena.get());
- std::array garbage = {0, 1, 2, 3, 4};
- auto freer = [garbage](int* p) { free(p + garbage[0]); };
- using Ptr = std::unique_ptr;
- Ptr x(new int(42), freer);
- static_assert(sizeof(x) > sizeof(arena_promise_detail::ArgType),
- "This test assumes the unique ptr will go down the allocated "
- "path for ArenaPromise");
- ArenaPromise initial_promise(
- [x = std::move(x)]() mutable { return Poll(std::move(x)); });
- ArenaPromise p(std::move(initial_promise));
- EXPECT_EQ(*absl::get(p()), 42);
-}
-
} // namespace grpc_core
int main(int argc, char** argv) {
diff --git a/test/core/promise/call_push_pull_test.cc b/test/core/promise/call_push_pull_test.cc
new file mode 100644
index 00000000000..89d2ab8107f
--- /dev/null
+++ b/test/core/promise/call_push_pull_test.cc
@@ -0,0 +1,77 @@
+// Copyright 2022 gRPC authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/core/lib/promise/call_push_pull.h"
+
+#include
+
+#include "absl/status/status.h"
+#include "gtest/gtest.h"
+
+namespace grpc_core {
+
+TEST(CallPushPullTest, Empty) {
+ auto p = CallPushPull([] { return absl::OkStatus(); },
+ [] { return absl::OkStatus(); },
+ [] { return absl::OkStatus(); });
+ EXPECT_EQ(p(), Poll(absl::OkStatus()));
+}
+
+TEST(CallPushPullTest, Paused) {
+ auto p = CallPushPull([]() -> Poll { return Pending{}; },
+ []() -> Poll { return Pending{}; },
+ []() -> Poll { return Pending{}; });
+ EXPECT_EQ(p(), Poll(Pending{}));
+}
+
+TEST(CallPushPullTest, OneReady) {
+ auto a = CallPushPull([]() -> Poll { return absl::OkStatus(); },
+ []() -> Poll { return Pending{}; },
+ []() -> Poll { return Pending{}; });
+ EXPECT_EQ(a(), Poll(Pending{}));
+ auto b = CallPushPull([]() -> Poll { return Pending{}; },
+ []() -> Poll { return absl::OkStatus(); },
+ []() -> Poll { return Pending{}; });
+ EXPECT_EQ(b(), Poll(Pending{}));
+ auto c =
+ CallPushPull([]() -> Poll { return Pending{}; },
+ []() -> Poll { return Pending{}; },
+ []() -> Poll { return absl::OkStatus(); });
+ EXPECT_EQ(c(), Poll(Pending{}));
+}
+
+TEST(CallPushPullTest, OneFailed) {
+ auto a = CallPushPull(
+ []() -> Poll { return absl::UnknownError("bah"); },
+ []() -> Poll { return absl::OkStatus(); },
+ []() -> Poll { return absl::OkStatus(); });
+ EXPECT_EQ(a(), Poll(absl::UnknownError("bah")));
+ auto b = CallPushPull(
+ []() -> Poll { return Pending{}; },
+ []() -> Poll { return absl::UnknownError("humbug"); },
+ []() -> Poll { return Pending{}; });
+ EXPECT_EQ(b(), Poll(absl::UnknownError("humbug")));
+ auto c = CallPushPull(
+ []() -> Poll { return Pending{}; },
+ []() -> Poll { return Pending{}; },
+ []() -> Poll { return absl::UnknownError("wha"); });
+ EXPECT_EQ(c(), Poll(absl::UnknownError("wha")));
+}
+
+} // namespace grpc_core
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/test/core/promise/try_concurrently_test.cc b/test/core/promise/try_concurrently_test.cc
deleted file mode 100644
index 7013e840017..00000000000
--- a/test/core/promise/try_concurrently_test.cc
+++ /dev/null
@@ -1,160 +0,0 @@
-// Copyright 2022 gRPC authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/core/lib/promise/try_concurrently.h"
-
-#include
-#include
-#include
-#include
-#include
-
-#include "absl/status/status.h"
-#include "gtest/gtest.h"
-
-namespace grpc_core {
-
-class PromiseFactory {
- public:
- // Create a promise that resolves to Ok but has a memory allocation (to verify
- // destruction)
- auto OkPromise(std::string tag) {
- return [this, tag = std::move(tag),
- p = std::make_unique(absl::OkStatus())]() mutable {
- order_.push_back(tag);
- return std::move(*p);
- };
- }
-
- // Create a promise that never resolves and carries a memory allocation
- auto NeverPromise(std::string tag) {
- return [this, tag = std::move(tag),
- p = std::make_unique()]() -> Poll {
- order_.push_back(tag);
- return *p;
- };
- }
-
- // Create a promise that fails and carries a memory allocation
- auto FailPromise(std::string tag) {
- return [this, p = std::make_unique(absl::UnknownError(tag)),
- tag = std::move(tag)]() mutable {
- order_.push_back(tag);
- return std::move(*p);
- };
- }
-
- // Finish one round and return a vector of strings representing which promises
- // were polled and in which order.
- std::vector Finish() { return std::exchange(order_, {}); }
-
- private:
- std::vector order_;
-};
-
-std::ostream& operator<<(std::ostream& out, const Poll& p) {
- return out << PollToString(
- p, [](const absl::Status& s) { return s.ToString(); });
-}
-
-TEST(TryConcurrentlyTest, Immediate) {
- PromiseFactory pf;
- auto a = TryConcurrently(pf.OkPromise("1"));
- EXPECT_EQ(a(), Poll(absl::OkStatus()));
- EXPECT_EQ(pf.Finish(), std::vector({"1"}));
- auto b = TryConcurrently(pf.OkPromise("1")).NecessaryPush(pf.OkPromise("2"));
- EXPECT_EQ(b(), Poll(absl::OkStatus()));
- EXPECT_EQ(pf.Finish(), std::vector({"2", "1"}));
- auto c = TryConcurrently(pf.OkPromise("1")).NecessaryPull(pf.OkPromise("2"));
- EXPECT_EQ(c(), Poll(absl::OkStatus()));
- EXPECT_EQ(pf.Finish(), std::vector({"1", "2"}));
- auto d = TryConcurrently(pf.OkPromise("1"))
- .NecessaryPull(pf.OkPromise("2"))
- .NecessaryPush(pf.OkPromise("3"));
- EXPECT_EQ(d(), Poll(absl::OkStatus()));
- EXPECT_EQ(pf.Finish(), std::vector