[promises] Convert http server filter (#29273)

* Begin adding abstractions for capturing batches

* code written

* fixes

* fix

* fmt

* placate clang-tidy

* simplify

* fixes

* Automated change: Fix sanity tests

* annotate unrefs

* fix tsan

* fix

* Update subchannel.cc

* Update BUILD

* Update generate_tests.bzl

* Update BUILD

* Automated change: Fix sanity tests

* Promiseize

* Move server promise based filter to Flusher

* review feedback

* compiles!

* fix

* fix

* start server initial metadata emulation

* comment

* comment

* fix

* refactoring-to-fix-the-bugs

* fixfixfix

* fix

* support cancellation

* fuzzer

* fix

Co-authored-by: ctiller <ctiller@users.noreply.github.com>
pull/29575/head
Craig Tiller 3 years ago committed by GitHub
parent b39e3d4406
commit 717732c044
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/core/ext/filters/http/http_filters_plugin.cc
  2. 331
      src/core/ext/filters/http/server/http_server_filter.cc
  3. 27
      src/core/ext/filters/http/server/http_server_filter.h
  4. 259
      src/core/lib/channel/promise_based_filter.cc
  5. 15
      src/core/lib/channel/promise_based_filter.h
  6. 4
      src/core/lib/promise/arena_promise.h
  7. 4
      test/core/filters/filter_fuzzer.cc
  8. 3
      test/cpp/microbenchmarks/bm_call_create.cc

@ -84,6 +84,6 @@ void RegisterHttpFilters(CoreConfiguration::Builder* builder) {
GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION, &MessageDecompressFilter);
required(GRPC_CLIENT_SUBCHANNEL, &HttpClientFilter::kFilter);
required(GRPC_CLIENT_DIRECT_CHANNEL, &HttpClientFilter::kFilter);
required(GRPC_SERVER_CHANNEL, &grpc_http_server_filter);
required(GRPC_SERVER_CHANNEL, &HttpServerFilter::kFilter);
}
} // namespace grpc_core

