From b2584dc863863cce891a47a5fd62299e3444ab24 Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Fri, 1 Dec 2023 05:47:58 +0000 Subject: [PATCH] [promises] Add a helper to forward from handler to initiator --- src/core/BUILD | 1 + src/core/lib/promise/detail/status.h | 2 +- src/core/lib/promise/for_each.h | 14 ++++++- src/core/lib/promise/try_seq.h | 38 ++++++++++++++++- src/core/lib/transport/transport.cc | 62 ++++++++++++++++++++++++++++ src/core/lib/transport/transport.h | 24 +++++++++-- 6 files changed, 133 insertions(+), 8 deletions(-) diff --git a/src/core/BUILD b/src/core/BUILD index 351f101b819..9f0c4bcd098 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -1021,6 +1021,7 @@ grpc_cc_library( "poll", "promise_factory", "promise_trace", + "status_flag", "//:gpr", "//:gpr_platform", ], diff --git a/src/core/lib/promise/detail/status.h b/src/core/lib/promise/detail/status.h index b20239de865..1063f329193 100644 --- a/src/core/lib/promise/detail/status.h +++ b/src/core/lib/promise/detail/status.h @@ -45,7 +45,7 @@ inline absl::Status IntoStatus(absl::Status* status) { // can participate in TrySeq as result types that affect control flow. inline bool IsStatusOk(const absl::Status& status) { return status.ok(); } -template +template struct StatusCastImpl; template diff --git a/src/core/lib/promise/for_each.h b/src/core/lib/promise/for_each.h index 1e6a8294312..8510b56b4ae 100644 --- a/src/core/lib/promise/for_each.h +++ b/src/core/lib/promise/for_each.h @@ -24,6 +24,7 @@ #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "detail/status.h" #include @@ -31,6 +32,7 @@ #include "src/core/lib/promise/activity.h" #include "src/core/lib/promise/detail/promise_factory.h" #include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/status_flag.h" #include "src/core/lib/promise/trace.h" namespace grpc_core { @@ -48,6 +50,16 @@ struct Done { static absl::Status Make() { return absl::OkStatus(); } }; +template <> +struct Done { + static StatusFlag Make() { return StatusFlag(true); } +}; + +template <> +struct Done { + static Success Make() { return Success{}; } +}; + template class ForEach { private: @@ -139,7 +151,7 @@ class ForEach { } auto r = in_action_.promise(); if (auto* p = r.value_if_ready()) { - if (p->ok()) { + if (IsStatusOk(*p)) { Destruct(&in_action_); Construct(&reader_next_, reader_.Next()); reading_next_ = true; diff --git a/src/core/lib/promise/try_seq.h b/src/core/lib/promise/try_seq.h index 71946e0e290..266a09839a7 100644 --- a/src/core/lib/promise/try_seq.h +++ b/src/core/lib/promise/try_seq.h @@ -85,13 +85,23 @@ struct TrySeqTraitsWithSfinae> { return run_next(std::move(prior)); } }; +template +struct TakeValueExists { + static constexpr bool value = false; +}; +template +struct TakeValueExists()))>> { + static constexpr bool value = true; +}; // If there exists a function 'IsStatusOk(const T&) -> bool' then we assume that // T is a status type for the purposes of promise sequences, and a non-OK T // should terminate the sequence and return. template struct TrySeqTraitsWithSfinae< T, absl::enable_if_t< - std::is_same())), bool>::value, + std::is_same())), bool>::value && + !TakeValueExists::value, void>> { using UnwrappedType = void; using WrappedType = T; @@ -102,7 +112,31 @@ struct TrySeqTraitsWithSfinae< static bool IsOk(const T& status) { return IsStatusOk(status); } template static R ReturnValue(T&& status) { - return R(std::move(status)); + return StatusCast(std::move(status)); + } + template + static Poll CheckResultAndRunNext(T prior, RunNext run_next) { + if (!IsStatusOk(prior)) return Result(std::move(prior)); + return run_next(std::move(prior)); + } +}; +template +struct TrySeqTraitsWithSfinae< + T, absl::enable_if_t< + std::is_same())), bool>::value && + TakeValueExists::value, + void>> { + using UnwrappedType = decltype(TakeValue(std::declval())); + using WrappedType = T; + template + static auto CallFactory(Next* next, T&& status) { + return next->Make(TakeValue(std::forward(status))); + } + static bool IsOk(const T& status) { return IsStatusOk(status); } + template + static R ReturnValue(T&& status) { + GPR_DEBUG_ASSERT(!IsStatusOk(status)); + return StatusCast(std::move(status)); } template static Poll CheckResultAndRunNext(T prior, RunNext run_next) { diff --git a/src/core/lib/transport/transport.cc b/src/core/lib/transport/transport.cc index bbe25613649..85c43f8576f 100644 --- a/src/core/lib/transport/transport.cc +++ b/src/core/lib/transport/transport.cc @@ -35,6 +35,9 @@ #include "src/core/lib/event_engine/default_event_engine.h" #include "src/core/lib/gprpp/time.h" #include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/promise/for_each.h" +#include "src/core/lib/promise/promise.h" +#include "src/core/lib/promise/try_seq.h" #include "src/core/lib/slice/slice.h" #include "src/core/lib/transport/error_utils.h" @@ -268,4 +271,63 @@ std::string Message::DebugString() const { return out; } +void ForwardCall(CallHandler call_handler, CallInitiator call_initiator, + ClientMetadataHandle client_initial_metadata) { + call_initiator.SpawnGuarded( + "send_initial_metadata", + [client_initial_metadata = std::move(client_initial_metadata), + call_initiator]() mutable { + return call_initiator.PushClientInitialMetadata( + std::move(client_initial_metadata)); + }); + call_handler.SpawnGuarded("read_messages", [call_handler, + call_initiator]() mutable { + return ForEach( + OutgoingMessages(call_handler), + [call_initiator](MessageHandle msg) mutable { + call_initiator.SpawnGuarded( + "send_message", [msg = std::move(msg), call_initiator]() mutable { + return Map(call_initiator.PushMessage(std::move(msg)), + [](bool r) { return StatusFlag(r); }); + }); + return Success{}; + }); + }); + call_initiator.SpawnInfallible("read_the_things", [call_initiator, + call_handler]() mutable { + return Seq( + call_initiator.CancelIfFails(TrySeq( + call_initiator.PullServerInitialMetadata(), + [call_handler](ServerMetadataHandle md) mutable { + call_handler.SpawnGuarded( + "recv_initial_metadata", + [md = std::move(md), call_handler]() mutable { + return call_handler.PushServerInitialMetadata( + std::move(md)); + }); + return Success{}; + }, + ForEach(OutgoingMessages(call_initiator), + [call_handler](MessageHandle msg) mutable { + call_handler.SpawnGuarded( + "recv_message", + [msg = std::move(msg), call_handler]() mutable { + return Map(call_handler.PushMessage(std::move(msg)), + [](bool r) { return StatusFlag(r); }); + }); + return Success{}; + }), + ImmediateOkStatus())), + call_initiator.PullServerTrailingMetadata(), + [call_handler](ServerMetadataHandle md) mutable { + call_handler.SpawnGuarded( + "recv_trailing_metadata", + [md = std::move(md), call_handler]() mutable { + return call_handler.PushServerTrailingMetadata(std::move(md)); + }); + return Empty{}; + }); + }); +} + } // namespace grpc_core diff --git a/src/core/lib/transport/transport.h b/src/core/lib/transport/transport.h index 7b3393bfcc3..e677e5142dc 100644 --- a/src/core/lib/transport/transport.h +++ b/src/core/lib/transport/transport.h @@ -150,6 +150,17 @@ struct StatusCastImpl { } }; +// Anything that can be first cast to absl::Status can then be cast to +// ServerMetadataHandle. +template +struct StatusCastImpl< + ServerMetadataHandle, T, + absl::void_t(std::declval()))>> { + static ServerMetadataHandle Cast(const T& m) { + return ServerMetadataFromStatus(StatusCast(m)); + } +}; + // Move only type that tracks call startup. // Allows observation of when client_initial_metadata has been processed by the // end of the local call stack. @@ -228,6 +239,9 @@ struct CallArgs { PipeSender* server_to_client_messages; }; +using NextPromiseFactory = + std::function(CallArgs)>; + // TODO(ctiller): eventually drop this when we don't need to reference into // legacy promise calls anymore class CallSpineInterface { @@ -266,7 +280,7 @@ class CallSpineInterface { using ResultType = typename P::Result; return Map(std::move(promise), [this](ResultType r) { if (!IsStatusOk(r)) { - Cancel(StatusCast(std::move(r))); + std::ignore = Cancel(StatusCast(std::move(r))); } return Empty{}; }); @@ -293,7 +307,7 @@ class CallSpineInterface { "SpawnGuarded promise must return a status-like object"); party().Spawn(name, std::move(promise_factory), [this](ResultType r) { if (!IsStatusOk(r)) { - Cancel(StatusCast(std::move(r))); + std::ignore = Cancel(StatusCast(std::move(r))); } }); } @@ -462,8 +476,10 @@ auto OutgoingMessages(CallHalf& h) { return Wrapper{h}; } -using NextPromiseFactory = - std::function(CallArgs)>; +// Forward a call from `call_handler` to `call_initiator` (with initial metadata +// `client_initial_metadata`) +void ForwardCall(CallHandler call_handler, CallInitiator call_initiator, + ClientMetadataHandle client_initial_metadata); } // namespace grpc_core