diff --git a/BUILD b/BUILD index 88d27497bd3..95713d03d40 100644 --- a/BUILD +++ b/BUILD @@ -1174,6 +1174,21 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "for_each", + external_deps = [ + "absl/status", + "absl/types:variant", + ], + language = "c++", + public_hdrs = ["src/core/lib/promise/for_each.h"], + deps = [ + "gpr_platform", + "poll", + "promise_factory", + ], +) + grpc_cc_library( name = "ref_counted", language = "c++", diff --git a/CMakeLists.txt b/CMakeLists.txt index 76e1c358856..90eca224d9d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -870,6 +870,7 @@ if(gRPC_BUILD_TESTS) add_dependencies(buildtests_cxx filter_end2end_test) add_dependencies(buildtests_cxx flaky_network_test) add_dependencies(buildtests_cxx flow_control_test) + add_dependencies(buildtests_cxx for_each_test) add_dependencies(buildtests_cxx generic_end2end_test) if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) add_dependencies(buildtests_cxx global_config_env_test) @@ -11258,6 +11259,111 @@ target_link_libraries(flow_control_test ) +endif() +if(gRPC_BUILD_TESTS) + +add_executable(for_each_test + src/core/ext/upb-generated/google/api/annotations.upb.c + src/core/ext/upb-generated/google/api/expr/v1alpha1/checked.upb.c + src/core/ext/upb-generated/google/api/expr/v1alpha1/syntax.upb.c + src/core/ext/upb-generated/google/api/http.upb.c + src/core/ext/upb-generated/google/protobuf/any.upb.c + src/core/ext/upb-generated/google/protobuf/duration.upb.c + src/core/ext/upb-generated/google/protobuf/empty.upb.c + src/core/ext/upb-generated/google/protobuf/struct.upb.c + src/core/ext/upb-generated/google/protobuf/timestamp.upb.c + src/core/ext/upb-generated/google/protobuf/wrappers.upb.c + src/core/ext/upb-generated/google/rpc/status.upb.c + src/core/lib/gpr/alloc.cc + src/core/lib/gpr/atm.cc + src/core/lib/gpr/cpu_iphone.cc + src/core/lib/gpr/cpu_linux.cc + src/core/lib/gpr/cpu_posix.cc + src/core/lib/gpr/cpu_windows.cc + src/core/lib/gpr/env_linux.cc + src/core/lib/gpr/env_posix.cc + src/core/lib/gpr/env_windows.cc + src/core/lib/gpr/log.cc + src/core/lib/gpr/log_android.cc + src/core/lib/gpr/log_linux.cc + src/core/lib/gpr/log_posix.cc + src/core/lib/gpr/log_windows.cc + src/core/lib/gpr/murmur_hash.cc + src/core/lib/gpr/string.cc + src/core/lib/gpr/string_posix.cc + src/core/lib/gpr/string_util_windows.cc + src/core/lib/gpr/string_windows.cc + src/core/lib/gpr/sync.cc + src/core/lib/gpr/sync_abseil.cc + src/core/lib/gpr/sync_posix.cc + src/core/lib/gpr/sync_windows.cc + src/core/lib/gpr/time.cc + src/core/lib/gpr/time_posix.cc + src/core/lib/gpr/time_precise.cc + src/core/lib/gpr/time_windows.cc + src/core/lib/gpr/tmpfile_msys.cc + src/core/lib/gpr/tmpfile_posix.cc + src/core/lib/gpr/tmpfile_windows.cc + src/core/lib/gpr/wrap_memcpy.cc + src/core/lib/gprpp/arena.cc + src/core/lib/gprpp/examine_stack.cc + src/core/lib/gprpp/fork.cc + src/core/lib/gprpp/global_config_env.cc + src/core/lib/gprpp/host_port.cc + src/core/lib/gprpp/mpscq.cc + src/core/lib/gprpp/stat_posix.cc + src/core/lib/gprpp/stat_windows.cc + src/core/lib/gprpp/status_helper.cc + src/core/lib/gprpp/thd_posix.cc + src/core/lib/gprpp/thd_windows.cc + src/core/lib/gprpp/time_util.cc + src/core/lib/profiling/basic_timers.cc + src/core/lib/profiling/stap_timers.cc + src/core/lib/promise/activity.cc + test/core/promise/for_each_test.cc + third_party/googletest/googletest/src/gtest-all.cc + third_party/googletest/googlemock/src/gmock-all.cc +) + +target_include_directories(for_each_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_XXHASH_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(for_each_test + ${_gRPC_PROTOBUF_LIBRARIES} + ${_gRPC_ALLTARGETS_LIBRARIES} + absl::base + absl::core_headers + absl::flat_hash_set + absl::memory + absl::status + absl::statusor + absl::cord + absl::str_format + absl::strings + absl::synchronization + absl::time + absl::optional + absl::variant + upb +) + + endif() if(gRPC_BUILD_TESTS) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 46b742113ce..e09b30ba638 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -5729,6 +5729,145 @@ targets: - test/core/transport/chttp2/flow_control_test.cc deps: - grpc_test_util +- name: for_each_test + gtest: true + build: test + language: c++ + headers: + - src/core/ext/upb-generated/google/api/annotations.upb.h + - src/core/ext/upb-generated/google/api/expr/v1alpha1/checked.upb.h + - src/core/ext/upb-generated/google/api/expr/v1alpha1/syntax.upb.h + - src/core/ext/upb-generated/google/api/http.upb.h + - src/core/ext/upb-generated/google/protobuf/any.upb.h + - src/core/ext/upb-generated/google/protobuf/duration.upb.h + - src/core/ext/upb-generated/google/protobuf/empty.upb.h + - src/core/ext/upb-generated/google/protobuf/struct.upb.h + - src/core/ext/upb-generated/google/protobuf/timestamp.upb.h + - src/core/ext/upb-generated/google/protobuf/wrappers.upb.h + - src/core/ext/upb-generated/google/rpc/status.upb.h + - src/core/lib/gpr/alloc.h + - src/core/lib/gpr/env.h + - src/core/lib/gpr/murmur_hash.h + - src/core/lib/gpr/spinlock.h + - src/core/lib/gpr/string.h + - src/core/lib/gpr/string_windows.h + - src/core/lib/gpr/time_precise.h + - src/core/lib/gpr/tls.h + - src/core/lib/gpr/tmpfile.h + - src/core/lib/gpr/useful.h + - src/core/lib/gprpp/arena.h + - src/core/lib/gprpp/atomic_utils.h + - src/core/lib/gprpp/bitset.h + - src/core/lib/gprpp/construct_destruct.h + - src/core/lib/gprpp/debug_location.h + - src/core/lib/gprpp/examine_stack.h + - src/core/lib/gprpp/fork.h + - src/core/lib/gprpp/global_config.h + - src/core/lib/gprpp/global_config_custom.h + - src/core/lib/gprpp/global_config_env.h + - src/core/lib/gprpp/global_config_generic.h + - src/core/lib/gprpp/host_port.h + - src/core/lib/gprpp/manual_constructor.h + - src/core/lib/gprpp/memory.h + - src/core/lib/gprpp/mpscq.h + - src/core/lib/gprpp/stat.h + - src/core/lib/gprpp/status_helper.h + - src/core/lib/gprpp/sync.h + - src/core/lib/gprpp/thd.h + - src/core/lib/gprpp/time_util.h + - src/core/lib/profiling/timers.h + - src/core/lib/promise/activity.h + - src/core/lib/promise/context.h + - src/core/lib/promise/detail/basic_join.h + - src/core/lib/promise/detail/basic_seq.h + - src/core/lib/promise/detail/promise_factory.h + - src/core/lib/promise/detail/promise_like.h + - src/core/lib/promise/detail/status.h + - src/core/lib/promise/detail/switch.h + - src/core/lib/promise/for_each.h + - src/core/lib/promise/intra_activity_waiter.h + - src/core/lib/promise/join.h + - src/core/lib/promise/map.h + - src/core/lib/promise/observable.h + - src/core/lib/promise/pipe.h + - src/core/lib/promise/poll.h + - src/core/lib/promise/seq.h + - src/core/lib/promise/wait_set.h + src: + - src/core/ext/upb-generated/google/api/annotations.upb.c + - src/core/ext/upb-generated/google/api/expr/v1alpha1/checked.upb.c + - src/core/ext/upb-generated/google/api/expr/v1alpha1/syntax.upb.c + - src/core/ext/upb-generated/google/api/http.upb.c + - src/core/ext/upb-generated/google/protobuf/any.upb.c + - src/core/ext/upb-generated/google/protobuf/duration.upb.c + - src/core/ext/upb-generated/google/protobuf/empty.upb.c + - src/core/ext/upb-generated/google/protobuf/struct.upb.c + - src/core/ext/upb-generated/google/protobuf/timestamp.upb.c + - src/core/ext/upb-generated/google/protobuf/wrappers.upb.c + - src/core/ext/upb-generated/google/rpc/status.upb.c + - src/core/lib/gpr/alloc.cc + - src/core/lib/gpr/atm.cc + - src/core/lib/gpr/cpu_iphone.cc + - src/core/lib/gpr/cpu_linux.cc + - src/core/lib/gpr/cpu_posix.cc + - src/core/lib/gpr/cpu_windows.cc + - src/core/lib/gpr/env_linux.cc + - src/core/lib/gpr/env_posix.cc + - src/core/lib/gpr/env_windows.cc + - src/core/lib/gpr/log.cc + - src/core/lib/gpr/log_android.cc + - src/core/lib/gpr/log_linux.cc + - src/core/lib/gpr/log_posix.cc + - src/core/lib/gpr/log_windows.cc + - src/core/lib/gpr/murmur_hash.cc + - src/core/lib/gpr/string.cc + - src/core/lib/gpr/string_posix.cc + - src/core/lib/gpr/string_util_windows.cc + - src/core/lib/gpr/string_windows.cc + - src/core/lib/gpr/sync.cc + - src/core/lib/gpr/sync_abseil.cc + - src/core/lib/gpr/sync_posix.cc + - src/core/lib/gpr/sync_windows.cc + - src/core/lib/gpr/time.cc + - src/core/lib/gpr/time_posix.cc + - src/core/lib/gpr/time_precise.cc + - src/core/lib/gpr/time_windows.cc + - src/core/lib/gpr/tmpfile_msys.cc + - src/core/lib/gpr/tmpfile_posix.cc + - src/core/lib/gpr/tmpfile_windows.cc + - src/core/lib/gpr/wrap_memcpy.cc + - src/core/lib/gprpp/arena.cc + - src/core/lib/gprpp/examine_stack.cc + - src/core/lib/gprpp/fork.cc + - src/core/lib/gprpp/global_config_env.cc + - src/core/lib/gprpp/host_port.cc + - src/core/lib/gprpp/mpscq.cc + - src/core/lib/gprpp/stat_posix.cc + - src/core/lib/gprpp/stat_windows.cc + - src/core/lib/gprpp/status_helper.cc + - src/core/lib/gprpp/thd_posix.cc + - src/core/lib/gprpp/thd_windows.cc + - src/core/lib/gprpp/time_util.cc + - src/core/lib/profiling/basic_timers.cc + - src/core/lib/profiling/stap_timers.cc + - src/core/lib/promise/activity.cc + - test/core/promise/for_each_test.cc + deps: + - absl/base:base + - absl/base:core_headers + - absl/container:flat_hash_set + - absl/memory:memory + - absl/status:status + - absl/status:statusor + - absl/strings:cord + - absl/strings:str_format + - absl/strings:strings + - absl/synchronization:synchronization + - absl/time:time + - absl/types:optional + - absl/types:variant + - upb + uses_polling: false - name: generic_end2end_test gtest: true build: test diff --git a/src/core/lib/promise/for_each.h b/src/core/lib/promise/for_each.h new file mode 100644 index 00000000000..c9136ee69b3 --- /dev/null +++ b/src/core/lib/promise/for_each.h @@ -0,0 +1,136 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_CORE_LIB_PROMISE_FOR_EACH_H +#define GRPC_CORE_LIB_PROMISE_FOR_EACH_H + +#include + +#include "absl/status/status.h" +#include "absl/types/variant.h" + +#include "src/core/lib/promise/detail/promise_factory.h" +#include "src/core/lib/promise/poll.h" + +namespace grpc_core { + +namespace for_each_detail { + +// Helper function: at the end of each iteration of a for-each loop, this is +// called. If the iteration failed, return failure. If the iteration succeeded, +// then call the next iteration. +template +Poll FinishIteration(absl::Status* r, Reader* reader, + CallPoll call_poll) { + if (r->ok()) { + auto next = reader->Next(); + return call_poll(next); + } + return std::move(*r); +} + +// Done creates statuses for the end of the iteration. It's templated on the +// type of the result of the ForEach loop, so that we can introduce new types +// easily. +template +struct Done; +template <> +struct Done { + static absl::Status Make() { return absl::OkStatus(); } +}; + +template +class ForEach { + private: + using ReaderNext = decltype(std::declval().Next()); + using ReaderResult = typename PollTraits()())>::Type::value_type; + using ActionFactory = promise_detail::PromiseFactory; + using ActionPromise = typename ActionFactory::Promise; + + public: + using Result = + typename PollTraits()())>::Type; + ForEach(Reader reader, Action action) + : reader_(std::move(reader)), + action_factory_(std::move(action)), + state_(reader_.Next()) {} + + ForEach(const ForEach&) = delete; + ForEach& operator=(const ForEach&) = delete; + // noexcept causes compiler errors on older gcc's + // NOLINTNEXTLINE(performance-noexcept-move-constructor) + ForEach(ForEach&&) = default; + // noexcept causes compiler errors on older gcc's + // NOLINTNEXTLINE(performance-noexcept-move-constructor) + ForEach& operator=(ForEach&&) = default; + + Poll operator()() { + return absl::visit(CallPoll{this}, state_); + } + + private: + Reader reader_; + ActionFactory action_factory_; + absl::variant state_; + + // Call the inner poll function, and if it's finished, start the next + // iteration. If kSetState==true, also set the current state in self->state_. + // We omit that on the first iteration because it's common to poll once and + // not change state, which saves us some work. + template + struct CallPoll { + ForEach* const self; + + Poll operator()(ReaderNext& reader_next) { + auto r = reader_next(); + if (auto* p = absl::get_if(&r)) { + if (p->has_value()) { + auto action = self->action_factory_.Repeated(std::move(**p)); + return CallPoll{self}(action); + } else { + return Done::Make(); + } + } + if (kSetState) { + self->state_.template emplace(std::move(reader_next)); + } + return Pending(); + } + + Poll operator()(ActionPromise& promise) { + auto r = promise(); + if (auto* p = absl::get_if(&r)) { + return FinishIteration(p, &self->reader_, CallPoll{self}); + } + if (kSetState) { + self->state_.template emplace(std::move(promise)); + } + return Pending(); + } + }; +}; + +} // namespace for_each_detail + +/// For each item acquired by calling Reader::Next, run the promise Action. +template +for_each_detail::ForEach ForEach(Reader reader, Action action) { + return for_each_detail::ForEach(std::move(reader), + std::move(action)); +} + +} // namespace grpc_core + +#endif // GRPC_CORE_LIB_PROMISE_FOR_EACH_H diff --git a/test/core/promise/BUILD b/test/core/promise/BUILD index f8735b42297..3382027e4c5 100644 --- a/test/core/promise/BUILD +++ b/test/core/promise/BUILD @@ -217,6 +217,23 @@ grpc_cc_test( ], ) +grpc_cc_test( + name = "for_each_test", + srcs = ["for_each_test.cc"], + external_deps = ["gtest"], + language = "c++", + uses_polling = False, + deps = [ + "//:for_each", + "//:join", + "//:map", + "//:observable", + "//:pipe", + "//:seq", + "//test/core/util:grpc_suppressions", + ], +) + grpc_cc_test( name = "pipe_test", srcs = ["pipe_test.cc"], diff --git a/test/core/promise/for_each_test.cc b/test/core/promise/for_each_test.cc new file mode 100644 index 00000000000..1d112e5e616 --- /dev/null +++ b/test/core/promise/for_each_test.cc @@ -0,0 +1,71 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/for_each.h" + +#include +#include + +#include "src/core/lib/promise/join.h" +#include "src/core/lib/promise/map.h" +#include "src/core/lib/promise/observable.h" +#include "src/core/lib/promise/pipe.h" +#include "src/core/lib/promise/seq.h" + +using testing::Mock; +using testing::MockFunction; +using testing::StrictMock; + +namespace grpc_core { + +TEST(ForEachTest, SendThriceWithPipe) { + Pipe pipe; + int num_received = 0; + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [&pipe, &num_received] { + return Map( + Join( + // Push 3 things into a pipe -- 1, 2, then 3 -- then close. + Seq( + pipe.sender.Push(1), + [&pipe] { return pipe.sender.Push(2); }, + [&pipe] { return pipe.sender.Push(3); }, + [&pipe] { + auto drop = std::move(pipe.sender); + return absl::OkStatus(); + }), + // Use a ForEach loop to read them out and verify all values are + // seen. + ForEach(std::move(pipe.receiver), + [&num_received](int i) { + num_received++; + EXPECT_EQ(num_received, i); + return absl::OkStatus(); + })), + JustElem<1>()); + }, + NoCallbackScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + Mock::VerifyAndClearExpectations(&on_done); + EXPECT_EQ(num_received, 3); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index a7a67c32ca1..acf05718e72 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -4799,6 +4799,30 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "for_each_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": false + }, { "args": [], "benchmark": false,