@ -28,307 +28,126 @@
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/gprpp/manual_constructor.h"
#include "src/core/lib/profiling/timers.h"
#include "src/core/lib/promise/call_push_pull.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/slice/b64.h"
#include "src/core/lib/slice/percent_encoding.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/lib/slice/slice_string_helpers.h"
static void hs_recv_initial_metadata_ready(void* user_data,
grpc_error_handle err);
static void hs_recv_trailing_metadata_ready(void* user_data,
grpc_error_handle err);
namespace grpc_core {
namespace {
struct call_data {
call_data(grpc_call_element* elem, const grpc_call_element_args& args)
: call_combiner(args.call_combiner) {
GRPC_CLOSURE_INIT(&recv_initial_metadata_ready,
hs_recv_initial_metadata_ready, elem,
grpc_schedule_on_exec_ctx);
GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready,
hs_recv_trailing_metadata_ready, elem,
grpc_schedule_on_exec_ctx);
}
~call_data() { GRPC_ERROR_UNREF(recv_initial_metadata_ready_error); }
grpc_core::CallCombiner* call_combiner;
// State for intercepting recv_initial_metadata.
grpc_closure recv_initial_metadata_ready;
grpc_error_handle recv_initial_metadata_ready_error = GRPC_ERROR_NONE;
grpc_closure* original_recv_initial_metadata_ready;
grpc_metadata_batch* recv_initial_metadata = nullptr;
uint32_t* recv_initial_metadata_flags;
bool seen_recv_initial_metadata_ready = false;
// State for intercepting recv_trailing_metadata
grpc_closure recv_trailing_metadata_ready;
grpc_closure* original_recv_trailing_metadata_ready;
grpc_error_handle recv_trailing_metadata_ready_error;
bool seen_recv_trailing_metadata_ready = false;
};
struct channel_data {
bool surface_user_agent;
bool allow_put_requests;
};
} // namespace
static grpc_error_handle hs_filter_outgoing_metadata(grpc_metadata_batch* b) {
if (grpc_core::Slice* grpc_message =
b->get_pointer(grpc_core::GrpcMessageMetadata())) {
*grpc_message = grpc_core::PercentEncodeSlice(
std::move(*grpc_message), grpc_core::PercentEncodingType::Compatible);
}
return GRPC_ERROR_NONE;
}
const grpc_channel_filter HttpServerFilter::kFilter =
MakePromiseBasedFilter<HttpServerFilter, FilterEndpoint::kServer,
kFilterExaminesServerInitialMetadata>("http-server");
static void hs_add_error(const char* error_name, grpc_error_handle* cumulative,
grpc_error_handle new_err) {
if (new_err == GRPC_ERROR_NONE) return;
if (*cumulative == GRPC_ERROR_NONE) {
*cumulative = GRPC_ERROR_CREATE_FROM_COPIED_STRING(error_name);
namespace {
void FilterOutgoingMetadata(ServerMetadata* md) {
if (Slice* grpc_message = md->get_pointer(GrpcMessageMetadata())) {
*grpc_message = PercentEncodeSlice(std::move(*grpc_message),
PercentEncodingType::Compatible);
}
*cumulative = grpc_error_add_child(*cumulative, new_err);
}
} // namespace
static grpc_error_handle hs_filter_incoming_metadata(grpc_call_element* elem,
grpc_metadata_batch* b) {
grpc_error_handle error = GRPC_ERROR_NONE;
static const char* error_name = "Failed processing incoming headers";
ArenaPromise<ServerMetadataHandle> HttpServerFilter::MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
const auto& md = call_args.client_initial_metadata;
auto method = b->get(grpc_core::HttpMethodMetadata());
auto method = md->get(HttpMethodMetadata());
if (method.has_value()) {
switch (*method) {
case grpc_core::HttpMethodMetadata::kPost:
case HttpMethodMetadata::kPost:
break;
case grpc_core::HttpMethodMetadata::kPut:
if (static_cast<channel_data*>(elem->channel_data)
->allow_put_requests) {
case HttpMethodMetadata::kPut:
if (allow_put_requests_) {
break;
}
ABSL_FALLTHROUGH_INTENDED;
case grpc_core::HttpMethodMetadata::kInvalid:
case grpc_core::HttpMethodMetadata::kGet:
hs_add_error(error_name, &error,
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Bad method header"));
break;
case HttpMethodMetadata::kInvalid:
case HttpMethodMetadata::kGet:
return Immediate(
ServerMetadataHandle(absl::UnknownError("Bad method header")));
}
} else {
hs_add_error(error_name, &error,
grpc_error_set_str(
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Missing header"),
GRPC_ERROR_STR_KEY, ":method"));
return Immediate(
ServerMetadataHandle(absl::UnknownError("Missing :method header")));
}
auto te = b->Take(grpc_core::TeMetadata());
if (te == grpc_core::TeMetadata::kTrailers) {
auto te = md->Take(TeMetadata());
if (te == TeMetadata::kTrailers) {
// Do nothing, ok.
} else if (!te.has_value()) {
hs_add_error(error_name, &error,
grpc_error_set_str(
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Missing header"),
GRPC_ERROR_STR_KEY, "te"));
return Immediate(
ServerMetadataHandle(absl::UnknownError("Missing :te header")));
} else {
hs_add_error(error_name, &error,
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Bad te header"));
return Immediate(
ServerMetadataHandle(absl::UnknownError("Bad :te header")));
}
auto scheme = b->Take(grpc_core::HttpSchemeMetadata());
auto scheme = md->Take(HttpSchemeMetadata());
if (scheme.has_value()) {
if (*scheme == grpc_core::HttpSchemeMetadata::kInvalid) {
hs_add_error(error_name, &error,
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Bad :scheme header"));
if (*scheme == HttpSchemeMetadata::kInvalid) {
return Immediate(
ServerMetadataHandle(absl::UnknownError("Bad :scheme header")));
}
} else {
hs_add_error(error_name, &error,
grpc_error_set_str(
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Missing header"),
GRPC_ERROR_STR_KEY, ":scheme"));
return Immediate(
ServerMetadataHandle(absl::UnknownError("Missing :scheme header")));
}
b->Remove(grpc_core::ContentTypeMetadata());
md->Remove(ContentTypeMetadata());
grpc_core::Slice* path_slice = b->get_pointer(grpc_core::HttpPathMetadata());
Slice* path_slice = md->get_pointer(HttpPathMetadata());
if (path_slice == nullptr) {
hs_add_error(error_name, &error,
grpc_error_set_str(
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Missing header"),
GRPC_ERROR_STR_KEY, ":path"));
return Immediate(
ServerMetadataHandle(absl::UnknownError("Missing :path header")));
}
if (b->get_pointer(grpc_core::HttpAuthorityMetadata()) == nullptr) {
absl::optional<grpc_core::Slice> host = b->Take(grpc_core::HostMetadata());
if (md->get_pointer(HttpAuthorityMetadata()) == nullptr) {
absl::optional<Slice> host = md->Take(HostMetadata());
if (host.has_value()) {
b->Set(grpc_core::HttpAuthorityMetadata(), std::move(*host));
md->Set(HttpAuthorityMetadata(), std::move(*host));
}
}
if (b->get_pointer(grpc_core::HttpAuthorityMetadata()) == nullptr) {
hs_add_error(error_name, &error,
grpc_error_set_str(
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Missing header"),
GRPC_ERROR_STR_KEY, ":authority"));
if (md->get_pointer(HttpAuthorityMetadata()) == nullptr) {
return Immediate(
ServerMetadataHandle(absl::UnknownError("Missing :authority header")));
}
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
if (!chand->surface_user_agent) {
b->Remove(grpc_core::UserAgentMetadata());
if (!surface_user_agent_) {
md->Remove(UserAgentMetadata());
}
return error;
}
auto* read_latch = GetContext<Arena>()->New<Latch<ServerMetadata*>>();
auto* write_latch =
absl::exchange(call_args.server_initial_metadata, read_latch);
static void hs_recv_initial_metadata_ready(void* user_data,
grpc_error_handle err) {
grpc_call_element* elem = static_cast<grpc_call_element*>(user_data);
call_data* calld = static_cast<call_data*>(elem->call_data);
calld->seen_recv_initial_metadata_ready = true;
if (err == GRPC_ERROR_NONE) {
err = hs_filter_incoming_metadata(elem, calld->recv_initial_metadata);
calld->recv_initial_metadata_ready_error = GRPC_ERROR_REF(err);
} else {
(void)GRPC_ERROR_REF(err);
}
if (calld->seen_recv_trailing_metadata_ready) {
GRPC_CALL_COMBINER_START(calld->call_combiner,
&calld->recv_trailing_metadata_ready,
calld->recv_trailing_metadata_ready_error,
"resuming hs_recv_trailing_metadata_ready from "
"hs_recv_initial_metadata_ready");
}
grpc_core::Closure::Run(DEBUG_LOCATION,
calld->original_recv_initial_metadata_ready, err);
return CallPushPull(Seq(next_promise_factory(std::move(call_args)),
[](ServerMetadataHandle md) -> ServerMetadataHandle {
FilterOutgoingMetadata(md.get());
return md;
}),
Seq(read_latch->Wait(),
[write_latch](ServerMetadata** md) {
FilterOutgoingMetadata(*md);
(*md)->Set(HttpStatusMetadata(), 200);
(*md)->Set(ContentTypeMetadata(),
ContentTypeMetadata::kApplicationGrpc);
write_latch->Set(*md);
return absl::OkStatus();
}),
[]() { return absl::OkStatus(); });
}
static void hs_recv_trailing_metadata_ready(void* user_data,
grpc_error_handle err) {
grpc_call_element* elem = static_cast<grpc_call_element*>(user_data);
call_data* calld = static_cast<call_data*>(elem->call_data);
if (!calld->seen_recv_initial_metadata_ready) {
calld->recv_trailing_metadata_ready_error = GRPC_ERROR_REF(err);
calld->seen_recv_trailing_metadata_ready = true;
GRPC_CALL_COMBINER_STOP(calld->call_combiner,
"deferring hs_recv_trailing_metadata_ready until "
"ater hs_recv_initial_metadata_ready");
return;
}
err = grpc_error_add_child(
GRPC_ERROR_REF(err),
GRPC_ERROR_REF(calld->recv_initial_metadata_ready_error));
grpc_core::Closure::Run(DEBUG_LOCATION,
calld->original_recv_trailing_metadata_ready, err);
absl::StatusOr<HttpServerFilter> HttpServerFilter::Create(ChannelArgs args,
ChannelFilter::Args) {
return HttpServerFilter(
args.GetBool(GRPC_ARG_SURFACE_USER_AGENT).value_or(true),
args.GetBool(
GRPC_ARG_DO_NOT_USE_UNLESS_YOU_HAVE_PERMISSION_FROM_GRPC_TEAM_ALLOW_BROKEN_PUT_REQUESTS)
.value_or(false));
}
static grpc_error_handle hs_mutate_op(grpc_call_element* elem,
grpc_transport_stream_op_batch* op) {
/* grab pointers to our data from the call element */
call_data* calld = static_cast<call_data*>(elem->call_data);
if (op->send_initial_metadata) {
grpc_error_handle error = GRPC_ERROR_NONE;
static const char* error_name = "Failed sending initial metadata";
op->payload->send_initial_metadata.send_initial_metadata->Set(
grpc_core::HttpStatusMetadata(), 200);
op->payload->send_initial_metadata.send_initial_metadata->Set(
grpc_core::ContentTypeMetadata(),
grpc_core::ContentTypeMetadata::kApplicationGrpc);
hs_add_error(error_name, &error,
hs_filter_outgoing_metadata(
op->payload->send_initial_metadata.send_initial_metadata));
if (error != GRPC_ERROR_NONE) return error;
}
if (op->recv_initial_metadata) {
/* substitute our callback for the higher callback */
GPR_ASSERT(op->payload->recv_initial_metadata.recv_flags != nullptr);
calld->recv_initial_metadata =
op->payload->recv_initial_metadata.recv_initial_metadata;
calld->recv_initial_metadata_flags =
op->payload->recv_initial_metadata.recv_flags;
calld->original_recv_initial_metadata_ready =
op->payload->recv_initial_metadata.recv_initial_metadata_ready;
op->payload->recv_initial_metadata.recv_initial_metadata_ready =
&calld->recv_initial_metadata_ready;
}
if (op->recv_trailing_metadata) {
calld->original_recv_trailing_metadata_ready =
op->payload->recv_trailing_metadata.recv_trailing_metadata_ready;
op->payload->recv_trailing_metadata.recv_trailing_metadata_ready =
&calld->recv_trailing_metadata_ready;
}
if (op->send_trailing_metadata) {
grpc_error_handle error = hs_filter_outgoing_metadata(
op->payload->send_trailing_metadata.send_trailing_metadata);
if (error != GRPC_ERROR_NONE) return error;
}
return GRPC_ERROR_NONE;
}
static void hs_start_transport_stream_op_batch(
grpc_call_element* elem, grpc_transport_stream_op_batch* op) {
GPR_TIMER_SCOPE("hs_start_transport_stream_op_batch", 0);
call_data* calld = static_cast<call_data*>(elem->call_data);
grpc_error_handle error = hs_mutate_op(elem, op);
if (error != GRPC_ERROR_NONE) {
grpc_transport_stream_op_batch_finish_with_failure(op, error,
calld->call_combiner);
} else {
grpc_call_next_op(elem, op);
}
}
/* Constructor for call_data */
static grpc_error_handle hs_init_call_elem(grpc_call_element* elem,
const grpc_call_element_args* args) {
new (elem->call_data) call_data(elem, *args);
return GRPC_ERROR_NONE;
}
/* Destructor for call_data */
static void hs_destroy_call_elem(grpc_call_element* elem,
const grpc_call_final_info* /*final_info*/,
grpc_closure* /*ignored*/) {
call_data* calld = static_cast<call_data*>(elem->call_data);
calld->~call_data();
}
/* Constructor for channel_data */
static grpc_error_handle hs_init_channel_elem(grpc_channel_element* elem,
grpc_channel_element_args* args) {
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
GPR_ASSERT(!args->is_last);
chand->surface_user_agent = grpc_channel_arg_get_bool(
grpc_channel_args_find(args->channel_args,
const_cast<char*>(GRPC_ARG_SURFACE_USER_AGENT)),
true);
chand->allow_put_requests = grpc_channel_args_find_bool(
args->channel_args,
GRPC_ARG_DO_NOT_USE_UNLESS_YOU_HAVE_PERMISSION_FROM_GRPC_TEAM_ALLOW_BROKEN_PUT_REQUESTS,
false);
return GRPC_ERROR_NONE;
}
/* Destructor for channel data */
static void hs_destroy_channel_elem(grpc_channel_element* /*elem*/) {}
const grpc_channel_filter grpc_http_server_filter = {
hs_start_transport_stream_op_batch,
nullptr,
grpc_channel_next_op,
sizeof(call_data),
hs_init_call_elem,
grpc_call_stack_ignore_set_pollset_or_pollset_set,
hs_destroy_call_elem,
sizeof(channel_data),
hs_init_channel_elem,
grpc_channel_stack_no_post_init,
hs_destroy_channel_elem,
grpc_channel_next_get_info,
"http-server"};
} // namespace grpc_core

@ -22,9 +22,32 @@
#include <grpc/support/port_platform.h>
#include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/channel/promise_based_filter.h"
/* Processes metadata on the server side for HTTP2 transports */
extern const grpc_channel_filter grpc_http_server_filter;
namespace grpc_core {
// Processes metadata on the server side for HTTP2 transports
class HttpServerFilter : public ChannelFilter {
public:
static const grpc_channel_filter kFilter;
static absl::StatusOr<HttpServerFilter> Create(
ChannelArgs args, ChannelFilter::Args filter_args);
// Construct a promise for one call.
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;
private:
HttpServerFilter(bool surface_user_agent, bool allow_put_requests)
: surface_user_agent_(surface_user_agent),
allow_put_requests_(allow_put_requests) {}
bool surface_user_agent_;
bool allow_put_requests_;
};
} // namespace grpc_core
// A Temporary channel arg that allows servers to accept PUT requests. DO NOT
// USE WITHOUT PERMISSION.

@ -100,6 +100,7 @@ BaseCallData::CapturedBatch::~CapturedBatch() {
BaseCallData::CapturedBatch::CapturedBatch(const CapturedBatch& rhs)
: batch_(rhs.batch_) {
if (batch_ == nullptr) return;
uintptr_t& refcnt = *RefCountField(batch_);
if (refcnt == 0) return; // refcnt==0 ==> cancelled
++refcnt;
@ -836,27 +837,101 @@ void ClientCallData::OnWakeup() {
///////////////////////////////////////////////////////////////////////////////
// ServerCallData
struct ServerCallData::SendInitialMetadata {
enum State {
kInitial,
kGotLatch,
kQueuedWaitingForLatch,
kQueuedAndGotLatch,
kQueuedAndSetLatch,
kForwarded,
kCancelled,
};
State state = kInitial;
CapturedBatch batch;
Latch<ServerMetadata*>* server_initial_metadata_publisher = nullptr;
};
class ServerCallData::PollContext {
public:
explicit PollContext(ServerCallData* self, Flusher* flusher)
: self_(self), flusher_(flusher) {
GPR_ASSERT(self_->poll_ctx_ == nullptr);
self_->poll_ctx_ = this;
scoped_activity_.Init(self_);
have_scoped_activity_ = true;
}
PollContext(const PollContext&) = delete;
PollContext& operator=(const PollContext&) = delete;
~PollContext() {
self_->poll_ctx_ = nullptr;
if (have_scoped_activity_) scoped_activity_.Destroy();
if (repoll_) {
struct NextPoll : public grpc_closure {
grpc_call_stack* call_stack;
ServerCallData* call_data;
};
auto run = [](void* p, grpc_error_handle) {
auto* next_poll = static_cast<NextPoll*>(p);
{
Flusher flusher(next_poll->call_data);
next_poll->call_data->WakeInsideCombiner(&flusher);
}
GRPC_CALL_STACK_UNREF(next_poll->call_stack, "re-poll");
delete next_poll;
};
auto* p = absl::make_unique<NextPoll>().release();
p->call_stack = self_->call_stack();
p->call_data = self_;
GRPC_CALL_STACK_REF(self_->call_stack(), "re-poll");
GRPC_CLOSURE_INIT(p, run, p, nullptr);
flusher_->AddClosure(p, GRPC_ERROR_NONE, "re-poll");
}
}
void Repoll() { repoll_ = true; }
void ClearRepoll() { repoll_ = false; }
private:
ManualConstructor<ScopedActivity> scoped_activity_;
ServerCallData* const self_;
Flusher* const flusher_;
bool repoll_ = false;
bool have_scoped_activity_;
};
ServerCallData::ServerCallData(grpc_call_element* elem,
const grpc_call_element_args* args,
uint8_t flags)
: BaseCallData(elem, args, flags) {
if (server_initial_metadata_latch() != nullptr) {
send_initial_metadata_ = arena()->New<SendInitialMetadata>();
}
GRPC_CLOSURE_INIT(&recv_initial_metadata_ready_,
RecvInitialMetadataReadyCallback, this,
grpc_schedule_on_exec_ctx);
}
ServerCallData::~ServerCallData() {
GPR_ASSERT(!is_polling_);
GPR_ASSERT(poll_ctx_ == nullptr);
GRPC_ERROR_UNREF(cancelled_error_);
}
// Activity implementation.
void ServerCallData::ForceImmediateRepoll() { abort(); } // Not implemented.
void ServerCallData::ForceImmediateRepoll() {
GPR_ASSERT(poll_ctx_ != nullptr);
poll_ctx_->Repoll();
}
// Handle one grpc_transport_stream_op_batch
void ServerCallData::StartBatch(grpc_transport_stream_op_batch* batch) {
void ServerCallData::StartBatch(grpc_transport_stream_op_batch* b) {
// Fake out the activity based context.
ScopedContext context(this);
CapturedBatch batch(b);
Flusher flusher(this);
bool wake = false;
// If this is a cancel stream, cancel anything we have pending and
// propagate the cancellation.
@ -865,8 +940,9 @@ void ServerCallData::StartBatch(grpc_transport_stream_op_batch* batch) {
!batch->send_trailing_metadata && !batch->send_message &&
!batch->recv_initial_metadata && !batch->recv_message &&
!batch->recv_trailing_metadata);
Cancel(batch->payload->cancel_stream.cancel_error);
grpc_call_next_op(elem(), batch);
Cancel(GRPC_ERROR_REF(batch->payload->cancel_stream.cancel_error),
&flusher);
batch.ResumeWith(&flusher);
return;
}
@ -888,61 +964,85 @@ void ServerCallData::StartBatch(grpc_transport_stream_op_batch* batch) {
recv_initial_state_ = RecvInitialState::kForwarded;
}
// send_initial_metadata
if (send_initial_metadata_ != nullptr && batch->send_initial_metadata) {
switch (send_initial_metadata_->state) {
case SendInitialMetadata::kInitial:
send_initial_metadata_->state =
SendInitialMetadata::kQueuedWaitingForLatch;
break;
case SendInitialMetadata::kGotLatch:
send_initial_metadata_->state = SendInitialMetadata::kQueuedAndGotLatch;
break;
case SendInitialMetadata::kCancelled:
batch.CancelWith(GRPC_ERROR_REF(cancelled_error_), &flusher);
break;
case SendInitialMetadata::kQueuedAndGotLatch:
case SendInitialMetadata::kQueuedWaitingForLatch:
case SendInitialMetadata::kQueuedAndSetLatch:
case SendInitialMetadata::kForwarded:
abort(); // not reachable
}
send_initial_metadata_->batch = batch;
wake = true;
}
// send_trailing_metadata
if (batch->send_trailing_metadata) {
if (batch.is_captured() && batch->send_trailing_metadata) {
switch (send_trailing_state_) {
case SendTrailingState::kInitial:
send_trailing_metadata_batch_ = batch;
send_trailing_state_ = SendTrailingState::kQueued;
WakeInsideCombiner([this](grpc_error_handle error) {
GPR_ASSERT(send_trailing_state_ == SendTrailingState::kQueued);
Cancel(error);
});
wake = true;
break;
case SendTrailingState::kQueued:
case SendTrailingState::kForwarded:
abort(); // unreachable
break;
case SendTrailingState::kCancelled:
grpc_transport_stream_op_batch_finish_with_failure(
batch, GRPC_ERROR_REF(cancelled_error_), call_combiner());
batch.CancelWith(GRPC_ERROR_REF(cancelled_error_), &flusher);
break;
}
return;
}
grpc_call_next_op(elem(), batch);
if (wake) WakeInsideCombiner(&flusher);
if (batch.is_captured()) batch.ResumeWith(&flusher);
}
// Handle cancellation.
void ServerCallData::Cancel(grpc_error_handle error) {
void ServerCallData::Cancel(grpc_error_handle error, Flusher* flusher) {
// Track the latest reason for cancellation.
GRPC_ERROR_UNREF(cancelled_error_);
cancelled_error_ = GRPC_ERROR_REF(error);
cancelled_error_ = error;
// Stop running the promise.
promise_ = ArenaPromise<ServerMetadataHandle>();
if (send_trailing_state_ == SendTrailingState::kQueued) {
send_trailing_state_ = SendTrailingState::kCancelled;
struct FailBatch : public grpc_closure {
grpc_transport_stream_op_batch* batch;
CallCombiner* call_combiner;
};
auto fail = [](void* p, grpc_error_handle error) {
auto* f = static_cast<FailBatch*>(p);
grpc_transport_stream_op_batch_finish_with_failure(
f->batch, GRPC_ERROR_REF(error), f->call_combiner);
delete f;
};
auto* b = new FailBatch();
GRPC_CLOSURE_INIT(b, fail, b, nullptr);
b->batch = absl::exchange(send_trailing_metadata_batch_, nullptr);
b->call_combiner = call_combiner();
GRPC_CALL_COMBINER_START(call_combiner(), b,
GRPC_ERROR_REF(cancelled_error_),
"cancel pending batch");
send_trailing_metadata_batch_.CancelWith(GRPC_ERROR_REF(error), flusher);
} else {
send_trailing_state_ = SendTrailingState::kCancelled;
}
if (send_initial_metadata_ != nullptr) {
switch (send_initial_metadata_->state) {
case SendInitialMetadata::kInitial:
case SendInitialMetadata::kGotLatch:
case SendInitialMetadata::kForwarded:
case SendInitialMetadata::kCancelled:
break;
case SendInitialMetadata::kQueuedWaitingForLatch:
case SendInitialMetadata::kQueuedAndGotLatch:
case SendInitialMetadata::kQueuedAndSetLatch:
send_initial_metadata_->batch.CancelWith(GRPC_ERROR_REF(error),
flusher);
break;
}
send_initial_metadata_->state = SendInitialMetadata::kCancelled;
}
if (auto* closure =
absl::exchange(original_recv_initial_metadata_ready_, nullptr)) {
flusher->AddClosure(closure, GRPC_ERROR_REF(error),
"original_recv_initial_metadata");
}
}
// Construct a promise that will "call" the next filter.
@ -955,6 +1055,31 @@ ArenaPromise<ServerMetadataHandle> ServerCallData::MakeNextPromise(
GPR_ASSERT(UnwrapMetadata(std::move(call_args.client_initial_metadata)) ==
recv_initial_metadata_);
forward_recv_initial_metadata_callback_ = true;
if (send_initial_metadata_ != nullptr) {
GPR_ASSERT(send_initial_metadata_->server_initial_metadata_publisher ==
nullptr);
GPR_ASSERT(call_args.server_initial_metadata != nullptr);
send_initial_metadata_->server_initial_metadata_publisher =
call_args.server_initial_metadata;
switch (send_initial_metadata_->state) {
case SendInitialMetadata::kInitial:
send_initial_metadata_->state = SendInitialMetadata::kGotLatch;
break;
case SendInitialMetadata::kGotLatch:
case SendInitialMetadata::kQueuedAndGotLatch:
case SendInitialMetadata::kQueuedAndSetLatch:
case SendInitialMetadata::kForwarded:
abort(); // not reachable
break;
case SendInitialMetadata::kQueuedWaitingForLatch:
send_initial_metadata_->state = SendInitialMetadata::kQueuedAndGotLatch;
break;
case SendInitialMetadata::kCancelled:
break;
}
} else {
GPR_ASSERT(call_args.server_initial_metadata == nullptr);
}
return ArenaPromise<ServerMetadataHandle>(
[this]() { return PollTrailingMetadata(); });
}
@ -986,12 +1111,14 @@ void ServerCallData::RecvInitialMetadataReadyCallback(void* arg,
}
void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) {
Flusher flusher(this);
GPR_ASSERT(recv_initial_state_ == RecvInitialState::kForwarded);
// If there was an error we just propagate that through
if (error != GRPC_ERROR_NONE) {
recv_initial_state_ = RecvInitialState::kResponded;
Closure::Run(DEBUG_LOCATION, original_recv_initial_metadata_ready_,
GRPC_ERROR_REF(error));
flusher.AddClosure(
absl::exchange(original_recv_initial_metadata_ready_, nullptr),
GRPC_ERROR_REF(error), "propagate error");
return;
}
// Record that we've got the callback.
@ -1008,30 +1135,46 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) {
return MakeNextPromise(std::move(call_args));
});
// Poll once.
bool own_error = false;
WakeInsideCombiner([&error, &own_error](grpc_error_handle new_error) {
GPR_ASSERT(error == GRPC_ERROR_NONE);
error = GRPC_ERROR_REF(new_error);
own_error = true;
});
Closure::Run(DEBUG_LOCATION, original_recv_initial_metadata_ready_,
GRPC_ERROR_REF(error));
if (own_error) GRPC_ERROR_UNREF(error);
WakeInsideCombiner(&flusher);
if (auto* closure =
absl::exchange(original_recv_initial_metadata_ready_, nullptr)) {
flusher.AddClosure(closure, GRPC_ERROR_NONE,
"original_recv_initial_metadata");
}
}
// Wakeup and poll the promise if appropriate.
void ServerCallData::WakeInsideCombiner(
absl::FunctionRef<void(grpc_error_handle)> cancel) {
GPR_ASSERT(!is_polling_);
bool forward_send_trailing_metadata = false;
is_polling_ = true;
if (recv_initial_state_ == RecvInitialState::kComplete) {
void ServerCallData::WakeInsideCombiner(Flusher* flusher) {
PollContext poll_ctx(this, flusher);
if (send_initial_metadata_ != nullptr &&
send_initial_metadata_->state ==
SendInitialMetadata::kQueuedAndGotLatch) {
send_initial_metadata_->state = SendInitialMetadata::kQueuedAndSetLatch;
send_initial_metadata_->server_initial_metadata_publisher->Set(
send_initial_metadata_->batch->payload->send_initial_metadata
.send_initial_metadata);
}
poll_ctx.ClearRepoll();
if (promise_.has_value()) {
Poll<ServerMetadataHandle> poll;
{
ScopedActivity activity(this);
poll = promise_();
poll = promise_();
if (send_initial_metadata_ != nullptr &&
send_initial_metadata_->state ==
SendInitialMetadata::kQueuedAndSetLatch) {
Poll<ServerMetadata**> p = server_initial_metadata_latch()->Wait()();
if (ServerMetadata*** ppp = absl::get_if<ServerMetadata**>(&p)) {
ServerMetadata* md = **ppp;
if (send_initial_metadata_->batch->payload->send_initial_metadata
.send_initial_metadata != md) {
*send_initial_metadata_->batch->payload->send_initial_metadata
.send_initial_metadata = std::move(*md);
}
send_initial_metadata_->state = SendInitialMetadata::kForwarded;
send_initial_metadata_->batch.ResumeWith(flusher);
}
}
if (auto* r = absl::get_if<ServerMetadataHandle>(&poll)) {
promise_ = ArenaPromise<ServerMetadataHandle>();
auto* md = UnwrapMetadata(std::move(*r));
bool destroy_md = true;
switch (send_trailing_state_) {
@ -1043,7 +1186,7 @@ void ServerCallData::WakeInsideCombiner(
} else {
destroy_md = false;
}
forward_send_trailing_metadata = true;
send_trailing_metadata_batch_.ResumeWith(flusher);
send_trailing_state_ = SendTrailingState::kForwarded;
} break;
case SendTrailingState::kForwarded:
@ -1060,8 +1203,7 @@ void ServerCallData::WakeInsideCombiner(
error = grpc_error_set_str(error, GRPC_ERROR_STR_GRPC_MESSAGE,
message->as_string_view());
}
cancel(error);
GRPC_ERROR_UNREF(error);
Cancel(error, flusher);
} break;
case SendTrailingState::kCancelled:
// Nothing to do.
@ -1072,11 +1214,6 @@ void ServerCallData::WakeInsideCombiner(
}
}
}
is_polling_ = false;
if (forward_send_trailing_metadata) {
grpc_call_next_op(elem(),
absl::exchange(send_trailing_metadata_batch_, nullptr));
}
}
void ServerCallData::OnWakeup() { abort(); } // not implemented

@ -354,8 +354,11 @@ class ServerCallData : public BaseCallData {
kCancelled
};
class PollContext;
struct SendInitialMetadata;
// Handle cancellation.
void Cancel(grpc_error_handle error);
void Cancel(grpc_error_handle error, Flusher* flusher);
// Construct a promise that will "call" the next filter.
// Effectively:
// - put the modified initial metadata into the batch being sent up.
@ -369,13 +372,15 @@ class ServerCallData : public BaseCallData {
grpc_error_handle error);
void RecvInitialMetadataReady(grpc_error_handle error);
// Wakeup and poll the promise if appropriate.
void WakeInsideCombiner(absl::FunctionRef<void(grpc_error_handle)> cancel);
void WakeInsideCombiner(Flusher* flusher);
void OnWakeup() override;
// Contained promise
ArenaPromise<ServerMetadataHandle> promise_;
// Pointer to where initial metadata will be stored.
grpc_metadata_batch* recv_initial_metadata_ = nullptr;
// State for sending initial metadata.
SendInitialMetadata* send_initial_metadata_ = nullptr;
// Closure to call when we're done with the trailing metadata.
grpc_closure* original_recv_initial_metadata_ready_ = nullptr;
// Our closure pointing to RecvInitialMetadataReadyCallback.
@ -383,13 +388,13 @@ class ServerCallData : public BaseCallData {
// Error received during cancellation.
grpc_error_handle cancelled_error_ = GRPC_ERROR_NONE;
// Trailing metadata batch
grpc_transport_stream_op_batch* send_trailing_metadata_batch_ = nullptr;
CapturedBatch send_trailing_metadata_batch_;
// State of the send_initial_metadata op.
RecvInitialState recv_initial_state_ = RecvInitialState::kInitial;
// State of the recv_trailing_metadata op.
SendTrailingState send_trailing_state_ = SendTrailingState::kInitial;
// Whether we're currently polling the promise.
bool is_polling_ = false;
// Current poll context (or nullptr if not polling).
PollContext* poll_ctx_ = nullptr;
// Whether to forward the recv_initial_metadata op at the end of promise
// wakeup.
bool forward_recv_initial_metadata_callback_ = false;

@ -185,6 +185,10 @@ class ArenaPromise {
// Expose the promise interface: a call operator that returns Poll<T>.
Poll<T> operator()() { return impl_->PollOnce(); }
bool has_value() const {
return impl_ != arena_promise_detail::NullImpl<T>::Get();
}
private:
// Underlying impl object.
arena_promise_detail::ImplInterface<T>* impl_ =

@ -19,6 +19,7 @@
#include "src/core/ext/filters/channel_idle/channel_idle_filter.h"
#include "src/core/ext/filters/http/client/http_client_filter.h"
#include "src/core/ext/filters/http/client_authority_filter.h"
#include "src/core/ext/filters/http/server/http_server_filter.h"
#include "src/core/lib/gpr/env.h"
#include "src/core/lib/iomgr/executor.h"
#include "src/core/lib/iomgr/timer_manager.h"
@ -259,7 +260,8 @@ ChannelArgs LoadChannelArgs(const FuzzerChannelArgs& fuzz_args,
const Filter* const kFilters[] = {
MAKE_FILTER(ClientAuthorityFilter), MAKE_FILTER(HttpClientFilter),
MAKE_FILTER(ClientAuthFilter), MAKE_FILTER(GrpcServerAuthzFilter),
MAKE_FILTER(ClientAuthFilter), MAKE_FILTER(GrpcServerAuthzFilter),
MAKE_FILTER(HttpServerFilter),
// We exclude this one internally, so we can't have it here - will need to
// pick it up through some future registration mechanism.
// MAKE_FILTER(ServerLoadReportingFilter),

@ -623,7 +623,8 @@ typedef Fixture<&grpc_core::HttpClientFilter::kFilter,
HttpClientFilter;
BENCHMARK_TEMPLATE(BM_IsolatedFilter, HttpClientFilter, NoOp);
BENCHMARK_TEMPLATE(BM_IsolatedFilter, HttpClientFilter, SendEmptyMetadata);
typedef Fixture<&grpc_http_server_filter, CHECKS_NOT_LAST> HttpServerFilter;
typedef Fixture<&grpc_core::HttpServerFilter::kFilter, CHECKS_NOT_LAST>
HttpServerFilter;
BENCHMARK_TEMPLATE(BM_IsolatedFilter, HttpServerFilter, NoOp);
BENCHMARK_TEMPLATE(BM_IsolatedFilter, HttpServerFilter, SendEmptyMetadata);
typedef Fixture<&grpc_message_size_filter, CHECKS_NOT_LAST> MessageSizeFilter;

Loading…
Cancel
Save