[call-v3] Convert message size filter to new API (#35233)

Closes #35233

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35233 from ctiller:cg-msg-size cce51d8bd5
PiperOrigin-RevId: 588793125
pull/35250/head
Craig Tiller 12 months ago committed by Copybara-Service
parent 6c816a4f99
commit 5f92a67f94
  1. 122
      src/core/ext/filters/message_size/message_size_filter.cc
  2. 56
      src/core/ext/filters/message_size/message_size_filter.h
  3. 148
      src/core/lib/channel/promise_based_filter.h

@ -50,6 +50,13 @@
namespace grpc_core { namespace grpc_core {
const NoInterceptor ClientMessageSizeFilter::Call::OnClientInitialMetadata;
const NoInterceptor ClientMessageSizeFilter::Call::OnServerInitialMetadata;
const NoInterceptor ClientMessageSizeFilter::Call::OnServerTrailingMetadata;
const NoInterceptor ServerMessageSizeFilter::Call::OnClientInitialMetadata;
const NoInterceptor ServerMessageSizeFilter::Call::OnServerInitialMetadata;
const NoInterceptor ServerMessageSizeFilter::Call::OnServerTrailingMetadata;
// //
// MessageSizeParsedConfig // MessageSizeParsedConfig
// //
@ -138,60 +145,6 @@ const grpc_channel_filter ServerMessageSizeFilter::kFilter =
kFilterExaminesOutboundMessages | kFilterExaminesOutboundMessages |
kFilterExaminesInboundMessages>("message_size"); kFilterExaminesInboundMessages>("message_size");
class MessageSizeFilter::CallBuilder {
private:
auto Interceptor(uint32_t max_length, bool is_send) {
return [max_length, is_send,
err = err_](MessageHandle msg) -> absl::optional<MessageHandle> {
if (grpc_call_trace.enabled()) {
gpr_log(GPR_INFO, "%s[message_size] %s len:%" PRIdPTR " max:%d",
Activity::current()->DebugTag().c_str(),
is_send ? "send" : "recv", msg->payload()->Length(),
max_length);
}
if (msg->payload()->Length() > max_length) {
if (err->is_set()) return std::move(msg);
auto r = GetContext<Arena>()->MakePooled<ServerMetadata>(
GetContext<Arena>());
r->Set(GrpcStatusMetadata(), GRPC_STATUS_RESOURCE_EXHAUSTED);
r->Set(GrpcMessageMetadata(),
Slice::FromCopiedString(
absl::StrFormat("%s message larger than max (%u vs. %d)",
is_send ? "Sent" : "Received",
msg->payload()->Length(), max_length)));
err->Set(std::move(r));
return absl::nullopt;
}
return std::move(msg);
};
}
public:
explicit CallBuilder(const MessageSizeParsedConfig& limits)
: limits_(limits) {}
template <typename T>
void AddSend(T* pipe_end) {
if (!limits_.max_send_size().has_value()) return;
pipe_end->InterceptAndMap(Interceptor(*limits_.max_send_size(), true));
}
template <typename T>
void AddRecv(T* pipe_end) {
if (!limits_.max_recv_size().has_value()) return;
pipe_end->InterceptAndMap(Interceptor(*limits_.max_recv_size(), false));
}
ArenaPromise<ServerMetadataHandle> Run(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
return Race(err_->Wait(), next_promise_factory(std::move(call_args)));
}
private:
Latch<ServerMetadataHandle>* const err_ =
GetContext<Arena>()->ManagedNew<Latch<ServerMetadataHandle>>();
MessageSizeParsedConfig limits_;
};
absl::StatusOr<ClientMessageSizeFilter> ClientMessageSizeFilter::Create( absl::StatusOr<ClientMessageSizeFilter> ClientMessageSizeFilter::Create(
const ChannelArgs& args, ChannelFilter::Args) { const ChannelArgs& args, ChannelFilter::Args) {
return ClientMessageSizeFilter(args); return ClientMessageSizeFilter(args);
@ -202,20 +155,40 @@ absl::StatusOr<ServerMessageSizeFilter> ServerMessageSizeFilter::Create(
return ServerMessageSizeFilter(args); return ServerMessageSizeFilter(args);
} }
ArenaPromise<ServerMetadataHandle> ClientMessageSizeFilter::MakeCallPromise( namespace {
CallArgs call_args, NextPromiseFactory next_promise_factory) { ServerMetadataHandle CheckPayload(const Message& msg,
absl::optional<uint32_t> max_length,
bool is_send) {
if (!max_length.has_value()) return nullptr;
if (GRPC_TRACE_FLAG_ENABLED(grpc_call_trace)) {
gpr_log(GPR_INFO, "%s[message_size] %s len:%" PRIdPTR " max:%d",
Activity::current()->DebugTag().c_str(), is_send ? "send" : "recv",
msg.payload()->Length(), *max_length);
}
if (msg.payload()->Length() <= *max_length) return nullptr;
auto r = GetContext<Arena>()->MakePooled<ServerMetadata>(GetContext<Arena>());
r->Set(GrpcStatusMetadata(), GRPC_STATUS_RESOURCE_EXHAUSTED);
r->Set(GrpcMessageMetadata(), Slice::FromCopiedString(absl::StrFormat(
"%s message larger than max (%u vs. %d)",
is_send ? "Sent" : "Received",
msg.payload()->Length(), *max_length)));
return r;
}
} // namespace
ClientMessageSizeFilter::Call::Call(ClientMessageSizeFilter* filter)
: limits_(filter->parsed_config_) {
// Get max sizes from channel data, then merge in per-method config values. // Get max sizes from channel data, then merge in per-method config values.
// Note: Per-method config is only available on the client, so we // Note: Per-method config is only available on the client, so we
// apply the max request size to the send limit and the max response // apply the max request size to the send limit and the max response
// size to the receive limit. // size to the receive limit.
MessageSizeParsedConfig limits = this->limits();
const MessageSizeParsedConfig* config_from_call_context = const MessageSizeParsedConfig* config_from_call_context =
MessageSizeParsedConfig::GetFromCallContext( MessageSizeParsedConfig::GetFromCallContext(
GetContext<grpc_call_context_element>(), GetContext<grpc_call_context_element>(),
service_config_parser_index_); filter->service_config_parser_index_);
if (config_from_call_context != nullptr) { if (config_from_call_context != nullptr) {
absl::optional<uint32_t> max_send_size = limits.max_send_size(); absl::optional<uint32_t> max_send_size = limits_.max_send_size();
absl::optional<uint32_t> max_recv_size = limits.max_recv_size(); absl::optional<uint32_t> max_recv_size = limits_.max_recv_size();
if (config_from_call_context->max_send_size().has_value() && if (config_from_call_context->max_send_size().has_value() &&
(!max_send_size.has_value() || (!max_send_size.has_value() ||
*config_from_call_context->max_send_size() < *max_send_size)) { *config_from_call_context->max_send_size() < *max_send_size)) {
@ -226,21 +199,28 @@ ArenaPromise<ServerMetadataHandle> ClientMessageSizeFilter::MakeCallPromise(
*config_from_call_context->max_recv_size() < *max_recv_size)) { *config_from_call_context->max_recv_size() < *max_recv_size)) {
max_recv_size = *config_from_call_context->max_recv_size(); max_recv_size = *config_from_call_context->max_recv_size();
} }
limits = MessageSizeParsedConfig(max_send_size, max_recv_size); limits_ = MessageSizeParsedConfig(max_send_size, max_recv_size);
}
}
ServerMetadataHandle ServerMessageSizeFilter::Call::OnClientToServerMessage(
const Message& message, ServerMessageSizeFilter* filter) {
return CheckPayload(message, filter->parsed_config_.max_recv_size(), false);
}
ServerMetadataHandle ServerMessageSizeFilter::Call::OnServerToClientMessage(
const Message& message, ServerMessageSizeFilter* filter) {
return CheckPayload(message, filter->parsed_config_.max_send_size(), true);
} }
CallBuilder b(limits); ServerMetadataHandle ClientMessageSizeFilter::Call::OnClientToServerMessage(
b.AddSend(call_args.client_to_server_messages); const Message& message) {
b.AddRecv(call_args.server_to_client_messages); return CheckPayload(message, limits_.max_send_size(), true);
return b.Run(std::move(call_args), std::move(next_promise_factory));
} }
ArenaPromise<ServerMetadataHandle> ServerMessageSizeFilter::MakeCallPromise( ServerMetadataHandle ClientMessageSizeFilter::Call::OnServerToClientMessage(
CallArgs call_args, NextPromiseFactory next_promise_factory) { const Message& message) {
CallBuilder b(limits()); return CheckPayload(message, limits_.max_recv_size(), false);
b.AddSend(call_args.server_to_client_messages);
b.AddRecv(call_args.client_to_server_messages);
return b.Run(std::move(call_args), std::move(next_promise_factory));
} }
namespace { namespace {

@ -86,48 +86,58 @@ class MessageSizeParser : public ServiceConfigParser::Parser {
absl::optional<uint32_t> GetMaxRecvSizeFromChannelArgs(const ChannelArgs& args); absl::optional<uint32_t> GetMaxRecvSizeFromChannelArgs(const ChannelArgs& args);
absl::optional<uint32_t> GetMaxSendSizeFromChannelArgs(const ChannelArgs& args); absl::optional<uint32_t> GetMaxSendSizeFromChannelArgs(const ChannelArgs& args);
class MessageSizeFilter : public ChannelFilter { class ServerMessageSizeFilter final
protected: : public ImplementChannelFilter<ServerMessageSizeFilter> {
explicit MessageSizeFilter(const ChannelArgs& args)
: limits_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {}
class CallBuilder;
const MessageSizeParsedConfig& limits() const { return limits_; }
private:
MessageSizeParsedConfig limits_;
};
class ServerMessageSizeFilter final : public MessageSizeFilter {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ServerMessageSizeFilter> Create( static absl::StatusOr<ServerMessageSizeFilter> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
// Construct a promise for one call. class Call {
ArenaPromise<ServerMetadataHandle> MakeCallPromise( public:
CallArgs call_args, NextPromiseFactory next_promise_factory) override; static const NoInterceptor OnClientInitialMetadata;
static const NoInterceptor OnServerInitialMetadata;
static const NoInterceptor OnServerTrailingMetadata;
ServerMetadataHandle OnClientToServerMessage(
const Message& message, ServerMessageSizeFilter* filter);
ServerMetadataHandle OnServerToClientMessage(
const Message& message, ServerMessageSizeFilter* filter);
};
private: private:
using MessageSizeFilter::MessageSizeFilter; explicit ServerMessageSizeFilter(const ChannelArgs& args)
: parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {}
const MessageSizeParsedConfig parsed_config_;
}; };
class ClientMessageSizeFilter final : public MessageSizeFilter { class ClientMessageSizeFilter final
: public ImplementChannelFilter<ClientMessageSizeFilter> {
public: public:
static const grpc_channel_filter kFilter; static const grpc_channel_filter kFilter;
static absl::StatusOr<ClientMessageSizeFilter> Create( static absl::StatusOr<ClientMessageSizeFilter> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args); const ChannelArgs& args, ChannelFilter::Args filter_args);
// Construct a promise for one call. class Call {
ArenaPromise<ServerMetadataHandle> MakeCallPromise( public:
CallArgs call_args, NextPromiseFactory next_promise_factory) override; explicit Call(ClientMessageSizeFilter* filter);
static const NoInterceptor OnClientInitialMetadata;
static const NoInterceptor OnServerInitialMetadata;
static const NoInterceptor OnServerTrailingMetadata;
ServerMetadataHandle OnClientToServerMessage(const Message& message);
ServerMetadataHandle OnServerToClientMessage(const Message& message);
private:
MessageSizeParsedConfig limits_;
};
private: private:
explicit ClientMessageSizeFilter(const ChannelArgs& args)
: parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {}
const size_t service_config_parser_index_{MessageSizeParser::ParserIndex()}; const size_t service_config_parser_index_{MessageSizeParser::ParserIndex()};
using MessageSizeFilter::MessageSizeFilter; const MessageSizeParsedConfig parsed_config_;
}; };
} // namespace grpc_core } // namespace grpc_core

@ -254,11 +254,31 @@ struct RaceAsyncCompletion<true> {
} }
}; };
// Zero-member wrapper to make sure that Call always has a constructor
// that takes a channel pointer (even if it's thrown away)
template <typename Derived, typename SfinaeVoid = void>
class CallWrapper;
template <typename Derived>
class CallWrapper<Derived, absl::void_t<decltype(typename Derived::Call(
std::declval<Derived*>()))>>
: public Derived::Call {
public:
explicit CallWrapper(Derived* channel) : Derived::Call(channel) {}
};
template <typename Derived>
class CallWrapper<Derived, absl::void_t<decltype(typename Derived::Call())>>
: public Derived::Call {
public:
explicit CallWrapper(Derived*) : Derived::Call() {}
};
// For the original promise scheme polyfill: data associated with once call. // For the original promise scheme polyfill: data associated with once call.
template <typename Derived> template <typename Derived>
struct FilterCallData { struct FilterCallData {
explicit FilterCallData(Derived* channel) : channel(channel) {} explicit FilterCallData(Derived* channel) : call(channel), channel(channel) {}
GPR_NO_UNIQUE_ADDRESS typename Derived::Call call; GPR_NO_UNIQUE_ADDRESS CallWrapper<Derived> call;
GPR_NO_UNIQUE_ADDRESS GPR_NO_UNIQUE_ADDRESS
typename TypeIfNeeded<Latch<ServerMetadataHandle>, typename TypeIfNeeded<Latch<ServerMetadataHandle>,
CallHasAsyncErrorInterceptor<Derived>()>::Type CallHasAsyncErrorInterceptor<Derived>()>::Type
@ -347,9 +367,68 @@ inline auto RunCall(void (Derived::Call::*fn)(ClientMetadata& md,
inline void InterceptClientToServerMessage(const NoInterceptor*, void*, inline void InterceptClientToServerMessage(const NoInterceptor*, void*,
const CallArgs&) {} const CallArgs&) {}
template <typename Derived>
inline void InterceptClientToServerMessage(
ServerMetadataHandle (Derived::Call::*fn)(const Message&),
FilterCallData<Derived>* call_data, const CallArgs& call_args) {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientToServerMessage);
call_args.client_to_server_messages->InterceptAndMap(
[call_data](MessageHandle msg) -> absl::optional<MessageHandle> {
auto return_md = call_data->call.OnClientToServerMessage(*msg);
if (return_md == nullptr) return std::move(msg);
if (call_data->error_latch.is_set()) return absl::nullopt;
call_data->error_latch.Set(std::move(return_md));
return absl::nullopt;
});
}
template <typename Derived>
inline void InterceptClientToServerMessage(
ServerMetadataHandle (Derived::Call::*fn)(const Message&, Derived*),
FilterCallData<Derived>* call_data, const CallArgs& call_args) {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientToServerMessage);
call_args.client_to_server_messages->InterceptAndMap(
[call_data](MessageHandle msg) -> absl::optional<MessageHandle> {
auto return_md =
call_data->call.OnClientToServerMessage(*msg, call_data->channel);
if (return_md == nullptr) return std::move(msg);
if (call_data->error_latch.is_set()) return absl::nullopt;
call_data->error_latch.Set(std::move(return_md));
return absl::nullopt;
});
}
inline void InterceptClientToServerMessage(const NoInterceptor*, void*, void*, inline void InterceptClientToServerMessage(const NoInterceptor*, void*, void*,
CallSpineInterface*) {} CallSpineInterface*) {}
template <typename Derived>
inline void InterceptClientToServerMessage(
ServerMetadataHandle (Derived::Call::*fn)(const Message&),
typename Derived::Call* call, Derived*, CallSpineInterface* call_spine) {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientToServerMessage);
call_spine->server_to_client_messages().sender.InterceptAndMap(
[call, call_spine](MessageHandle msg) -> absl::optional<MessageHandle> {
auto return_md = call->OnClientToServerMessage(*msg);
if (return_md == nullptr) return std::move(msg);
return call_spine->Cancel(std::move(return_md));
});
}
template <typename Derived>
inline void InterceptClientToServerMessage(
ServerMetadataHandle (Derived::Call::*fn)(const Message&, Derived*),
typename Derived::Call* call, Derived* channel,
CallSpineInterface* call_spine) {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientToServerMessage);
call_spine->server_to_client_messages().sender.InterceptAndMap(
[call, call_spine,
channel](MessageHandle msg) -> absl::optional<MessageHandle> {
auto return_md = call->OnClientToServerMessage(*msg, channel);
if (return_md == nullptr) return std::move(msg);
return call_spine->Cancel(std::move(return_md));
});
}
inline void InterceptClientInitialMetadata(const NoInterceptor*, void*, void*, inline void InterceptClientInitialMetadata(const NoInterceptor*, void*, void*,
CallSpineInterface*) {} CallSpineInterface*) {}
@ -441,7 +520,6 @@ inline void InterceptServerInitialMetadata(
}); });
} }
template <typename CallArgs>
inline void InterceptServerInitialMetadata(const NoInterceptor*, void*, void*, inline void InterceptServerInitialMetadata(const NoInterceptor*, void*, void*,
CallSpineInterface*) {} CallSpineInterface*) {}
@ -474,9 +552,68 @@ inline void InterceptServerInitialMetadata(
inline void InterceptServerToClientMessage(const NoInterceptor*, void*, inline void InterceptServerToClientMessage(const NoInterceptor*, void*,
const CallArgs&) {} const CallArgs&) {}
template <typename Derived>
inline void InterceptServerToClientMessage(
ServerMetadataHandle (Derived::Call::*fn)(const Message&),
FilterCallData<Derived>* call_data, const CallArgs& call_args) {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerToClientMessage);
call_args.server_to_client_messages->InterceptAndMap(
[call_data](MessageHandle msg) -> absl::optional<MessageHandle> {
auto return_md = call_data->call.OnServerToClientMessage(*msg);
if (return_md == nullptr) return std::move(msg);
if (call_data->error_latch.is_set()) return absl::nullopt;
call_data->error_latch.Set(std::move(return_md));
return absl::nullopt;
});
}
template <typename Derived>
inline void InterceptServerToClientMessage(
ServerMetadataHandle (Derived::Call::*fn)(const Message&, Derived*),
FilterCallData<Derived>* call_data, const CallArgs& call_args) {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerToClientMessage);
call_args.server_to_client_messages->InterceptAndMap(
[call_data](MessageHandle msg) -> absl::optional<MessageHandle> {
auto return_md =
call_data->call.OnServerToClientMessage(*msg, call_data->channel);
if (return_md == nullptr) return std::move(msg);
if (call_data->error_latch.is_set()) return absl::nullopt;
call_data->error_latch.Set(std::move(return_md));
return absl::nullopt;
});
}
inline void InterceptServerToClientMessage(const NoInterceptor*, void*, void*, inline void InterceptServerToClientMessage(const NoInterceptor*, void*, void*,
CallSpineInterface*) {} CallSpineInterface*) {}
template <typename Derived>
inline void InterceptServerToClientMessage(
ServerMetadataHandle (Derived::Call::*fn)(const Message&),
typename Derived::Call* call, Derived*, CallSpineInterface* call_spine) {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerToClientMessage);
call_spine->server_to_client_messages().sender.InterceptAndMap(
[call, call_spine](MessageHandle msg) -> absl::optional<MessageHandle> {
auto return_md = call->OnServerToClientMessage(*msg);
if (return_md == nullptr) return std::move(msg);
return call_spine->Cancel(std::move(return_md));
});
}
template <typename Derived>
inline void InterceptServerToClientMessage(
ServerMetadataHandle (Derived::Call::*fn)(const Message&, Derived*),
typename Derived::Call* call, Derived* channel,
CallSpineInterface* call_spine) {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerToClientMessage);
call_spine->server_to_client_messages().sender.InterceptAndMap(
[call, call_spine,
channel](MessageHandle msg) -> absl::optional<MessageHandle> {
auto return_md = call->OnServerToClientMessage(*msg, channel);
if (return_md == nullptr) return std::move(msg);
return call_spine->Cancel(std::move(return_md));
});
}
inline void InterceptServerTrailingMetadata(const NoInterceptor*, void*, void*, inline void InterceptServerTrailingMetadata(const NoInterceptor*, void*, void*,
CallSpineInterface*) {} CallSpineInterface*) {}
@ -574,7 +711,10 @@ class ImplementChannelFilter : public ChannelFilter {
public: public:
// Natively construct a v3 call. // Natively construct a v3 call.
void InitCall(CallSpineInterface* call_spine) { void InitCall(CallSpineInterface* call_spine) {
auto* call = GetContext<Arena>()->ManagedNew<typename Derived::Call>(); typename Derived::Call* call =
GetContext<Arena>()
->ManagedNew<promise_filter_detail::CallWrapper<Derived>>(
static_cast<Derived*>(this));
promise_filter_detail::InterceptClientInitialMetadata( promise_filter_detail::InterceptClientInitialMetadata(
&Derived::Call::OnClientInitialMetadata, call, &Derived::Call::OnClientInitialMetadata, call,
static_cast<Derived*>(this), call_spine); static_cast<Derived*>(this), call_spine);

Loading…
Cancel
Save