[promises] New `Join`, `TryJoin` implementation (#33995)

Our current implementation of Join, TryJoin leverage some complicated
template stuff to work, which makes them hard to maintain. I've been
thinking about ways to simplify that for some time and had something
like this in mind - using a code generator that's at least a little more
understandable to code generate most of the complexity into a file that
is checkable.

Concurrently - I have a cool optimization in mind - but it requires that
we can move promises after polling, which is a contract change. I'm
going to work through the set of primitives we have in the coming weeks
and change that contract to enable the optimization.

---------

Co-authored-by: ctiller <ctiller@users.noreply.github.com>
pull/34078/head
Craig Tiller 1 year ago committed by GitHub
parent 4d24b93cbb
commit 8bbd11ebed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      CMakeLists.txt
  2. 26
      build_autogenerated.yaml
  3. 13
      src/core/BUILD
  4. 197
      src/core/lib/promise/detail/basic_join.h
  5. 1592
      src/core/lib/promise/detail/join_state.h
  6. 42
      src/core/lib/promise/join.h
  7. 52
      src/core/lib/promise/try_join.h
  8. 5
      test/core/promise/BUILD
  9. 1
      test/core/promise/for_each_test.cc
  10. 1
      test/core/promise/join_test.cc
  11. 1
      test/core/promise/latch_test.cc
  12. 1
      test/core/promise/map_pipe_test.cc
  13. 1
      test/core/promise/pipe_test.cc
  14. 1
      test/core/promise/promise_fuzzer.cc
  15. 1
      test/core/promise/try_join_test.cc
  16. 1
      test/core/transport/promise_endpoint_test.cc
  17. 153
      tools/codegen/core/gen_join.py

5
CMakeLists.txt generated

@ -5237,7 +5237,6 @@ target_link_libraries(activity_test
absl::hash
absl::type_traits
absl::statusor
absl::utility
gpr
)
@ -12178,7 +12177,6 @@ target_link_libraries(for_each_test
absl::hash
absl::type_traits
absl::statusor
absl::utility
gpr
upb
)
@ -15606,7 +15604,6 @@ target_link_libraries(join_test
${_gRPC_ZLIB_LIBRARIES}
${_gRPC_ALLTARGETS_LIBRARIES}
absl::type_traits
absl::utility
gpr
)
@ -15951,7 +15948,6 @@ target_link_libraries(latch_test
${_gRPC_ALLTARGETS_LIBRARIES}
absl::type_traits
absl::statusor
absl::utility
gpr
)
@ -16292,7 +16288,6 @@ target_link_libraries(map_pipe_test
absl::hash
absl::type_traits
absl::statusor
absl::utility
gpr
upb
)

@ -4303,13 +4303,14 @@ targets:
- src/core/lib/gprpp/ref_counted_ptr.h
- src/core/lib/promise/activity.h
- src/core/lib/promise/context.h
- src/core/lib/promise/detail/basic_join.h
- src/core/lib/promise/detail/basic_seq.h
- src/core/lib/promise/detail/join_state.h
- src/core/lib/promise/detail/promise_factory.h
- src/core/lib/promise/detail/promise_like.h
- src/core/lib/promise/detail/seq_state.h
- src/core/lib/promise/detail/status.h
- src/core/lib/promise/join.h
- src/core/lib/promise/map.h
- src/core/lib/promise/poll.h
- src/core/lib/promise/promise.h
- src/core/lib/promise/seq.h
@ -4323,7 +4324,6 @@ targets:
- absl/hash:hash
- absl/meta:type_traits
- absl/status:statusor
- absl/utility:utility
- gpr
uses_polling: false
- name: address_sorting_test
@ -7846,8 +7846,8 @@ targets:
- src/core/lib/iomgr/iomgr_internal.h
- src/core/lib/promise/activity.h
- src/core/lib/promise/context.h
- src/core/lib/promise/detail/basic_join.h
- src/core/lib/promise/detail/basic_seq.h
- src/core/lib/promise/detail/join_state.h
- src/core/lib/promise/detail/promise_factory.h
- src/core/lib/promise/detail/promise_like.h
- src/core/lib/promise/detail/seq_state.h
@ -7911,7 +7911,6 @@ targets:
- absl/hash:hash
- absl/meta:type_traits
- absl/status:statusor
- absl/utility:utility
- gpr
- upb
uses_polling: false
@ -9877,15 +9876,15 @@ targets:
language: c++
headers:
- src/core/lib/gprpp/bitset.h
- src/core/lib/promise/detail/basic_join.h
- src/core/lib/promise/detail/join_state.h
- src/core/lib/promise/detail/promise_like.h
- src/core/lib/promise/join.h
- src/core/lib/promise/map.h
- src/core/lib/promise/poll.h
src:
- test/core/promise/join_test.cc
deps:
- absl/meta:type_traits
- absl/utility:utility
- gpr
uses_polling: false
- name: json_object_loader_test
@ -10066,14 +10065,15 @@ targets:
- src/core/lib/gprpp/ref_counted_ptr.h
- src/core/lib/promise/activity.h
- src/core/lib/promise/context.h
- src/core/lib/promise/detail/basic_join.h
- src/core/lib/promise/detail/basic_seq.h
- src/core/lib/promise/detail/join_state.h
- src/core/lib/promise/detail/promise_factory.h
- src/core/lib/promise/detail/promise_like.h
- src/core/lib/promise/detail/seq_state.h
- src/core/lib/promise/detail/status.h
- src/core/lib/promise/join.h
- src/core/lib/promise/latch.h
- src/core/lib/promise/map.h
- src/core/lib/promise/poll.h
- src/core/lib/promise/seq.h
- src/core/lib/promise/trace.h
@ -10086,7 +10086,6 @@ targets:
deps:
- absl/meta:type_traits
- absl/status:statusor
- absl/utility:utility
- gpr
uses_polling: false
- name: lb_get_cpu_stats_test
@ -10213,8 +10212,8 @@ targets:
- src/core/lib/iomgr/iomgr_internal.h
- src/core/lib/promise/activity.h
- src/core/lib/promise/context.h
- src/core/lib/promise/detail/basic_join.h
- src/core/lib/promise/detail/basic_seq.h
- src/core/lib/promise/detail/join_state.h
- src/core/lib/promise/detail/promise_factory.h
- src/core/lib/promise/detail/promise_like.h
- src/core/lib/promise/detail/seq_state.h
@ -10279,7 +10278,6 @@ targets:
- absl/hash:hash
- absl/meta:type_traits
- absl/status:statusor
- absl/utility:utility
- gpr
- upb
uses_polling: false
@ -11319,7 +11317,7 @@ targets:
build: test
language: c++
headers:
- src/core/lib/promise/detail/basic_join.h
- src/core/lib/promise/detail/join_state.h
- src/core/lib/promise/join.h
- test/core/promise/test_wakeup_schedulers.h
src:
@ -11455,7 +11453,7 @@ targets:
build: test
language: c++
headers:
- src/core/lib/promise/detail/basic_join.h
- src/core/lib/promise/detail/join_state.h
- src/core/lib/promise/join.h
- src/core/lib/transport/promise_endpoint.h
- test/core/promise/test_wakeup_schedulers.h
@ -15376,9 +15374,9 @@ targets:
language: c++
headers:
- src/core/lib/gprpp/bitset.h
- src/core/lib/promise/detail/basic_join.h
- src/core/lib/promise/detail/join_state.h
- src/core/lib/promise/detail/promise_like.h
- src/core/lib/promise/detail/status.h
- src/core/lib/promise/map.h
- src/core/lib/promise/poll.h
- src/core/lib/promise/try_join.h
src:

@ -607,17 +607,17 @@ grpc_cc_library(
)
grpc_cc_library(
name = "basic_join",
external_deps = ["absl/utility"],
name = "join_state",
language = "c++",
public_hdrs = [
"lib/promise/detail/basic_join.h",
"lib/promise/detail/join_state.h",
],
deps = [
"bitset",
"construct_destruct",
"poll",
"promise_like",
"//:gpr",
"//:gpr_platform",
],
)
@ -630,7 +630,8 @@ grpc_cc_library(
"lib/promise/join.h",
],
deps = [
"basic_join",
"join_state",
"map",
"//:gpr_platform",
],
)
@ -647,9 +648,9 @@ grpc_cc_library(
"lib/promise/try_join.h",
],
deps = [
"basic_join",
"join_state",
"map",
"poll",
"promise_status",
"//:gpr_platform",
],
)

