Add a Sanitize method to metadata map to allow filtering select metadata types from the map.

Includes a unit test

PiperOrigin-RevId: 658632905
pull/37228/head
Vignesh Babu 4 months ago committed by Copybara-Service
parent 449d1b248f
commit 848272c570
  1. 28
      src/core/lib/gprpp/table.h
  2. 41
      src/core/lib/transport/metadata_batch.h
  3. 77
      test/core/transport/metadata_map_test.cc

@ -334,6 +334,15 @@ class Table {
absl::index_sequence<table_detail::IndexOf<Vs, Ts...>()...>());
}
// Iterate through each set field in the table if it exists in Vs, in the
// order of Vs. For each existing field, call the filter function. If the
// function returns true, keep the field. Otherwise, remove the field.
template <typename F, typename... Vs>
void FilterIn(F f) {
FilterInImpl(std::move(f),
absl::index_sequence<table_detail::IndexOf<Vs, Ts...>()...>());
}
// Count the number of set fields in the table
size_t count() const { return present_bits_.count(); }
@ -415,6 +424,18 @@ class Table {
}
}
// Call (*f)(value) if that value is in the table.
// If the value is present in the table and (*f)(value) returns false, remove
// the value from the table.
template <size_t I, typename F>
void FilterIf(F* f) {
if (auto* p = get<I>()) {
if (!(*f)(*p)) {
clear<I>();
}
}
}
// For each field (element I=0, 1, ...) if that field is present, call its
// destructor.
template <size_t... I>
@ -444,6 +465,13 @@ class Table {
table_detail::do_these_things<int>({(CallIf<I>(&f), 1)...});
}
// For each field (element I=0, 1, ...) if that field is present, call f. If
// f returns false, remove the field from the table.
template <typename F, size_t... I>
void FilterInImpl(F f, absl::index_sequence<I...>) {
table_detail::do_these_things<int>({(FilterIf<I>(&f), 1)...});
}
template <size_t... I>
void ClearAllImpl(absl::index_sequence<I...>) {
table_detail::do_these_things<int>({(clear<I>(), 1)...});

@ -1052,6 +1052,26 @@ struct LogWrapper {
}
};
// Callable for the table FilterIn -- for each value, call the
// appropriate filter method to determine of the value should be kept or
// removed.
template <typename Filterer>
struct FilterWrapper {
Filterer filter_fn;
template <typename Which,
absl::enable_if_t<IsEncodableTrait<Which>::value, bool> = true>
bool operator()(const Value<Which>& /*which*/) {
return filter_fn(Which());
}
template <typename Which,
absl::enable_if_t<!IsEncodableTrait<Which>::value, bool> = true>
bool operator()(const Value<Which>& /*which*/) {
return true;
}
};
// Encoder to compute TransportSize
class TransportSizeEncoder {
public:
@ -1094,6 +1114,16 @@ class UnknownMap {
BackingType::const_iterator begin() const { return unknown_.cbegin(); }
BackingType::const_iterator end() const { return unknown_.cend(); }
template <typename Filterer>
void Filter(Filterer* filter_fn) {
unknown_.erase(
std::remove_if(unknown_.begin(), unknown_.end(),
[&](auto& pair) {
return !(*filter_fn)(pair.first.as_string_view());
}),
unknown_.end());
}
bool empty() const { return unknown_.empty(); }
size_t size() const { return unknown_.size(); }
void Clear() { unknown_.clear(); }
@ -1314,6 +1344,17 @@ class MetadataMap {
}
}
// Filter the metadata map.
// Iterates over all encodable and unknown headers and calls the filter_fn
// for each of them. If the function returns true, the header is kept.
template <typename Filterer>
void Filter(Filterer filter_fn) {
table_.template FilterIn<metadata_detail::FilterWrapper<Filterer>,
Value<Traits>...>(
metadata_detail::FilterWrapper<Filterer>{filter_fn});
unknown_.Filter<Filterer>(&filter_fn);
}
std::string DebugString() const {
metadata_detail::DebugStringBuilder builder;
Log([&builder](absl::string_view key, absl::string_view value) {

@ -260,6 +260,83 @@ TEST(DebugStringBuilderTest, TestAllRedacted) {
}
}
std::vector<std::string> GetEncodableHeaders() {
return {
// clang-format off
std::string(ContentTypeMetadata::key()),
std::string(EndpointLoadMetricsBinMetadata::key()),
std::string(GrpcAcceptEncodingMetadata::key()),
std::string(GrpcEncodingMetadata::key()),
std::string(GrpcInternalEncodingRequest::key()),
std::string(GrpcLbClientStatsMetadata::key()),
std::string(GrpcMessageMetadata::key()),
std::string(GrpcPreviousRpcAttemptsMetadata::key()),
std::string(GrpcRetryPushbackMsMetadata::key()),
std::string(GrpcServerStatsBinMetadata::key()),
std::string(GrpcStatusMetadata::key()),
std::string(GrpcTagsBinMetadata::key()),
std::string(GrpcTimeoutMetadata::key()),
std::string(GrpcTraceBinMetadata::key()),
std::string(HostMetadata::key()),
std::string(HttpAuthorityMetadata::key()),
std::string(HttpMethodMetadata::key()),
std::string(HttpPathMetadata::key()),
std::string(HttpSchemeMetadata::key()),
std::string(HttpStatusMetadata::key()),
std::string(LbCostBinMetadata::key()),
std::string(LbTokenMetadata::key()),
std::string(TeMetadata::key()),
// clang-format on
};
}
template <typename NonEncodableHeader, typename Value>
void AddNonEncodableHeader(grpc_metadata_batch& md, Value value) {
md.Set(NonEncodableHeader(), value);
}
template <bool filter_unknown>
class HeaderFilter {
public:
template <typename Key>
bool operator()(Key) {
return filter_unknown;
}
bool operator()(absl::string_view /*key*/) { return !filter_unknown; }
};
TEST(MetadataMapTest, FilterTest) {
grpc_metadata_batch map;
std::vector<std::string> allow_list_keys = GetEncodableHeaders();
std::vector<std::string> unknown_keys = {"unknown_key_1", "unknown_key_2"};
allow_list_keys.insert(allow_list_keys.end(), unknown_keys.begin(),
unknown_keys.end());
// Add some encodable and unknown headers
for (const std::string& curr_key : allow_list_keys) {
map.Append(curr_key, Slice::FromStaticString("value1"),
[](absl::string_view /*error*/, const Slice& /*value*/) {});
}
// Add 5 non-encodable headers
constexpr int kNumNonEncodableHeaders = 5;
AddNonEncodableHeader<GrpcCallWasCancelled, bool>(map, true);
AddNonEncodableHeader<GrpcRegisteredMethod, void*>(map, nullptr);
AddNonEncodableHeader<GrpcStatusContext, std::string>(map, "value1");
AddNonEncodableHeader<GrpcStatusFromWire>(map, "value1");
AddNonEncodableHeader<GrpcStreamNetworkState,
GrpcStreamNetworkState::ValueType>(
map, GrpcStreamNetworkState::kNotSentOnWire);
EXPECT_EQ(map.count(), allow_list_keys.size() + kNumNonEncodableHeaders);
// Remove all unknown headers
map.Filter(HeaderFilter<true>());
EXPECT_EQ(map.count(), allow_list_keys.size() + kNumNonEncodableHeaders -
unknown_keys.size());
// Remove all encodable headers
map.Filter(HeaderFilter<false>());
EXPECT_EQ(map.count(), kNumNonEncodableHeaders);
}
} // namespace testing
} // namespace grpc_core

Loading…
Cancel
Save