enforce order only for some methods

pull/38784/head
Vignesh Babu 1 week ago
parent 765301d514
commit 8619e193ef
  1. 176
      src/core/call/filter_fusion.h
  2. 167
      test/core/call/filter_fusion_test.cc

@ -633,17 +633,16 @@ void ExecuteCombinedWithChannelAccess(Call* call, Derived* channel,
typename FilterMethods::Methods(), typename FilterMethods::Idxs());
}
#define GRPC_FUSE_METHOD(name, type, forward, prefix) \
#define GRPC_FUSE_METHOD(name, type, forward) \
template <MethodVariant variant, typename Derived, typename... Filters> \
class FuseImpl##prefix##name; \
class FuseImpl##name; \
template <typename Derived, typename... Filters> \
class FuseImpl##prefix##name<MethodVariant::kNoInterceptor, Derived, \
Filters...> { \
class FuseImpl##name<MethodVariant::kNoInterceptor, Derived, Filters...> { \
public: \
static inline const NoInterceptor name; \
}; \
template <typename Derived, typename... Filters> \
class FuseImpl##prefix##name<MethodVariant::kSimple, Derived, Filters...> { \
class FuseImpl##name<MethodVariant::kSimple, Derived, Filters...> { \
public: \
auto name(type x) { \
return ExecuteCombined<typename ForwardOrReverse< \
@ -652,8 +651,7 @@ void ExecuteCombinedWithChannelAccess(Call* call, Derived* channel,
} \
}; \
template <typename Derived, typename... Filters> \
class FuseImpl##prefix##name<MethodVariant::kChannelAccess, Derived, \
Filters...> { \
class FuseImpl##name<MethodVariant::kChannelAccess, Derived, Filters...> { \
public: \
auto name(type x, Derived* channel) { \
return ExecuteCombinedWithChannelAccess< \
@ -664,112 +662,82 @@ void ExecuteCombinedWithChannelAccess(Call* call, Derived* channel,
} \
}; \
template <typename Derived, typename... Filters> \
using Fuse##prefix##name = FuseImpl##prefix##name< \
MethodVariantForFilters<&Filters::Call::name...>(), Derived, Filters...>
GRPC_FUSE_METHOD(OnClientInitialMetadata, ClientMetadataHandle, true,
ClientFilter);
GRPC_FUSE_METHOD(OnClientInitialMetadata, ClientMetadataHandle, false,
ServerFilter);
GRPC_FUSE_METHOD(OnServerInitialMetadata, ServerMetadataHandle, true,
ClientFilter);
GRPC_FUSE_METHOD(OnServerInitialMetadata, ServerMetadataHandle, false,
ServerFilter);
GRPC_FUSE_METHOD(OnClientToServerMessage, MessageHandle, true, ClientFilter);
GRPC_FUSE_METHOD(OnClientToServerMessage, MessageHandle, false, ServerFilter);
GRPC_FUSE_METHOD(OnServerToClientMessage, MessageHandle, true, ClientFilter);
GRPC_FUSE_METHOD(OnServerToClientMessage, MessageHandle, false, ServerFilter);
GRPC_FUSE_METHOD(OnServerTrailingMetadata, ServerMetadataHandle, true,
ClientFilter);
GRPC_FUSE_METHOD(OnServerTrailingMetadata, ServerMetadataHandle, false,
ServerFilter);
GRPC_FUSE_METHOD(OnClientToServerHalfClose, ServerMetadataHandle, true,
ClientFilter);
GRPC_FUSE_METHOD(OnClientToServerHalfClose, ServerMetadataHandle, false,
ServerFilter);
GRPC_FUSE_METHOD(OnFinalize, grpc_call_final_info*, true, ClientFilter);
GRPC_FUSE_METHOD(OnFinalize, grpc_call_final_info*, false, ServerFilter);
using Fuse##name = \
FuseImpl##name<MethodVariantForFilters<&Filters::Call::name...>(), \
Derived, Filters...>
GRPC_FUSE_METHOD(OnClientInitialMetadata, ClientMetadataHandle, true);
GRPC_FUSE_METHOD(OnServerInitialMetadata, ServerMetadataHandle, false);
GRPC_FUSE_METHOD(OnClientToServerMessage, MessageHandle, true);
GRPC_FUSE_METHOD(OnServerToClientMessage, MessageHandle, false);
GRPC_FUSE_METHOD(OnServerTrailingMetadata, ServerMetadataHandle, false);
GRPC_FUSE_METHOD(OnClientToServerHalfClose, ServerMetadataHandle, true);
GRPC_FUSE_METHOD(OnFinalize, grpc_call_final_info*, true);
#undef GRPC_FUSE_METHOD
#define GRPC_FUSED_FILTER(prefix, forward) \
template <typename... Filters> \
class Fused##prefix : public Filters... { \
public: \
class Call : public Fuse##prefix##OnClientInitialMetadata<Fused##prefix, \
Filters...>, \
public Fuse##prefix##OnServerInitialMetadata<Fused##prefix, \
Filters...>, \
public Fuse##prefix##OnClientToServerMessage<Fused##prefix, \
Filters...>, \
public Fuse##prefix##OnServerToClientMessage<Fused##prefix, \
Filters...>, \
public Fuse##prefix##OnServerTrailingMetadata<Fused##prefix, \
Filters...>, \
public Fuse##prefix##OnClientToServerHalfClose<Fused##prefix, \
Filters...>, \
public Fuse##prefix##OnFinalize<Fused##prefix, Filters...> { \
public: \
template <size_t I> \
auto* fused_child() { \
return &std::get<I>(filter_calls_); \
} \
\
using Fuse##prefix##OnClientInitialMetadata< \
Fused##prefix, Filters...>::OnClientInitialMetadata; \
using Fuse##prefix##OnServerInitialMetadata< \
Fused##prefix, Filters...>::OnServerInitialMetadata; \
using Fuse##prefix##OnClientToServerMessage< \
Fused##prefix, Filters...>::OnClientToServerMessage; \
using Fuse##prefix##OnServerToClientMessage< \
Fused##prefix, Filters...>::OnServerToClientMessage; \
using Fuse##prefix##OnServerTrailingMetadata< \
Fused##prefix, Filters...>::OnServerTrailingMetadata; \
using Fuse##prefix##OnClientToServerHalfClose< \
Fused##prefix, Filters...>::OnClientToServerHalfClose; \
using Fuse##prefix##OnFinalize<Fused##prefix, Filters...>::OnFinalize; \
\
private: \
std::tuple<typename Filters::Call...> filter_calls_; \
}; \
using FilterTypeList = \
typename ForwardOrReverseTypes<forward, \
Filters...>::OrderMethod::Types; \
\
bool StartTransportOp(grpc_transport_op* op) { \
return StartTransportOpInternal(op, FilterTypeList()); \
} \
\
bool GetChannelInfo(const grpc_channel_info* info) { \
return GetChannelInfoInternal(info, FilterTypeList()); \
} \
\
private: \
template <typename... FilterTypes> \
bool StartTransportOpInternal(grpc_transport_op* op, \
Typelist<FilterTypes...>) { \
return (std::get<FilterTypes>(filters_).StartTransportOp(op) || ...); \
} \
\
template <typename... FilterTypes> \
bool GetChannelInfoInternal(const grpc_channel_info* info, \
Typelist<FilterTypes...>) { \
return (std::get<FilterTypes>(filters_).GetChannelInfo(info) || ...); \
} \
\
std::tuple<Filters...> filters_; \
template <typename... Filters>
class FusedFilter : public Filters... {
public:
class Call : public FuseOnClientInitialMetadata<FusedFilter, Filters...>,
public FuseOnServerInitialMetadata<FusedFilter, Filters...>,
public FuseOnClientToServerMessage<FusedFilter, Filters...>,
public FuseOnServerToClientMessage<FusedFilter, Filters...>,
public FuseOnServerTrailingMetadata<FusedFilter, Filters...>,
public FuseOnClientToServerHalfClose<FusedFilter, Filters...>,
public FuseOnFinalize<FusedFilter, Filters...> {
public:
template <size_t I>
auto* fused_child() {
return &std::get<I>(filter_calls_);
}
using FuseOnClientInitialMetadata<FusedFilter,
Filters...>::OnClientInitialMetadata;
using FuseOnServerInitialMetadata<FusedFilter,
Filters...>::OnServerInitialMetadata;
using FuseOnClientToServerMessage<FusedFilter,
Filters...>::OnClientToServerMessage;
using FuseOnServerToClientMessage<FusedFilter,
Filters...>::OnServerToClientMessage;
using FuseOnServerTrailingMetadata<FusedFilter,
Filters...>::OnServerTrailingMetadata;
using FuseOnClientToServerHalfClose<FusedFilter,
Filters...>::OnClientToServerHalfClose;
using FuseOnFinalize<FusedFilter, Filters...>::OnFinalize;
private:
std::tuple<typename Filters::Call...> filter_calls_;
};
GRPC_FUSED_FILTER(ClientFilter, true);
GRPC_FUSED_FILTER(ServerFilter, false);
bool StartTransportOp(grpc_transport_op* op) {
return StartTransportOpInternal(op, Typelist<Filters...>());
}
} // namespace filters_detail
bool GetChannelInfo(const grpc_channel_info* info) {
return GetChannelInfoInternal(info, Typelist<Filters...>());
}
template <typename... Filters>
using FusedClientFilter = filters_detail::FusedClientFilter<Filters...>;
private:
template <typename... FilterTypes>
bool StartTransportOpInternal(grpc_transport_op* op,
Typelist<FilterTypes...>) {
return (std::get<FilterTypes>(filters_).StartTransportOp(op) || ...);
}
template <typename... FilterTypes>
bool GetChannelInfoInternal(const grpc_channel_info* info,
Typelist<FilterTypes...>) {
return (std::get<FilterTypes>(filters_).GetChannelInfo(info) || ...);
}
std::tuple<Filters...> filters_;
};
} // namespace filters_detail
template <typename... Filters>
using FusedServerFilter = filters_detail::FusedServerFilter<Filters...>;
using FusedFilter = filters_detail::FusedFilter<Filters...>;
} // namespace grpc_core

