[promises] Add a helper to forward from handler to initiator

pull/35256/head
Craig Tiller 1 year ago
parent af5dd84a0e
commit b2584dc863
  1. 1
      src/core/BUILD
  2. 2
      src/core/lib/promise/detail/status.h
  3. 14
      src/core/lib/promise/for_each.h
  4. 38
      src/core/lib/promise/try_seq.h
  5. 62
      src/core/lib/transport/transport.cc
  6. 24
      src/core/lib/transport/transport.h

@ -1021,6 +1021,7 @@ grpc_cc_library(
"poll",
"promise_factory",
"promise_trace",
"status_flag",
"//:gpr",
"//:gpr_platform",
],

@ -45,7 +45,7 @@ inline absl::Status IntoStatus(absl::Status* status) {
// can participate in TrySeq as result types that affect control flow.
inline bool IsStatusOk(const absl::Status& status) { return status.ok(); }
template <typename To, typename From>
template <typename To, typename From, typename SfinaeVoid = void>
struct StatusCastImpl;
template <typename To>

@ -24,6 +24,7 @@
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "detail/status.h"
#include <grpc/support/log.h>
@ -31,6 +32,7 @@
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/detail/promise_factory.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/promise/status_flag.h"
#include "src/core/lib/promise/trace.h"
namespace grpc_core {
@ -48,6 +50,16 @@ struct Done<absl::Status> {
static absl::Status Make() { return absl::OkStatus(); }
};
template <>
struct Done<StatusFlag> {
static StatusFlag Make() { return StatusFlag(true); }
};
template <>
struct Done<Success> {
static Success Make() { return Success{}; }
};
template <typename Reader, typename Action>
class ForEach {
private:
@ -139,7 +151,7 @@ class ForEach {
}
auto r = in_action_.promise();
if (auto* p = r.value_if_ready()) {
if (p->ok()) {
if (IsStatusOk(*p)) {
Destruct(&in_action_);
Construct(&reader_next_, reader_.Next());
reading_next_ = true;

@ -85,13 +85,23 @@ struct TrySeqTraitsWithSfinae<absl::StatusOr<T>> {
return run_next(std::move(prior));
}
};
template <typename T, typename AnyType = void>
struct TakeValueExists {
static constexpr bool value = false;
};
template <typename T>
struct TakeValueExists<T,
absl::void_t<decltype(TakeValue(std::declval<T>()))>> {
static constexpr bool value = true;
};
// If there exists a function 'IsStatusOk(const T&) -> bool' then we assume that
// T is a status type for the purposes of promise sequences, and a non-OK T
// should terminate the sequence and return.
template <typename T>
struct TrySeqTraitsWithSfinae<
T, absl::enable_if_t<
std::is_same<decltype(IsStatusOk(std::declval<T>())), bool>::value,
std::is_same<decltype(IsStatusOk(std::declval<T>())), bool>::value &&
!TakeValueExists<T>::value,
void>> {
using UnwrappedType = void;
using WrappedType = T;
@ -102,7 +112,31 @@ struct TrySeqTraitsWithSfinae<
static bool IsOk(const T& status) { return IsStatusOk(status); }
template <typename R>
static R ReturnValue(T&& status) {
return R(std::move(status));
return StatusCast<R>(std::move(status));
}
template <typename Result, typename RunNext>
static Poll<Result> CheckResultAndRunNext(T prior, RunNext run_next) {
if (!IsStatusOk(prior)) return Result(std::move(prior));
return run_next(std::move(prior));
}
};
template <typename T>
struct TrySeqTraitsWithSfinae<
T, absl::enable_if_t<
std::is_same<decltype(IsStatusOk(std::declval<T>())), bool>::value &&
TakeValueExists<T>::value,
void>> {
using UnwrappedType = decltype(TakeValue(std::declval<T>()));
using WrappedType = T;
template <typename Next>
static auto CallFactory(Next* next, T&& status) {
return next->Make(TakeValue(std::forward<T>(status)));
}
static bool IsOk(const T& status) { return IsStatusOk(status); }
template <typename R>
static R ReturnValue(T&& status) {
GPR_DEBUG_ASSERT(!IsStatusOk(status));
return StatusCast<R>(std::move(status));
}
template <typename Result, typename RunNext>
static Poll<Result> CheckResultAndRunNext(T prior, RunNext run_next) {

@ -35,6 +35,9 @@
#include "src/core/lib/event_engine/default_event_engine.h"
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/promise/for_each.h"
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/promise/try_seq.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/transport/error_utils.h"
@ -268,4 +271,63 @@ std::string Message::DebugString() const {
return out;
}
void ForwardCall(CallHandler call_handler, CallInitiator call_initiator,
ClientMetadataHandle client_initial_metadata) {
call_initiator.SpawnGuarded(
"send_initial_metadata",
[client_initial_metadata = std::move(client_initial_metadata),
call_initiator]() mutable {
return call_initiator.PushClientInitialMetadata(
std::move(client_initial_metadata));
});
call_handler.SpawnGuarded("read_messages", [call_handler,
call_initiator]() mutable {
return ForEach(
OutgoingMessages(call_handler),
[call_initiator](MessageHandle msg) mutable {
call_initiator.SpawnGuarded(
"send_message", [msg = std::move(msg), call_initiator]() mutable {
return Map(call_initiator.PushMessage(std::move(msg)),
[](bool r) { return StatusFlag(r); });
});
return Success{};
});
});
call_initiator.SpawnInfallible("read_the_things", [call_initiator,
call_handler]() mutable {
return Seq(
call_initiator.CancelIfFails(TrySeq(
call_initiator.PullServerInitialMetadata(),
[call_handler](ServerMetadataHandle md) mutable {
call_handler.SpawnGuarded(
"recv_initial_metadata",
[md = std::move(md), call_handler]() mutable {
return call_handler.PushServerInitialMetadata(
std::move(md));
});
return Success{};
},
ForEach(OutgoingMessages(call_initiator),
[call_handler](MessageHandle msg) mutable {
call_handler.SpawnGuarded(
"recv_message",
[msg = std::move(msg), call_handler]() mutable {
return Map(call_handler.PushMessage(std::move(msg)),
[](bool r) { return StatusFlag(r); });
});
return Success{};
}),
ImmediateOkStatus())),
call_initiator.PullServerTrailingMetadata(),
[call_handler](ServerMetadataHandle md) mutable {
call_handler.SpawnGuarded(
"recv_trailing_metadata",
[md = std::move(md), call_handler]() mutable {
return call_handler.PushServerTrailingMetadata(std::move(md));
});
return Empty{};
});
});
}
} // namespace grpc_core

@ -150,6 +150,17 @@ struct StatusCastImpl<ServerMetadataHandle, absl::Status&> {
}
};
// Anything that can be first cast to absl::Status can then be cast to
// ServerMetadataHandle.
template <typename T>
struct StatusCastImpl<
ServerMetadataHandle, T,
absl::void_t<decltype(StatusCast<absl::Status>(std::declval<T>()))>> {
static ServerMetadataHandle Cast(const T& m) {
return ServerMetadataFromStatus(StatusCast<absl::Status>(m));
}
};
// Move only type that tracks call startup.
// Allows observation of when client_initial_metadata has been processed by the
// end of the local call stack.
@ -228,6 +239,9 @@ struct CallArgs {
PipeSender<MessageHandle>* server_to_client_messages;
};
using NextPromiseFactory =
std::function<ArenaPromise<ServerMetadataHandle>(CallArgs)>;
// TODO(ctiller): eventually drop this when we don't need to reference into
// legacy promise calls anymore
class CallSpineInterface {
@ -266,7 +280,7 @@ class CallSpineInterface {
using ResultType = typename P::Result;
return Map(std::move(promise), [this](ResultType r) {
if (!IsStatusOk(r)) {
Cancel(StatusCast<ServerMetadataHandle>(std::move(r)));
std::ignore = Cancel(StatusCast<ServerMetadataHandle>(std::move(r)));
}
return Empty{};
});
@ -293,7 +307,7 @@ class CallSpineInterface {
"SpawnGuarded promise must return a status-like object");
party().Spawn(name, std::move(promise_factory), [this](ResultType r) {
if (!IsStatusOk(r)) {
Cancel(StatusCast<ServerMetadataHandle>(std::move(r)));
std::ignore = Cancel(StatusCast<ServerMetadataHandle>(std::move(r)));
}
});
}
@ -462,8 +476,10 @@ auto OutgoingMessages(CallHalf& h) {
return Wrapper{h};
}
using NextPromiseFactory =
std::function<ArenaPromise<ServerMetadataHandle>(CallArgs)>;
// Forward a call from `call_handler` to `call_initiator` (with initial metadata
// `client_initial_metadata`)
void ForwardCall(CallHandler call_handler, CallInitiator call_initiator,
ClientMetadataHandle client_initial_metadata);
} // namespace grpc_core

Loading…
Cancel
Save