[channel-args] Reland UnionWith optimizations (#33163)

Same as yesterday, with a fix in
695e2e24ba.
pull/33167/head
Craig Tiller 2 years ago committed by GitHub
parent f60d0c7247
commit 0526a51734
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      bazel/grpc_deps.bzl
  2. 26
      fuzztest/core/channel/BUILD
  3. 47
      fuzztest/core/channel/union_with_test.cc
  4. 5
      src/core/lib/avl/avl.h
  5. 102
      src/core/lib/channel/channel_args.cc
  6. 35
      src/core/lib/channel/channel_args.h
  7. 7
      src/core/lib/compression/compression_internal.cc
  8. 2
      tools/bazel.rc
  9. 1
      tools/distrib/fix_build_deps.py
  10. 11
      tools/fuzztest.bazelrc

@ -278,13 +278,15 @@ def grpc_deps():
) )
if "com_google_fuzztest" not in native.existing_rules(): if "com_google_fuzztest" not in native.existing_rules():
# when updating this remember to run:
# bazel run @com_google_fuzztest//bazel:setup_configs > tools/fuzztest.bazelrc
http_archive( http_archive(
name = "com_google_fuzztest", name = "com_google_fuzztest",
sha256 = "f7bb5b3bd162576f3fbbe9bb768b57931fdd98581c1818789aceee5be4eeee64", sha256 = "cdf8d8cd3cdc77280a7c59b310edf234e489a96b6e727cb271e7dfbeb9bcca8d",
strip_prefix = "fuzztest-62cf00c7341eb05d128d0a3cbce79ac31dbda032", strip_prefix = "fuzztest-4ecaeb5084a061a862af8f86789ee184cd3d3f18",
urls = [ urls = [
# 2023-03-03 # 2023-05-16
"https://github.com/google/fuzztest/archive/62cf00c7341eb05d128d0a3cbce79ac31dbda032.tar.gz", "https://github.com/google/fuzztest/archive/4ecaeb5084a061a862af8f86789ee184cd3d3f18.tar.gz",
], ],
) )

@ -0,0 +1,26 @@
# Copyright 2023 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//fuzztest:grpc_fuzz_test.bzl", "grpc_fuzz_test")
grpc_fuzz_test(
name = "union_with_test",
srcs = ["union_with_test.cc"],
external_deps = [
"fuzztest",
"fuzztest_main",
"gtest",
],
deps = ["//src/core:channel_args"],
)

@ -0,0 +1,47 @@
// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Test to verify Fuzztest integration
#include "fuzztest/fuzztest.h"
#include "gtest/gtest.h"
#include "src/core/lib/channel/channel_args.h"
namespace grpc_core {
using IntOrString = absl::variant<int, std::string>;
using VectorOfArgs = std::vector<std::pair<std::string, IntOrString>>;
ChannelArgs ChannelArgsFromVector(VectorOfArgs va) {
ChannelArgs result;
for (auto& [key, value] : va) {
if (absl::holds_alternative<int>(value)) {
result = result.Set(key, absl::get<int>(value));
} else {
result = result.Set(key, absl::get<std::string>(value));
}
}
return result;
}
void UnionWithIsCorrect(VectorOfArgs va, VectorOfArgs vb) {
auto a = ChannelArgsFromVector(std::move(va));
auto b = ChannelArgsFromVector(std::move(vb));
EXPECT_EQ(a.UnionWith(b), a.FuzzingReferenceUnionWith(b));
}
FUZZ_TEST(MyTestSuite, UnionWithIsCorrect);
}

@ -87,6 +87,11 @@ class AVL {
return QsortCompare(*this, other) < 0; return QsortCompare(*this, other) < 0;
} }
size_t Height() const {
if (root_ == nullptr) return 0;
return root_->height;
}
private: private:
struct Node; struct Node;

