enforce order only for some methods

pull/38784/head
Vignesh Babu 1 week ago
parent 765301d514
commit 8619e193ef
  1. 176
      src/core/call/filter_fusion.h
  2. 167
      test/core/call/filter_fusion_test.cc

@ -633,17 +633,16 @@ void ExecuteCombinedWithChannelAccess(Call* call, Derived* channel,
typename FilterMethods::Methods(), typename FilterMethods::Idxs()); typename FilterMethods::Methods(), typename FilterMethods::Idxs());
} }
#define GRPC_FUSE_METHOD(name, type, forward, prefix) \ #define GRPC_FUSE_METHOD(name, type, forward) \
template <MethodVariant variant, typename Derived, typename... Filters> \ template <MethodVariant variant, typename Derived, typename... Filters> \
class FuseImpl##prefix##name; \ class FuseImpl##name; \
template <typename Derived, typename... Filters> \ template <typename Derived, typename... Filters> \
class FuseImpl##prefix##name<MethodVariant::kNoInterceptor, Derived, \ class FuseImpl##name<MethodVariant::kNoInterceptor, Derived, Filters...> { \
Filters...> { \
public: \ public: \
static inline const NoInterceptor name; \ static inline const NoInterceptor name; \
}; \ }; \
template <typename Derived, typename... Filters> \ template <typename Derived, typename... Filters> \
class FuseImpl##prefix##name<MethodVariant::kSimple, Derived, Filters...> { \ class FuseImpl##name<MethodVariant::kSimple, Derived, Filters...> { \
public: \ public: \
auto name(type x) { \ auto name(type x) { \
return ExecuteCombined<typename ForwardOrReverse< \ return ExecuteCombined<typename ForwardOrReverse< \
@ -652,8 +651,7 @@ void ExecuteCombinedWithChannelAccess(Call* call, Derived* channel,
} \ } \
}; \ }; \
template <typename Derived, typename... Filters> \ template <typename Derived, typename... Filters> \
class FuseImpl##prefix##name<MethodVariant::kChannelAccess, Derived, \ class FuseImpl##name<MethodVariant::kChannelAccess, Derived, Filters...> { \
Filters...> { \
public: \ public: \
auto name(type x, Derived* channel) { \ auto name(type x, Derived* channel) { \
return ExecuteCombinedWithChannelAccess< \ return ExecuteCombinedWithChannelAccess< \
@ -664,112 +662,82 @@ void ExecuteCombinedWithChannelAccess(Call* call, Derived* channel,
} \ } \
}; \ }; \
template <typename Derived, typename... Filters> \ template <typename Derived, typename... Filters> \
using Fuse##prefix##name = FuseImpl##prefix##name< \ using Fuse##name = \
MethodVariantForFilters<&Filters::Call::name...>(), Derived, Filters...> FuseImpl##name<MethodVariantForFilters<&Filters::Call::name...>(), \
Derived, Filters...>
GRPC_FUSE_METHOD(OnClientInitialMetadata, ClientMetadataHandle, true,
ClientFilter); GRPC_FUSE_METHOD(OnClientInitialMetadata, ClientMetadataHandle, true);
GRPC_FUSE_METHOD(OnClientInitialMetadata, ClientMetadataHandle, false, GRPC_FUSE_METHOD(OnServerInitialMetadata, ServerMetadataHandle, false);
ServerFilter); GRPC_FUSE_METHOD(OnClientToServerMessage, MessageHandle, true);
GRPC_FUSE_METHOD(OnServerInitialMetadata, ServerMetadataHandle, true, GRPC_FUSE_METHOD(OnServerToClientMessage, MessageHandle, false);
ClientFilter); GRPC_FUSE_METHOD(OnServerTrailingMetadata, ServerMetadataHandle, false);
GRPC_FUSE_METHOD(OnServerInitialMetadata, ServerMetadataHandle, false, GRPC_FUSE_METHOD(OnClientToServerHalfClose, ServerMetadataHandle, true);
ServerFilter); GRPC_FUSE_METHOD(OnFinalize, grpc_call_final_info*, true);
GRPC_FUSE_METHOD(OnClientToServerMessage, MessageHandle, true, ClientFilter);
GRPC_FUSE_METHOD(OnClientToServerMessage, MessageHandle, false, ServerFilter);
GRPC_FUSE_METHOD(OnServerToClientMessage, MessageHandle, true, ClientFilter);
GRPC_FUSE_METHOD(OnServerToClientMessage, MessageHandle, false, ServerFilter);
GRPC_FUSE_METHOD(OnServerTrailingMetadata, ServerMetadataHandle, true,
ClientFilter);
GRPC_FUSE_METHOD(OnServerTrailingMetadata, ServerMetadataHandle, false,
ServerFilter);
GRPC_FUSE_METHOD(OnClientToServerHalfClose, ServerMetadataHandle, true,
ClientFilter);
GRPC_FUSE_METHOD(OnClientToServerHalfClose, ServerMetadataHandle, false,
ServerFilter);
GRPC_FUSE_METHOD(OnFinalize, grpc_call_final_info*, true, ClientFilter);
GRPC_FUSE_METHOD(OnFinalize, grpc_call_final_info*, false, ServerFilter);
#undef GRPC_FUSE_METHOD #undef GRPC_FUSE_METHOD
#define GRPC_FUSED_FILTER(prefix, forward) \ template <typename... Filters>
template <typename... Filters> \ class FusedFilter : public Filters... {
class Fused##prefix : public Filters... { \ public:
public: \ class Call : public FuseOnClientInitialMetadata<FusedFilter, Filters...>,
class Call : public Fuse##prefix##OnClientInitialMetadata<Fused##prefix, \ public FuseOnServerInitialMetadata<FusedFilter, Filters...>,
Filters...>, \ public FuseOnClientToServerMessage<FusedFilter, Filters...>,
public Fuse##prefix##OnServerInitialMetadata<Fused##prefix, \ public FuseOnServerToClientMessage<FusedFilter, Filters...>,
Filters...>, \ public FuseOnServerTrailingMetadata<FusedFilter, Filters...>,
public Fuse##prefix##OnClientToServerMessage<Fused##prefix, \ public FuseOnClientToServerHalfClose<FusedFilter, Filters...>,
Filters...>, \ public FuseOnFinalize<FusedFilter, Filters...> {
public Fuse##prefix##OnServerToClientMessage<Fused##prefix, \ public:
Filters...>, \ template <size_t I>
public Fuse##prefix##OnServerTrailingMetadata<Fused##prefix, \ auto* fused_child() {
Filters...>, \ return &std::get<I>(filter_calls_);
public Fuse##prefix##OnClientToServerHalfClose<Fused##prefix, \ }
Filters...>, \
public Fuse##prefix##OnFinalize<Fused##prefix, Filters...> { \ using FuseOnClientInitialMetadata<FusedFilter,
public: \ Filters...>::OnClientInitialMetadata;
template <size_t I> \ using FuseOnServerInitialMetadata<FusedFilter,
auto* fused_child() { \ Filters...>::OnServerInitialMetadata;
return &std::get<I>(filter_calls_); \ using FuseOnClientToServerMessage<FusedFilter,
} \ Filters...>::OnClientToServerMessage;
\ using FuseOnServerToClientMessage<FusedFilter,
using Fuse##prefix##OnClientInitialMetadata< \ Filters...>::OnServerToClientMessage;
Fused##prefix, Filters...>::OnClientInitialMetadata; \ using FuseOnServerTrailingMetadata<FusedFilter,
using Fuse##prefix##OnServerInitialMetadata< \ Filters...>::OnServerTrailingMetadata;
Fused##prefix, Filters...>::OnServerInitialMetadata; \ using FuseOnClientToServerHalfClose<FusedFilter,
using Fuse##prefix##OnClientToServerMessage< \ Filters...>::OnClientToServerHalfClose;
Fused##prefix, Filters...>::OnClientToServerMessage; \ using FuseOnFinalize<FusedFilter, Filters...>::OnFinalize;
using Fuse##prefix##OnServerToClientMessage< \
Fused##prefix, Filters...>::OnServerToClientMessage; \ private:
using Fuse##prefix##OnServerTrailingMetadata< \ std::tuple<typename Filters::Call...> filter_calls_;
Fused##prefix, Filters...>::OnServerTrailingMetadata; \
using Fuse##prefix##OnClientToServerHalfClose< \
Fused##prefix, Filters...>::OnClientToServerHalfClose; \
using Fuse##prefix##OnFinalize<Fused##prefix, Filters...>::OnFinalize; \
\
private: \
std::tuple<typename Filters::Call...> filter_calls_; \
}; \
using FilterTypeList = \
typename ForwardOrReverseTypes<forward, \
Filters...>::OrderMethod::Types; \
\
bool StartTransportOp(grpc_transport_op* op) { \
return StartTransportOpInternal(op, FilterTypeList()); \
} \
\
bool GetChannelInfo(const grpc_channel_info* info) { \
return GetChannelInfoInternal(info, FilterTypeList()); \
} \
\
private: \
template <typename... FilterTypes> \
bool StartTransportOpInternal(grpc_transport_op* op, \
Typelist<FilterTypes...>) { \
return (std::get<FilterTypes>(filters_).StartTransportOp(op) || ...); \
} \
\
template <typename... FilterTypes> \
bool GetChannelInfoInternal(const grpc_channel_info* info, \
Typelist<FilterTypes...>) { \
return (std::get<FilterTypes>(filters_).GetChannelInfo(info) || ...); \
} \
\
std::tuple<Filters...> filters_; \
}; };
GRPC_FUSED_FILTER(ClientFilter, true); bool StartTransportOp(grpc_transport_op* op) {
GRPC_FUSED_FILTER(ServerFilter, false); return StartTransportOpInternal(op, Typelist<Filters...>());
}
} // namespace filters_detail bool GetChannelInfo(const grpc_channel_info* info) {
return GetChannelInfoInternal(info, Typelist<Filters...>());
}
template <typename... Filters> private:
using FusedClientFilter = filters_detail::FusedClientFilter<Filters...>; template <typename... FilterTypes>
bool StartTransportOpInternal(grpc_transport_op* op,
Typelist<FilterTypes...>) {
return (std::get<FilterTypes>(filters_).StartTransportOp(op) || ...);
}
template <typename... FilterTypes>
bool GetChannelInfoInternal(const grpc_channel_info* info,
Typelist<FilterTypes...>) {
return (std::get<FilterTypes>(filters_).GetChannelInfo(info) || ...);
}
std::tuple<Filters...> filters_;
};
} // namespace filters_detail
template <typename... Filters> template <typename... Filters>
using FusedServerFilter = filters_detail::FusedServerFilter<Filters...>; using FusedFilter = filters_detail::FusedFilter<Filters...>;
} // namespace grpc_core } // namespace grpc_core

