pull/36509/head
Craig Tiller 10 months ago
parent 51933446c1
commit 7c82e2f2b9
  1. 11
      src/core/lib/promise/detail/promise_like.h
  2. 137
      src/core/lib/surface/call.cc

@ -17,6 +17,7 @@
#include <utility> #include <utility>
#include "absl/functional/any_invocable.h"
#include "absl/meta/type_traits.h" #include "absl/meta/type_traits.h"
#include <grpc/support/port_platform.h> #include <grpc/support/port_platform.h>
@ -63,6 +64,10 @@ auto WrapInPoll(T&& x) -> decltype(PollWrapper<T>::Wrap(std::forward<T>(x))) {
return PollWrapper<T>::Wrap(std::forward<T>(x)); return PollWrapper<T>::Wrap(std::forward<T>(x));
} }
// T -> T, const T& -> T
template <typename T>
using RemoveCVRef = absl::remove_cv_t<absl::remove_reference_t<T>>;
template <typename F, typename SfinaeVoid = void> template <typename F, typename SfinaeVoid = void>
class PromiseLike; class PromiseLike;
@ -73,7 +78,7 @@ template <typename F>
class PromiseLike<F, absl::enable_if_t<!std::is_void< class PromiseLike<F, absl::enable_if_t<!std::is_void<
typename std::result_of<F()>::type>::value>> { typename std::result_of<F()>::type>::value>> {
private: private:
GPR_NO_UNIQUE_ADDRESS F f_; GPR_NO_UNIQUE_ADDRESS RemoveCVRef<F> f_;
public: public:
// NOLINTNEXTLINE - internal detail that drastically simplifies calling code. // NOLINTNEXTLINE - internal detail that drastically simplifies calling code.
@ -82,10 +87,6 @@ class PromiseLike<F, absl::enable_if_t<!std::is_void<
using Result = typename PollTraits<decltype(WrapInPoll(f_()))>::Type; using Result = typename PollTraits<decltype(WrapInPoll(f_()))>::Type;
}; };
// T -> T, const T& -> T
template <typename T>
using RemoveCVRef = absl::remove_cv_t<absl::remove_reference_t<T>>;
} // namespace promise_detail } // namespace promise_detail
} // namespace grpc_core } // namespace grpc_core

