|
|
|
@ -157,6 +157,8 @@ class BaseCallData : public Activity, private Wakeable { |
|
|
|
|
finalization_.Run(final_info); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void StartBatch(grpc_transport_stream_op_batch* batch) = 0; |
|
|
|
|
|
|
|
|
|
protected: |
|
|
|
|
class ScopedContext |
|
|
|
|
: public promise_detail::Context<Arena>, |
|
|
|
@ -293,7 +295,7 @@ class ClientCallData : public BaseCallData { |
|
|
|
|
// Activity implementation.
|
|
|
|
|
void ForceImmediateRepoll() final; |
|
|
|
|
// Handle one grpc_transport_stream_op_batch
|
|
|
|
|
void StartBatch(grpc_transport_stream_op_batch* batch); |
|
|
|
|
void StartBatch(grpc_transport_stream_op_batch* batch) override; |
|
|
|
|
|
|
|
|
|
private: |
|
|
|
|
// At what stage is our handling of send initial metadata?
|
|
|
|
@ -393,7 +395,7 @@ class ServerCallData : public BaseCallData { |
|
|
|
|
// Activity implementation.
|
|
|
|
|
void ForceImmediateRepoll() final; |
|
|
|
|
// Handle one grpc_transport_stream_op_batch
|
|
|
|
|
void StartBatch(grpc_transport_stream_op_batch* batch); |
|
|
|
|
void StartBatch(grpc_transport_stream_op_batch* batch) override; |
|
|
|
|
|
|
|
|
|
private: |
|
|
|
|
// At what stage is our handling of recv initial metadata?
|
|
|
|
@ -471,23 +473,116 @@ class ServerCallData : public BaseCallData { |
|
|
|
|
// Specific call data per channel filter.
|
|
|
|
|
// Note that we further specialize for clients and servers since their
|
|
|
|
|
// implementations are very different.
|
|
|
|
|
template <class ChannelFilter, FilterEndpoint endpoint> |
|
|
|
|
template <FilterEndpoint endpoint> |
|
|
|
|
class CallData; |
|
|
|
|
|
|
|
|
|
// Client implementation of call data.
|
|
|
|
|
template <class ChannelFilter> |
|
|
|
|
class CallData<ChannelFilter, FilterEndpoint::kClient> : public ClientCallData { |
|
|
|
|
template <> |
|
|
|
|
class CallData<FilterEndpoint::kClient> : public ClientCallData { |
|
|
|
|
public: |
|
|
|
|
using ClientCallData::ClientCallData; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
// Server implementation of call data.
|
|
|
|
|
template <class ChannelFilter> |
|
|
|
|
class CallData<ChannelFilter, FilterEndpoint::kServer> : public ServerCallData { |
|
|
|
|
template <> |
|
|
|
|
class CallData<FilterEndpoint::kServer> : public ServerCallData { |
|
|
|
|
public: |
|
|
|
|
using ServerCallData::ServerCallData; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
struct BaseCallDataMethods { |
|
|
|
|
static void SetPollsetOrPollsetSet(grpc_call_element* elem, |
|
|
|
|
grpc_polling_entity* pollent) { |
|
|
|
|
static_cast<BaseCallData*>(elem->call_data)->set_pollent(pollent); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static void DestructCallData(grpc_call_element* elem, |
|
|
|
|
const grpc_call_final_info* final_info) { |
|
|
|
|
auto* cd = static_cast<BaseCallData*>(elem->call_data); |
|
|
|
|
cd->Finalize(final_info); |
|
|
|
|
cd->~BaseCallData(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static void StartTransportStreamOpBatch( |
|
|
|
|
grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { |
|
|
|
|
static_cast<BaseCallData*>(elem->call_data)->StartBatch(batch); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
template <typename CallData, uint8_t kFlags> |
|
|
|
|
struct CallDataFilterWithFlagsMethods { |
|
|
|
|
static absl::Status InitCallElem(grpc_call_element* elem, |
|
|
|
|
const grpc_call_element_args* args) { |
|
|
|
|
new (elem->call_data) CallData(elem, args, kFlags); |
|
|
|
|
return absl::OkStatus(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static void DestroyCallElem(grpc_call_element* elem, |
|
|
|
|
const grpc_call_final_info* final_info, |
|
|
|
|
grpc_closure* then_schedule_closure) { |
|
|
|
|
BaseCallDataMethods::DestructCallData(elem, final_info); |
|
|
|
|
if ((kFlags & kFilterIsLast) != 0) { |
|
|
|
|
ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, absl::OkStatus()); |
|
|
|
|
} else { |
|
|
|
|
GPR_ASSERT(then_schedule_closure == nullptr); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
struct ChannelFilterMethods { |
|
|
|
|
static ArenaPromise<ServerMetadataHandle> MakeCallPromise( |
|
|
|
|
grpc_channel_element* elem, CallArgs call_args, |
|
|
|
|
NextPromiseFactory next_promise_factory) { |
|
|
|
|
return static_cast<ChannelFilter*>(elem->channel_data) |
|
|
|
|
->MakeCallPromise(std::move(call_args), |
|
|
|
|
std::move(next_promise_factory)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static void StartTransportOp(grpc_channel_element* elem, |
|
|
|
|
grpc_transport_op* op) { |
|
|
|
|
if (!static_cast<ChannelFilter*>(elem->channel_data) |
|
|
|
|
->StartTransportOp(op)) { |
|
|
|
|
grpc_channel_next_op(elem, op); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static void PostInitChannelElem(grpc_channel_stack*, |
|
|
|
|
grpc_channel_element* elem) { |
|
|
|
|
static_cast<ChannelFilter*>(elem->channel_data)->PostInit(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static void DestroyChannelElem(grpc_channel_element* elem) { |
|
|
|
|
static_cast<ChannelFilter*>(elem->channel_data)->~ChannelFilter(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static void GetChannelInfo(grpc_channel_element* elem, |
|
|
|
|
const grpc_channel_info* info) { |
|
|
|
|
if (!static_cast<ChannelFilter*>(elem->channel_data) |
|
|
|
|
->GetChannelInfo(info)) { |
|
|
|
|
grpc_channel_next_get_info(elem, info); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
template <typename F, uint8_t kFlags> |
|
|
|
|
struct ChannelFilterWithFlagsMethods { |
|
|
|
|
static absl::Status InitChannelElem(grpc_channel_element* elem, |
|
|
|
|
grpc_channel_element_args* args) { |
|
|
|
|
GPR_ASSERT(args->is_last == ((kFlags & kFilterIsLast) != 0)); |
|
|
|
|
auto status = F::Create(ChannelArgs::FromC(args->channel_args), |
|
|
|
|
ChannelFilter::Args(args->channel_stack, elem)); |
|
|
|
|
if (!status.ok()) { |
|
|
|
|
static_assert( |
|
|
|
|
sizeof(promise_filter_detail::InvalidChannelFilter) <= sizeof(F), |
|
|
|
|
"InvalidChannelFilter must fit in F"); |
|
|
|
|
new (elem->channel_data) promise_filter_detail::InvalidChannelFilter(); |
|
|
|
|
return absl_status_to_grpc_error(status.status()); |
|
|
|
|
} |
|
|
|
|
new (elem->channel_data) F(std::move(*status)); |
|
|
|
|
return absl::OkStatus(); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
} // namespace promise_filter_detail
|
|
|
|
|
|
|
|
|
|
// F implements ChannelFilter and :
|
|
|
|
@ -499,83 +594,36 @@ class CallData<ChannelFilter, FilterEndpoint::kServer> : public ServerCallData { |
|
|
|
|
template <typename F, FilterEndpoint kEndpoint, uint8_t kFlags = 0> |
|
|
|
|
absl::enable_if_t<std::is_base_of<ChannelFilter, F>::value, grpc_channel_filter> |
|
|
|
|
MakePromiseBasedFilter(const char* name) { |
|
|
|
|
using CallData = promise_filter_detail::CallData<F, kEndpoint>; |
|
|
|
|
using CallData = promise_filter_detail::CallData<kEndpoint>; |
|
|
|
|
|
|
|
|
|
return grpc_channel_filter{ |
|
|
|
|
// start_transport_stream_op_batch
|
|
|
|
|
[](grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { |
|
|
|
|
static_cast<CallData*>(elem->call_data)->StartBatch(batch); |
|
|
|
|
}, |
|
|
|
|
promise_filter_detail::BaseCallDataMethods::StartTransportStreamOpBatch, |
|
|
|
|
// make_call_promise
|
|
|
|
|
[](grpc_channel_element* elem, CallArgs call_args, |
|
|
|
|
NextPromiseFactory next_promise_factory) { |
|
|
|
|
return static_cast<ChannelFilter*>(elem->channel_data) |
|
|
|
|
->MakeCallPromise(std::move(call_args), |
|
|
|
|
std::move(next_promise_factory)); |
|
|
|
|
}, |
|
|
|
|
promise_filter_detail::ChannelFilterMethods::MakeCallPromise, |
|
|
|
|
// start_transport_op
|
|
|
|
|
[](grpc_channel_element* elem, grpc_transport_op* op) { |
|
|
|
|
if (!static_cast<ChannelFilter*>(elem->channel_data) |
|
|
|
|
->StartTransportOp(op)) { |
|
|
|
|
grpc_channel_next_op(elem, op); |
|
|
|
|
} |
|
|
|
|
}, |
|
|
|
|
promise_filter_detail::ChannelFilterMethods::StartTransportOp, |
|
|
|
|
// sizeof_call_data
|
|
|
|
|
sizeof(CallData), |
|
|
|
|
// init_call_elem
|
|
|
|
|
[](grpc_call_element* elem, const grpc_call_element_args* args) { |
|
|
|
|
new (elem->call_data) CallData(elem, args, kFlags); |
|
|
|
|
return absl::OkStatus(); |
|
|
|
|
}, |
|
|
|
|
promise_filter_detail::CallDataFilterWithFlagsMethods< |
|
|
|
|
CallData, kFlags>::InitCallElem, |
|
|
|
|
// set_pollset_or_pollset_set
|
|
|
|
|
[](grpc_call_element* elem, grpc_polling_entity* pollent) { |
|
|
|
|
static_cast<CallData*>(elem->call_data)->set_pollent(pollent); |
|
|
|
|
}, |
|
|
|
|
promise_filter_detail::BaseCallDataMethods::SetPollsetOrPollsetSet, |
|
|
|
|
// destroy_call_elem
|
|
|
|
|
[](grpc_call_element* elem, const grpc_call_final_info* final_info, |
|
|
|
|
grpc_closure* then_schedule_closure) { |
|
|
|
|
auto* cd = static_cast<CallData*>(elem->call_data); |
|
|
|
|
cd->Finalize(final_info); |
|
|
|
|
cd->~CallData(); |
|
|
|
|
if ((kFlags & kFilterIsLast) != 0) { |
|
|
|
|
ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, absl::OkStatus()); |
|
|
|
|
} else { |
|
|
|
|
GPR_ASSERT(then_schedule_closure == nullptr); |
|
|
|
|
} |
|
|
|
|
}, |
|
|
|
|
promise_filter_detail::CallDataFilterWithFlagsMethods< |
|
|
|
|
CallData, kFlags>::DestroyCallElem, |
|
|
|
|
// sizeof_channel_data
|
|
|
|
|
sizeof(F), |
|
|
|
|
// init_channel_elem
|
|
|
|
|
[](grpc_channel_element* elem, grpc_channel_element_args* args) { |
|
|
|
|
GPR_ASSERT(args->is_last == ((kFlags & kFilterIsLast) != 0)); |
|
|
|
|
auto status = F::Create(ChannelArgs::FromC(args->channel_args), |
|
|
|
|
ChannelFilter::Args(args->channel_stack, elem)); |
|
|
|
|
if (!status.ok()) { |
|
|
|
|
static_assert( |
|
|
|
|
sizeof(promise_filter_detail::InvalidChannelFilter) <= sizeof(F), |
|
|
|
|
"InvalidChannelFilter must fit in F"); |
|
|
|
|
new (elem->channel_data) |
|
|
|
|
promise_filter_detail::InvalidChannelFilter(); |
|
|
|
|
return absl_status_to_grpc_error(status.status()); |
|
|
|
|
} |
|
|
|
|
new (elem->channel_data) F(std::move(*status)); |
|
|
|
|
return absl::OkStatus(); |
|
|
|
|
}, |
|
|
|
|
promise_filter_detail::ChannelFilterWithFlagsMethods< |
|
|
|
|
F, kFlags>::InitChannelElem, |
|
|
|
|
// post_init_channel_elem
|
|
|
|
|
[](grpc_channel_stack*, grpc_channel_element* elem) { |
|
|
|
|
static_cast<ChannelFilter*>(elem->channel_data)->PostInit(); |
|
|
|
|
}, |
|
|
|
|
promise_filter_detail::ChannelFilterMethods::PostInitChannelElem, |
|
|
|
|
// destroy_channel_elem
|
|
|
|
|
[](grpc_channel_element* elem) { |
|
|
|
|
static_cast<ChannelFilter*>(elem->channel_data)->~ChannelFilter(); |
|
|
|
|
}, |
|
|
|
|
promise_filter_detail::ChannelFilterMethods::DestroyChannelElem, |
|
|
|
|
// get_channel_info
|
|
|
|
|
[](grpc_channel_element* elem, const grpc_channel_info* info) { |
|
|
|
|
if (!static_cast<ChannelFilter*>(elem->channel_data) |
|
|
|
|
->GetChannelInfo(info)) { |
|
|
|
|
grpc_channel_next_get_info(elem, info); |
|
|
|
|
} |
|
|
|
|
}, |
|
|
|
|
promise_filter_detail::ChannelFilterMethods::GetChannelInfo, |
|
|
|
|
// name
|
|
|
|
|
name, |
|
|
|
|
}; |
|
|
|
|