@ -209,56 +209,28 @@ class Test5 {
} }
}; };
using TestFusedClientFilter = using TestFusedFilter = FusedFilter<Test1, Test2, Test3, Test4, Test5>;
FusedClientFilter<Test1, Test2, Test3, Test4, Test5>;
using TestFusedServerFilter =
FusedServerFilter<Test1, Test2, Test3, Test4, Test5>;
// ClientFilter
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnClientInitialMetadata),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnServerInitialMetadata),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnClientToServerMessage),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnServerToClientMessage),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnClientToServerHalfClose),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedClientFilter::Call::OnServerTrailingMetadata),
const NoInterceptor*>);
static_assert( static_assert(
!std::is_same_v<decltype(&TestFusedClientFilter::Call::OnFinalize), !std::is_same_v<decltype(&TestFusedFilter::Call::OnClientInitialMetadata),
const NoInterceptor*>);
static_assert(
!std::is_same_v<decltype(&TestFusedFilter::Call::OnServerInitialMetadata),
const NoInterceptor*>);
static_assert(
!std::is_same_v<decltype(&TestFusedFilter::Call::OnClientToServerMessage),
const NoInterceptor*>); const NoInterceptor*>);
// ServerFilter
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnClientInitialMetadata),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnServerInitialMetadata),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnClientToServerMessage),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnServerToClientMessage),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnClientToServerHalfClose),
const NoInterceptor*>);
static_assert(!std::is_same_v<
decltype(&TestFusedServerFilter::Call::OnServerTrailingMetadata),
const NoInterceptor*>);
static_assert( static_assert(
!std::is_same_v<decltype(&TestFusedServerFilter::Call::OnFinalize), !std::is_same_v<decltype(&TestFusedFilter::Call::OnServerToClientMessage),
const NoInterceptor*>); const NoInterceptor*>);
static_assert(
!std::is_same_v<decltype(&TestFusedFilter::Call::OnClientToServerHalfClose),
const NoInterceptor*>);
static_assert(
!std::is_same_v<decltype(&TestFusedFilter::Call::OnServerTrailingMetadata),
const NoInterceptor*>);
static_assert(!std::is_same_v<decltype(&TestFusedFilter::Call::OnFinalize),
const NoInterceptor*>);
template <typename T> template <typename T>
typename ServerMetadataOrHandle<T>::ValueType RunSuccessfulPromise( typename ServerMetadataOrHandle<T>::ValueType RunSuccessfulPromise(
@ -275,8 +247,8 @@ typename ServerMetadataOrHandle<T>::ValueType RunSuccessfulPromise(
TEST(FusedFilterTest, ClientFilterTest) { TEST(FusedFilterTest, ClientFilterTest) {
history.clear(); history.clear();
TestFusedClientFilter filter; TestFusedFilter filter;
TestFusedClientFilter::Call call; TestFusedFilter::Call call;
history.clear(); history.clear();
auto message = Arena::MakePooled<Message>(); auto message = Arena::MakePooled<Message>();
auto server_metadata_handle = Arena::MakePooled<ServerMetadata>(); auto server_metadata_handle = Arena::MakePooled<ServerMetadata>();
@ -298,26 +270,30 @@ TEST(FusedFilterTest, ClientFilterTest) {
RunSuccessfulPromise<ServerMetadata>(call.OnClientToServerHalfClose( RunSuccessfulPromise<ServerMetadata>(call.OnClientToServerHalfClose(
std::move(server_trailing_metadata_handle_half_close))); std::move(server_trailing_metadata_handle_half_close)));
call.OnFinalize(&info, &filter); call.OnFinalize(&info, &filter);
EXPECT_THAT(history, EXPECT_THAT(
ElementsAre("Test2::Call::OnClientToServerMessage", history,
"Test3::Call::OnClientToServerMessage", ElementsAre("Test2::Call::OnClientToServerMessage",
"Test1::Call::OnServerToClientMessage", "Test3::Call::OnClientToServerMessage",
"Test2::Call::OnServerToClientMessage", // ServerToClientMessage execution order must be reversed.
"Test3::Call::OnServerToClientMessage", "Test3::Call::OnServerToClientMessage",
"Test4::Call::OnServerInitialMetadata", "Test2::Call::OnServerToClientMessage",
"Test5::Call::OnServerInitialMetadata", "Test1::Call::OnServerToClientMessage",
"Test1::Call::OnClientInitialMetadata", // ServerInitialMetadata execution order must be reversed.
"Test2::Call::OnClientInitialMetadata", "Test5::Call::OnServerInitialMetadata",
"Test3::Call::OnClientInitialMetadata", "Test4::Call::OnServerInitialMetadata",
"Test4::Call::OnClientInitialMetadata", "Test1::Call::OnClientInitialMetadata",
"Test5::Call::OnClientInitialMetadata", "Test2::Call::OnClientInitialMetadata",
"Test1::Call::OnServerTrailingMetadata", "Test3::Call::OnClientInitialMetadata",
"Test2::Call::OnServerTrailingMetadata", "Test4::Call::OnClientInitialMetadata",
"Test3::Call::OnServerTrailingMetadata", "Test5::Call::OnClientInitialMetadata",
"Test1::Call::OnClientToServerHalfClose", // ServerTrailingMetadata execution order must be reversed.
"Test2::Call::OnClientToServerHalfClose", "Test3::Call::OnServerTrailingMetadata",
"Test1::Call::OnFinalize", "Test3::Call::OnFinalize", "Test2::Call::OnServerTrailingMetadata",
"Test4::Call::OnFinalize")); "Test1::Call::OnServerTrailingMetadata",
"Test1::Call::OnClientToServerHalfClose",
"Test2::Call::OnClientToServerHalfClose",
"Test1::Call::OnFinalize", "Test3::Call::OnFinalize",
"Test4::Call::OnFinalize"));
history.clear(); history.clear();
grpc_transport_op op; grpc_transport_op op;
grpc_channel_info channel_info; grpc_channel_info channel_info;
@ -330,61 +306,6 @@ TEST(FusedFilterTest, ClientFilterTest) {
"Test3::GetChannelInfo", "Test4::GetChannelInfo")); "Test3::GetChannelInfo", "Test4::GetChannelInfo"));
} }
TEST(FusedFilterTest, ServerFilterTest) {
history.clear();
TestFusedServerFilter filter;
TestFusedServerFilter::Call call;
history.clear();
auto message = Arena::MakePooled<Message>();
auto server_metadata_handle = Arena::MakePooled<ServerMetadata>();
auto server_trailing_metadata_handle = Arena::MakePooled<ServerMetadata>();
auto server_trailing_metadata_handle_half_close =
Arena::MakePooled<ServerMetadata>();
auto client_metadata_handle = Arena::MakePooled<ClientMetadata>();
struct grpc_call_final_info info;
message = RunSuccessfulPromise<Message>(
call.OnClientToServerMessage(std::move(message), &filter));
RunSuccessfulPromise<Message>(
call.OnServerToClientMessage(std::move(message), &filter));
RunSuccessfulPromise<ServerMetadata>(
call.OnServerInitialMetadata(std::move(server_metadata_handle), &filter));
RunSuccessfulPromise<ClientMetadata>(
call.OnClientInitialMetadata(std::move(client_metadata_handle), &filter));
RunSuccessfulPromise<ServerMetadata>(call.OnServerTrailingMetadata(
std::move(server_trailing_metadata_handle), &filter));
RunSuccessfulPromise<ServerMetadata>(call.OnClientToServerHalfClose(
std::move(server_trailing_metadata_handle_half_close)));
call.OnFinalize(&info, &filter);
EXPECT_THAT(history,
ElementsAre("Test3::Call::OnClientToServerMessage",
"Test2::Call::OnClientToServerMessage",
"Test3::Call::OnServerToClientMessage",
"Test2::Call::OnServerToClientMessage",
"Test1::Call::OnServerToClientMessage",
"Test5::Call::OnServerInitialMetadata",
"Test4::Call::OnServerInitialMetadata",
"Test5::Call::OnClientInitialMetadata",
"Test4::Call::OnClientInitialMetadata",
"Test3::Call::OnClientInitialMetadata",
"Test2::Call::OnClientInitialMetadata",
"Test1::Call::OnClientInitialMetadata",
"Test3::Call::OnServerTrailingMetadata",
"Test2::Call::OnServerTrailingMetadata",
"Test1::Call::OnServerTrailingMetadata",
"Test2::Call::OnClientToServerHalfClose",
"Test1::Call::OnClientToServerHalfClose",
"Test4::Call::OnFinalize", "Test3::Call::OnFinalize",
"Test1::Call::OnFinalize"));
history.clear();
grpc_transport_op op;
grpc_channel_info channel_info;
EXPECT_TRUE(filter.StartTransportOp(&op));
EXPECT_TRUE(filter.GetChannelInfo(&channel_info));
EXPECT_THAT(history,
ElementsAre("Test5::StartTransportOp", "Test4::StartTransportOp",
"Test5::GetChannelInfo", "Test4::GetChannelInfo"));
}
} // namespace } // namespace
} // namespace grpc_core } // namespace grpc_core

Loading…
Cancel
Save