@ -39,6 +39,7 @@
#include <grpc/support/string_util.h> #include <grpc/support/string_util.h>
#include "src/core/lib/gpr/useful.h" #include "src/core/lib/gpr/useful.h"
#include "src/core/lib/gprpp/crash.h"
#include "src/core/lib/gprpp/match.h" #include "src/core/lib/gprpp/match.h"
namespace grpc_core { namespace grpc_core {
@ -126,21 +127,55 @@ ChannelArgs ChannelArgs::FromC(const grpc_channel_args* args) {
return result; return result;
} }
grpc_arg ChannelArgs::Value::MakeCArg(const char* name) const {
char* c_name = const_cast<char*>(name);
return Match(
rep_,
[c_name](int i) { return grpc_channel_arg_integer_create(c_name, i); },
[c_name](const std::shared_ptr<const std::string>& s) {
return grpc_channel_arg_string_create(c_name,
const_cast<char*>(s->c_str()));
},
[c_name](const Pointer& p) {
return grpc_channel_arg_pointer_create(c_name, p.c_pointer(),
p.c_vtable());
});
}
bool ChannelArgs::Value::operator<(const Value& rhs) const {
if (rhs.rep_.index() != rep_.index()) return rep_.index() < rhs.rep_.index();
switch (rep_.index()) {
case 0:
return absl::get<int>(rep_) < absl::get<int>(rhs.rep_);
case 1:
return *absl::get<std::shared_ptr<const std::string>>(rep_) <
*absl::get<std::shared_ptr<const std::string>>(rhs.rep_);
case 2:
return absl::get<Pointer>(rep_) < absl::get<Pointer>(rhs.rep_);
default:
Crash("unreachable");
}
}
bool ChannelArgs::Value::operator==(const Value& rhs) const {
if (rhs.rep_.index() != rep_.index()) return false;
switch (rep_.index()) {
case 0:
return absl::get<int>(rep_) == absl::get<int>(rhs.rep_);
case 1:
return *absl::get<std::shared_ptr<const std::string>>(rep_) ==
*absl::get<std::shared_ptr<const std::string>>(rhs.rep_);
case 2:
return absl::get<Pointer>(rep_) == absl::get<Pointer>(rhs.rep_);
default:
Crash("unreachable");
}
}
ChannelArgs::CPtr ChannelArgs::ToC() const { ChannelArgs::CPtr ChannelArgs::ToC() const {
std::vector<grpc_arg> c_args; std::vector<grpc_arg> c_args;
args_.ForEach([&c_args](const std::string& key, const Value& value) { args_.ForEach([&c_args](const std::string& key, const Value& value) {
char* name = const_cast<char*>(key.c_str()); c_args.push_back(value.MakeCArg(key.c_str()));
c_args.push_back(Match(
value,
[name](int i) { return grpc_channel_arg_integer_create(name, i); },
[name](const std::string& s) {
return grpc_channel_arg_string_create(name,
const_cast<char*>(s.c_str()));
},
[name](const Pointer& p) {
return grpc_channel_arg_pointer_create(name, p.c_pointer(),
p.c_vtable());
}));
}); });
return CPtr(static_cast<const grpc_channel_args*>( return CPtr(static_cast<const grpc_channel_args*>(
grpc_channel_args_copy_and_add(nullptr, c_args.data(), c_args.size()))); grpc_channel_args_copy_and_add(nullptr, c_args.data(), c_args.size())));
@ -178,8 +213,9 @@ ChannelArgs ChannelArgs::Remove(absl::string_view key) const {
absl::optional<int> ChannelArgs::GetInt(absl::string_view name) const { absl::optional<int> ChannelArgs::GetInt(absl::string_view name) const {
auto* v = Get(name); auto* v = Get(name);
if (v == nullptr) return absl::nullopt; if (v == nullptr) return absl::nullopt;
if (!absl::holds_alternative<int>(*v)) return absl::nullopt; const auto* i = v->GetIfInt();
return absl::get<int>(*v); if (i == nullptr) return absl::nullopt;
return *i;
} }
absl::optional<Duration> ChannelArgs::GetDurationFromIntMillis( absl::optional<Duration> ChannelArgs::GetDurationFromIntMillis(
@ -195,8 +231,9 @@ absl::optional<absl::string_view> ChannelArgs::GetString(
absl::string_view name) const { absl::string_view name) const {
auto* v = Get(name); auto* v = Get(name);
if (v == nullptr) return absl::nullopt; if (v == nullptr) return absl::nullopt;
if (!absl::holds_alternative<std::string>(*v)) return absl::nullopt; const auto* s = v->GetIfString();
return absl::get<std::string>(*v); if (s == nullptr) return absl::nullopt;
return *s;
} }
absl::optional<std::string> ChannelArgs::GetOwnedString( absl::optional<std::string> ChannelArgs::GetOwnedString(
@ -209,14 +246,15 @@ absl::optional<std::string> ChannelArgs::GetOwnedString(
void* ChannelArgs::GetVoidPointer(absl::string_view name) const { void* ChannelArgs::GetVoidPointer(absl::string_view name) const {
auto* v = Get(name); auto* v = Get(name);
if (v == nullptr) return nullptr; if (v == nullptr) return nullptr;
if (!absl::holds_alternative<Pointer>(*v)) return nullptr; const auto* pp = v->GetIfPointer();
return absl::get<Pointer>(*v).c_pointer(); if (pp == nullptr) return nullptr;
return pp->c_pointer();
} }
absl::optional<bool> ChannelArgs::GetBool(absl::string_view name) const { absl::optional<bool> ChannelArgs::GetBool(absl::string_view name) const {
auto* v = Get(name); auto* v = Get(name);
if (v == nullptr) return absl::nullopt; if (v == nullptr) return absl::nullopt;
auto* i = absl::get_if<int>(v); auto* i = v->GetIfInt();
if (i == nullptr) { if (i == nullptr) {
gpr_log(GPR_ERROR, "%s ignored: it must be an integer", gpr_log(GPR_ERROR, "%s ignored: it must be an integer",
std::string(name).c_str()); std::string(name).c_str());
@ -238,11 +276,11 @@ std::string ChannelArgs::ToString() const {
std::vector<std::string> arg_strings; std::vector<std::string> arg_strings;
args_.ForEach([&arg_strings](const std::string& key, const Value& value) { args_.ForEach([&arg_strings](const std::string& key, const Value& value) {
std::string value_str; std::string value_str;
if (auto* i = absl::get_if<int>(&value)) { if (auto* i = value.GetIfInt()) {
value_str = std::to_string(*i); value_str = std::to_string(*i);
} else if (auto* s = absl::get_if<std::string>(&value)) { } else if (auto* s = value.GetIfString()) {
value_str = *s; value_str = *s;
} else if (auto* p = absl::get_if<Pointer>(&value)) { } else if (auto* p = value.GetIfPointer()) {
value_str = absl::StrFormat("%p", p->c_pointer()); value_str = absl::StrFormat("%p", p->c_pointer());
} }
arg_strings.push_back(absl::StrCat(key, "=", value_str)); arg_strings.push_back(absl::StrCat(key, "=", value_str));
@ -251,6 +289,26 @@ std::string ChannelArgs::ToString() const {
} }
ChannelArgs ChannelArgs::UnionWith(ChannelArgs other) const { ChannelArgs ChannelArgs::UnionWith(ChannelArgs other) const {
if (args_.Empty()) return other;
if (other.args_.Empty()) return *this;
if (args_.Height() <= other.args_.Height()) {
args_.ForEach([&other](const std::string& key, const Value& value) {
other.args_ = other.args_.Add(key, value);
});
return other;
} else {
auto result = *this;
other.args_.ForEach([&result](const std::string& key, const Value& value) {
if (result.args_.Lookup(key) == nullptr) {
result.args_ = result.args_.Add(key, value);
}
});
return result;
}
}
ChannelArgs ChannelArgs::FuzzingReferenceUnionWith(ChannelArgs other) const {
// DO NOT OPTIMIZE THIS!!
args_.ForEach([&other](const std::string& key, const Value& value) { args_.ForEach([&other](const std::string& key, const Value& value) {
other.args_ = other.args_.Add(key, value); other.args_ = other.args_.Add(key, value);
}); });

@ -281,7 +281,35 @@ class ChannelArgs {
const grpc_arg_pointer_vtable* vtable_; const grpc_arg_pointer_vtable* vtable_;
}; };
using Value = absl::variant<int, std::string, Pointer>; class Value {
public:
explicit Value(int n) : rep_(n) {}
explicit Value(std::string s)
: rep_(std::make_shared<const std::string>(std::move(s))) {}
explicit Value(Pointer p) : rep_(std::move(p)) {}
const int* GetIfInt() const { return absl::get_if<int>(&rep_); }
const std::string* GetIfString() const {
auto* p = absl::get_if<std::shared_ptr<const std::string>>(&rep_);
if (p == nullptr) return nullptr;
return p->get();
}
const Pointer* GetIfPointer() const { return absl::get_if<Pointer>(&rep_); }
grpc_arg MakeCArg(const char* name) const;
bool operator<(const Value& rhs) const;
bool operator==(const Value& rhs) const;
bool operator!=(const Value& rhs) const { return !this->operator==(rhs); }
bool operator==(absl::string_view rhs) const {
auto* p = absl::get_if<std::shared_ptr<const std::string>>(&rep_);
if (p == nullptr) return false;
return **p == rhs;
}
private:
absl::variant<int, std::shared_ptr<const std::string>, Pointer> rep_;
};
struct ChannelArgsDeleter { struct ChannelArgsDeleter {
void operator()(const grpc_channel_args* p) const; void operator()(const grpc_channel_args* p) const;
@ -307,6 +335,11 @@ class ChannelArgs {
// If a key is present in both, the value from this is used. // If a key is present in both, the value from this is used.
GRPC_MUST_USE_RESULT ChannelArgs UnionWith(ChannelArgs other) const; GRPC_MUST_USE_RESULT ChannelArgs UnionWith(ChannelArgs other) const;
// Only used in union_with_test.cc, reference version of UnionWith for
// differential fuzzing.
GRPC_MUST_USE_RESULT ChannelArgs
FuzzingReferenceUnionWith(ChannelArgs other) const;
const Value* Get(absl::string_view name) const; const Value* Get(absl::string_view name) const;
GRPC_MUST_USE_RESULT ChannelArgs Set(absl::string_view name, GRPC_MUST_USE_RESULT ChannelArgs Set(absl::string_view name,
Pointer value) const; Pointer value) const;

@ -22,13 +22,10 @@
#include <stdlib.h> #include <stdlib.h>
#include <string>
#include "absl/container/inlined_vector.h" #include "absl/container/inlined_vector.h"
#include "absl/strings/ascii.h" #include "absl/strings/ascii.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "absl/types/variant.h"
#include <grpc/support/log.h> #include <grpc/support/log.h>
@ -230,10 +227,10 @@ absl::optional<grpc_compression_algorithm>
DefaultCompressionAlgorithmFromChannelArgs(const ChannelArgs& args) { DefaultCompressionAlgorithmFromChannelArgs(const ChannelArgs& args) {
auto* value = args.Get(GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM); auto* value = args.Get(GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM);
if (value == nullptr) return absl::nullopt; if (value == nullptr) return absl::nullopt;
if (auto* p = absl::get_if<int>(value)) { if (auto* p = value->GetIfInt()) {
return static_cast<grpc_compression_algorithm>(*p); return static_cast<grpc_compression_algorithm>(*p);
} }
if (auto* p = absl::get_if<std::string>(value)) { if (auto* p = value->GetIfString()) {
return ParseCompressionAlgorithm(*p); return ParseCompressionAlgorithm(*p);
} }
return absl::nullopt; return absl::nullopt;

@ -145,3 +145,5 @@ build:compdb --build_tag_filters=-nocompdb --features=-layering_check
try-import %workspace%/tools/fuzztest.bazelrc try-import %workspace%/tools/fuzztest.bazelrc
build:fuzztest --cxxopt=-std=c++17 build:fuzztest --cxxopt=-std=c++17
build:fuzztest_test --cxxopt=-std=c++17

@ -461,6 +461,7 @@ for dirname in [
"test/core/resource_quota", "test/core/resource_quota",
"test/core/transport/chaotic_good", "test/core/transport/chaotic_good",
"fuzztest", "fuzztest",
"fuzztest/core/channel",
]: ]:
parsing_path = dirname parsing_path = dirname
exec( exec(

@ -13,8 +13,9 @@
# #
# Do not use directly. # Do not use directly.
# Link with Address Sanitizer (ASAN). # Compile and link with Address Sanitizer (ASAN).
build:fuzztest-common --linkopt=-fsanitize=address build:fuzztest-common --linkopt=-fsanitize=address
build:fuzztest-common --copt=-fsanitize=address
# Standard define for "ifdef-ing" any fuzz test specific code. # Standard define for "ifdef-ing" any fuzz test specific code.
build:fuzztest-common --copt=-DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION build:fuzztest-common --copt=-DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
@ -22,6 +23,10 @@ build:fuzztest-common --copt=-DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
# In fuzz tests, we want to catch assertion violations even in optimized builds. # In fuzz tests, we want to catch assertion violations even in optimized builds.
build:fuzztest-common --copt=-UNDEBUG build:fuzztest-common --copt=-UNDEBUG
# Enable libc++ assertions.
# See https://libcxx.llvm.org/UsingLibcxx.html#enabling-the-safe-libc-mode
build:fuzztest-common --copt=-D_LIBCPP_ENABLE_ASSERTIONS=1
### FuzzTest build configuration. ### FuzzTest build configuration.
# #
@ -37,7 +42,7 @@ build:fuzztest --dynamic_mode=off
# the uninstrumented runtime. # the uninstrumented runtime.
build:fuzztest --copt=-DADDRESS_SANITIZER build:fuzztest --copt=-DADDRESS_SANITIZER
# We apply coverage tracking and ASAN instrumentation to everything but the # We apply coverage tracking instrumentation to everything but the
# FuzzTest framework itself (including GoogleTest and GoogleMock). # FuzzTest framework itself (including GoogleTest and GoogleMock).
build:fuzztest --per_file_copt=+//,-fuzztest/.*,-googletest/.*,-googlemock/.*@-fsanitize=address,-fsanitize-coverage=inline-8bit-counters,-fsanitize-coverage=trace-cmp build:fuzztest --per_file_copt=+//,-fuzztest/.*,-googletest/.*,-googlemock/.*@-fsanitize-coverage=inline-8bit-counters,-fsanitize-coverage=trace-cmp

Loading…
Cancel
Save