[promises] Reduce bloat for promise_based_filter (#31209)

* [bloat] Rewrite promise_based_filter vtables to maximize code sharing

* moar
pull/31238/head
Craig Tiller 2 years ago committed by GitHub
parent 7698fbba5a
commit a9d3398010
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 184
      src/core/lib/channel/promise_based_filter.h

@ -157,6 +157,8 @@ class BaseCallData : public Activity, private Wakeable {
finalization_.Run(final_info);
}
virtual void StartBatch(grpc_transport_stream_op_batch* batch) = 0;
protected:
class ScopedContext
: public promise_detail::Context<Arena>,
@ -293,7 +295,7 @@ class ClientCallData : public BaseCallData {
// Activity implementation.
void ForceImmediateRepoll() final;
// Handle one grpc_transport_stream_op_batch
void StartBatch(grpc_transport_stream_op_batch* batch);
void StartBatch(grpc_transport_stream_op_batch* batch) override;
private:
// At what stage is our handling of send initial metadata?
@ -393,7 +395,7 @@ class ServerCallData : public BaseCallData {
// Activity implementation.
void ForceImmediateRepoll() final;
// Handle one grpc_transport_stream_op_batch
void StartBatch(grpc_transport_stream_op_batch* batch);
void StartBatch(grpc_transport_stream_op_batch* batch) override;
private:
// At what stage is our handling of recv initial metadata?
@ -471,23 +473,116 @@ class ServerCallData : public BaseCallData {
// Specific call data per channel filter.
// Note that we further specialize for clients and servers since their
// implementations are very different.
template <class ChannelFilter, FilterEndpoint endpoint>
template <FilterEndpoint endpoint>
class CallData;
// Client implementation of call data.
template <class ChannelFilter>
class CallData<ChannelFilter, FilterEndpoint::kClient> : public ClientCallData {
template <>
class CallData<FilterEndpoint::kClient> : public ClientCallData {
public:
using ClientCallData::ClientCallData;
};
// Server implementation of call data.
template <class ChannelFilter>
class CallData<ChannelFilter, FilterEndpoint::kServer> : public ServerCallData {
template <>
class CallData<FilterEndpoint::kServer> : public ServerCallData {
public:
using ServerCallData::ServerCallData;
};
struct BaseCallDataMethods {
static void SetPollsetOrPollsetSet(grpc_call_element* elem,
grpc_polling_entity* pollent) {
static_cast<BaseCallData*>(elem->call_data)->set_pollent(pollent);
}
static void DestructCallData(grpc_call_element* elem,
const grpc_call_final_info* final_info) {
auto* cd = static_cast<BaseCallData*>(elem->call_data);
cd->Finalize(final_info);
cd->~BaseCallData();
}
static void StartTransportStreamOpBatch(
grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
static_cast<BaseCallData*>(elem->call_data)->StartBatch(batch);
}
};
template <typename CallData, uint8_t kFlags>
struct CallDataFilterWithFlagsMethods {
static absl::Status InitCallElem(grpc_call_element* elem,
const grpc_call_element_args* args) {
new (elem->call_data) CallData(elem, args, kFlags);
return absl::OkStatus();
}
static void DestroyCallElem(grpc_call_element* elem,
const grpc_call_final_info* final_info,
grpc_closure* then_schedule_closure) {
BaseCallDataMethods::DestructCallData(elem, final_info);
if ((kFlags & kFilterIsLast) != 0) {
ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, absl::OkStatus());
} else {
GPR_ASSERT(then_schedule_closure == nullptr);
}
}
};
struct ChannelFilterMethods {
static ArenaPromise<ServerMetadataHandle> MakeCallPromise(
grpc_channel_element* elem, CallArgs call_args,
NextPromiseFactory next_promise_factory) {
return static_cast<ChannelFilter*>(elem->channel_data)
->MakeCallPromise(std::move(call_args),
std::move(next_promise_factory));
}
static void StartTransportOp(grpc_channel_element* elem,
grpc_transport_op* op) {
if (!static_cast<ChannelFilter*>(elem->channel_data)
->StartTransportOp(op)) {
grpc_channel_next_op(elem, op);
}
}
static void PostInitChannelElem(grpc_channel_stack*,
grpc_channel_element* elem) {
static_cast<ChannelFilter*>(elem->channel_data)->PostInit();
}
static void DestroyChannelElem(grpc_channel_element* elem) {
static_cast<ChannelFilter*>(elem->channel_data)->~ChannelFilter();
}
static void GetChannelInfo(grpc_channel_element* elem,
const grpc_channel_info* info) {
if (!static_cast<ChannelFilter*>(elem->channel_data)
->GetChannelInfo(info)) {
grpc_channel_next_get_info(elem, info);
}
}
};
template <typename F, uint8_t kFlags>
struct ChannelFilterWithFlagsMethods {
static absl::Status InitChannelElem(grpc_channel_element* elem,
grpc_channel_element_args* args) {
GPR_ASSERT(args->is_last == ((kFlags & kFilterIsLast) != 0));
auto status = F::Create(ChannelArgs::FromC(args->channel_args),
ChannelFilter::Args(args->channel_stack, elem));
if (!status.ok()) {
static_assert(
sizeof(promise_filter_detail::InvalidChannelFilter) <= sizeof(F),
"InvalidChannelFilter must fit in F");
new (elem->channel_data) promise_filter_detail::InvalidChannelFilter();
return absl_status_to_grpc_error(status.status());
}
new (elem->channel_data) F(std::move(*status));
return absl::OkStatus();
}
};
} // namespace promise_filter_detail
// F implements ChannelFilter and :
@ -499,83 +594,36 @@ class CallData<ChannelFilter, FilterEndpoint::kServer> : public ServerCallData {
template <typename F, FilterEndpoint kEndpoint, uint8_t kFlags = 0>
absl::enable_if_t<std::is_base_of<ChannelFilter, F>::value, grpc_channel_filter>
MakePromiseBasedFilter(const char* name) {
using CallData = promise_filter_detail::CallData<F, kEndpoint>;
using CallData = promise_filter_detail::CallData<kEndpoint>;
return grpc_channel_filter{
// start_transport_stream_op_batch
[](grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
static_cast<CallData*>(elem->call_data)->StartBatch(batch);
},
promise_filter_detail::BaseCallDataMethods::StartTransportStreamOpBatch,
// make_call_promise
[](grpc_channel_element* elem, CallArgs call_args,
NextPromiseFactory next_promise_factory) {
return static_cast<ChannelFilter*>(elem->channel_data)
->MakeCallPromise(std::move(call_args),
std::move(next_promise_factory));
},
promise_filter_detail::ChannelFilterMethods::MakeCallPromise,
// start_transport_op
[](grpc_channel_element* elem, grpc_transport_op* op) {
if (!static_cast<ChannelFilter*>(elem->channel_data)
->StartTransportOp(op)) {
grpc_channel_next_op(elem, op);
}
},
promise_filter_detail::ChannelFilterMethods::StartTransportOp,
// sizeof_call_data
sizeof(CallData),
// init_call_elem
[](grpc_call_element* elem, const grpc_call_element_args* args) {
new (elem->call_data) CallData(elem, args, kFlags);
return absl::OkStatus();
},
promise_filter_detail::CallDataFilterWithFlagsMethods<
CallData, kFlags>::InitCallElem,
// set_pollset_or_pollset_set
[](grpc_call_element* elem, grpc_polling_entity* pollent) {
static_cast<CallData*>(elem->call_data)->set_pollent(pollent);
},
promise_filter_detail::BaseCallDataMethods::SetPollsetOrPollsetSet,
// destroy_call_elem
[](grpc_call_element* elem, const grpc_call_final_info* final_info,
grpc_closure* then_schedule_closure) {
auto* cd = static_cast<CallData*>(elem->call_data);
cd->Finalize(final_info);
cd->~CallData();
if ((kFlags & kFilterIsLast) != 0) {
ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, absl::OkStatus());
} else {
GPR_ASSERT(then_schedule_closure == nullptr);
}
},
promise_filter_detail::CallDataFilterWithFlagsMethods<
CallData, kFlags>::DestroyCallElem,
// sizeof_channel_data
sizeof(F),
// init_channel_elem
[](grpc_channel_element* elem, grpc_channel_element_args* args) {
GPR_ASSERT(args->is_last == ((kFlags & kFilterIsLast) != 0));
auto status = F::Create(ChannelArgs::FromC(args->channel_args),
ChannelFilter::Args(args->channel_stack, elem));
if (!status.ok()) {
static_assert(
sizeof(promise_filter_detail::InvalidChannelFilter) <= sizeof(F),
"InvalidChannelFilter must fit in F");
new (elem->channel_data)
promise_filter_detail::InvalidChannelFilter();
return absl_status_to_grpc_error(status.status());
}
new (elem->channel_data) F(std::move(*status));
return absl::OkStatus();
},
promise_filter_detail::ChannelFilterWithFlagsMethods<
F, kFlags>::InitChannelElem,
// post_init_channel_elem
[](grpc_channel_stack*, grpc_channel_element* elem) {
static_cast<ChannelFilter*>(elem->channel_data)->PostInit();
},
promise_filter_detail::ChannelFilterMethods::PostInitChannelElem,
// destroy_channel_elem
[](grpc_channel_element* elem) {
static_cast<ChannelFilter*>(elem->channel_data)->~ChannelFilter();
},
promise_filter_detail::ChannelFilterMethods::DestroyChannelElem,
// get_channel_info
[](grpc_channel_element* elem, const grpc_channel_info* info) {
if (!static_cast<ChannelFilter*>(elem->channel_data)
->GetChannelInfo(info)) {
grpc_channel_next_get_info(elem, info);
}
},
promise_filter_detail::ChannelFilterMethods::GetChannelInfo,
// name
name,
};

Loading…
Cancel
Save