diff --git a/BUILD b/BUILD index e272184eb04..80b747f7520 100644 --- a/BUILD +++ b/BUILD @@ -908,6 +908,20 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "if", + external_deps = [ + "absl/status:statusor", + ], + language = "c++", + public_hdrs = ["src/core/lib/promise/if.h"], + deps = [ + "gpr_platform", + "poll", + "promise_factory", + ], +) + grpc_cc_library( name = "promise_status", external_deps = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c334715aab..07fb2ca6c0e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -886,6 +886,7 @@ if(gRPC_BUILD_TESTS) add_dependencies(buildtests_cxx hpack_encoder_index_test) add_dependencies(buildtests_cxx http2_client) add_dependencies(buildtests_cxx hybrid_end2end_test) + add_dependencies(buildtests_cxx if_test) add_dependencies(buildtests_cxx init_test) add_dependencies(buildtests_cxx initial_settings_frame_bad_client_test) add_dependencies(buildtests_cxx insecure_security_connector_test) @@ -11960,6 +11961,42 @@ target_link_libraries(hybrid_end2end_test ) +endif() +if(gRPC_BUILD_TESTS) + +add_executable(if_test + test/core/promise/if_test.cc + third_party/googletest/googletest/src/gtest-all.cc + third_party/googletest/googlemock/src/gmock-all.cc +) + +target_include_directories(if_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_XXHASH_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(if_test + ${_gRPC_PROTOBUF_LIBRARIES} + ${_gRPC_ALLTARGETS_LIBRARIES} + absl::statusor + absl::variant +) + + endif() if(gRPC_BUILD_TESTS) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index e36783431a5..06212380bfc 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -5681,6 +5681,21 @@ targets: - test/cpp/end2end/test_service_impl.cc deps: - grpc++_test_util +- name: if_test + gtest: true + build: test + language: c++ + headers: + - src/core/lib/promise/detail/promise_factory.h + - src/core/lib/promise/detail/promise_like.h + - src/core/lib/promise/if.h + - src/core/lib/promise/poll.h + src: + - test/core/promise/if_test.cc + deps: + - absl/status:statusor + - absl/types:variant + uses_polling: false - name: init_test gtest: true build: test diff --git a/src/core/lib/promise/detail/promise_factory.h b/src/core/lib/promise/detail/promise_factory.h index e487f4cefc3..2b802165b35 100644 --- a/src/core/lib/promise/detail/promise_factory.h +++ b/src/core/lib/promise/detail/promise_factory.h @@ -162,16 +162,16 @@ class PromiseFactory { public: using Arg = A; + using Promise = + decltype(PromiseFactoryImpl(std::move(f_), std::declval())); explicit PromiseFactory(F f) : f_(std::move(f)) {} - auto Once(Arg&& a) - -> decltype(PromiseFactoryImpl(std::move(f_), std::forward(a))) { + Promise Once(Arg&& a) { return PromiseFactoryImpl(std::move(f_), std::forward(a)); } - auto Repeated(Arg&& a) const - -> decltype(PromiseFactoryImpl(f_, std::forward(a))) { + Promise Repeated(Arg&& a) const { return PromiseFactoryImpl(f_, std::forward(a)); } }; @@ -183,16 +183,13 @@ class PromiseFactory { public: using Arg = void; + using Promise = decltype(PromiseFactoryImpl(std::move(f_))); explicit PromiseFactory(F f) : f_(std::move(f)) {} - auto Once() -> decltype(PromiseFactoryImpl(std::move(f_))) { - return PromiseFactoryImpl(std::move(f_)); - } + Promise Once() { return PromiseFactoryImpl(std::move(f_)); } - auto Repeated() const -> decltype(PromiseFactoryImpl(f_)) { - return PromiseFactoryImpl(f_); - } + Promise Repeated() const { return PromiseFactoryImpl(f_); } }; } // namespace promise_detail diff --git a/src/core/lib/promise/if.h b/src/core/lib/promise/if.h new file mode 100644 index 00000000000..cef1e792eb3 --- /dev/null +++ b/src/core/lib/promise/if.h @@ -0,0 +1,130 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_CORE_LIB_PROMISE_IF_H +#define GRPC_CORE_LIB_PROMISE_IF_H + +#include + +#include "absl/status/statusor.h" +#include "absl/types/variant.h" +#include "src/core/lib/promise/detail/promise_factory.h" +#include "src/core/lib/promise/poll.h" + +namespace grpc_core { + +namespace promise_detail { + +template +typename CallPoll::PollResult ChooseIf(CallPoll call_poll, bool result, + T* if_true, F* if_false) { + if (result) { + auto promise = if_true->Once(); + return call_poll(promise); + } else { + auto promise = if_false->Once(); + return call_poll(promise); + } +} + +template +typename CallPoll::PollResult ChooseIf(CallPoll call_poll, + absl::StatusOr result, T* if_true, + F* if_false) { + if (!result.ok()) { + return typename CallPoll::PollResult(result.status()); + } else if (*result) { + auto promise = if_true->Once(); + return call_poll(promise); + } else { + auto promise = if_false->Once(); + return call_poll(promise); + } +} + +template +class If { + private: + using TrueFactory = promise_detail::PromiseFactory; + using FalseFactory = promise_detail::PromiseFactory; + using ConditionPromise = PromiseLike; + using TruePromise = typename TrueFactory::Promise; + using FalsePromise = typename FalseFactory::Promise; + using Result = + typename PollTraits()())>::Type; + + public: + If(C condition, T if_true, F if_false) + : state_(Evaluating{ConditionPromise(std::move(condition)), + TrueFactory(std::move(if_true)), + FalseFactory(std::move(if_false))}) {} + + Poll operator()() { + return absl::visit(CallPoll{this}, state_); + } + + private: + struct Evaluating { + ConditionPromise condition; + TrueFactory if_true; + FalseFactory if_false; + }; + using State = absl::variant; + State state_; + + template + struct CallPoll { + using PollResult = Poll; + + If* const self; + + PollResult operator()(Evaluating& evaluating) const { + static_assert( + !kSetState, + "shouldn't need to set state coming through the initial branch"); + auto r = evaluating.condition(); + if (auto* p = absl::get_if(&r)) { + return ChooseIf(CallPoll{self}, std::move(*p), + &evaluating.if_true, &evaluating.if_false); + } + return Pending(); + } + + template + PollResult operator()(Promise& promise) const { + auto r = promise(); + if (kSetState && absl::holds_alternative(r)) { + self->state_.template emplace(std::move(promise)); + } + return r; + } + }; +}; + +} // namespace promise_detail + +// If promise combinator. +// Takes 3 promise factories, and evaluates the first. +// If it returns failure, returns failure for the entire combinator. +// If it returns true, evaluates the second promise. +// If it returns false, evaluates the third promise. +template +promise_detail::If If(C condition, T if_true, F if_false) { + return promise_detail::If(std::move(condition), std::move(if_true), + std::move(if_false)); +} + +} // namespace grpc_core + +#endif // GRPC_CORE_LIB_PROMISE_IF_H diff --git a/test/core/promise/BUILD b/test/core/promise/BUILD index fced9932e23..f722eee8ad4 100644 --- a/test/core/promise/BUILD +++ b/test/core/promise/BUILD @@ -76,3 +76,15 @@ grpc_cc_test( "//test/core/util:grpc_suppressions", ], ) + +grpc_cc_test( + name = "if_test", + srcs = ["if_test.cc"], + external_deps = ["gtest"], + language = "c++", + uses_polling = False, + deps = [ + "//:if", + "//test/core/util:grpc_suppressions", + ], +) diff --git a/test/core/promise/if_test.cc b/test/core/promise/if_test.cc new file mode 100644 index 00000000000..99ad2d2efc4 --- /dev/null +++ b/test/core/promise/if_test.cc @@ -0,0 +1,57 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/if.h" +#include + +namespace grpc_core { + +TEST(IfTest, ChooseTrue) { + EXPECT_EQ(If([]() { return true; }, []() { return 1; }, []() { return 2; })(), + Poll(1)); +} + +TEST(IfTest, ChooseFalse) { + EXPECT_EQ( + If([]() { return false; }, []() { return 1; }, []() { return 2; })(), + Poll(2)); +} + +TEST(IfTest, ChooseSuccesfulTrue) { + EXPECT_EQ(If([]() { return absl::StatusOr(true); }, + []() { return absl::StatusOr(1); }, + []() { return absl::StatusOr(2); })(), + Poll>(absl::StatusOr(1))); +} + +TEST(IfTest, ChooseSuccesfulFalse) { + EXPECT_EQ(If([]() { return absl::StatusOr(false); }, + []() { return absl::StatusOr(1); }, + []() { return absl::StatusOr(2); })(), + Poll>(absl::StatusOr(2))); +} + +TEST(IfTest, ChooseFailure) { + EXPECT_EQ(If([]() { return absl::StatusOr(); }, + []() { return absl::StatusOr(1); }, + []() { return absl::StatusOr(2); })(), + Poll>(absl::StatusOr())); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index e5ade4a05a9..3da05659499 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -5083,6 +5083,30 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "if_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": false + }, { "args": [], "benchmark": false,