@ -3276,62 +3276,83 @@ grpc_call_error ServerCall::StartBatch(const grpc_op* ops, size_t nops,
} }
namespace { namespace {
template <typename SetupFn> template <typename SetupResult, grpc_op_type kOp>
class MaybeOpImpl { class MaybeOpImpl {
public: public:
using SetupResult = decltype(std::declval<SetupFn>()(grpc_op()));
using PromiseFactory = promise_detail::OncePromiseFactory<void, SetupResult>; using PromiseFactory = promise_detail::OncePromiseFactory<void, SetupResult>;
using Promise = typename PromiseFactory::Promise; using Promise = typename PromiseFactory::Promise;
static_assert(!std::is_same<Promise, void>::value, static_assert(!std::is_same<Promise, void>::value,
"PromiseFactory must return a promise"); "PromiseFactory must return a promise");
struct Dismissed {};
using State = absl::variant<Dismissed, PromiseFactory, Promise>;
// op_ is garbage but shouldn't be uninitialized MaybeOpImpl() : state_(State::kDismissed) {}
MaybeOpImpl() : state_(Dismissed{}), op_(GRPC_OP_RECV_STATUS_ON_CLIENT) {} explicit MaybeOpImpl(SetupResult result) : state_(State::kPromiseFactory) {
MaybeOpImpl(SetupResult result, grpc_op_type op) Construct(&promise_factory_, std::move(result));
: state_(PromiseFactory(std::move(result))), op_(op) {} }
~MaybeOpImpl() {
switch (state_) {
case State::kDismissed:
case State::kPromiseFactory:
Destruct(&promise_factory_);
break;
case State::kPromise:
Destruct(&promise_);
break;
}
}
MaybeOpImpl(const MaybeOpImpl&) = delete; MaybeOpImpl(const MaybeOpImpl&) = delete;
MaybeOpImpl& operator=(const MaybeOpImpl&) = delete; MaybeOpImpl& operator=(const MaybeOpImpl&) = delete;
MaybeOpImpl(MaybeOpImpl&& other) noexcept MaybeOpImpl(MaybeOpImpl&& other) noexcept : state_(other.state_) {
: state_(MoveState(other.state_)), op_(other.op_) {} switch (state_) {
MaybeOpImpl& operator=(MaybeOpImpl&& other) noexcept { case State::kDismissed:
op_ = other.op_; case State::kPromiseFactory:
if (absl::holds_alternative<Dismissed>(state_)) { Construct(&promise_factory_, std::move(other.promise_factory_));
state_.template emplace<Dismissed>(); break;
return *this; case State::kPromise:
Construct(&promise_, std::move(other.promise_));
break;
} }
// Can't move after first poll => Promise is not an option
state_.template emplace<PromiseFactory>(
std::move(absl::get<PromiseFactory>(other.state_)));
return *this;
} }
MaybeOpImpl& operator=(MaybeOpImpl&& other) noexcept = delete;
Poll<StatusFlag> operator()() { Poll<StatusFlag> operator()() {
if (absl::holds_alternative<Dismissed>(state_)) return Success{}; switch (state_) {
if (absl::holds_alternative<PromiseFactory>(state_)) { case State::kDismissed:
auto& factory = absl::get<PromiseFactory>(state_); return Success{};
auto promise = factory.Make(); case State::kPromiseFactory: {
state_.template emplace<Promise>(std::move(promise)); auto promise = promise_factory_.Make();
} Destruct(&promise_factory_);
if (grpc_call_trace.enabled()) { Construct(&promise_, std::move(promise));
gpr_log(GPR_INFO, "%sBeginPoll %s", state_ = State::kPromise;
Activity::current()->DebugTag().c_str(), OpName(op_).c_str()); }
} ABSL_FALLTHROUGH_INTENDED;
auto& promise = absl::get<Promise>(state_); case State::kPromise: {
auto r = poll_cast<StatusFlag>(promise()); if (grpc_call_trace.enabled()) {
if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, "%sBeginPoll %s",
gpr_log(GPR_INFO, "%sEndPoll %s --> %s", Activity::current()->DebugTag().c_str(), OpName());
Activity::current()->DebugTag().c_str(), OpName(op_).c_str(), }
auto r = poll_cast<StatusFlag>(promise_());
if (grpc_call_trace.enabled()) {
gpr_log(
GPR_INFO, "%sEndPoll %s --> %s",
Activity::current()->DebugTag().c_str(), OpName(),
r.pending() ? "PENDING" : (r.value().ok() ? "OK" : "FAILURE")); r.pending() ? "PENDING" : (r.value().ok() ? "OK" : "FAILURE"));
}
return r;
}
} }
return r;
} }
private: private:
static std::string OpName(grpc_op_type op) { enum class State {
switch (op) { kDismissed,
kPromiseFactory,
kPromise,
};
static const char* OpName() {
switch (kOp) {
case GRPC_OP_SEND_INITIAL_METADATA: case GRPC_OP_SEND_INITIAL_METADATA:
return "SendInitialMetadata"; return "SendInitialMetadata";
case GRPC_OP_SEND_MESSAGE: case GRPC_OP_SEND_MESSAGE:
@ -3349,17 +3370,15 @@ class MaybeOpImpl {
case GRPC_OP_RECV_STATUS_ON_CLIENT: case GRPC_OP_RECV_STATUS_ON_CLIENT:
return "RecvStatusOnClient"; return "RecvStatusOnClient";
} }
return absl::StrCat("UnknownOp(", op, ")"); Crash("Unreachable");
}
static State MoveState(State& state) {
if (absl::holds_alternative<Dismissed>(state)) return Dismissed{};
// Can't move after first poll => Promise is not an option
return std::move(absl::get<PromiseFactory>(state));
} }
// gcc-12 has problems with this being a variant
GPR_NO_UNIQUE_ADDRESS State state_; GPR_NO_UNIQUE_ADDRESS State state_;
GPR_NO_UNIQUE_ADDRESS grpc_op_type op_; union {
PromiseFactory promise_factory_;
Promise promise_;
};
}; };
// MaybeOp captures a fairly complicated dance we need to do for the batch // MaybeOp captures a fairly complicated dance we need to do for the batch
@ -3370,12 +3389,14 @@ class MaybeOpImpl {
// ultimately poll on til completion. // ultimately poll on til completion.
// Once we express our surface API in terms of core internal types this whole // Once we express our surface API in terms of core internal types this whole
// dance will go away. // dance will go away.
template <typename SetupFn> template <grpc_op_type op_type, typename SetupFn>
auto MaybeOp(const grpc_op* ops, uint8_t idx, SetupFn setup) { auto MaybeOp(const grpc_op* ops, const std::array<uint8_t, 8>& idxs, SetupFn setup) {
if (idx == 255) { using SetupResult = decltype(std::declval<SetupFn>()(grpc_op()));
return MaybeOpImpl<SetupFn>(); if (idxs[op_type] == 255) {
return MaybeOpImpl<SetupResult, op_type>();
} else { } else {
return MaybeOpImpl<SetupFn>(setup(ops[idx]), ops[idx].op); auto r = setup(ops[idxs[op_type]]);
return MaybeOpImpl<SetupResult, op_type>(std::move(r));
} }
} }
@ -3466,8 +3487,8 @@ void ServerCall::CommitBatch(const grpc_op* ops, size_t nops, void* notify_tag,
got_ops[op.op] = op_idx; got_ops[op.op] = op_idx;
} }
if (!is_notify_tag_closure) grpc_cq_begin_op(cq_, notify_tag); if (!is_notify_tag_closure) grpc_cq_begin_op(cq_, notify_tag);
auto send_initial_metadata = MaybeOp( auto send_initial_metadata = MaybeOp<GRPC_OP_SEND_INITIAL_METADATA>(
ops, got_ops[GRPC_OP_SEND_INITIAL_METADATA], [this](const grpc_op& op) { ops, got_ops, [this](const grpc_op& op) {
auto metadata = arena()->MakePooled<ServerMetadata>(); auto metadata = arena()->MakePooled<ServerMetadata>();
PrepareOutgoingInitialMetadata(op, *metadata); PrepareOutgoingInitialMetadata(op, *metadata);
CToMetadata(op.data.send_initial_metadata.metadata, CToMetadata(op.data.send_initial_metadata.metadata,
@ -3481,7 +3502,7 @@ void ServerCall::CommitBatch(const grpc_op* ops, size_t nops, void* notify_tag,
}; };
}); });
auto send_message = auto send_message =
MaybeOp(ops, got_ops[GRPC_OP_SEND_MESSAGE], [this](const grpc_op& op) { MaybeOp<GRPC_OP_SEND_MESSAGE>(ops, got_ops, [this](const grpc_op& op) {
SliceBuffer send; SliceBuffer send;
grpc_slice_buffer_swap( grpc_slice_buffer_swap(
&op.data.send_message.send_message->data.raw.slice_buffer, &op.data.send_message.send_message->data.raw.slice_buffer,
@ -3491,8 +3512,8 @@ void ServerCall::CommitBatch(const grpc_op* ops, size_t nops, void* notify_tag,
return call_handler_.PushMessage(std::move(msg)); return call_handler_.PushMessage(std::move(msg));
}; };
}); });
auto send_trailing_metadata = MaybeOp( auto send_trailing_metadata = MaybeOp<GRPC_OP_SEND_STATUS_FROM_SERVER>(
ops, got_ops[GRPC_OP_SEND_STATUS_FROM_SERVER], [this](const grpc_op& op) { ops, got_ops, [this](const grpc_op& op) {
auto metadata = arena()->MakePooled<ServerMetadata>(); auto metadata = arena()->MakePooled<ServerMetadata>();
CToMetadata(op.data.send_status_from_server.trailing_metadata, CToMetadata(op.data.send_status_from_server.trailing_metadata,
op.data.send_status_from_server.trailing_metadata_count, op.data.send_status_from_server.trailing_metadata_count,
@ -3519,7 +3540,7 @@ void ServerCall::CommitBatch(const grpc_op* ops, size_t nops, void* notify_tag,
}; };
}); });
auto recv_message = auto recv_message =
MaybeOp(ops, got_ops[GRPC_OP_RECV_MESSAGE], [this](const grpc_op& op) { MaybeOp<GRPC_OP_RECV_MESSAGE>(ops, got_ops, [this](const grpc_op& op) {
CHECK_EQ(recv_message_, nullptr); CHECK_EQ(recv_message_, nullptr);
recv_message_ = op.data.recv_message.recv_message; recv_message_ = op.data.recv_message.recv_message;
return [this]() mutable { return [this]() mutable {
@ -3535,8 +3556,8 @@ void ServerCall::CommitBatch(const grpc_op* ops, size_t nops, void* notify_tag,
std::move(send_trailing_metadata)), std::move(send_trailing_metadata)),
std::move(recv_message)); std::move(recv_message));
if (got_ops[GRPC_OP_RECV_CLOSE_ON_SERVER] != 255) { if (got_ops[GRPC_OP_RECV_CLOSE_ON_SERVER] != 255) {
auto recv_trailing_metadata = MaybeOp( auto recv_trailing_metadata = MaybeOp<GRPC_OP_RECV_CLOSE_ON_SERVER>(
ops, got_ops[GRPC_OP_RECV_CLOSE_ON_SERVER], [this](const grpc_op& op) { ops, got_ops, [this](const grpc_op& op) {
return [this, cancelled = op.data.recv_close_on_server.cancelled]() { return [this, cancelled = op.data.recv_close_on_server.cancelled]() {
return Map(call_handler_.WasCancelled(), return Map(call_handler_.WasCancelled(),
[cancelled, this](bool result) -> Success { [cancelled, this](bool result) -> Success {

Loading…
Cancel
Save