diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index 7a87a091303..b5d898b8799 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -157,6 +157,8 @@ class BaseCallData : public Activity, private Wakeable { finalization_.Run(final_info); } + virtual void StartBatch(grpc_transport_stream_op_batch* batch) = 0; + protected: class ScopedContext : public promise_detail::Context, @@ -293,7 +295,7 @@ class ClientCallData : public BaseCallData { // Activity implementation. void ForceImmediateRepoll() final; // Handle one grpc_transport_stream_op_batch - void StartBatch(grpc_transport_stream_op_batch* batch); + void StartBatch(grpc_transport_stream_op_batch* batch) override; private: // At what stage is our handling of send initial metadata? @@ -393,7 +395,7 @@ class ServerCallData : public BaseCallData { // Activity implementation. void ForceImmediateRepoll() final; // Handle one grpc_transport_stream_op_batch - void StartBatch(grpc_transport_stream_op_batch* batch); + void StartBatch(grpc_transport_stream_op_batch* batch) override; private: // At what stage is our handling of recv initial metadata? @@ -471,23 +473,116 @@ class ServerCallData : public BaseCallData { // Specific call data per channel filter. // Note that we further specialize for clients and servers since their // implementations are very different. -template +template class CallData; // Client implementation of call data. -template -class CallData : public ClientCallData { +template <> +class CallData : public ClientCallData { public: using ClientCallData::ClientCallData; }; // Server implementation of call data. -template -class CallData : public ServerCallData { +template <> +class CallData : public ServerCallData { public: using ServerCallData::ServerCallData; }; +struct BaseCallDataMethods { + static void SetPollsetOrPollsetSet(grpc_call_element* elem, + grpc_polling_entity* pollent) { + static_cast(elem->call_data)->set_pollent(pollent); + } + + static void DestructCallData(grpc_call_element* elem, + const grpc_call_final_info* final_info) { + auto* cd = static_cast(elem->call_data); + cd->Finalize(final_info); + cd->~BaseCallData(); + } + + static void StartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + static_cast(elem->call_data)->StartBatch(batch); + } +}; + +template +struct CallDataFilterWithFlagsMethods { + static absl::Status InitCallElem(grpc_call_element* elem, + const grpc_call_element_args* args) { + new (elem->call_data) CallData(elem, args, kFlags); + return absl::OkStatus(); + } + + static void DestroyCallElem(grpc_call_element* elem, + const grpc_call_final_info* final_info, + grpc_closure* then_schedule_closure) { + BaseCallDataMethods::DestructCallData(elem, final_info); + if ((kFlags & kFilterIsLast) != 0) { + ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, absl::OkStatus()); + } else { + GPR_ASSERT(then_schedule_closure == nullptr); + } + } +}; + +struct ChannelFilterMethods { + static ArenaPromise MakeCallPromise( + grpc_channel_element* elem, CallArgs call_args, + NextPromiseFactory next_promise_factory) { + return static_cast(elem->channel_data) + ->MakeCallPromise(std::move(call_args), + std::move(next_promise_factory)); + } + + static void StartTransportOp(grpc_channel_element* elem, + grpc_transport_op* op) { + if (!static_cast(elem->channel_data) + ->StartTransportOp(op)) { + grpc_channel_next_op(elem, op); + } + } + + static void PostInitChannelElem(grpc_channel_stack*, + grpc_channel_element* elem) { + static_cast(elem->channel_data)->PostInit(); + } + + static void DestroyChannelElem(grpc_channel_element* elem) { + static_cast(elem->channel_data)->~ChannelFilter(); + } + + static void GetChannelInfo(grpc_channel_element* elem, + const grpc_channel_info* info) { + if (!static_cast(elem->channel_data) + ->GetChannelInfo(info)) { + grpc_channel_next_get_info(elem, info); + } + } +}; + +template +struct ChannelFilterWithFlagsMethods { + static absl::Status InitChannelElem(grpc_channel_element* elem, + grpc_channel_element_args* args) { + GPR_ASSERT(args->is_last == ((kFlags & kFilterIsLast) != 0)); + auto status = F::Create(ChannelArgs::FromC(args->channel_args), + ChannelFilter::Args(args->channel_stack, elem)); + if (!status.ok()) { + static_assert( + sizeof(promise_filter_detail::InvalidChannelFilter) <= sizeof(F), + "InvalidChannelFilter must fit in F"); + new (elem->channel_data) promise_filter_detail::InvalidChannelFilter(); + return absl_status_to_grpc_error(status.status()); + } + new (elem->channel_data) F(std::move(*status)); + return absl::OkStatus(); + } +}; + } // namespace promise_filter_detail // F implements ChannelFilter and : @@ -499,83 +594,36 @@ class CallData : public ServerCallData { template absl::enable_if_t::value, grpc_channel_filter> MakePromiseBasedFilter(const char* name) { - using CallData = promise_filter_detail::CallData; + using CallData = promise_filter_detail::CallData; return grpc_channel_filter{ // start_transport_stream_op_batch - [](grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { - static_cast(elem->call_data)->StartBatch(batch); - }, + promise_filter_detail::BaseCallDataMethods::StartTransportStreamOpBatch, // make_call_promise - [](grpc_channel_element* elem, CallArgs call_args, - NextPromiseFactory next_promise_factory) { - return static_cast(elem->channel_data) - ->MakeCallPromise(std::move(call_args), - std::move(next_promise_factory)); - }, + promise_filter_detail::ChannelFilterMethods::MakeCallPromise, // start_transport_op - [](grpc_channel_element* elem, grpc_transport_op* op) { - if (!static_cast(elem->channel_data) - ->StartTransportOp(op)) { - grpc_channel_next_op(elem, op); - } - }, + promise_filter_detail::ChannelFilterMethods::StartTransportOp, // sizeof_call_data sizeof(CallData), // init_call_elem - [](grpc_call_element* elem, const grpc_call_element_args* args) { - new (elem->call_data) CallData(elem, args, kFlags); - return absl::OkStatus(); - }, + promise_filter_detail::CallDataFilterWithFlagsMethods< + CallData, kFlags>::InitCallElem, // set_pollset_or_pollset_set - [](grpc_call_element* elem, grpc_polling_entity* pollent) { - static_cast(elem->call_data)->set_pollent(pollent); - }, + promise_filter_detail::BaseCallDataMethods::SetPollsetOrPollsetSet, // destroy_call_elem - [](grpc_call_element* elem, const grpc_call_final_info* final_info, - grpc_closure* then_schedule_closure) { - auto* cd = static_cast(elem->call_data); - cd->Finalize(final_info); - cd->~CallData(); - if ((kFlags & kFilterIsLast) != 0) { - ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, absl::OkStatus()); - } else { - GPR_ASSERT(then_schedule_closure == nullptr); - } - }, + promise_filter_detail::CallDataFilterWithFlagsMethods< + CallData, kFlags>::DestroyCallElem, // sizeof_channel_data sizeof(F), // init_channel_elem - [](grpc_channel_element* elem, grpc_channel_element_args* args) { - GPR_ASSERT(args->is_last == ((kFlags & kFilterIsLast) != 0)); - auto status = F::Create(ChannelArgs::FromC(args->channel_args), - ChannelFilter::Args(args->channel_stack, elem)); - if (!status.ok()) { - static_assert( - sizeof(promise_filter_detail::InvalidChannelFilter) <= sizeof(F), - "InvalidChannelFilter must fit in F"); - new (elem->channel_data) - promise_filter_detail::InvalidChannelFilter(); - return absl_status_to_grpc_error(status.status()); - } - new (elem->channel_data) F(std::move(*status)); - return absl::OkStatus(); - }, + promise_filter_detail::ChannelFilterWithFlagsMethods< + F, kFlags>::InitChannelElem, // post_init_channel_elem - [](grpc_channel_stack*, grpc_channel_element* elem) { - static_cast(elem->channel_data)->PostInit(); - }, + promise_filter_detail::ChannelFilterMethods::PostInitChannelElem, // destroy_channel_elem - [](grpc_channel_element* elem) { - static_cast(elem->channel_data)->~ChannelFilter(); - }, + promise_filter_detail::ChannelFilterMethods::DestroyChannelElem, // get_channel_info - [](grpc_channel_element* elem, const grpc_channel_info* info) { - if (!static_cast(elem->channel_data) - ->GetChannelInfo(info)) { - grpc_channel_next_get_info(elem, info); - } - }, + promise_filter_detail::ChannelFilterMethods::GetChannelInfo, // name name, };