pull/36509/head
Craig Tiller 10 months ago
parent 51933446c1
commit 7c82e2f2b9
  1. 11
      src/core/lib/promise/detail/promise_like.h
  2. 137
      src/core/lib/surface/call.cc

@ -17,6 +17,7 @@
#include <utility>
#include "absl/functional/any_invocable.h"
#include "absl/meta/type_traits.h"
#include <grpc/support/port_platform.h>
@ -63,6 +64,10 @@ auto WrapInPoll(T&& x) -> decltype(PollWrapper<T>::Wrap(std::forward<T>(x))) {
return PollWrapper<T>::Wrap(std::forward<T>(x));
}
// T -> T, const T& -> T
template <typename T>
using RemoveCVRef = absl::remove_cv_t<absl::remove_reference_t<T>>;
template <typename F, typename SfinaeVoid = void>
class PromiseLike;
@ -73,7 +78,7 @@ template <typename F>
class PromiseLike<F, absl::enable_if_t<!std::is_void<
typename std::result_of<F()>::type>::value>> {
private:
GPR_NO_UNIQUE_ADDRESS F f_;
GPR_NO_UNIQUE_ADDRESS RemoveCVRef<F> f_;
public:
// NOLINTNEXTLINE - internal detail that drastically simplifies calling code.
@ -82,10 +87,6 @@ class PromiseLike<F, absl::enable_if_t<!std::is_void<
using Result = typename PollTraits<decltype(WrapInPoll(f_()))>::Type;
};
// T -> T, const T& -> T
template <typename T>
using RemoveCVRef = absl::remove_cv_t<absl::remove_reference_t<T>>;
} // namespace promise_detail
} // namespace grpc_core

