From 8448d499e27a9e036796c07530ff6da4d98c043d Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Thu, 14 Dec 2023 13:10:59 -0800 Subject: [PATCH] [promises] Add `AllOk` combinator (#35304) `AllOk` runs a set of promises concurrently, and like `TryJoin` waits for them all to succeed or one to fail. Unlike `TryJoin` it returns a single unified status of the composition, so cannot handle member promises that might return `StatusOr` or the like. Closes #35304 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35304 from ctiller:all-review 30f5f809c6514d38bf1ee3fe9180d52ae03a6833 PiperOrigin-RevId: 591031189 --- CMakeLists.txt | 39 ++++++++++++ build_autogenerated.yaml | 26 ++++++++ src/core/BUILD | 20 ++++++ src/core/lib/promise/all_ok.h | 80 ++++++++++++++++++++++++ src/core/lib/promise/detail/join_state.h | 43 ++++++------- src/core/lib/promise/join.h | 4 ++ src/core/lib/promise/status_flag.h | 3 + src/core/lib/promise/try_join.h | 5 ++ test/core/promise/BUILD | 14 +++++ test/core/promise/all_ok_test.cc | 53 ++++++++++++++++ tools/codegen/core/gen_join.py | 2 +- tools/run_tests/generated/tests.json | 24 +++++++ 12 files changed, 291 insertions(+), 22 deletions(-) create mode 100644 src/core/lib/promise/all_ok.h create mode 100644 test/core/promise/all_ok_test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 18c76605560..60817e90943 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -852,6 +852,7 @@ if(gRPC_BUILD_TESTS) if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) add_dependencies(buildtests_cxx alarm_test) endif() + add_dependencies(buildtests_cxx all_ok_test) add_dependencies(buildtests_cxx alloc_test) add_dependencies(buildtests_cxx alpn_test) if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_POSIX) @@ -5895,6 +5896,44 @@ endif() endif() if(gRPC_BUILD_TESTS) +add_executable(all_ok_test + src/core/lib/debug/trace.cc + src/core/lib/promise/trace.cc + test/core/promise/all_ok_test.cc +) +target_compile_features(all_ok_test PUBLIC cxx_std_14) +target_include_directories(all_ok_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_XXHASH_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(all_ok_test + ${_gRPC_ALLTARGETS_LIBRARIES} + gtest + absl::type_traits + absl::statusor + absl::utility + gpr +) + + +endif() +if(gRPC_BUILD_TESTS) + add_executable(alloc_test test/core/gpr/alloc_test.cc ) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index a62fc967c36..71a87f34adb 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -5410,6 +5410,32 @@ targets: - linux - posix - mac +- name: all_ok_test + gtest: true + build: test + language: c++ + headers: + - src/core/lib/debug/trace.h + - src/core/lib/gprpp/bitset.h + - src/core/lib/promise/all_ok.h + - src/core/lib/promise/detail/join_state.h + - src/core/lib/promise/detail/promise_like.h + - src/core/lib/promise/detail/status.h + - src/core/lib/promise/map.h + - src/core/lib/promise/poll.h + - src/core/lib/promise/status_flag.h + - src/core/lib/promise/trace.h + src: + - src/core/lib/debug/trace.cc + - src/core/lib/promise/trace.cc + - test/core/promise/all_ok_test.cc + deps: + - gtest + - absl/meta:type_traits + - absl/status:statusor + - absl/utility:utility + - gpr + uses_polling: false - name: alloc_test gtest: true build: test diff --git a/src/core/BUILD b/src/core/BUILD index 58017456079..b7c2cdf9c9d 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -710,6 +710,26 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "all_ok", + external_deps = [ + "absl/meta:type_traits", + "absl/status", + "absl/status:statusor", + ], + language = "c++", + public_hdrs = [ + "lib/promise/all_ok.h", + ], + deps = [ + "join_state", + "map", + "poll", + "status_flag", + "//:gpr_platform", + ], +) + grpc_cc_library( name = "switch", language = "c++", diff --git a/src/core/lib/promise/all_ok.h b/src/core/lib/promise/all_ok.h new file mode 100644 index 00000000000..33941a709f2 --- /dev/null +++ b/src/core/lib/promise/all_ok.h @@ -0,0 +1,80 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_SRC_CORE_LIB_PROMISE_ALL_OK_H +#define GRPC_SRC_CORE_LIB_PROMISE_ALL_OK_H + +#include + +#include +#include + +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +#include "src/core/lib/promise/detail/join_state.h" +#include "src/core/lib/promise/map.h" +#include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/status_flag.h" + +namespace grpc_core { + +namespace promise_detail { + +// Traits object to pass to JoinState +template +struct AllOkTraits { + template + using ResultType = Result; + template + static bool IsOk(const T& x) { + return IsStatusOk(x); + } + static Empty Unwrapped(StatusFlag) { return Empty{}; } + static Empty Unwrapped(absl::Status) { return Empty{}; } + template + static R EarlyReturn(T&& x) { + return StatusCast(std::forward(x)); + } + template + static Result FinalReturn(A&&...) { + return Result{}; + } +}; + +// Implementation of AllOk combinator. +template +class AllOk { + public: + explicit AllOk(Promises... promises) : state_(std::move(promises)...) {} + auto operator()() { return state_.PollOnce(); } + + private: + JoinState, Promises...> state_; +}; + +} // namespace promise_detail + +// Run all promises. +// If any fail, cancel the rest and return the failure. +// If all succeed, return Ok. +template +auto AllOk(Promises... promises) { + return promise_detail::AllOk(std::move(promises)...); +} + +} // namespace grpc_core + +#endif // GRPC_SRC_CORE_LIB_PROMISE_ALL_OK_H diff --git a/src/core/lib/promise/detail/join_state.h b/src/core/lib/promise/detail/join_state.h index 1fbadd4372d..4c36208e026 100644 --- a/src/core/lib/promise/detail/join_state.h +++ b/src/core/lib/promise/detail/join_state.h @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -137,7 +138,7 @@ struct JoinState { gpr_log(GPR_DEBUG, "join[%p]: joint 2/2 already ready", this); } if (ready.all()) { - return Result{std::make_tuple(std::move(result0), std::move(result1))}; + return Traits::FinalReturn(std::move(result0), std::move(result1)); } return Pending{}; } @@ -286,8 +287,8 @@ struct JoinState { gpr_log(GPR_DEBUG, "join[%p]: joint 3/3 already ready", this); } if (ready.all()) { - return Result{std::make_tuple(std::move(result0), std::move(result1), - std::move(result2))}; + return Traits::FinalReturn(std::move(result0), std::move(result1), + std::move(result2)); } return Pending{}; } @@ -477,8 +478,8 @@ struct JoinState { gpr_log(GPR_DEBUG, "join[%p]: joint 4/4 already ready", this); } if (ready.all()) { - return Result{std::make_tuple(std::move(result0), std::move(result1), - std::move(result2), std::move(result3))}; + return Traits::FinalReturn(std::move(result0), std::move(result1), + std::move(result2), std::move(result3)); } return Pending{}; } @@ -710,9 +711,9 @@ struct JoinState { gpr_log(GPR_DEBUG, "join[%p]: joint 5/5 already ready", this); } if (ready.all()) { - return Result{std::make_tuple(std::move(result0), std::move(result1), - std::move(result2), std::move(result3), - std::move(result4))}; + return Traits::FinalReturn(std::move(result0), std::move(result1), + std::move(result2), std::move(result3), + std::move(result4)); } return Pending{}; } @@ -985,9 +986,9 @@ struct JoinState { gpr_log(GPR_DEBUG, "join[%p]: joint 6/6 already ready", this); } if (ready.all()) { - return Result{std::make_tuple(std::move(result0), std::move(result1), - std::move(result2), std::move(result3), - std::move(result4), std::move(result5))}; + return Traits::FinalReturn(std::move(result0), std::move(result1), + std::move(result2), std::move(result3), + std::move(result4), std::move(result5)); } return Pending{}; } @@ -1301,10 +1302,10 @@ struct JoinState { gpr_log(GPR_DEBUG, "join[%p]: joint 7/7 already ready", this); } if (ready.all()) { - return Result{std::make_tuple(std::move(result0), std::move(result1), - std::move(result2), std::move(result3), - std::move(result4), std::move(result5), - std::move(result6))}; + return Traits::FinalReturn(std::move(result0), std::move(result1), + std::move(result2), std::move(result3), + std::move(result4), std::move(result5), + std::move(result6)); } return Pending{}; } @@ -1660,10 +1661,10 @@ struct JoinState { gpr_log(GPR_DEBUG, "join[%p]: joint 8/8 already ready", this); } if (ready.all()) { - return Result{std::make_tuple(std::move(result0), std::move(result1), - std::move(result2), std::move(result3), - std::move(result4), std::move(result5), - std::move(result6), std::move(result7))}; + return Traits::FinalReturn(std::move(result0), std::move(result1), + std::move(result2), std::move(result3), + std::move(result4), std::move(result5), + std::move(result6), std::move(result7)); } return Pending{}; } @@ -2061,10 +2062,10 @@ struct JoinState { gpr_log(GPR_DEBUG, "join[%p]: joint 9/9 already ready", this); } if (ready.all()) { - return Result{std::make_tuple( + return Traits::FinalReturn( std::move(result0), std::move(result1), std::move(result2), std::move(result3), std::move(result4), std::move(result5), - std::move(result6), std::move(result7), std::move(result8))}; + std::move(result6), std::move(result7), std::move(result8)); } return Pending{}; } diff --git a/src/core/lib/promise/join.h b/src/core/lib/promise/join.h index 014fe4ac296..9918319ce1e 100644 --- a/src/core/lib/promise/join.h +++ b/src/core/lib/promise/join.h @@ -44,6 +44,10 @@ struct JoinTraits { static R EarlyReturn(T) { abort(); } + template + static std::tuple FinalReturn(A... a) { + return std::make_tuple(std::move(a)...); + } }; template diff --git a/src/core/lib/promise/status_flag.h b/src/core/lib/promise/status_flag.h index 1d7b09fabea..0e5af4f0da8 100644 --- a/src/core/lib/promise/status_flag.h +++ b/src/core/lib/promise/status_flag.h @@ -45,6 +45,7 @@ struct StatusCastImpl { // (false). class StatusFlag { public: + StatusFlag() : value_(true) {} explicit StatusFlag(bool value) : value_(value) {} // NOLINTNEXTLINE(google-explicit-constructor) StatusFlag(Failure) : value_(false) {} @@ -53,6 +54,8 @@ class StatusFlag { bool ok() const { return value_; } + bool operator==(StatusFlag other) const { return value_ == other.value_; } + private: bool value_; }; diff --git a/src/core/lib/promise/try_join.h b/src/core/lib/promise/try_join.h index efc0cfd78cb..14503db5796 100644 --- a/src/core/lib/promise/try_join.h +++ b/src/core/lib/promise/try_join.h @@ -65,6 +65,11 @@ struct TryJoinTraits { static R EarlyReturn(absl::Status x) { return x; } + template + static auto FinalReturn(A&&... a) { + return absl::StatusOr>( + std::make_tuple(std::forward(a)...)); + } }; // Implementation of TryJoin combinator. diff --git a/test/core/promise/BUILD b/test/core/promise/BUILD index 2370dfdc4a7..b4c4ce855c6 100644 --- a/test/core/promise/BUILD +++ b/test/core/promise/BUILD @@ -239,6 +239,20 @@ grpc_cc_test( deps = ["//src/core:try_join"], ) +grpc_cc_test( + name = "all_ok_test", + srcs = ["all_ok_test.cc"], + external_deps = [ + "absl/utility", + "gtest", + ], + language = "c++", + tags = ["promise_test"], + uses_event_engine = False, + uses_polling = False, + deps = ["//src/core:all_ok"], +) + grpc_cc_test( name = "seq_test", srcs = ["seq_test.cc"], diff --git a/test/core/promise/all_ok_test.cc b/test/core/promise/all_ok_test.cc new file mode 100644 index 00000000000..3f4e7eed5fb --- /dev/null +++ b/test/core/promise/all_ok_test.cc @@ -0,0 +1,53 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/all_ok.h" + +#include +#include +#include + +#include "absl/utility/utility.h" +#include "gtest/gtest.h" + +namespace grpc_core { + +using P = std::function()>; + +P instant_success() { + return [] { return Success{}; }; +} + +P instant_fail() { + return [] { return Failure{}; }; +} + +Poll succeeded() { return Poll(Success{}); } + +Poll failed() { return Poll(Failure{}); } + +TEST(AllOkTest, Join2) { + EXPECT_EQ(AllOk(instant_fail(), instant_fail())(), failed()); + EXPECT_EQ(AllOk(instant_fail(), instant_success())(), failed()); + EXPECT_EQ(AllOk(instant_success(), instant_fail())(), failed()); + EXPECT_EQ(AllOk(instant_success(), instant_success())(), + succeeded()); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tools/codegen/core/gen_join.py b/tools/codegen/core/gen_join.py index cc978eba285..6ddea69532a 100755 --- a/tools/codegen/core/gen_join.py +++ b/tools/codegen/core/gen_join.py @@ -92,7 +92,7 @@ struct JoinState { } % endfor if (ready.all()) { - return Result{std::make_tuple(${",".join(f"std::move(result{i})" for i in range(0,n))})}; + return Traits::FinalReturn(${",".join(f"std::move(result{i})" for i in range(0,n))}); } return Pending{}; } diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index 6b2a8024173..97e35cd758b 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -203,6 +203,30 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "all_ok_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": false + }, { "args": [], "benchmark": false,