From cc3d0349487ae35126a4ba61aa23437b16e25262 Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Tue, 16 May 2023 19:55:55 -0700 Subject: [PATCH] [channel_args] Optimize UnionWith (#33154) --------- Co-authored-by: ctiller --- bazel/grpc_deps.bzl | 10 +- fuzztest/core/channel/BUILD | 27 +++++ fuzztest/core/channel/union_with_test.cc | 47 ++++++++ src/core/lib/avl/avl.h | 5 + src/core/lib/channel/channel_args.cc | 102 ++++++++++++++---- src/core/lib/channel/channel_args.h | 35 +++++- .../lib/compression/compression_internal.cc | 7 +- tools/bazel.rc | 2 + tools/fuzztest.bazelrc | 11 +- 9 files changed, 211 insertions(+), 35 deletions(-) create mode 100644 fuzztest/core/channel/BUILD create mode 100644 fuzztest/core/channel/union_with_test.cc diff --git a/bazel/grpc_deps.bzl b/bazel/grpc_deps.bzl index dcabaf9bca1..869f743e8b6 100644 --- a/bazel/grpc_deps.bzl +++ b/bazel/grpc_deps.bzl @@ -278,13 +278,15 @@ def grpc_deps(): ) if "com_google_fuzztest" not in native.existing_rules(): + # when updating this remember to run: + # bazel run @com_google_fuzztest//bazel:setup_configs > tools/fuzztest.bazelrc http_archive( name = "com_google_fuzztest", - sha256 = "f7bb5b3bd162576f3fbbe9bb768b57931fdd98581c1818789aceee5be4eeee64", - strip_prefix = "fuzztest-62cf00c7341eb05d128d0a3cbce79ac31dbda032", + sha256 = "cdf8d8cd3cdc77280a7c59b310edf234e489a96b6e727cb271e7dfbeb9bcca8d", + strip_prefix = "fuzztest-4ecaeb5084a061a862af8f86789ee184cd3d3f18", urls = [ - # 2023-03-03 - "https://github.com/google/fuzztest/archive/62cf00c7341eb05d128d0a3cbce79ac31dbda032.tar.gz", + # 2023-05-16 + "https://github.com/google/fuzztest/archive/4ecaeb5084a061a862af8f86789ee184cd3d3f18.tar.gz", ], ) diff --git a/fuzztest/core/channel/BUILD b/fuzztest/core/channel/BUILD new file mode 100644 index 00000000000..d6192adb0d1 --- /dev/null +++ b/fuzztest/core/channel/BUILD @@ -0,0 +1,27 @@ +# Copyright 2023 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//fuzztest:grpc_fuzz_test.bzl", "grpc_fuzz_test") + +grpc_fuzz_test( + name = "union_with_test", + srcs = ["union_with_test.cc"], + external_deps = [ + "fuzztest", + "fuzztest_main", + ], + deps = [ + "//src/core:channel_args", + ], +) diff --git a/fuzztest/core/channel/union_with_test.cc b/fuzztest/core/channel/union_with_test.cc new file mode 100644 index 00000000000..6e3d0a06d19 --- /dev/null +++ b/fuzztest/core/channel/union_with_test.cc @@ -0,0 +1,47 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test to verify Fuzztest integration + +#include "fuzztest/fuzztest.h" + +#include "gtest/gtest.h" + +#include "src/core/lib/channel/channel_args.h" + +namespace grpc_core { + +using IntOrString = absl::variant; +using VectorOfArgs = std::vector>; + +ChannelArgs ChannelArgsFromVector(VectorOfArgs va) { + ChannelArgs result; + for (auto& [key, value] : va) { + if (absl::holds_alternative(value)) { + result = result.Set(key, absl::get(value)); + } else { + result = result.Set(key, absl::get(value)); + } + } + return result; +} + +void UnionWithIsCorrect(VectorOfArgs va, VectorOfArgs vb) { + auto a = ChannelArgsFromVector(std::move(va)); + auto b = ChannelArgsFromVector(std::move(vb)); + EXPECT_EQ(a.UnionWith(b), a.FuzzingReferenceUnionWith(b)); +} +FUZZ_TEST(MyTestSuite, UnionWithIsCorrect); + +} diff --git a/src/core/lib/avl/avl.h b/src/core/lib/avl/avl.h index e44fb5d8a40..9ce529689cb 100644 --- a/src/core/lib/avl/avl.h +++ b/src/core/lib/avl/avl.h @@ -87,6 +87,11 @@ class AVL { return QsortCompare(*this, other) < 0; } + size_t Height() const { + if (root_ == nullptr) return 0; + return root_->height; + } + private: struct Node; diff --git a/src/core/lib/channel/channel_args.cc b/src/core/lib/channel/channel_args.cc index 7f54a320b4b..9fced0de4a1 100644 --- a/src/core/lib/channel/channel_args.cc +++ b/src/core/lib/channel/channel_args.cc @@ -39,6 +39,7 @@ #include #include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/crash.h" #include "src/core/lib/gprpp/match.h" namespace grpc_core { @@ -126,21 +127,55 @@ ChannelArgs ChannelArgs::FromC(const grpc_channel_args* args) { return result; } +grpc_arg ChannelArgs::Value::MakeCArg(const char* name) const { + char* c_name = const_cast(name); + return Match( + rep_, + [c_name](int i) { return grpc_channel_arg_integer_create(c_name, i); }, + [c_name](const std::shared_ptr& s) { + return grpc_channel_arg_string_create(c_name, + const_cast(s->c_str())); + }, + [c_name](const Pointer& p) { + return grpc_channel_arg_pointer_create(c_name, p.c_pointer(), + p.c_vtable()); + }); +} + +bool ChannelArgs::Value::operator<(const Value& rhs) const { + if (rhs.rep_.index() != rep_.index()) return rep_.index() < rhs.rep_.index(); + switch (rep_.index()) { + case 0: + return absl::get(rep_) < absl::get(rhs.rep_); + case 1: + return *absl::get>(rep_) < + *absl::get>(rhs.rep_); + case 2: + return absl::get(rep_) < absl::get(rhs.rep_); + default: + Crash("unreachable"); + } +} + +bool ChannelArgs::Value::operator==(const Value& rhs) const { + if (rhs.rep_.index() != rep_.index()) return false; + switch (rep_.index()) { + case 0: + return absl::get(rep_) == absl::get(rhs.rep_); + case 1: + return *absl::get>(rep_) == + *absl::get>(rhs.rep_); + case 2: + return absl::get(rep_) == absl::get(rhs.rep_); + default: + Crash("unreachable"); + } +} + ChannelArgs::CPtr ChannelArgs::ToC() const { std::vector c_args; args_.ForEach([&c_args](const std::string& key, const Value& value) { - char* name = const_cast(key.c_str()); - c_args.push_back(Match( - value, - [name](int i) { return grpc_channel_arg_integer_create(name, i); }, - [name](const std::string& s) { - return grpc_channel_arg_string_create(name, - const_cast(s.c_str())); - }, - [name](const Pointer& p) { - return grpc_channel_arg_pointer_create(name, p.c_pointer(), - p.c_vtable()); - })); + c_args.push_back(value.MakeCArg(key.c_str())); }); return CPtr(static_cast( grpc_channel_args_copy_and_add(nullptr, c_args.data(), c_args.size()))); @@ -178,8 +213,9 @@ ChannelArgs ChannelArgs::Remove(absl::string_view key) const { absl::optional ChannelArgs::GetInt(absl::string_view name) const { auto* v = Get(name); if (v == nullptr) return absl::nullopt; - if (!absl::holds_alternative(*v)) return absl::nullopt; - return absl::get(*v); + const auto* i = v->GetIfInt(); + if (i == nullptr) return absl::nullopt; + return *i; } absl::optional ChannelArgs::GetDurationFromIntMillis( @@ -195,8 +231,9 @@ absl::optional ChannelArgs::GetString( absl::string_view name) const { auto* v = Get(name); if (v == nullptr) return absl::nullopt; - if (!absl::holds_alternative(*v)) return absl::nullopt; - return absl::get(*v); + const auto* s = v->GetIfString(); + if (s == nullptr) return absl::nullopt; + return *s; } absl::optional ChannelArgs::GetOwnedString( @@ -209,14 +246,15 @@ absl::optional ChannelArgs::GetOwnedString( void* ChannelArgs::GetVoidPointer(absl::string_view name) const { auto* v = Get(name); if (v == nullptr) return nullptr; - if (!absl::holds_alternative(*v)) return nullptr; - return absl::get(*v).c_pointer(); + const auto* pp = v->GetIfPointer(); + if (pp == nullptr) return nullptr; + return pp->c_pointer(); } absl::optional ChannelArgs::GetBool(absl::string_view name) const { auto* v = Get(name); if (v == nullptr) return absl::nullopt; - auto* i = absl::get_if(v); + auto* i = v->GetIfInt(); if (i == nullptr) { gpr_log(GPR_ERROR, "%s ignored: it must be an integer", std::string(name).c_str()); @@ -238,11 +276,11 @@ std::string ChannelArgs::ToString() const { std::vector arg_strings; args_.ForEach([&arg_strings](const std::string& key, const Value& value) { std::string value_str; - if (auto* i = absl::get_if(&value)) { + if (auto* i = value.GetIfInt()) { value_str = std::to_string(*i); - } else if (auto* s = absl::get_if(&value)) { + } else if (auto* s = value.GetIfString()) { value_str = *s; - } else if (auto* p = absl::get_if(&value)) { + } else if (auto* p = value.GetIfPointer()) { value_str = absl::StrFormat("%p", p->c_pointer()); } arg_strings.push_back(absl::StrCat(key, "=", value_str)); @@ -251,6 +289,26 @@ std::string ChannelArgs::ToString() const { } ChannelArgs ChannelArgs::UnionWith(ChannelArgs other) const { + if (args_.Empty()) return other; + if (other.args_.Empty()) return *this; + if (args_.Height() <= other.args_.Height()) { + args_.ForEach([&other](const std::string& key, const Value& value) { + other.args_ = other.args_.Add(key, value); + }); + return other; + } else { + auto result = *this; + other.args_.ForEach([&result](const std::string& key, const Value& value) { + if (result.args_.Lookup(key) == nullptr) { + result.args_ = result.args_.Add(key, value); + } + }); + return result; + } +} + +ChannelArgs ChannelArgs::FuzzingReferenceUnionWith(ChannelArgs other) const { + // DO NOT OPTIMIZE THIS!! args_.ForEach([&other](const std::string& key, const Value& value) { other.args_ = other.args_.Add(key, value); }); diff --git a/src/core/lib/channel/channel_args.h b/src/core/lib/channel/channel_args.h index 2ff18146de9..8bf1a9c84c7 100644 --- a/src/core/lib/channel/channel_args.h +++ b/src/core/lib/channel/channel_args.h @@ -281,7 +281,35 @@ class ChannelArgs { const grpc_arg_pointer_vtable* vtable_; }; - using Value = absl::variant; + class Value { + public: + explicit Value(int n) : rep_(n) {} + explicit Value(std::string s) + : rep_(std::make_shared(std::move(s))) {} + explicit Value(Pointer p) : rep_(std::move(p)) {} + + const int* GetIfInt() const { return absl::get_if(&rep_); } + const std::string* GetIfString() const { + auto* p = absl::get_if>(&rep_); + if (p == nullptr) return nullptr; + return p->get(); + } + const Pointer* GetIfPointer() const { return absl::get_if(&rep_); } + + grpc_arg MakeCArg(const char* name) const; + + bool operator<(const Value& rhs) const; + bool operator==(const Value& rhs) const; + bool operator!=(const Value& rhs) const { return !this->operator==(rhs); } + bool operator==(absl::string_view rhs) const { + auto* p = absl::get_if>(&rep_); + if (p == nullptr) return false; + return **p == rhs; + } + + private: + absl::variant, Pointer> rep_; + }; struct ChannelArgsDeleter { void operator()(const grpc_channel_args* p) const; @@ -307,6 +335,11 @@ class ChannelArgs { // If a key is present in both, the value from this is used. GRPC_MUST_USE_RESULT ChannelArgs UnionWith(ChannelArgs other) const; + // Only used in union_with_test.cc, reference version of UnionWith for + // differential fuzzing. + GRPC_MUST_USE_RESULT ChannelArgs + FuzzingReferenceUnionWith(ChannelArgs other) const; + const Value* Get(absl::string_view name) const; GRPC_MUST_USE_RESULT ChannelArgs Set(absl::string_view name, Pointer value) const; diff --git a/src/core/lib/compression/compression_internal.cc b/src/core/lib/compression/compression_internal.cc index 36044e0b6c2..4871e46f705 100644 --- a/src/core/lib/compression/compression_internal.cc +++ b/src/core/lib/compression/compression_internal.cc @@ -22,13 +22,10 @@ #include -#include - #include "absl/container/inlined_vector.h" #include "absl/strings/ascii.h" #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" -#include "absl/types/variant.h" #include @@ -230,10 +227,10 @@ absl::optional DefaultCompressionAlgorithmFromChannelArgs(const ChannelArgs& args) { auto* value = args.Get(GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM); if (value == nullptr) return absl::nullopt; - if (auto* p = absl::get_if(value)) { + if (auto* p = value->GetIfInt()) { return static_cast(*p); } - if (auto* p = absl::get_if(value)) { + if (auto* p = value->GetIfString()) { return ParseCompressionAlgorithm(*p); } return absl::nullopt; diff --git a/tools/bazel.rc b/tools/bazel.rc index da0d5892086..976ac748622 100644 --- a/tools/bazel.rc +++ b/tools/bazel.rc @@ -145,3 +145,5 @@ build:compdb --build_tag_filters=-nocompdb --features=-layering_check try-import %workspace%/tools/fuzztest.bazelrc build:fuzztest --cxxopt=-std=c++17 + +build:fuzztest_test --cxxopt=-std=c++17 diff --git a/tools/fuzztest.bazelrc b/tools/fuzztest.bazelrc index b2616effcea..1721cf56bde 100644 --- a/tools/fuzztest.bazelrc +++ b/tools/fuzztest.bazelrc @@ -13,8 +13,9 @@ # # Do not use directly. -# Link with Address Sanitizer (ASAN). +# Compile and link with Address Sanitizer (ASAN). build:fuzztest-common --linkopt=-fsanitize=address +build:fuzztest-common --copt=-fsanitize=address # Standard define for "ifdef-ing" any fuzz test specific code. build:fuzztest-common --copt=-DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION @@ -22,6 +23,10 @@ build:fuzztest-common --copt=-DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION # In fuzz tests, we want to catch assertion violations even in optimized builds. build:fuzztest-common --copt=-UNDEBUG +# Enable libc++ assertions. +# See https://libcxx.llvm.org/UsingLibcxx.html#enabling-the-safe-libc-mode +build:fuzztest-common --copt=-D_LIBCPP_ENABLE_ASSERTIONS=1 + ### FuzzTest build configuration. # @@ -37,7 +42,7 @@ build:fuzztest --dynamic_mode=off # the uninstrumented runtime. build:fuzztest --copt=-DADDRESS_SANITIZER -# We apply coverage tracking and ASAN instrumentation to everything but the +# We apply coverage tracking instrumentation to everything but the # FuzzTest framework itself (including GoogleTest and GoogleMock). -build:fuzztest --per_file_copt=+//,-fuzztest/.*,-googletest/.*,-googlemock/.*@-fsanitize=address,-fsanitize-coverage=inline-8bit-counters,-fsanitize-coverage=trace-cmp +build:fuzztest --per_file_copt=+//,-fuzztest/.*,-googletest/.*,-googlemock/.*@-fsanitize-coverage=inline-8bit-counters,-fsanitize-coverage=trace-cmp