@ -1,197 +0,0 @@
// Copyright 2021 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_SRC_CORE_LIB_PROMISE_DETAIL_BASIC_JOIN_H
#define GRPC_SRC_CORE_LIB_PROMISE_DETAIL_BASIC_JOIN_H
#include <grpc/support/port_platform.h>
#include <assert.h>
#include <stddef.h>
#include <array>
#include <tuple>
#include <type_traits>
#include <utility>
#include "absl/utility/utility.h"
#include "src/core/lib/gprpp/bitset.h"
#include "src/core/lib/gprpp/construct_destruct.h"
#include "src/core/lib/promise/detail/promise_like.h"
#include "src/core/lib/promise/poll.h"
namespace grpc_core {
namespace promise_detail {
// This union can either be a functor, or the result of the functor (after
// mapping via a trait). Allows us to remember the result of one joined functor
// until the rest are ready.
template <typename Traits, typename F>
union Fused {
explicit Fused(F&& f) : f(std::forward<F>(f)) {}
explicit Fused(PromiseLike<F>&& f) : f(std::forward<PromiseLike<F>>(f)) {}
~Fused() {}
// Wrap the functor in a PromiseLike to handle immediately returning functors
// and the like.
using Promise = PromiseLike<F>;
GPR_NO_UNIQUE_ADDRESS Promise f;
// Compute the result type: We take the result of the promise, and pass it via
// our traits, so that, for example, TryJoin and take a StatusOr<T> and just
// store a T.
using Result = typename Traits::template ResultType<typename Promise::Result>;
GPR_NO_UNIQUE_ADDRESS Result result;
};
// A join gets composed of joints... these are just wrappers around a Fused for
// their data, with some machinery as methods to get the system working.
template <typename Traits, size_t kRemaining, typename... Fs>
struct Joint : public Joint<Traits, kRemaining - 1, Fs...> {
// The index into Fs for this Joint
static constexpr size_t kIdx = sizeof...(Fs) - kRemaining;
// The next join (the one we derive from)
using NextJoint = Joint<Traits, kRemaining - 1, Fs...>;
// From Fs, extract the functor for this joint.
using F = typename std::tuple_element<kIdx, std::tuple<Fs...>>::type;
// Generate the Fused type for this functor.
using Fsd = Fused<Traits, F>;
GPR_NO_UNIQUE_ADDRESS Fsd fused;
// Figure out what kind of bitmask will be used by the outer join.
using Bits = BitSet<sizeof...(Fs)>;
// Initialize from a tuple of pointers to Fs
explicit Joint(std::tuple<Fs*...> fs)
: NextJoint(fs), fused(std::move(*std::get<kIdx>(fs))) {}
// Copy: assume that the Fuse is still in the promise state (since it's not
// legal to copy after the first poll!)
Joint(const Joint& j) : NextJoint(j), fused(j.fused.f) {}
// Move: assume that the Fuse is still in the promise state (since it's not
// legal to move after the first poll!)
Joint(Joint&& j) noexcept
: NextJoint(std::forward<NextJoint>(j)), fused(std::move(j.fused.f)) {}
// Destruct: check bits to see if we're in promise or result state, and call
// the appropriate destructor. Recursively, call up through the join.
void DestructAll(const Bits& bits) {
if (!bits.is_set(kIdx)) {
Destruct(&fused.f);
} else {
Destruct(&fused.result);
}
NextJoint::DestructAll(bits);
}
// Poll all joints up, and then call finally.
template <typename F>
auto Run(Bits* bits, F finally) -> decltype(finally()) {
// If we're still in the promise state...
if (!bits->is_set(kIdx)) {
// Poll the promise
auto r = fused.f();
if (auto* p = r.value_if_ready()) {
// If it's done, then ask the trait to unwrap it and store that result
// in the Fused, and continue the iteration. Note that OnResult could
// instead choose to return a value instead of recursing through the
// iteration, in that case we continue returning the same result up.
// Here is where TryJoin can escape out.
return Traits::OnResult(
std::move(*p), [this, bits, &finally](typename Fsd::Result result) {
bits->set(kIdx);
Destruct(&fused.f);
Construct(&fused.result, std::move(result));
return NextJoint::Run(bits, std::move(finally));
});
}
}
// That joint is still pending... we'll still poll the result of the joints.
return NextJoint::Run(bits, std::move(finally));
}
};
// Terminating joint... for each of the recursions, do the thing we're supposed
// to do at the end.
template <typename Traits, typename... Fs>
struct Joint<Traits, 0, Fs...> {
explicit Joint(std::tuple<Fs*...>) {}
Joint(const Joint&) {}
Joint(Joint&&) noexcept {}
template <typename T>
void DestructAll(const T&) {}
template <typename F>
auto Run(BitSet<sizeof...(Fs)>*, F finally) -> decltype(finally()) {
return finally();
}
};
template <typename Traits, typename... Fs>
class BasicJoin {
private:
// How many things are we joining?
static constexpr size_t N = sizeof...(Fs);
// Bitset: if a bit is 0, that joint is still in promise state. If it's 1,
// then the joint has a result.
GPR_NO_UNIQUE_ADDRESS BitSet<N> state_;
// The actual joints, wrapped in an anonymous union to give us control of
// construction/destruction.
union {
GPR_NO_UNIQUE_ADDRESS Joint<Traits, sizeof...(Fs), Fs...> joints_;
};
// Access joint index I
template <size_t I>
Joint<Traits, sizeof...(Fs) - I, Fs...>* GetJoint() {
return static_cast<Joint<Traits, sizeof...(Fs) - I, Fs...>*>(&joints_);
}
// The tuple of results of all our promises
using Tuple = std::tuple<typename Fused<Traits, Fs>::Result...>;
// Collect up all the results and construct a tuple.
template <size_t... I>
Tuple Finish(absl::index_sequence<I...>) {
return Tuple(std::move(GetJoint<I>()->fused.result)...);
}
public:
explicit BasicJoin(Fs&&... fs) : joints_(std::tuple<Fs*...>(&fs...)) {}
BasicJoin& operator=(const BasicJoin&) = delete;
// Copy a join - only available before polling.
BasicJoin(const BasicJoin& other) {
assert(other.state_.none());
Construct(&joints_, other.joints_);
}
// Move a join - only available before polling.
BasicJoin(BasicJoin&& other) noexcept {
assert(other.state_.none());
Construct(&joints_, std::move(other.joints_));
}
~BasicJoin() { joints_.DestructAll(state_); }
using Result = decltype(Traits::Wrap(std::declval<Tuple>()));
// Poll the join
Poll<Result> operator()() {
// Poll the joints...
return joints_.Run(&state_, [this]() -> Poll<Result> {
// If all of them are completed, collect the results, and then ask our
// traits to wrap them - allowing for example TryJoin to turn tuple<A,B,C>
// into StatusOr<tuple<A,B,C>>.
if (state_.all()) {
return Traits::Wrap(Finish(absl::make_index_sequence<N>()));
} else {
return Pending();
}
});
}
};
} // namespace promise_detail
} // namespace grpc_core
#endif // GRPC_SRC_CORE_LIB_PROMISE_DETAIL_BASIC_JOIN_H