@ -209,56 +209,28 @@ class Test5 {
}
};
using TestFusedClientFilter =
FusedClientFilter<Test1, Test2, Test3, Test4, Test5>;
using TestFusedServerFilter =
FusedServerFilter<Test1, Test2, Test3, Test4, Test5>;
// ClientFilter
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnClientInitialMetadata),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnServerInitialMetadata),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnClientToServerMessage),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnServerToClientMessage),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnClientToServerHalfClose),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnServerTrailingMetadata),
const NoInterceptor*>);
using TestFusedFilter = FusedFilter<Test1, Test2, Test3, Test4, Test5>;
static_assert(
!std::is_same_v<decltype(&TestFusedClientFilter::Call::OnFinalize),
!std::is_same_v<decltype(&TestFusedFilter::Call::OnClientInitialMetadata),
const NoInterceptor*>);
static_assert(
!std::is_same_v<decltype(&TestFusedFilter::Call::OnServerInitialMetadata),
const NoInterceptor*>);
static_assert(
!std::is_same_v<decltype(&TestFusedFilter::Call::OnClientToServerMessage),
const NoInterceptor*>);
// ServerFilter
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnClientInitialMetadata),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnServerInitialMetadata),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnClientToServerMessage),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnServerToClientMessage),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnClientToServerHalfClose),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnServerTrailingMetadata),
const NoInterceptor*>);
static_assert(
!std::is_same_v<decltype(&TestFusedServerFilter::Call::OnFinalize),
!std::is_same_v<decltype(&TestFusedFilter::Call::OnServerToClientMessage),
const NoInterceptor*>);
static_assert(
!std::is_same_v<decltype(&TestFusedFilter::Call::OnClientToServerHalfClose),
const NoInterceptor*>);
static_assert(
!std::is_same_v<decltype(&TestFusedFilter::Call::OnServerTrailingMetadata),
const NoInterceptor*>);
static_assert(!std::is_same_v<decltype(&TestFusedFilter::Call::OnFinalize),
const NoInterceptor*>);
template <typename T>
typename ServerMetadataOrHandle<T>::ValueType RunSuccessfulPromise(
@ -275,8 +247,8 @@ typename ServerMetadataOrHandle<T>::ValueType RunSuccessfulPromise(
TEST(FusedFilterTest, ClientFilterTest) {
history.clear();
TestFusedClientFilter filter;
TestFusedClientFilter::Call call;
TestFusedFilter filter;
TestFusedFilter::Call call;
history.clear();
auto message = Arena::MakePooled<Message>();
auto server_metadata_handle = Arena::MakePooled<ServerMetadata>();
@ -298,26 +270,30 @@ TEST(FusedFilterTest, ClientFilterTest) {
RunSuccessfulPromise<ServerMetadata>(call.OnClientToServerHalfClose(
std::move(server_trailing_metadata_handle_half_close)));
call.OnFinalize(&info, &filter);
EXPECT_THAT(history,
ElementsAre("Test2::Call::OnClientToServerMessage",
"Test3::Call::OnClientToServerMessage",
"Test1::Call::OnServerToClientMessage",
"Test2::Call::OnServerToClientMessage",
"Test3::Call::OnServerToClientMessage",
"Test4::Call::OnServerInitialMetadata",
"Test5::Call::OnServerInitialMetadata",
"Test1::Call::OnClientInitialMetadata",
"Test2::Call::OnClientInitialMetadata",
"Test3::Call::OnClientInitialMetadata",
"Test4::Call::OnClientInitialMetadata",
"Test5::Call::OnClientInitialMetadata",
"Test1::Call::OnServerTrailingMetadata",
"Test2::Call::OnServerTrailingMetadata",
"Test3::Call::OnServerTrailingMetadata",
"Test1::Call::OnClientToServerHalfClose",
"Test2::Call::OnClientToServerHalfClose",
"Test1::Call::OnFinalize", "Test3::Call::OnFinalize",
"Test4::Call::OnFinalize"));
EXPECT_THAT(
history,
ElementsAre("Test2::Call::OnClientToServerMessage",
"Test3::Call::OnClientToServerMessage",
// ServerToClientMessage execution order must be reversed.
"Test3::Call::OnServerToClientMessage",
"Test2::Call::OnServerToClientMessage",
"Test1::Call::OnServerToClientMessage",
// ServerInitialMetadata execution order must be reversed.
"Test5::Call::OnServerInitialMetadata",
"Test4::Call::OnServerInitialMetadata",
"Test1::Call::OnClientInitialMetadata",
"Test2::Call::OnClientInitialMetadata",
"Test3::Call::OnClientInitialMetadata",
"Test4::Call::OnClientInitialMetadata",
"Test5::Call::OnClientInitialMetadata",
// ServerTrailingMetadata execution order must be reversed.
"Test3::Call::OnServerTrailingMetadata",
"Test2::Call::OnServerTrailingMetadata",
"Test1::Call::OnServerTrailingMetadata",
"Test1::Call::OnClientToServerHalfClose",
"Test2::Call::OnClientToServerHalfClose",
"Test1::Call::OnFinalize", "Test3::Call::OnFinalize",
"Test4::Call::OnFinalize"));
history.clear();
grpc_transport_op op;
grpc_channel_info channel_info;
@ -330,61 +306,6 @@ TEST(FusedFilterTest, ClientFilterTest) {
"Test3::GetChannelInfo", "Test4::GetChannelInfo"));
}
TEST(FusedFilterTest, ServerFilterTest) {
history.clear();
TestFusedServerFilter filter;
TestFusedServerFilter::Call call;
history.clear();
auto message = Arena::MakePooled<Message>();
auto server_metadata_handle = Arena::MakePooled<ServerMetadata>();
auto server_trailing_metadata_handle = Arena::MakePooled<ServerMetadata>();
auto server_trailing_metadata_handle_half_close =
Arena::MakePooled<ServerMetadata>();
auto client_metadata_handle = Arena::MakePooled<ClientMetadata>();
struct grpc_call_final_info info;
message = RunSuccessfulPromise<Message>(
call.OnClientToServerMessage(std::move(message), &filter));
RunSuccessfulPromise<Message>(
call.OnServerToClientMessage(std::move(message), &filter));
RunSuccessfulPromise<ServerMetadata>(
call.OnServerInitialMetadata(std::move(server_metadata_handle), &filter));
RunSuccessfulPromise<ClientMetadata>(
call.OnClientInitialMetadata(std::move(client_metadata_handle), &filter));
RunSuccessfulPromise<ServerMetadata>(call.OnServerTrailingMetadata(
std::move(server_trailing_metadata_handle), &filter));
RunSuccessfulPromise<ServerMetadata>(call.OnClientToServerHalfClose(
std::move(server_trailing_metadata_handle_half_close)));
call.OnFinalize(&info, &filter);
EXPECT_THAT(history,
ElementsAre("Test3::Call::OnClientToServerMessage",
"Test2::Call::OnClientToServerMessage",
"Test3::Call::OnServerToClientMessage",
"Test2::Call::OnServerToClientMessage",
"Test1::Call::OnServerToClientMessage",
"Test5::Call::OnServerInitialMetadata",
"Test4::Call::OnServerInitialMetadata",
"Test5::Call::OnClientInitialMetadata",
"Test4::Call::OnClientInitialMetadata",
"Test3::Call::OnClientInitialMetadata",
"Test2::Call::OnClientInitialMetadata",
"Test1::Call::OnClientInitialMetadata",
"Test3::Call::OnServerTrailingMetadata",
"Test2::Call::OnServerTrailingMetadata",
"Test1::Call::OnServerTrailingMetadata",
"Test2::Call::OnClientToServerHalfClose",
"Test1::Call::OnClientToServerHalfClose",
"Test4::Call::OnFinalize", "Test3::Call::OnFinalize",
"Test1::Call::OnFinalize"));
history.clear();
grpc_transport_op op;
grpc_channel_info channel_info;
EXPECT_TRUE(filter.StartTransportOp(&op));
EXPECT_TRUE(filter.GetChannelInfo(&channel_info));
EXPECT_THAT(history,
ElementsAre("Test5::StartTransportOp", "Test4::StartTransportOp",
"Test5::GetChannelInfo", "Test4::GetChannelInfo"));
}
} // namespace
} // namespace grpc_core

Loading…
Cancel
Save