From 2fde70a6be978a4de9b7b119da8ca6120c5e271b Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Fri, 8 Dec 2023 16:54:07 -0800 Subject: [PATCH] [call-v3] Generic forwarder from a CallHandler to a CallInterceptor (#35256) Closes #35256 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35256 from ctiller:cg-fwd-call cdaae8bccd99d1b39c310efc42b3fe98a6723802 PiperOrigin-RevId: 589278551 --- build_autogenerated.yaml | 2 + src/core/BUILD | 2 + src/core/lib/promise/detail/status.h | 2 +- src/core/lib/promise/for_each.h | 14 +++++- src/core/lib/promise/status_flag.h | 7 +++ src/core/lib/promise/try_seq.h | 38 +++++++++++++++- src/core/lib/transport/transport.cc | 66 ++++++++++++++++++++++++++++ src/core/lib/transport/transport.h | 30 ++++++++++++- 8 files changed, 155 insertions(+), 6 deletions(-) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index f5ea18449e2..09460f09a88 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -9393,6 +9393,7 @@ targets: - src/core/lib/promise/poll.h - src/core/lib/promise/race.h - src/core/lib/promise/seq.h + - src/core/lib/promise/status_flag.h - src/core/lib/promise/trace.h - src/core/lib/promise/try_seq.h - src/core/lib/resource_quota/arena.h @@ -11561,6 +11562,7 @@ targets: - src/core/lib/promise/poll.h - src/core/lib/promise/race.h - src/core/lib/promise/seq.h + - src/core/lib/promise/status_flag.h - src/core/lib/promise/trace.h - src/core/lib/promise/try_seq.h - src/core/lib/resource_quota/arena.h diff --git a/src/core/BUILD b/src/core/BUILD index 449649facdc..36bad285ae5 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -1020,7 +1020,9 @@ grpc_cc_library( "construct_destruct", "poll", "promise_factory", + "promise_status", "promise_trace", + "status_flag", "//:gpr", "//:gpr_platform", ], diff --git a/src/core/lib/promise/detail/status.h b/src/core/lib/promise/detail/status.h index b20239de865..1063f329193 100644 --- a/src/core/lib/promise/detail/status.h +++ b/src/core/lib/promise/detail/status.h @@ -45,7 +45,7 @@ inline absl::Status IntoStatus(absl::Status* status) { // can participate in TrySeq as result types that affect control flow. inline bool IsStatusOk(const absl::Status& status) { return status.ok(); } -template +template struct StatusCastImpl; template diff --git a/src/core/lib/promise/for_each.h b/src/core/lib/promise/for_each.h index 1e6a8294312..2b8e9b10cc9 100644 --- a/src/core/lib/promise/for_each.h +++ b/src/core/lib/promise/for_each.h @@ -30,7 +30,9 @@ #include "src/core/lib/gprpp/construct_destruct.h" #include "src/core/lib/promise/activity.h" #include "src/core/lib/promise/detail/promise_factory.h" +#include "src/core/lib/promise/detail/status.h" #include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/status_flag.h" #include "src/core/lib/promise/trace.h" namespace grpc_core { @@ -48,6 +50,16 @@ struct Done { static absl::Status Make() { return absl::OkStatus(); } }; +template <> +struct Done { + static StatusFlag Make() { return StatusFlag(true); } +}; + +template <> +struct Done { + static Success Make() { return Success{}; } +}; + template class ForEach { private: @@ -139,7 +151,7 @@ class ForEach { } auto r = in_action_.promise(); if (auto* p = r.value_if_ready()) { - if (p->ok()) { + if (IsStatusOk(*p)) { Destruct(&in_action_); Construct(&reader_next_, reader_.Next()); reading_next_ = true; diff --git a/src/core/lib/promise/status_flag.h b/src/core/lib/promise/status_flag.h index 54019d38740..1d7b09fabea 100644 --- a/src/core/lib/promise/status_flag.h +++ b/src/core/lib/promise/status_flag.h @@ -66,6 +66,13 @@ struct StatusCastImpl { } }; +template <> +struct StatusCastImpl { + static absl::Status Cast(StatusFlag flag) { + return flag.ok() ? absl::OkStatus() : absl::CancelledError(); + } +}; + template <> struct StatusCastImpl { static absl::Status Cast(StatusFlag flag) { diff --git a/src/core/lib/promise/try_seq.h b/src/core/lib/promise/try_seq.h index 71946e0e290..266a09839a7 100644 --- a/src/core/lib/promise/try_seq.h +++ b/src/core/lib/promise/try_seq.h @@ -85,13 +85,23 @@ struct TrySeqTraitsWithSfinae> { return run_next(std::move(prior)); } }; +template +struct TakeValueExists { + static constexpr bool value = false; +}; +template +struct TakeValueExists()))>> { + static constexpr bool value = true; +}; // If there exists a function 'IsStatusOk(const T&) -> bool' then we assume that // T is a status type for the purposes of promise sequences, and a non-OK T // should terminate the sequence and return. template struct TrySeqTraitsWithSfinae< T, absl::enable_if_t< - std::is_same())), bool>::value, + std::is_same())), bool>::value && + !TakeValueExists::value, void>> { using UnwrappedType = void; using WrappedType = T; @@ -102,7 +112,31 @@ struct TrySeqTraitsWithSfinae< static bool IsOk(const T& status) { return IsStatusOk(status); } template static R ReturnValue(T&& status) { - return R(std::move(status)); + return StatusCast(std::move(status)); + } + template + static Poll CheckResultAndRunNext(T prior, RunNext run_next) { + if (!IsStatusOk(prior)) return Result(std::move(prior)); + return run_next(std::move(prior)); + } +}; +template +struct TrySeqTraitsWithSfinae< + T, absl::enable_if_t< + std::is_same())), bool>::value && + TakeValueExists::value, + void>> { + using UnwrappedType = decltype(TakeValue(std::declval())); + using WrappedType = T; + template + static auto CallFactory(Next* next, T&& status) { + return next->Make(TakeValue(std::forward(status))); + } + static bool IsOk(const T& status) { return IsStatusOk(status); } + template + static R ReturnValue(T&& status) { + GPR_DEBUG_ASSERT(!IsStatusOk(status)); + return StatusCast(std::move(status)); } template static Poll CheckResultAndRunNext(T prior, RunNext run_next) { diff --git a/src/core/lib/transport/transport.cc b/src/core/lib/transport/transport.cc index bbe25613649..ab405804065 100644 --- a/src/core/lib/transport/transport.cc +++ b/src/core/lib/transport/transport.cc @@ -35,6 +35,9 @@ #include "src/core/lib/event_engine/default_event_engine.h" #include "src/core/lib/gprpp/time.h" #include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/promise/for_each.h" +#include "src/core/lib/promise/promise.h" +#include "src/core/lib/promise/try_seq.h" #include "src/core/lib/slice/slice.h" #include "src/core/lib/transport/error_utils.h" @@ -268,4 +271,67 @@ std::string Message::DebugString() const { return out; } +void ForwardCall(CallHandler call_handler, CallInitiator call_initiator, + ClientMetadataHandle client_initial_metadata) { + // Send initial metadata. + call_initiator.SpawnGuarded( + "send_initial_metadata", + [client_initial_metadata = std::move(client_initial_metadata), + call_initiator]() mutable { + return call_initiator.PushClientInitialMetadata( + std::move(client_initial_metadata)); + }); + // Read messages from handler into initiator. + call_handler.SpawnGuarded( + "read_messages", [call_handler, call_initiator]() mutable { + return ForEach(OutgoingMessages(call_handler), + [call_initiator](MessageHandle msg) mutable { + // Need to spawn a job into the initiator's activity to + // push the message in. + return call_initiator.SpawnWaitable( + "send_message", + [msg = std::move(msg), call_initiator]() mutable { + return call_initiator.CancelIfFails(Map( + call_initiator.PushMessage(std::move(msg)), + [](bool r) { return StatusFlag(r); })); + }); + }); + }); + call_initiator.SpawnInfallible("read_the_things", [call_initiator, + call_handler]() mutable { + return Seq( + call_initiator.CancelIfFails(TrySeq( + call_initiator.PullServerInitialMetadata(), + [call_handler](ServerMetadataHandle md) mutable { + call_handler.SpawnGuarded( + "recv_initial_metadata", + [md = std::move(md), call_handler]() mutable { + return call_handler.PushServerInitialMetadata( + std::move(md)); + }); + return Success{}; + }, + ForEach(OutgoingMessages(call_initiator), + [call_handler](MessageHandle msg) mutable { + return call_handler.SpawnWaitable( + "recv_message", + [msg = std::move(msg), call_handler]() mutable { + return call_handler.CancelIfFails( + Map(call_handler.PushMessage(std::move(msg)), + [](bool r) { return StatusFlag(r); })); + }); + }), + ImmediateOkStatus())), + call_initiator.PullServerTrailingMetadata(), + [call_handler](ServerMetadataHandle md) mutable { + call_handler.SpawnGuarded( + "recv_trailing_metadata", + [md = std::move(md), call_handler]() mutable { + return call_handler.PushServerTrailingMetadata(std::move(md)); + }); + return Empty{}; + }); + }); +} + } // namespace grpc_core diff --git a/src/core/lib/transport/transport.h b/src/core/lib/transport/transport.h index f34692c8ddc..d35bdf5179a 100644 --- a/src/core/lib/transport/transport.h +++ b/src/core/lib/transport/transport.h @@ -150,6 +150,17 @@ struct StatusCastImpl { } }; +// Anything that can be first cast to absl::Status can then be cast to +// ServerMetadataHandle. +template +struct StatusCastImpl< + ServerMetadataHandle, T, + absl::void_t(std::declval()))>> { + static ServerMetadataHandle Cast(const T& m) { + return ServerMetadataFromStatus(StatusCast(m)); + } +}; + // Move only type that tracks call startup. // Allows observation of when client_initial_metadata has been processed by the // end of the local call stack. @@ -283,9 +294,9 @@ class CallSpineInterface { using ResultType = typename P::Result; return Map(std::move(promise), [this](ResultType r) { if (!IsStatusOk(r)) { - std::ignore = Cancel(StatusCast(std::move(r))); + std::ignore = Cancel(StatusCast(r)); } - return Empty{}; + return r; }); } @@ -410,6 +421,11 @@ class CallInitiator { spine_->SpawnInfallible(name, std::move(promise_factory)); } + template + auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) { + return spine_->party().SpawnWaitable(name, std::move(promise_factory)); + } + private: const RefCountedPtr spine_; }; @@ -466,6 +482,11 @@ class CallHandler { spine_->SpawnInfallible(name, std::move(promise_factory)); } + template + auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) { + return spine_->party().SpawnWaitable(name, std::move(promise_factory)); + } + private: const RefCountedPtr spine_; }; @@ -479,6 +500,11 @@ auto OutgoingMessages(CallHalf& h) { return Wrapper{h}; } +// Forward a call from `call_handler` to `call_initiator` (with initial metadata +// `client_initial_metadata`) +void ForwardCall(CallHandler call_handler, CallInitiator call_initiator, + ClientMetadataHandle client_initial_metadata); + } // namespace grpc_core // forward declarations