File diff suppressed because it is too large Load Diff

@ -17,9 +17,15 @@
#include <grpc/support/port_platform.h>
#include <stdlib.h>
#include <tuple>
#include <utility>
#include "absl/meta/type_traits.h"
#include "src/core/lib/promise/detail/basic_join.h"
#include "src/core/lib/promise/detail/join_state.h"
#include "src/core/lib/promise/map.h"
namespace grpc_core {
namespace promise_detail {
@ -27,19 +33,36 @@ namespace promise_detail {
struct JoinTraits {
template <typename T>
using ResultType = absl::remove_reference_t<T>;
template <typename T, typename F>
static auto OnResult(T result, F kontinue)
-> decltype(kontinue(std::move(result))) {
return kontinue(std::move(result));
template <typename T>
static bool IsOk(const T&) {
return true;
}
template <typename T>
static T Wrap(T x) {
static T Unwrapped(T x) {
return x;
}
template <typename R, typename T>
static R EarlyReturn(T) {
abort();
}
};
template <typename... Promises>
using Join = BasicJoin<JoinTraits, Promises...>;
class Join {
public:
explicit Join(Promises... promises) : state_(std::move(promises)...) {}
auto operator()() { return state_.PollOnce(); }
private:
JoinState<JoinTraits, Promises...> state_;
};
struct WrapInTuple {
template <typename T>
std::tuple<T> operator()(T x) {
return std::make_tuple(std::move(x));
}
};
} // namespace promise_detail
@ -50,6 +73,11 @@ promise_detail::Join<Promise...> Join(Promise... promises) {
return promise_detail::Join<Promise...>(std::move(promises)...);
}
template <typename F>
auto Join(F promise) {
return Map(std::move(promise), promise_detail::WrapInTuple{});
}
} // namespace grpc_core
#endif // GRPC_SRC_CORE_LIB_PROMISE_JOIN_H

@ -17,14 +17,15 @@
#include <grpc/support/port_platform.h>
#include <type_traits>
#include <tuple>
#include <utility>
#include "absl/meta/type_traits.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "src/core/lib/promise/detail/basic_join.h"
#include "src/core/lib/promise/detail/status.h"
#include "src/core/lib/promise/detail/join_state.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/poll.h"
namespace grpc_core {
@ -45,27 +46,39 @@ inline Empty IntoResult(absl::Status*) { return Empty{}; }
// Traits object to pass to BasicJoin
struct TryJoinTraits {
template <typename T>
using ResultType =
decltype(IntoResult(std::declval<absl::remove_reference_t<T>*>()));
template <typename T, typename F>
static auto OnResult(T result, F kontinue)
-> decltype(kontinue(IntoResult(&result))) {
using Result =
typename PollTraits<decltype(kontinue(IntoResult(&result)))>::Type;
if (!result.ok()) {
return Result(IntoStatus(&result));
}
return kontinue(IntoResult(&result));
using ResultType = absl::StatusOr<absl::remove_reference_t<T>>;
template <typename T>
static bool IsOk(const absl::StatusOr<T>& x) {
return x.ok();
}
template <typename T>
static absl::StatusOr<T> Wrap(T x) {
return absl::StatusOr<T>(std::move(x));
static T Unwrapped(absl::StatusOr<T> x) {
return std::move(*x);
}
template <typename R, typename T>
static R EarlyReturn(absl::StatusOr<T> x) {
return x.status();
}
};
// Implementation of TryJoin combinator.
template <typename... Promises>
using TryJoin = BasicJoin<TryJoinTraits, Promises...>;
class TryJoin {
public:
explicit TryJoin(Promises... promises) : state_(std::move(promises)...) {}
auto operator()() { return state_.PollOnce(); }
private:
JoinState<TryJoinTraits, Promises...> state_;
};
struct WrapInStatusOrTuple {
template <typename T>
absl::StatusOr<std::tuple<T>> operator()(absl::StatusOr<T> x) {
if (!x.ok()) return x.status();
return std::make_tuple(std::move(*x));
}
};
} // namespace promise_detail
@ -77,6 +90,11 @@ promise_detail::TryJoin<Promises...> TryJoin(Promises... promises) {
return promise_detail::TryJoin<Promises...>(std::move(promises)...);
}
template <typename F>
auto TryJoin(F promise) {
return Map(promise, promise_detail::WrapInStatusOrTuple{});
}
} // namespace grpc_core
#endif // GRPC_SRC_CORE_LIB_PROMISE_TRY_JOIN_H

@ -315,7 +315,6 @@ grpc_cc_test(
deps = [
"test_wakeup_schedulers",
"//src/core:activity",
"//src/core:basic_join",
"//src/core:join",
"//src/core:latch",
"//src/core:seq",
@ -335,7 +334,6 @@ grpc_cc_test(
"//:ref_counted_ptr",
"//src/core:activity",
"//src/core:arena",
"//src/core:basic_join",
"//src/core:event_engine_memory_allocator",
"//src/core:for_each",
"//src/core:join",
@ -361,7 +359,6 @@ grpc_cc_test(
"//:ref_counted_ptr",
"//src/core:activity",
"//src/core:arena",
"//src/core:basic_join",
"//src/core:event_engine_memory_allocator",
"//src/core:for_each",
"//src/core:join",
@ -393,7 +390,6 @@ grpc_cc_test(
"//:grpc",
"//:ref_counted_ptr",
"//src/core:activity",
"//src/core:basic_join",
"//src/core:event_engine_memory_allocator",
"//src/core:join",
"//src/core:map",
@ -440,7 +436,6 @@ grpc_proto_fuzzer(
"//:gpr",
"//:promise",
"//src/core:activity",
"//src/core:basic_join",
"//src/core:join",
"//src/core:map",
"//src/core:poll",

@ -26,7 +26,6 @@
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/detail/basic_join.h"
#include "src/core/lib/promise/join.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/pipe.h"

@ -15,7 +15,6 @@
#include "src/core/lib/promise/join.h"
#include <tuple>
#include <utility>
#include "gtest/gtest.h"

@ -22,7 +22,6 @@
#include "gtest/gtest.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/detail/basic_join.h"
#include "src/core/lib/promise/join.h"
#include "src/core/lib/promise/seq.h"
#include "test/core/promise/test_wakeup_schedulers.h"

@ -27,7 +27,6 @@
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/detail/basic_join.h"
#include "src/core/lib/promise/for_each.h"
#include "src/core/lib/promise/join.h"
#include "src/core/lib/promise/map.h"

@ -30,7 +30,6 @@
#include "src/core/lib/gprpp/crash.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/detail/basic_join.h"
#include "src/core/lib/promise/join.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/seq.h"

@ -28,7 +28,6 @@
#include <grpc/support/log.h>
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/detail/basic_join.h"
#include "src/core/lib/promise/join.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/poll.h"

@ -16,7 +16,6 @@
#include <functional>
#include <tuple>
#include <utility>
#include "absl/utility/utility.h"
#include "gtest/gtest.h"

@ -30,7 +30,6 @@
#include <grpc/event_engine/slice_buffer.h>
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/detail/basic_join.h"
#include "src/core/lib/promise/join.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/slice/slice.h"

@ -0,0 +1,153 @@
#!/usr/bin/env python3
# Copyright 2023 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from mako.template import Template
join_state = Template(
"""
template <class Traits, ${",".join(f"typename P{i}" for i in range(0,n))}>
struct JoinState<Traits, ${",".join(f"P{i}" for i in range(0,n))}> {
template <typename T>
using UnwrappedType = decltype(Traits::Unwrapped(std::declval<T>()));
% for i in range(0,n):
using Promise${i} = PromiseLike<P${i}>;
using Result${i} = UnwrappedType<typename Promise${i}::Result>;
union {
GPR_NO_UNIQUE_ADDRESS Promise${i} promise${i};
GPR_NO_UNIQUE_ADDRESS Result${i} result${i};
};
% endfor
GPR_NO_UNIQUE_ADDRESS BitSet<${n}> ready;
JoinState(${",".join(f"P{i}&& p{i}" for i in range(0,n))}) {
% for i in range(0,n):
Construct(&promise${i}, std::forward<P${i}>(p${i}));
% endfor
}
JoinState(const JoinState& other) {
GPR_ASSERT(other.ready.none());
% for i in range(0,n):
Construct(&promise${i}, other.promise${i});
% endfor
}
JoinState& operator=(const JoinState& other) = delete;
JoinState& operator=(JoinState&& other) = delete;
JoinState(JoinState&& other) noexcept : ready(other.ready) {
% for i in range(0,n):
if (ready.is_set(${i})) {
Construct(&result${i}, std::move(other.result${i}));
} else {
Construct(&promise${i}, std::move(other.promise${i}));
}
% endfor
}
~JoinState() {
% for i in range(0,n):
if (ready.is_set(${i})) {
Destruct(&result${i});
} else {
Destruct(&promise${i});
}
% endfor
}
using Result = typename Traits::template ResultType<std::tuple<
${",".join(f"Result{i}" for i in range(0,n))}>>;
Poll<Result> PollOnce() {
% for i in range(0,n):
if (!ready.is_set(${i})) {
auto poll = promise${i}();
if (auto* p = poll.value_if_ready()) {
if (Traits::IsOk(*p)) {
ready.set(${i});
Destruct(&promise${i});
Construct(&result${i}, Traits::Unwrapped(std::move(*p)));
} else {
return Traits::template EarlyReturn<Result>(std::move(*p));
}
}
}
% endfor
if (ready.all()) {
return Result{std::make_tuple(${",".join(f"std::move(result{i})" for i in range(0,n))})};
}
return Pending{};
}
};
"""
)
front_matter = """
#ifndef GRPC_SRC_CORE_LIB_PROMISE_DETAIL_JOIN_STATE_H
#define GRPC_SRC_CORE_LIB_PROMISE_DETAIL_JOIN_STATE_H
// This file is generated by tools/codegen/core/gen_seq.py
#include <grpc/support/port_platform.h>
#include "src/core/lib/gprpp/construct_destruct.h"
#include "src/core/lib/promise/detail/promise_like.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/gprpp/bitset.h"
#include <grpc/support/log.h>
#include <tuple>
#include <type_traits>
#include <utility>
namespace grpc_core {
namespace promise_detail {
template <class Traits, typename... Ps>
struct JoinState;
"""
end_matter = """
} // namespace promise_detail
} // namespace grpc_core
#endif // GRPC_SRC_CORE_LIB_PROMISE_DETAIL_JOIN_STATE_H
"""
# utility: print a big comment block into a set of files
def put_banner(files, banner):
for f in files:
for line in banner:
print("// %s" % line, file=f)
print("", file=f)
with open(sys.argv[0]) as my_source:
copyright = []
for line in my_source:
if line[0] != "#":
break
for line in my_source:
if line[0] == "#":
copyright.append(line)
break
for line in my_source:
if line[0] != "#":
break
copyright.append(line)
copyright = [line[2:].rstrip() for line in copyright]
with open("src/core/lib/promise/detail/join_state.h", "w") as f:
put_banner([f], copyright)
print(front_matter, file=f)
for n in range(2, 10):
print(join_state.render(n=n), file=f)
print(end_matter, file=f)
Loading…
Cancel
Save