From 8004254a535f3df3fbd2c728a922819a2b8c82cf Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Fri, 28 Jul 2023 15:01:24 -0700 Subject: [PATCH] [channel_args] Size optimizations (#33901) - Make `Value` a simple wrapper around `Pointer` and use some blessed vtables to distinguish strings vs ints vs actual pointers - this saves 8 bytes per value stored - introduce `RcString` as a lightweight container around an immutable string - this saves some bytes vs the shared_ptr approach we previously had, and importantly opens up the technique (via `RcStringValue`) to channel node keys also, which should increase sharing and consequently also decrease total memory usage --------- Co-authored-by: ctiller --- build_autogenerated.yaml | 2 - src/core/BUILD | 2 - src/core/lib/channel/channel_args.cc | 152 ++++++++++-------- src/core/lib/channel/channel_args.h | 130 +++++++++++++-- .../lib/compression/compression_internal.cc | 11 +- src/core/lib/gprpp/ref_counted.h | 4 +- 6 files changed, 205 insertions(+), 96 deletions(-) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 49236f99ad5..4a8a34eec8e 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -7076,9 +7076,7 @@ targets: - src/core/lib/event_engine/channel_args_endpoint_config.h - src/core/lib/gprpp/atomic_utils.h - src/core/lib/gprpp/dual_ref_counted.h - - src/core/lib/gprpp/match.h - src/core/lib/gprpp/orphanable.h - - src/core/lib/gprpp/overload.h - src/core/lib/gprpp/ref_counted.h - src/core/lib/gprpp/ref_counted_ptr.h - src/core/lib/gprpp/time.h diff --git a/src/core/BUILD b/src/core/BUILD index 14f615158b1..f1a323e6687 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -2624,7 +2624,6 @@ grpc_cc_library( "absl/strings", "absl/strings:str_format", "absl/types:optional", - "absl/types:variant", ], language = "c++", visibility = [ @@ -2634,7 +2633,6 @@ grpc_cc_library( "avl", "channel_stack_type", "dual_ref_counted", - "match", "ref_counted", "time", "useful", diff --git a/src/core/lib/channel/channel_args.cc b/src/core/lib/channel/channel_args.cc index 6a74d9b1044..4496abad4b9 100644 --- a/src/core/lib/channel/channel_args.cc +++ b/src/core/lib/channel/channel_args.cc @@ -27,6 +27,9 @@ #include #include #include +#include +#include +#include #include #include "absl/strings/match.h" @@ -40,11 +43,47 @@ #include #include "src/core/lib/gpr/useful.h" -#include "src/core/lib/gprpp/crash.h" -#include "src/core/lib/gprpp/match.h" namespace grpc_core { +RefCountedPtr RcString::Make(absl::string_view src) { + void* p = gpr_malloc(sizeof(Header) + src.length() + 1); + return RefCountedPtr(new (p) RcString(src)); +} + +RcString::RcString(absl::string_view src) : header_{{}, src.length()} { + memcpy(payload_, src.data(), header_.length); + // Null terminate because we frequently need to convert to char* still to go + // back and forth to the old c-style api. + payload_[header_.length] = 0; +} + +void RcString::Destroy() { gpr_free(this); } + +const grpc_arg_pointer_vtable ChannelArgs::Value::int_vtable_{ + // copy + [](void* p) { return p; }, + // destroy + [](void*) {}, + // cmp + [](void* p1, void* p2) -> int { + return QsortCompare(reinterpret_cast(p1), + reinterpret_cast(p2)); + }, +}; + +const grpc_arg_pointer_vtable ChannelArgs::Value::string_vtable_{ + // copy + [](void* p) -> void* { return static_cast(p)->Ref().release(); }, + // destroy + [](void* p) { static_cast(p)->Unref(); }, + // cmp + [](void* p1, void* p2) -> int { + return QsortCompare(static_cast(p1)->as_string_view(), + static_cast(p2)->as_string_view()); + }, +}; + ChannelArgs::Pointer::Pointer(void* p, const grpc_arg_pointer_vtable* vtable) : p_(p), vtable_(vtable == nullptr ? EmptyVTable() : vtable) {} @@ -100,7 +139,7 @@ bool ChannelArgs::WantMinimalStack() const { return GetBool(GRPC_ARG_MINIMAL_STACK).value_or(false); } -ChannelArgs::ChannelArgs(AVL args) +ChannelArgs::ChannelArgs(AVL args) : args_(std::move(args)) {} ChannelArgs ChannelArgs::Set(grpc_arg arg) const { @@ -130,52 +169,22 @@ ChannelArgs ChannelArgs::FromC(const grpc_channel_args* args) { grpc_arg ChannelArgs::Value::MakeCArg(const char* name) const { char* c_name = const_cast(name); - return Match( - rep_, - [c_name](int i) { return grpc_channel_arg_integer_create(c_name, i); }, - [c_name](const std::shared_ptr& s) { - return grpc_channel_arg_string_create(c_name, - const_cast(s->c_str())); - }, - [c_name](const Pointer& p) { - return grpc_channel_arg_pointer_create(c_name, p.c_pointer(), - p.c_vtable()); - }); -} - -bool ChannelArgs::Value::operator<(const Value& rhs) const { - if (rhs.rep_.index() != rep_.index()) return rep_.index() < rhs.rep_.index(); - switch (rep_.index()) { - case 0: - return absl::get(rep_) < absl::get(rhs.rep_); - case 1: - return *absl::get>(rep_) < - *absl::get>(rhs.rep_); - case 2: - return absl::get(rep_) < absl::get(rhs.rep_); - default: - Crash("unreachable"); + if (rep_.c_vtable() == &int_vtable_) { + return grpc_channel_arg_integer_create( + c_name, reinterpret_cast(rep_.c_pointer())); } -} - -bool ChannelArgs::Value::operator==(const Value& rhs) const { - if (rhs.rep_.index() != rep_.index()) return false; - switch (rep_.index()) { - case 0: - return absl::get(rep_) == absl::get(rhs.rep_); - case 1: - return *absl::get>(rep_) == - *absl::get>(rhs.rep_); - case 2: - return absl::get(rep_) == absl::get(rhs.rep_); - default: - Crash("unreachable"); + if (rep_.c_vtable() == &string_vtable_) { + return grpc_channel_arg_string_create( + c_name, + const_cast(static_cast(rep_.c_pointer())->c_str())); } + return grpc_channel_arg_pointer_create(c_name, rep_.c_pointer(), + rep_.c_vtable()); } ChannelArgs::CPtr ChannelArgs::ToC() const { std::vector c_args; - args_.ForEach([&c_args](const std::string& key, const Value& value) { + args_.ForEach([&c_args](const RcStringValue& key, const Value& value) { c_args.push_back(value.MakeCArg(key.c_str())); }); return CPtr(static_cast( @@ -194,7 +203,7 @@ ChannelArgs ChannelArgs::Set(absl::string_view name, Value value) const { if (const auto* p = args_.Lookup(name)) { if (*p == value) return *this; // already have this value for this key } - return ChannelArgs(args_.Add(std::string(name), std::move(value))); + return ChannelArgs(args_.Add(RcStringValue(name), std::move(value))); } ChannelArgs ChannelArgs::Set(absl::string_view name, @@ -218,8 +227,8 @@ ChannelArgs ChannelArgs::Remove(absl::string_view name) const { ChannelArgs ChannelArgs::RemoveAllKeysWithPrefix( absl::string_view prefix) const { auto args = args_; - args_.ForEach([&args, prefix](const std::string& key, const Value&) { - if (absl::StartsWith(key, prefix)) args = args.Remove(key); + args_.ForEach([&](const RcStringValue& key, const Value&) { + if (absl::StartsWith(key.as_string_view(), prefix)) args = args.Remove(key); }); return ChannelArgs(std::move(args)); } @@ -227,9 +236,7 @@ ChannelArgs ChannelArgs::RemoveAllKeysWithPrefix( absl::optional ChannelArgs::GetInt(absl::string_view name) const { auto* v = Get(name); if (v == nullptr) return absl::nullopt; - const auto* i = v->GetIfInt(); - if (i == nullptr) return absl::nullopt; - return *i; + return v->GetIfInt(); } absl::optional ChannelArgs::GetDurationFromIntMillis( @@ -245,9 +252,9 @@ absl::optional ChannelArgs::GetString( absl::string_view name) const { auto* v = Get(name); if (v == nullptr) return absl::nullopt; - const auto* s = v->GetIfString(); + const auto s = v->GetIfString(); if (s == nullptr) return absl::nullopt; - return *s; + return s->as_string_view(); } absl::optional ChannelArgs::GetOwnedString( @@ -268,8 +275,8 @@ void* ChannelArgs::GetVoidPointer(absl::string_view name) const { absl::optional ChannelArgs::GetBool(absl::string_view name) const { auto* v = Get(name); if (v == nullptr) return absl::nullopt; - auto* i = v->GetIfInt(); - if (i == nullptr) { + auto i = v->GetIfInt(); + if (!i.has_value()) { gpr_log(GPR_ERROR, "%s ignored: it must be an integer", std::string(name).c_str()); return absl::nullopt; @@ -286,18 +293,22 @@ absl::optional ChannelArgs::GetBool(absl::string_view name) const { } } +std::string ChannelArgs::Value::ToString() const { + if (rep_.c_vtable() == &int_vtable_) { + return std::to_string(reinterpret_cast(rep_.c_pointer())); + } + if (rep_.c_vtable() == &string_vtable_) { + return std::string( + static_cast(rep_.c_pointer())->as_string_view()); + } + return absl::StrFormat("%p", rep_.c_pointer()); +} + std::string ChannelArgs::ToString() const { std::vector arg_strings; - args_.ForEach([&arg_strings](const std::string& key, const Value& value) { - std::string value_str; - if (auto* i = value.GetIfInt()) { - value_str = std::to_string(*i); - } else if (auto* s = value.GetIfString()) { - value_str = *s; - } else if (auto* p = value.GetIfPointer()) { - value_str = absl::StrFormat("%p", p->c_pointer()); - } - arg_strings.push_back(absl::StrCat(key, "=", value_str)); + args_.ForEach([&arg_strings](const RcStringValue& key, const Value& value) { + arg_strings.push_back( + absl::StrCat(key.as_string_view(), "=", value.ToString())); }); return absl::StrCat("{", absl::StrJoin(arg_strings, ", "), "}"); } @@ -306,24 +317,25 @@ ChannelArgs ChannelArgs::UnionWith(ChannelArgs other) const { if (args_.Empty()) return other; if (other.args_.Empty()) return *this; if (args_.Height() <= other.args_.Height()) { - args_.ForEach([&other](const std::string& key, const Value& value) { + args_.ForEach([&other](const RcStringValue& key, const Value& value) { other.args_ = other.args_.Add(key, value); }); return other; } else { auto result = *this; - other.args_.ForEach([&result](const std::string& key, const Value& value) { - if (result.args_.Lookup(key) == nullptr) { - result.args_ = result.args_.Add(key, value); - } - }); + other.args_.ForEach( + [&result](const RcStringValue& key, const Value& value) { + if (result.args_.Lookup(key) == nullptr) { + result.args_ = result.args_.Add(key, value); + } + }); return result; } } ChannelArgs ChannelArgs::FuzzingReferenceUnionWith(ChannelArgs other) const { // DO NOT OPTIMIZE THIS!! - args_.ForEach([&other](const std::string& key, const Value& value) { + args_.ForEach([&other](const RcStringValue& key, const Value& value) { other.args_ = other.args_.Add(key, value); }); return other; diff --git a/src/core/lib/channel/channel_args.h b/src/core/lib/channel/channel_args.h index 57d6f35e335..676ca6066f8 100644 --- a/src/core/lib/channel/channel_args.h +++ b/src/core/lib/channel/channel_args.h @@ -22,6 +22,7 @@ #include #include +#include #include // IWYU pragma: keep #include @@ -33,7 +34,6 @@ #include "absl/meta/type_traits.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "absl/types/variant.h" #include #include @@ -220,6 +220,91 @@ struct GetObjectImpl::value, void>> { }; }; +// Immutable reference counted string +class RcString { + public: + static RefCountedPtr Make(absl::string_view src); + + RefCountedPtr Ref() { + IncrementRefCount(); + return RefCountedPtr(this); + } + void IncrementRefCount() { header_.rc.Ref(); } + void Unref() { + if (header_.rc.Unref()) Destroy(); + } + + absl::string_view as_string_view() const { + return absl::string_view(payload_, header_.length); + } + + char* c_str() { return payload_; } + + private: + explicit RcString(absl::string_view src); + void Destroy(); + + struct Header { + RefCount rc; + size_t length; + }; + Header header_; + char payload_[]; +}; + +// Wrapper around RefCountedPtr to give value semantics, especially to +// overloaded operators. +class RcStringValue { + public: + RcStringValue() : str_{} {} + explicit RcStringValue(absl::string_view str) : str_(RcString::Make(str)) {} + + absl::string_view as_string_view() const { + return str_ == nullptr ? absl::string_view() : str_->as_string_view(); + } + + const char* c_str() const { return str_ == nullptr ? "" : str_->c_str(); } + + private: + RefCountedPtr str_; +}; + +inline bool operator==(const RcStringValue& lhs, absl::string_view rhs) { + return lhs.as_string_view() == rhs; +} + +inline bool operator==(absl::string_view lhs, const RcStringValue& rhs) { + return lhs == rhs.as_string_view(); +} + +inline bool operator==(const RcStringValue& lhs, const RcStringValue& rhs) { + return lhs.as_string_view() == rhs.as_string_view(); +} + +inline bool operator<(const RcStringValue& lhs, absl::string_view rhs) { + return lhs.as_string_view() < rhs; +} + +inline bool operator<(absl::string_view lhs, const RcStringValue& rhs) { + return lhs < rhs.as_string_view(); +} + +inline bool operator<(const RcStringValue& lhs, const RcStringValue& rhs) { + return lhs.as_string_view() < rhs.as_string_view(); +} + +inline bool operator>(const RcStringValue& lhs, absl::string_view rhs) { + return lhs.as_string_view() > rhs; +} + +inline bool operator>(absl::string_view lhs, const RcStringValue& rhs) { + return lhs > rhs.as_string_view(); +} + +inline bool operator>(const RcStringValue& lhs, const RcStringValue& rhs) { + return lhs.as_string_view() > rhs.as_string_view(); +} + // Provide the canonical name for a type's channel arg key template struct ChannelArgNameTraits { @@ -283,32 +368,43 @@ class ChannelArgs { class Value { public: - explicit Value(int n) : rep_(n) {} + explicit Value(int n) : rep_(reinterpret_cast(n), &int_vtable_) {} explicit Value(std::string s) - : rep_(std::make_shared(std::move(s))) {} + : rep_(RcString::Make(s).release(), &string_vtable_) {} explicit Value(Pointer p) : rep_(std::move(p)) {} - const int* GetIfInt() const { return absl::get_if(&rep_); } - const std::string* GetIfString() const { - auto* p = absl::get_if>(&rep_); - if (p == nullptr) return nullptr; - return p->get(); + absl::optional GetIfInt() const { + if (rep_.c_vtable() != &int_vtable_) return absl::nullopt; + return reinterpret_cast(rep_.c_pointer()); + } + RefCountedPtr GetIfString() const { + if (rep_.c_vtable() != &string_vtable_) return nullptr; + return static_cast(rep_.c_pointer())->Ref(); + } + const Pointer* GetIfPointer() const { + if (rep_.c_vtable() == &int_vtable_) return nullptr; + if (rep_.c_vtable() == &string_vtable_) return nullptr; + return &rep_; } - const Pointer* GetIfPointer() const { return absl::get_if(&rep_); } + + std::string ToString() const; grpc_arg MakeCArg(const char* name) const; - bool operator<(const Value& rhs) const; - bool operator==(const Value& rhs) const; + bool operator<(const Value& rhs) const { return rep_ < rhs.rep_; } + bool operator==(const Value& rhs) const { return rep_ == rhs.rep_; } bool operator!=(const Value& rhs) const { return !this->operator==(rhs); } bool operator==(absl::string_view rhs) const { - auto* p = absl::get_if>(&rep_); - if (p == nullptr) return false; - return **p == rhs; + auto str = GetIfString(); + if (str == nullptr) return false; + return str->as_string_view() == rhs; } private: - absl::variant, Pointer> rep_; + static const grpc_arg_pointer_vtable int_vtable_; + static const grpc_arg_pointer_vtable string_vtable_; + + Pointer rep_; }; struct ChannelArgsDeleter { @@ -462,12 +558,12 @@ class ChannelArgs { std::string ToString() const; private: - explicit ChannelArgs(AVL args); + explicit ChannelArgs(AVL args); GRPC_MUST_USE_RESULT ChannelArgs Set(absl::string_view name, Value value) const; - AVL args_; + AVL args_; }; std::ostream& operator<<(std::ostream& out, const ChannelArgs& args); diff --git a/src/core/lib/compression/compression_internal.cc b/src/core/lib/compression/compression_internal.cc index 4871e46f705..793e8969d87 100644 --- a/src/core/lib/compression/compression_internal.cc +++ b/src/core/lib/compression/compression_internal.cc @@ -32,6 +32,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/debug/trace.h" #include "src/core/lib/gprpp/crash.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/surface/api_trace.h" namespace grpc_core { @@ -227,11 +228,13 @@ absl::optional DefaultCompressionAlgorithmFromChannelArgs(const ChannelArgs& args) { auto* value = args.Get(GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM); if (value == nullptr) return absl::nullopt; - if (auto* p = value->GetIfInt()) { - return static_cast(*p); + auto ival = value->GetIfInt(); + if (ival.has_value()) { + return static_cast(*ival); } - if (auto* p = value->GetIfString()) { - return ParseCompressionAlgorithm(*p); + auto sval = value->GetIfString(); + if (sval != nullptr) { + return ParseCompressionAlgorithm(sval->as_string_view()); } return absl::nullopt; } diff --git a/src/core/lib/gprpp/ref_counted.h b/src/core/lib/gprpp/ref_counted.h index 66ef0424899..cdf692c5ce7 100644 --- a/src/core/lib/gprpp/ref_counted.h +++ b/src/core/lib/gprpp/ref_counted.h @@ -45,12 +45,14 @@ class RefCount { public: using Value = intptr_t; + RefCount() : RefCount(1) {} + // `init` is the initial refcount stored in this object. // // `trace` is a string to be logged with trace events; if null, no // trace logging will be done. Tracing is a no-op in non-debug builds. explicit RefCount( - Value init = 1, + Value init, const char* #ifndef NDEBUG // Leave unnamed if NDEBUG to avoid unused parameter warning