@ -3276,62 +3276,83 @@ grpc_call_error ServerCall::StartBatch(const grpc_op* ops, size_t nops,
}
namespace {
template <typename SetupFn>
template <typename SetupResult, grpc_op_type kOp>
class MaybeOpImpl {
public:
using SetupResult = decltype(std::declval<SetupFn>()(grpc_op()));
using PromiseFactory = promise_detail::OncePromiseFactory<void, SetupResult>;
using Promise = typename PromiseFactory::Promise;
static_assert(!std::is_same<Promise, void>::value,
"PromiseFactory must return a promise");
struct Dismissed {};
using State = absl::variant<Dismissed, PromiseFactory, Promise>;
// op_ is garbage but shouldn't be uninitialized
MaybeOpImpl() : state_(Dismissed{}), op_(GRPC_OP_RECV_STATUS_ON_CLIENT) {}
MaybeOpImpl(SetupResult result, grpc_op_type op)
: state_(PromiseFactory(std::move(result))), op_(op) {}
MaybeOpImpl() : state_(State::kDismissed) {}
explicit MaybeOpImpl(SetupResult result) : state_(State::kPromiseFactory) {
Construct(&promise_factory_, std::move(result));
}
~MaybeOpImpl() {
switch (state_) {
case State::kDismissed:
case State::kPromiseFactory:
Destruct(&promise_factory_);
break;
case State::kPromise:
Destruct(&promise_);
break;
}
}
MaybeOpImpl(const MaybeOpImpl&) = delete;
MaybeOpImpl& operator=(const MaybeOpImpl&) = delete;
MaybeOpImpl(MaybeOpImpl&& other) noexcept
: state_(MoveState(other.state_)), op_(other.op_) {}
MaybeOpImpl& operator=(MaybeOpImpl&& other) noexcept {
op_ = other.op_;
if (absl::holds_alternative<Dismissed>(state_)) {
state_.template emplace<Dismissed>();
return *this;
MaybeOpImpl(MaybeOpImpl&& other) noexcept : state_(other.state_) {
switch (state_) {
case State::kDismissed:
case State::kPromiseFactory:
Construct(&promise_factory_, std::move(other.promise_factory_));
break;
case State::kPromise:
Construct(&promise_, std::move(other.promise_));
break;
}
// Can't move after first poll => Promise is not an option
state_.template emplace<PromiseFactory>(
std::move(absl::get<PromiseFactory>(other.state_)));
return *this;
}
MaybeOpImpl& operator=(MaybeOpImpl&& other) noexcept = delete;
Poll<StatusFlag> operator()() {
if (absl::holds_alternative<Dismissed>(state_)) return Success{};
if (absl::holds_alternative<PromiseFactory>(state_)) {
auto& factory = absl::get<PromiseFactory>(state_);
auto promise = factory.Make();
state_.template emplace<Promise>(std::move(promise));
}
if (grpc_call_trace.enabled()) {
gpr_log(GPR_INFO, "%sBeginPoll %s",
Activity::current()->DebugTag().c_str(), OpName(op_).c_str());
}
auto& promise = absl::get<Promise>(state_);
auto r = poll_cast<StatusFlag>(promise());
if (grpc_call_trace.enabled()) {
gpr_log(GPR_INFO, "%sEndPoll %s --> %s",
Activity::current()->DebugTag().c_str(), OpName(op_).c_str(),
switch (state_) {
case State::kDismissed:
return Success{};
case State::kPromiseFactory: {
auto promise = promise_factory_.Make();
Destruct(&promise_factory_);
Construct(&promise_, std::move(promise));
state_ = State::kPromise;
}
ABSL_FALLTHROUGH_INTENDED;
case State::kPromise: {
if (grpc_call_trace.enabled()) {
gpr_log(GPR_INFO, "%sBeginPoll %s",
Activity::current()->DebugTag().c_str(), OpName());
}
auto r = poll_cast<StatusFlag>(promise_());
if (grpc_call_trace.enabled()) {
gpr_log(
GPR_INFO, "%sEndPoll %s --> %s",
Activity::current()->DebugTag().c_str(), OpName(),
r.pending() ? "PENDING" : (r.value().ok() ? "OK" : "FAILURE"));
}
return r;
}
}
return r;
}
private:
static std::string OpName(grpc_op_type op) {
switch (op) {
enum class State {
kDismissed,
kPromiseFactory,
kPromise,
};
static const char* OpName() {
switch (kOp) {
case GRPC_OP_SEND_INITIAL_METADATA:
return "SendInitialMetadata";
case GRPC_OP_SEND_MESSAGE:
@ -3349,17 +3370,15 @@ class MaybeOpImpl {
case GRPC_OP_RECV_STATUS_ON_CLIENT:
return "RecvStatusOnClient";
}
return absl::StrCat("UnknownOp(", op, ")");
}
static State MoveState(State& state) {
if (absl::holds_alternative<Dismissed>(state)) return Dismissed{};
// Can't move after first poll => Promise is not an option
return std::move(absl::get<PromiseFactory>(state));
Crash("Unreachable");
}
// gcc-12 has problems with this being a variant
GPR_NO_UNIQUE_ADDRESS State state_;
GPR_NO_UNIQUE_ADDRESS grpc_op_type op_;
union {
PromiseFactory promise_factory_;
Promise promise_;
};
};
// MaybeOp captures a fairly complicated dance we need to do for the batch
@ -3370,12 +3389,14 @@ class MaybeOpImpl {
// ultimately poll on til completion.
// Once we express our surface API in terms of core internal types this whole
// dance will go away.
template <typename SetupFn>
auto MaybeOp(const grpc_op* ops, uint8_t idx, SetupFn setup) {
if (idx == 255) {
return MaybeOpImpl<SetupFn>();
template <grpc_op_type op_type, typename SetupFn>
auto MaybeOp(const grpc_op* ops, const std::array<uint8_t, 8>& idxs, SetupFn setup) {
using SetupResult = decltype(std::declval<SetupFn>()(grpc_op()));
if (idxs[op_type] == 255) {
return MaybeOpImpl<SetupResult, op_type>();
} else {
return MaybeOpImpl<SetupFn>(setup(ops[idx]), ops[idx].op);
auto r = setup(ops[idxs[op_type]]);
return MaybeOpImpl<SetupResult, op_type>(std::move(r));
}
}
@ -3466,8 +3487,8 @@ void ServerCall::CommitBatch(const grpc_op* ops, size_t nops, void* notify_tag,
got_ops[op.op] = op_idx;
}
if (!is_notify_tag_closure) grpc_cq_begin_op(cq_, notify_tag);
auto send_initial_metadata = MaybeOp(
ops, got_ops[GRPC_OP_SEND_INITIAL_METADATA], [this](const grpc_op& op) {
auto send_initial_metadata = MaybeOp<GRPC_OP_SEND_INITIAL_METADATA>(
ops, got_ops, [this](const grpc_op& op) {
auto metadata = arena()->MakePooled<ServerMetadata>();
PrepareOutgoingInitialMetadata(op, *metadata);
CToMetadata(op.data.send_initial_metadata.metadata,
@ -3481,7 +3502,7 @@ void ServerCall::CommitBatch(const grpc_op* ops, size_t nops, void* notify_tag,
};
});
auto send_message =
MaybeOp(ops, got_ops[GRPC_OP_SEND_MESSAGE], [this](const grpc_op& op) {
MaybeOp<GRPC_OP_SEND_MESSAGE>(ops, got_ops, [this](const grpc_op& op) {
SliceBuffer send;
grpc_slice_buffer_swap(
&op.data.send_message.send_message->data.raw.slice_buffer,
@ -3491,8 +3512,8 @@ void ServerCall::CommitBatch(const grpc_op* ops, size_t nops, void* notify_tag,
return call_handler_.PushMessage(std::move(msg));
};
});
auto send_trailing_metadata = MaybeOp(
ops, got_ops[GRPC_OP_SEND_STATUS_FROM_SERVER], [this](const grpc_op& op) {
auto send_trailing_metadata = MaybeOp<GRPC_OP_SEND_STATUS_FROM_SERVER>(
ops, got_ops, [this](const grpc_op& op) {
auto metadata = arena()->MakePooled<ServerMetadata>();
CToMetadata(op.data.send_status_from_server.trailing_metadata,
op.data.send_status_from_server.trailing_metadata_count,
@ -3519,7 +3540,7 @@ void ServerCall::CommitBatch(const grpc_op* ops, size_t nops, void* notify_tag,
};
});
auto recv_message =
MaybeOp(ops, got_ops[GRPC_OP_RECV_MESSAGE], [this](const grpc_op& op) {
MaybeOp<GRPC_OP_RECV_MESSAGE>(ops, got_ops, [this](const grpc_op& op) {
CHECK_EQ(recv_message_, nullptr);
recv_message_ = op.data.recv_message.recv_message;
return [this]() mutable {
@ -3535,8 +3556,8 @@ void ServerCall::CommitBatch(const grpc_op* ops, size_t nops, void* notify_tag,
std::move(send_trailing_metadata)),
std::move(recv_message));
if (got_ops[GRPC_OP_RECV_CLOSE_ON_SERVER] != 255) {
auto recv_trailing_metadata = MaybeOp(
ops, got_ops[GRPC_OP_RECV_CLOSE_ON_SERVER], [this](const grpc_op& op) {
auto recv_trailing_metadata = MaybeOp<GRPC_OP_RECV_CLOSE_ON_SERVER>(
ops, got_ops, [this](const grpc_op& op) {
return [this, cancelled = op.data.recv_close_on_server.cancelled]() {
return Map(call_handler_.WasCancelled(),
[cancelled, this](bool result) -> Success {

Loading…
Cancel
Save