From 200e24900e4e4337e785aa211b03750f886fb3ef Mon Sep 17 00:00:00 2001
From: Protobuf Team Bot
Date: Tue, 16 May 2023 15:42:38 -0700
Subject: [PATCH] Make the fallback tree use a variant key.

This allows moving most of the tree logic into the .cc file instead of having
it duplicated in each template instantiation. This reduces the code size of
the cold paths, making the hot paths more inlineable.

Make the iterator base completely untyped now that the tree fallback is
untyped. This further reduces code duplication, and it will allow more
improvements on MapField in follow-up changes.

Move the clearing logic to the .cc file and optimize it. Having a single copy
of it allows adding more logic to it without bloating the template
instantiations:
 - The map destructor no longer resets the table before deleting it. That was
   wasted work.
 - Use prefetching to load the nodes ahead of time. Even for trivially
   destructible nodes we need to read the `next` pointer.
 - Start the clearing at index_of_first_non_null_ instead of 0.
 - Check for arena==nullptr only once for the whole call instead of once per
   element.

PiperOrigin-RevId: 532595044
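The core of the change is a single ordered key type that can represent every
map key kind, so one btree instantiation backs the fallback tree of all
Map<K, V> types. The following self-contained sketch illustrates the idea with
hypothetical Toy* names; the real type is VariantKey in the map.h diff below:

#include <cstdint>
#include <cstring>
#include <map>
#include <string>

// One key type for both integral and string keys. The resulting string order
// is (length, bytes) rather than lexicographic; any strict weak ordering
// works, because the tree is only used for O(lg n) bucket lookups.
struct ToyVariantKey {
  const char* data;   // nullptr => integral key; else a string of `integral` bytes
  uint64_t integral;  // the value, or the string size

  explicit ToyVariantKey(uint64_t v) : data(nullptr), integral(v) {}
  explicit ToyVariantKey(const std::string& s)
      : data(s.data()), integral(s.size()) {}

  friend bool operator<(const ToyVariantKey& a, const ToyVariantKey& b) {
    // A single map only ever holds keys of one kind, so mixed comparisons
    // cannot happen.
    if (a.integral != b.integral) return a.integral < b.integral;
    if (a.data == nullptr) return false;  // equal integers
    return std::memcmp(a.data, b.data, a.integral) < 0;
  }
};

// One tree type, shared by every map instantiation:
using ToyTree = std::map<ToyVariantKey, void*>;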
---
 src/google/protobuf/map.cc       | 204 +++++++++++
 src/google/protobuf/map.h        | 571 +++++++++++++++----------------
 src/google/protobuf/map_field.cc |  19 +
 src/google/protobuf/map_field.h  |  44 +--
 src/google/protobuf/map_test.inc |  10 +-
 5 files changed, 527 insertions(+), 321 deletions(-)

diff --git a/src/google/protobuf/map.cc b/src/google/protobuf/map.cc
index 3490fcae20..e8e139c20f 100644
--- a/src/google/protobuf/map.cc
+++ b/src/google/protobuf/map.cc
@@ -30,12 +30,216 @@
 
 #include "google/protobuf/map.h"
 
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <string>
+
+#include "absl/hash/hash.h"
+#include "absl/strings/string_view.h"
+#include "google/protobuf/message_lite.h"
+
+
+// Must be included last.
+#include "google/protobuf/port_def.inc"
+
 namespace google {
 namespace protobuf {
 namespace internal {
 
 const TableEntryPtr kGlobalEmptyTable[kGlobalEmptyTableSize] = {};
 
+NodeBase* UntypedMapBase::DestroyTree(Tree* tree) {
+  NodeBase* head = tree->empty() ? nullptr : tree->begin()->second;
+  if (alloc_.arena() == nullptr) {
+    delete tree;
+  }
+  return head;
+}
+
+void UntypedMapBase::EraseFromTree(size_type b,
+                                   typename Tree::iterator tree_it) {
+  ABSL_DCHECK(TableEntryIsTree(b));
+  Tree* tree = TableEntryToTree(table_[b]);
+  if (tree_it != tree->begin()) {
+    NodeBase* prev = std::prev(tree_it)->second;
+    prev->next = prev->next->next;
+  }
+  tree->erase(tree_it);
+  if (tree->empty()) {
+    DestroyTree(tree);
+    table_[b] = TableEntryPtr{};
+  }
+}
+
+size_t UntypedMapBase::VariantBucketNumber(VariantKey key) const {
+  return BucketNumberFromHash(key.Hash());
+}
+
+void UntypedMapBase::InsertUniqueInTree(size_type b, GetKey get_key,
+                                        NodeBase* node) {
+  if (TableEntryIsNonEmptyList(b)) {
+    // To save in binary size, we delegate to an out-of-line function to do
+    // the conversion.
+    table_[b] = ConvertToTree(TableEntryToNode(table_[b]), get_key);
+  }
+  ABSL_DCHECK(TableEntryIsTree(b))
+      << (void*)table_[b] << " " << (uintptr_t)table_[b];
+
+  Tree* tree = TableEntryToTree(table_[b]);
+  auto it = tree->try_emplace(get_key(node), node).first;
+  // Maintain the linked list of the nodes in the tree.
+  // For simplicity, they are in the same order as the tree iteration.
+  if (it != tree->begin()) {
+    NodeBase* prev = std::prev(it)->second;
+    prev->next = node;
+  }
+  auto next = std::next(it);
+  node->next = next != tree->end() ? next->second : nullptr;
+}
+
+void UntypedMapBase::TransferTree(Tree* tree, GetKey get_key) {
+  NodeBase* node = DestroyTree(tree);
+  do {
+    NodeBase* next = node->next;
+
+    size_type b = VariantBucketNumber(get_key(node));
+    // This is similar to InsertUnique, but with erasure.
+    if (TableEntryIsEmpty(b)) {
+      InsertUniqueInList(b, node);
+      index_of_first_non_null_ = (std::min)(index_of_first_non_null_, b);
+    } else if (TableEntryIsNonEmptyList(b) && !TableEntryIsTooLong(b)) {
+      InsertUniqueInList(b, node);
+    } else {
+      InsertUniqueInTree(b, get_key, node);
+    }
+
+    node = next;
+  } while (node != nullptr);
+}
+
+TableEntryPtr UntypedMapBase::ConvertToTree(NodeBase* node, GetKey get_key) {
+  auto* tree = Arena::Create<Tree>(alloc_.arena(), typename Tree::key_compare(),
+                                   typename Tree::allocator_type(alloc_));
+  for (; node != nullptr; node = node->next) {
+    tree->try_emplace(get_key(node), node);
+  }
+  ABSL_DCHECK_EQ(MapTreeLengthThreshold(), tree->size());
+
+  // Relink the nodes.
+  NodeBase* next = nullptr;
+  auto it = tree->end();
+  do {
+    node = (--it)->second;
+    node->next = next;
+    next = node;
+  } while (it != tree->begin());
+
+  return TreeToTableEntry(tree);
+}
+
+void UntypedMapBase::ClearTable(const ClearInput input) {
+  ABSL_DCHECK_NE(num_buckets_, kGlobalEmptyTableSize);
+
+  if (alloc_.arena() == nullptr) {
+    const auto loop = [=](auto destroy_node) {
+      const TableEntryPtr* table = table_;
+      for (size_type b = index_of_first_non_null_, end = num_buckets_; b < end;
+           ++b) {
+        NodeBase* node =
+            PROTOBUF_PREDICT_FALSE(internal::TableEntryIsTree(table[b]))
+                ? DestroyTree(TableEntryToTree(table[b]))
+                : TableEntryToNode(table[b]);
+
+        while (node != nullptr) {
+          NodeBase* next = node->next;
+          destroy_node(node);
+          SizedDelete(node, SizeFromInfo(input.size_info));
+          node = next;
+        }
+      }
+    };
+    switch (input.destroy_bits) {
+      case 0:
+        loop([](NodeBase*) {});
+        break;
+      case kKeyIsString:
+        loop([](NodeBase* node) {
+          static_cast<std::string*>(node->GetVoidKey())->~basic_string();
+        });
+        break;
+      case kValueIsString:
+        loop([size_info = input.size_info](NodeBase* node) {
+          static_cast<std::string*>(node->GetVoidValue(size_info))
+              ->~basic_string();
+        });
+        break;
+      case kKeyIsString | kValueIsString:
+        loop([size_info = input.size_info](NodeBase* node) {
+          static_cast<std::string*>(node->GetVoidKey())->~basic_string();
+          static_cast<std::string*>(node->GetVoidValue(size_info))
+              ->~basic_string();
+        });
+        break;
+      case kValueIsProto:
+        loop([size_info = input.size_info](NodeBase* node) {
+          static_cast<MessageLite*>(node->GetVoidValue(size_info))
+              ->~MessageLite();
+        });
+        break;
+      case kKeyIsString | kValueIsProto:
+        loop([size_info = input.size_info](NodeBase* node) {
+          static_cast<std::string*>(node->GetVoidKey())->~basic_string();
+          static_cast<MessageLite*>(node->GetVoidValue(size_info))
+              ->~MessageLite();
+        });
+        break;
+      case kUseDestructFunc:
+        loop(input.destroy_node);
+        break;
+    }
+  }
+
+  if (input.reset_table) {
+    std::fill(table_, table_ + num_buckets_, TableEntryPtr{});
+    num_elements_ = 0;
+    index_of_first_non_null_ = num_buckets_;
+  } else {
+    DeleteTable(table_, num_buckets_);
+  }
+}
+
+auto UntypedMapBase::FindFromTree(size_type b, VariantKey key,
+                                  Tree::iterator* it) const -> NodeAndBucket {
+  Tree* tree = TableEntryToTree(table_[b]);
+  auto tree_it = tree->find(key);
+  if (it != nullptr) *it = tree_it;
+  if (tree_it != tree->end()) {
+    return {tree_it->second, b};
+  }
+  return {nullptr, b};
+}
+
+size_t UntypedMapBase::SpaceUsedInTable(size_t sizeof_node) const {
+  size_t size = 0;
+  // The size of the table.
+  size += sizeof(void*) * num_buckets_;
+  // All the nodes.
+  size += sizeof_node * num_elements_;
+  // For each tree, count the overhead of those nodes.
+  for (size_t b = 0; b < num_buckets_; ++b) {
+    if (TableEntryIsTree(b)) {
+      size += sizeof(Tree);
+      size += sizeof(Tree::value_type) * TableEntryToTree(table_[b])->size();
+    }
+  }
+  return size;
+}
+
 }  // namespace internal
 }  // namespace protobuf
 }  // namespace google
+
+#include "google/protobuf/port_undef.inc"
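ClearTable above drives all type-specific destruction through a small bit set
computed at compile time, so the clearing loop is compiled exactly once. A
self-contained sketch of the pattern (Toy* names are hypothetical; the real
ClearInput/destroy_bits machinery is in the hunk above):

#include <cstdint>
#include <string>

// A node stores its key in raw storage so that one node layout can serve many
// key types (as protobuf's NodeBase/KeyNode do).
struct ToyNode {
  ToyNode* next;
  alignas(std::string) unsigned char key_storage[sizeof(std::string)];
};

enum : uint8_t { kToyKeyIsString = 1 << 0 };

// One out-of-line loop; the per-type work is confined to a tiny lambda,
// selected by bits the templated caller computed at compile time.
inline void ToyClear(ToyNode* head, uint8_t destroy_bits) {
  const auto loop = [&](auto destroy_node) {
    for (ToyNode* node = head; node != nullptr;) {
      ToyNode* next = node->next;  // read `next` before destroying the node
      destroy_node(node);
      delete node;
      node = next;
    }
  };
  switch (destroy_bits) {
    case 0:  // trivially destructible key: nothing extra to run
      loop([](ToyNode*) {});
      break;
    case kToyKeyIsString:
      loop([](ToyNode* n) {
        reinterpret_cast<std::string*>(n->key_storage)->~basic_string();
      });
      break;
  }
}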
diff --git a/src/google/protobuf/map.h b/src/google/protobuf/map.h
index 7b8d08ec45..00dbefbb0e 100644
--- a/src/google/protobuf/map.h
+++ b/src/google/protobuf/map.h
@@ -38,6 +38,7 @@
 #define GOOGLE_PROTOBUF_MAP_H__
 
 #include <algorithm>
+#include <cstring>
 #include <functional>
 #include <initializer_list>
 #include <iterator>
@@ -221,13 +222,20 @@ using KeyForBase = typename KeyForBaseImpl<T>::type;
 // only accept `key_type`.
 template <typename key_type>
 struct TransparentSupport {
-  using hash = std::hash<key_type>;
-  using less = std::less<key_type>;
+  // We hash all the scalars as uint64_t so that we can implement the same hash
+  // function for VariantKey. This way we can have MapKey provide the same hash
+  // as the underlying value would have.
+  using hash = std::hash<
+      std::conditional_t<std::is_scalar<key_type>::value, uint64_t, key_type>>;
 
   static bool Equals(const key_type& a, const key_type& b) { return a == b; }
 
   template <typename K>
   using key_arg = key_type;
+
+  using ViewType = std::conditional_t<std::is_scalar<key_type>::value,
+                                      key_type, const key_type&>;
+  static ViewType ToView(const key_type& v) { return v; }
 };
 
 // We add transparent support for std::string keys. We use
@@ -274,15 +282,6 @@ struct TransparentSupport<std::string> {
         ImplicitConvert(std::forward<T>(str)));
     }
   };
-  struct less {
-    using is_transparent = void;
-
-    template <typename T, typename U>
-    bool operator()(T&& t, U&& u) const {
-      return ImplicitConvert(std::forward<T>(t)) <
-             ImplicitConvert(std::forward<U>(u));
-    }
-  };
 
   template <typename T, typename U>
   static bool Equals(T&& t, U&& u) {
@@ -292,6 +291,12 @@ struct TransparentSupport<std::string> {
 
   template <typename K>
   using key_arg = K;
+
+  using ViewType = absl::string_view;
+  template <typename T>
+  static ViewType ToView(const T& v) {
+    return ImplicitConvert(v);
+  }
 };
 
 enum class MapNodeSizeInfoT : uint32_t;
@@ -329,8 +334,9 @@ inline NodeBase* EraseFromLinkedList(NodeBase* item, NodeBase* head) {
   }
 }
 
+constexpr size_t MapTreeLengthThreshold() { return 8; }
 inline bool TableEntryIsTooLong(NodeBase* node) {
-  const size_t kMaxLength = 8;
+  const size_t kMaxLength = MapTreeLengthThreshold();
   size_t count = 0;
   do {
     ++count;
@@ -341,18 +347,64 @@ inline bool TableEntryIsTooLong(NodeBase* node) {
   return count >= kMaxLength;
 }
 
-template <typename T>
-using KeyForTree = std::conditional_t<std::is_integral<T>::value, uint64_t,
-                                      std::reference_wrapper<const T>>;
+// Similar to the public MapKey, but specialized for the internal
+// implementation.
+struct VariantKey {
+  // We make this value 16 bytes to make it cheaper to pass in the ABI.
+  // Can't overload string_view this way, so we unpack the fields.
+  // data==nullptr means this is a number and `integral` is the value.
+  // data!=nullptr means this is a string and `integral` is the size.
+  const char* data;
+  uint64_t integral;
+
+  explicit VariantKey(uint64_t v) : data(nullptr), integral(v) {}
+  explicit VariantKey(absl::string_view v)
+      : data(v.data()), integral(v.size()) {
+    // We use `data` to discriminate between the types, so make sure it is
+    // never null here.
+    if (data == nullptr) data = "";
+  }
+
+  size_t Hash() const {
+    return data == nullptr ? std::hash<uint64_t>{}(integral)
+                           : absl::Hash<absl::string_view>{}(
+                                 absl::string_view(data, integral));
+  }
+
+  friend bool operator<(const VariantKey& left, const VariantKey& right) {
+    ABSL_DCHECK_EQ(left.data == nullptr, right.data == nullptr);
+    if (left.integral != right.integral) {
+      // If they are numbers with different value, or strings with different
+      // size, check the number only.
+      return left.integral < right.integral;
+    }
+    if (left.data == nullptr) {
+      // If they are numbers they have the same value, so return.
+      return false;
+    }
+    // They are strings of the same size, so check the bytes.
+    return memcmp(left.data, right.data, left.integral) < 0;
+  }
+};
 
+// This is to be specialized by MapKey.
 template <typename T>
-using LessForTree = typename TransparentSupport<
-    std::conditional_t<std::is_integral<T>::value, uint64_t, T>>::less;
+struct RealKeyToVariantKey {
+  VariantKey operator()(T value) const { return VariantKey(value); }
+};
 
-template <typename T>
+template <>
+struct RealKeyToVariantKey<std::string> {
+  template <typename T>
+  VariantKey operator()(const T& value) const {
+    return VariantKey(TransparentSupport<std::string>::ImplicitConvert(value));
+  }
+};
+
+// We use a single kind of tree for all maps. This reduces code duplication.
 using TreeForMap =
-    absl::btree_map<KeyForTree<T>, NodeBase*, LessForTree<T>,
-                    MapAllocator<std::pair<const KeyForTree<T>, NodeBase*>>>;
+    absl::btree_map<VariantKey, NodeBase*, std::less<VariantKey>,
+                    MapAllocator<std::pair<const VariantKey, NodeBase*>>>;
 
 // Type safe tagged pointer.
 // We convert to/from nodes and trees using the operations below.
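The next hunk turns the tagged-pointer helpers from templates into plain
inline functions. For readers unfamiliar with the scheme, here is a minimal
sketch of low-bit pointer tagging (Toy* names are hypothetical; TableEntryPtr
below is the real type):

#include <cstdint>

// Nodes and trees are both at least 2-byte aligned, so bit 0 of the address
// is free to record which kind a bucket holds.
struct ToyNode;
struct ToyTree;

enum class ToyEntryPtr : uintptr_t { kEmpty = 0 };

inline ToyEntryPtr ToEntry(ToyNode* node) {  // lists keep bit 0 == 0
  return static_cast<ToyEntryPtr>(reinterpret_cast<uintptr_t>(node));
}
inline ToyEntryPtr ToEntry(ToyTree* tree) {  // trees get bit 0 == 1
  return static_cast<ToyEntryPtr>(reinterpret_cast<uintptr_t>(tree) | 1);
}
inline bool IsTree(ToyEntryPtr e) {
  return (static_cast<uintptr_t>(e) & 1) != 0;
}
inline ToyTree* ToTree(ToyEntryPtr e) {  // strip the tag before dereferencing
  return reinterpret_cast<ToyTree*>(static_cast<uintptr_t>(e) - 1);
}
inline ToyNode* ToNode(ToyEntryPtr e) {
  return reinterpret_cast<ToyNode*>(static_cast<uintptr_t>(e));
}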
@@ -383,13 +435,11 @@ inline TableEntryPtr NodeToTableEntry(NodeBase* node) {
   ABSL_DCHECK((reinterpret_cast<uintptr_t>(node) & 1) == 0);
   return static_cast<TableEntryPtr>(reinterpret_cast<uintptr_t>(node));
 }
-template <typename Tree>
-Tree* TableEntryToTree(TableEntryPtr entry) {
+inline TreeForMap* TableEntryToTree(TableEntryPtr entry) {
   ABSL_DCHECK(TableEntryIsTree(entry));
-  return reinterpret_cast<Tree*>(static_cast<uintptr_t>(entry) - 1);
+  return reinterpret_cast<TreeForMap*>(static_cast<uintptr_t>(entry) - 1);
 }
-template <typename Tree>
-TableEntryPtr TreeToTableEntry(Tree* node) {
+inline TableEntryPtr TreeToTableEntry(TreeForMap* node) {
   ABSL_DCHECK((reinterpret_cast<uintptr_t>(node) & 1) == 0);
   return static_cast<TableEntryPtr>(reinterpret_cast<uintptr_t>(node) | 1);
 }
@@ -409,33 +459,6 @@ constexpr size_t kGlobalEmptyTableSize = 1;
 PROTOBUF_EXPORT extern const TableEntryPtr
     kGlobalEmptyTable[kGlobalEmptyTableSize];
 
-// Space used for the table, trees, and nodes.
-// Does not include the indirect space used. Eg the data of a std::string.
-template <typename Key>
-PROTOBUF_NOINLINE size_t SpaceUsedInTable(TableEntryPtr* table,
-                                          size_t num_buckets,
-                                          size_t num_elements,
-                                          size_t sizeof_node) {
-  size_t size = 0;
-  // The size of the table.
-  size += sizeof(void*) * num_buckets;
-  // All the nodes.
-  size += sizeof_node * num_elements;
-  // For each tree, count the overhead of the those nodes.
-  // Two buckets at a time because we only care about trees.
-  for (size_t b = 0; b < num_buckets; ++b) {
-    if (internal::TableEntryIsTree(table[b])) {
-      using Tree = TreeForMap<Key>;
-      Tree* tree = TableEntryToTree<Tree>(table[b]);
-      // Estimated cost of the red-black tree nodes, 3 pointers plus a
-      // bool (plus alignment, so 4 pointers).
-      size += tree->size() *
-              (sizeof(typename Tree::value_type) + sizeof(void*) * 4);
-    }
-  }
-  return size;
-}
-
 template <typename Map,
           typename = typename std::enable_if<
               !std::is_scalar<typename Map::key_type>::value ||
@@ -449,8 +472,64 @@ size_t SpaceUsedInValues(const Map* map) {
   return size;
 }
 
+// Multiply two numbers where overflow is expected.
+template <typename N>
+N MultiplyWithOverflow(N a, N b) {
+#if defined(PROTOBUF_HAS_BUILTIN_MUL_OVERFLOW)
+  N res;
+  (void)__builtin_mul_overflow(a, b, &res);
+  return res;
+#else
+  return a * b;
+#endif
+}
+
 inline size_t SpaceUsedInValues(const void*) { return 0; }
 
+class UntypedMapBase;
+
+class UntypedMapIterator {
+ public:
+  // Invariants:
+  // node_ is always correct. This is handy because the most common
+  // operations are operator* and operator-> and they only use node_.
+  // When node_ is set to a non-null value, all the other non-const fields
+  // are updated to be correct also, but those fields can become stale
+  // if the underlying map is modified. When those fields are needed they
+  // are rechecked, and updated if necessary.
+  UntypedMapIterator() : node_(nullptr), m_(nullptr), bucket_index_(0) {}
+
+  explicit UntypedMapIterator(const UntypedMapBase* m);
+
+  UntypedMapIterator(NodeBase* n, const UntypedMapBase* m, size_t index)
+      : node_(n), m_(m), bucket_index_(index) {}
+
+  // Advance through buckets, looking for the first that isn't empty.
+  // If nothing non-empty is found then leave node_ == nullptr.
+  void SearchFrom(size_t start_bucket);
+
+  // The definition of operator== is handled by the derived type. If we were
+  // to do it in this class it would allow comparing iterators of different
+  // map types.
+  bool Equals(const UntypedMapIterator& other) const {
+    return node_ == other.node_;
+  }
+
+  // The definition of operator++ is handled in the derived type. We would not
+  // be able to return the right type from here.
+  void PlusPlus() {
+    if (node_->next == nullptr) {
+      SearchFrom(bucket_index_ + 1);
+    } else {
+      node_ = node_->next;
+    }
+  }
+
+  NodeBase* node_;
+  const UntypedMapBase* m_;
+  size_t bucket_index_;
+};
+
 // Base class for all Map instantiations.
 // This class holds all the data and provides the basic functionality shared
 // among all instantiations.
@@ -458,6 +537,7 @@ inline size_t SpaceUsedInValues(const void*) { return 0; }
 // parser) by having non-template code that can handle all instantiations.
 class PROTOBUF_EXPORT UntypedMapBase {
   using Allocator = internal::MapAllocator<void*>;
+  using Tree = internal::TreeForMap;
 
  public:
   using size_type = size_t;
@@ -497,6 +577,7 @@ class PROTOBUF_EXPORT UntypedMapBase {
 protected:
   friend class TcParser;
   friend struct MapTestPeer;
+  friend class UntypedMapIterator;
 
   struct NodeAndBucket {
     NodeBase* node;
@@ -583,6 +664,27 @@ class PROTOBUF_EXPORT UntypedMapBase {
     AllocFor<TableEntryPtr>(alloc_).deallocate(table, n);
   }
 
+  NodeBase* DestroyTree(Tree* tree);
+  using GetKey = VariantKey (*)(NodeBase*);
+  void InsertUniqueInTree(size_type b, GetKey get_key, NodeBase* node);
+  void TransferTree(Tree* tree, GetKey get_key);
+  TableEntryPtr ConvertToTree(NodeBase* node, GetKey get_key);
+  void EraseFromTree(size_type b, typename Tree::iterator tree_it);
+
+  size_type VariantBucketNumber(VariantKey key) const;
+
+  size_type BucketNumberFromHash(uint64_t h) const {
+    // We xor the hash value against the random seed so that we effectively
+    // have a random hash function.
+    h ^= seed_;
+
+    // We use the multiplication method to determine the bucket number from
+    // the hash value. The constant kPhi (suggested by Knuth) is roughly
+    // (sqrt(5) - 1) / 2 * 2^64.
+    constexpr uint64_t kPhi = uint64_t{0x9e3779b97f4a7c15};
+    return (MultiplyWithOverflow(kPhi, h) >> 32) & (num_buckets_ - 1);
+  }
+
   TableEntryPtr* CreateEmptyTable(size_type n) {
     ABSL_DCHECK_GE(n, size_type{kMinTableSize});
     ABSL_DCHECK_EQ(n & (n - 1), 0u);
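BucketNumberFromHash above is multiplicative ("Fibonacci") hashing. A compact
standalone sketch of the computation (ToyBucket is a hypothetical name):

#include <cstdint>

// kPhi ~= 2^64 / golden ratio; multiplying by it mixes the bits so that the
// extracted window of bits spreads keys evenly over a power-of-two table.
inline uint64_t ToyBucket(uint64_t hash, uint64_t seed, uint64_t num_buckets) {
  constexpr uint64_t kPhi = uint64_t{0x9e3779b97f4a7c15};
  const uint64_t h = hash ^ seed;  // per-map seed => effectively a fresh hash
  // Unsigned wrap-around is well defined; the patch spells the multiply as
  // MultiplyWithOverflow to make the intentional overflow explicit.
  return (kPhi * h >> 32) & (num_buckets - 1);  // num_buckets: power of two
}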
@@ -618,6 +720,63 @@ class PROTOBUF_EXPORT UntypedMapBase {
     return s;
   }
 
+  enum {
+    kKeyIsString = 1 << 0,
+    kValueIsString = 1 << 1,
+    kValueIsProto = 1 << 2,
+    kUseDestructFunc = 1 << 3,
+  };
+  template <typename Key, typename Value>
+  static constexpr uint8_t MakeDestroyBits() {
+    uint8_t result = 0;
+    if (!std::is_trivially_destructible<Key>::value) {
+      if (std::is_same<Key, std::string>::value) {
+        result |= kKeyIsString;
+      } else {
+        return kUseDestructFunc;
+      }
+    }
+    if (!std::is_trivially_destructible<Value>::value) {
+      if (std::is_same<Value, std::string>::value) {
+        result |= kValueIsString;
+      } else if (std::is_base_of<MessageLite, Value>::value) {
+        result |= kValueIsProto;
+      } else {
+        return kUseDestructFunc;
+      }
+    }
+    return result;
+  }
+
+  struct ClearInput {
+    MapNodeSizeInfoT size_info;
+    uint8_t destroy_bits;
+    bool reset_table;
+    void (*destroy_node)(NodeBase*);
+  };
+
+  template <typename Node>
+  static void DestroyNode(NodeBase* node) {
+    static_cast<Node*>(node)->~Node();
+  }
+
+  template <typename Node>
+  static constexpr ClearInput MakeClearInput(bool reset) {
+    constexpr auto bits =
+        MakeDestroyBits<typename Node::key_type, typename Node::mapped_type>();
+    return ClearInput{Node::size_info(), bits, reset,
+                      bits & kUseDestructFunc ? DestroyNode<Node> : nullptr};
+  }
+
+  void ClearTable(ClearInput input);
+
+  NodeAndBucket FindFromTree(size_type b, VariantKey key,
+                             Tree::iterator* it) const;
+
+  // Space used for the table, trees, and nodes.
+  // Does not include the indirect space used. Eg the data of a std::string.
+  size_t SpaceUsedInTable(size_t sizeof_node) const;
+
   size_type num_elements_;
   size_type num_buckets_;
   size_type seed_;
@@ -626,6 +785,40 @@ class PROTOBUF_EXPORT UntypedMapBase {
   Allocator alloc_;
 };
 
+inline UntypedMapIterator::UntypedMapIterator(const UntypedMapBase* m) : m_(m) {
+  if (m_->index_of_first_non_null_ == m_->num_buckets_) {
+    bucket_index_ = 0;
+    node_ = nullptr;
+  } else {
+    bucket_index_ = m_->index_of_first_non_null_;
+    TableEntryPtr entry = m_->table_[bucket_index_];
+    node_ = PROTOBUF_PREDICT_TRUE(TableEntryIsList(entry))
+                ? TableEntryToNode(entry)
+                : TableEntryToTree(entry)->begin()->second;
+    PROTOBUF_ASSUME(node_ != nullptr);
+  }
+}
+
+inline void UntypedMapIterator::SearchFrom(size_t start_bucket) {
+  ABSL_DCHECK(m_->index_of_first_non_null_ == m_->num_buckets_ ||
+              !m_->TableEntryIsEmpty(m_->index_of_first_non_null_));
+  for (size_t i = start_bucket; i < m_->num_buckets_; ++i) {
+    TableEntryPtr entry = m_->table_[i];
+    if (entry == TableEntryPtr{}) continue;
+    bucket_index_ = i;
+    if (PROTOBUF_PREDICT_TRUE(TableEntryIsList(entry))) {
+      node_ = TableEntryToNode(entry);
+    } else {
+      TreeForMap* tree = TableEntryToTree(entry);
+      ABSL_DCHECK(!tree->empty());
+      node_ = tree->begin()->second;
+    }
+    return;
+  }
+  node_ = nullptr;
+  bucket_index_ = 0;
+}
+
 // Base class used by TcParser to extract the map object from a map field.
 // We keep it here to avoid a dependency into map_field.h from the main TcParser
 // code, since that would bring in Message too.
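To make the iterator's advance logic above concrete, here is the same walk
reduced to list-only buckets (a hypothetical ToyIterator, not the patch's
API): follow the bucket's chain first, and only then scan forward for the
next non-empty bucket, giving O(nodes + buckets) full iteration.

#include <cstddef>

struct ToyIterator {
  struct Node { Node* next; };

  Node* node;
  Node* const* table;
  std::size_t bucket;
  std::size_t num_buckets;

  void SearchFrom(std::size_t b) {
    for (; b < num_buckets; ++b) {
      if (table[b] != nullptr) {
        bucket = b;
        node = table[b];
        return;
      }
    }
    node = nullptr;  // reached end()
  }

  void PlusPlus() {
    if (node->next == nullptr) {
      SearchFrom(bucket + 1);  // chain exhausted: find next non-empty bucket
    } else {
      node = node->next;
    }
  }
};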
@@ -656,31 +849,6 @@ struct KeyNode : NodeBase {
   decltype(auto) key() const { return ReadKey<Key>(GetVoidKey()); }
 };
 
-// Multiply two numbers where overflow is expected.
-template <typename N>
-N MultiplyWithOverflow(N a, N b) {
-#if defined(PROTOBUF_HAS_BUILTIN_MUL_OVERFLOW)
-  N res;
-  (void)__builtin_mul_overflow(a, b, &res);
-  return res;
-#else
-  return a * b;
-#endif
-}
-
-// This struct contains the field of the iterators, but no other API.
-// This allows MapIterator to allocate space for an iterator generically.
-class MapIteratorPayload {
- protected:
-  MapIteratorPayload() = default;
-  MapIteratorPayload(NodeBase* node, const UntypedMapBase* m,
-                     size_t bucket_index)
-      : node_(node), m_(m), bucket_index_(bucket_index) {}
-  NodeBase* node_ = nullptr;
-  const UntypedMapBase* m_ = nullptr;
-  size_t bucket_index_ = 0;
-};
-
 // KeyMapBase is a chaining hash map with the additional feature that some
 // buckets can be converted to use an ordered container. This ensures O(lg n)
 // bounds on find, insert, and erase, while avoiding the overheads of ordered
@@ -701,18 +869,18 @@ class MapIteratorPayload {
 // 5. Mutations to a map do not invalidate the map's iterators, pointers to
 //    elements, or references to elements.
 // 6. Except for erase(iterator), any non-const method can reorder iterators.
-// 7. Uses KeyForTree<Key> when using the Tree representation, which
-//    is either `uint64_t` if `Key` is an integer, or
-//    `reference_wrapper<const Key>` otherwise. This avoids unnecessary copies
-//    of string keys, for example.
+// 7. Uses VariantKey when using the Tree representation, which holds all
+//    possible key types as a variant value.
 template <typename Key>
 class KeyMapBase : public UntypedMapBase {
   static_assert(!std::is_signed<Key>::value || !std::is_integral<Key>::value,
                 "");
 
+  using TS = TransparentSupport<Key>;
+
  public:
-  using hasher = typename TransparentSupport<Key>::hash;
+  using hasher = typename TS::hash;
 
   using UntypedMapBase::UntypedMapBase;
 
@@ -724,73 +892,9 @@ class KeyMapBase : public UntypedMapBase {
   // The value is a void* pointing to Node. We use void* instead of Node* to
   // avoid code bloat. That way there is only one instantiation of the tree
   // class per key type.
-  using Tree = internal::TreeForMap<Key>;
+  using Tree = internal::TreeForMap;
   using TreeIterator = typename Tree::iterator;
 
-  class KeyIteratorBase : protected MapIteratorPayload {
-   public:
-    // Invariants:
-    // node_ is always correct. This is handy because the most common
-    // operations are operator* and operator-> and they only use node_.
-    // When node_ is set to a non-null value, all the other non-const fields
-    // are updated to be correct also, but those fields can become stale
-    // if the underlying map is modified. When those fields are needed they
-    // are rechecked, and updated if necessary.
-    KeyIteratorBase() = default;
-
-    explicit KeyIteratorBase(const KeyMapBase* m) {
-      m_ = m;
-      SearchFrom(m->index_of_first_non_null_);
-    }
-
-    KeyIteratorBase(KeyNode* n, const KeyMapBase* m, size_type index)
-        : MapIteratorPayload(n, m, index) {}
-
-    KeyIteratorBase(TreeIterator tree_it, const KeyMapBase* m, size_type index)
-        : MapIteratorPayload(NodeFromTreeIterator(tree_it), m, index) {}
-
-    // Advance through buckets, looking for the first that isn't empty.
-    // If nothing non-empty is found then leave node_ == nullptr.
-    void SearchFrom(size_type start_bucket) {
-      ABSL_DCHECK(map().index_of_first_non_null_ == map().num_buckets_ ||
-                  !map().TableEntryIsEmpty(map().index_of_first_non_null_));
-      for (size_type i = start_bucket; i < map().num_buckets_; ++i) {
-        TableEntryPtr entry = map().table_[i];
-        if (entry == TableEntryPtr{}) continue;
-        bucket_index_ = i;
-        if (PROTOBUF_PREDICT_TRUE(internal::TableEntryIsList(entry))) {
-          node_ = static_cast<KeyNode*>(TableEntryToNode(entry));
-        } else {
-          Tree* tree = TableEntryToTree<Tree>(entry);
-          ABSL_DCHECK(!tree->empty());
-          node_ = static_cast<KeyNode*>(tree->begin()->second);
-        }
-        return;
-      }
-      node_ = nullptr;
-      bucket_index_ = 0;
-    }
-
-    // The definition of operator== is handled by the derived type. If we were
-    // to do it in this class it would allow comparing iterators of different
-    // map types.
-    bool Equals(const KeyIteratorBase& other) const {
-      return node_ == other.node_;
-    }
-
-    // The definition of operator++ is handled in the derived type. We would not
-    // be able to return the right type from here.
-    void PlusPlus() {
-      if (node_->next == nullptr) {
-        SearchFrom(bucket_index_ + 1);
-      } else {
-        node_ = static_cast<KeyNode*>(node_->next);
-      }
-    }
-
-    auto& map() const { return static_cast<const KeyMapBase&>(*m_); }
-  };
-
  public:
   hasher hash_function() const { return {}; }
 
@@ -807,17 +911,7 @@ class KeyMapBase : public UntypedMapBase {
       head = EraseFromLinkedList(node, head);
       table_[b] = NodeToTableEntry(head);
     } else {
-      ABSL_DCHECK(this->TableEntryIsTree(b));
-      Tree* tree = internal::TableEntryToTree<Tree>(this->table_[b]);
-      if (tree_it != tree->begin()) {
-        auto* prev = std::prev(tree_it)->second;
-        prev->next = prev->next->next;
-      }
-      tree->erase(tree_it);
-      if (tree->empty()) {
-        this->DestroyTree(tree);
-        this->table_[b] = TableEntryPtr{};
-      }
+      EraseFromTree(b, tree_it);
     }
     --num_elements_;
     if (PROTOBUF_PREDICT_FALSE(b == index_of_first_non_null_)) {
@@ -828,29 +922,20 @@ class KeyMapBase : public UntypedMapBase {
     }
   }
 
-  // TODO(sbenza): We can reduce duplication by coercing `K` to a common type.
-  // Eg, for string keys we can coerce to string_view. Otherwise, we instantiate
-  // this with all the different `char[N]` of the caller.
-  template <typename K>
-  NodeAndBucket FindHelper(const K& k, TreeIterator* it = nullptr) const {
+  NodeAndBucket FindHelper(typename TS::ViewType k,
+                           TreeIterator* it = nullptr) const {
     size_type b = BucketNumber(k);
     if (TableEntryIsNonEmptyList(b)) {
       auto* node = internal::TableEntryToNode(table_[b]);
       do {
-        if (internal::TransparentSupport<Key>::Equals(
-                static_cast<KeyNode*>(node)->key(), k)) {
+        if (TS::Equals(static_cast<KeyNode*>(node)->key(), k)) {
           return {node, b};
         } else {
           node = node->next;
         }
       } while (node != nullptr);
     } else if (TableEntryIsTree(b)) {
-      Tree* tree = internal::TableEntryToTree<Tree>(table_[b]);
-      auto tree_it = tree->find(k);
-      if (it != nullptr) *it = tree_it;
-      if (tree_it != tree->end()) {
-        return {tree_it->second, b};
-      }
+      return FindFromTree(b, internal::RealKeyToVariantKey<Key>{}(k), it);
     }
     return {nullptr, b};
   }
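The new FindHelper resolves the removed TODO: TS::ViewType/ToView coerce every
lookup argument to one canonical view type, so the lookup path is instantiated
once instead of once per char[N]/string type. A sketch of the pattern (Toy*
names are hypothetical):

#include <string>
#include "absl/strings/string_view.h"

struct ToyStringSupport {
  using ViewType = absl::string_view;
  template <typename T>
  static ViewType ToView(const T& v) { return v; }  // implicit conversion
};

// The real work happens in one non-template function...
inline bool ToyContains(absl::string_view key) { return !key.empty(); }

// ...while the templated shim is trivially inlined away.
template <typename K>
bool ToyFind(const K& key) {
  return ToyContains(ToyStringSupport::ToView(key));
}

// ToyFind("abc") and ToyFind(std::string("abc")) share one lookup body.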
@@ -892,29 +977,13 @@ class KeyMapBase : public UntypedMapBase {
     } else if (TableEntryIsNonEmptyList(b) && !TableEntryIsTooLong(b)) {
       InsertUniqueInList(b, node);
     } else {
-      if (TableEntryIsNonEmptyList(b)) {
-        TreeConvert(b);
-      }
-      ABSL_DCHECK(TableEntryIsTree(b))
-          << (void*)table_[b] << " " << (uintptr_t)table_[b];
-      InsertUniqueInTree(b, node);
-      index_of_first_non_null_ = (std::min)(index_of_first_non_null_, b);
+      InsertUniqueInTree(b, NodeToVariantKey, node);
     }
   }
 
-  // Helper for InsertUnique.  Handles the case where bucket b points to a
-  // Tree.
-  void InsertUniqueInTree(size_type b, KeyNode* node) {
-    auto* tree = TableEntryToTree<Tree>(table_[b]);
-    auto it = tree->insert({node->key(), node}).first;
-    // Maintain the linked list of the nodes in the tree.
-    // For simplicity, they are in the same order as the tree iteration.
-    if (it != tree->begin()) {
-      auto* prev = std::prev(it)->second;
-      prev->next = node;
-    }
-    auto next = std::next(it);
-    node->next = next != tree->end() ? next->second : nullptr;
+  static VariantKey NodeToVariantKey(NodeBase* node) {
+    return internal::RealKeyToVariantKey<Key>{}(
+        static_cast<KeyNode*>(node)->key());
   }
 
   // Returns whether it did resize. Currently this is only used when
@@ -979,7 +1048,7 @@ class KeyMapBase : public UntypedMapBase {
       if (internal::TableEntryIsNonEmptyList(old_table[i])) {
         TransferList(static_cast<KeyNode*>(TableEntryToNode(old_table[i])));
       } else if (internal::TableEntryIsTree(old_table[i])) {
-        TransferTree(TableEntryToTree<Tree>(old_table[i]));
+        this->TransferTree(TableEntryToTree(old_table[i]), NodeToVariantKey);
       }
     }
     DeleteTable(old_table, old_table_size);
@@ -994,63 +1063,10 @@ class KeyMapBase : public UntypedMapBase {
     } while (node != nullptr);
   }
 
-  // Transfer all nodes in the tree `tree` into `this` and destroy the tree.
-  void TransferTree(Tree* tree) {
-    auto* node = tree->begin()->second;
-    DestroyTree(tree);
-    TransferList(static_cast<KeyNode*>(node));
-  }
-
-  void TreeConvert(size_type b) {
-    ABSL_DCHECK(!TableEntryIsTree(b));
-    Tree* tree =
-        Arena::Create<Tree>(alloc_.arena(), typename Tree::key_compare(),
-                            typename Tree::allocator_type(alloc_));
-    size_type count = CopyListToTree(b, tree);
-    ABSL_DCHECK_EQ(count, tree->size());
-    table_[b] = TreeToTableEntry(tree);
-    // Relink the nodes.
-    NodeBase* next = nullptr;
-    auto it = tree->end();
-    do {
-      auto* node = (--it)->second;
-      node->next = next;
-      next = node;
-    } while (it != tree->begin());
-  }
-
-  // Copy a linked list in the given bucket to a tree.
-  // Returns the number of things it copied.
-  size_type CopyListToTree(size_type b, Tree* tree) {
-    size_type count = 0;
-    auto* node = TableEntryToNode(table_[b]);
-    while (node != nullptr) {
-      tree->insert({static_cast<KeyNode*>(node)->key(), node});
-      ++count;
-      auto* next = node->next;
-      node->next = nullptr;
-      node = next;
-    }
-    return count;
-  }
-
-  template <typename K>
-  size_type BucketNumber(const K& k) const {
-    // We xor the hash value against the random seed so that we effectively
-    // have a random hash function.
-    uint64_t h = hash_function()(k) ^ seed_;
-
-    // We use the multiplication method to determine the bucket number from
-    // the hash value. The constant kPhi (suggested by Knuth) is roughly
-    // (sqrt(5) - 1) / 2 * 2^64.
-    constexpr uint64_t kPhi = uint64_t{0x9e3779b97f4a7c15};
-    return (MultiplyWithOverflow(kPhi, h) >> 32) & (num_buckets_ - 1);
-  }
-
-  void DestroyTree(Tree* tree) {
-    if (alloc_.arena() == nullptr) {
-      delete tree;
-    }
+  size_type BucketNumber(typename TS::ViewType k) const {
+    ABSL_DCHECK_EQ(BucketNumberFromHash(hash_function()(k)),
+                   VariantBucketNumber(RealKeyToVariantKey<Key>{}(k)));
+    return BucketNumberFromHash(hash_function()(k));
   }
 
   // Assumes node_ and m_ are correct and non-null, but other fields may be
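NodeToVariantKey above is the type-erasure hinge of the whole patch: the
untyped base (in the .cc) walks nodes it cannot interpret, and the typed
subclass hands it a single function pointer that knows how to read a key out
of a node. A minimal sketch of the trick (Toy* names are hypothetical):

#include <cstdint>

struct ToyNodeBase {
  ToyNodeBase* next;
};

template <typename Key>
struct ToyKeyNode : ToyNodeBase {
  Key key;
};

using ToyGetKey = uint64_t (*)(ToyNodeBase*);  // stand-in for GetKey

// Instantiated once per key type, but tiny:
template <typename Key>
uint64_t ToyNodeToKey(ToyNodeBase* node) {
  return static_cast<uint64_t>(static_cast<ToyKeyNode<Key>*>(node)->key);
}

// Lives once, out of line, shared by every instantiation:
inline ToyNodeBase* ToyFindMax(ToyNodeBase* head, ToyGetKey get_key) {
  ToyNodeBase* best = head;
  for (ToyNodeBase* n = head; n != nullptr; n = n->next) {
    if (get_key(n) > get_key(best)) best = n;
  }
  return best;
}

// Usage: ToyFindMax(head, &ToyNodeToKey<uint32_t>);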
@@ -1109,6 +1125,8 @@
 template <typename Key, typename T>
 class Map : private internal::KeyMapBase<KeyForBase<Key>> {
   using Base = typename Map::KeyMapBase;
 
+  using TS = internal::TransparentSupport<Key>;
+
  public:
   using key_type = Key;
   using mapped_type = T;
@@ -1121,7 +1139,7 @@ class Map : private internal::KeyMapBase<KeyForBase<Key>> {
   using const_reference = const value_type&;
   using size_type = size_t;
 
-  using hasher = typename internal::TransparentSupport<Key>::hash;
+  using hasher = typename TS::hash;
 
   constexpr Map() : Base(nullptr) { StaticValidityCheck(); }
   explicit Map(Arena* arena) : Base(arena) { StaticValidityCheck(); }
@@ -1157,10 +1175,8 @@ class Map : private internal::KeyMapBase<KeyForBase<Key>> {
     // won't trigger for leaked maps that never get destructed.
     StaticValidityCheck();
 
-    if (this->alloc_.arena() == nullptr &&
-        this->num_buckets_ != internal::kGlobalEmptyTableSize) {
-      clear();
-      this->DeleteTable(this->table_, this->num_buckets_);
+    if (this->num_buckets_ != internal::kGlobalEmptyTableSize) {
+      this->ClearTable(this->template MakeClearInput<Node>(false));
     }
   }
@@ -1212,13 +1228,12 @@ class Map : private internal::KeyMapBase<KeyForBase<Key>> {
       typename std::enable_if<!std::is_same<P, init_type>::value, int>::type;
 
   template <typename K>
-  using key_arg = typename internal::TransparentSupport<
-      key_type>::template key_arg<K>;
+  using key_arg = typename TS::template key_arg<K>;
 
  public:
   // Iterators
-  class const_iterator : private Base::KeyIteratorBase {
-    using BaseIt = typename Base::KeyIteratorBase;
+  class const_iterator : private internal::UntypedMapIterator {
+    using BaseIt = internal::UntypedMapIterator;
 
    public:
     using iterator_category = std::forward_iterator_tag;
@@ -1257,8 +1272,8 @@ class Map : private internal::KeyMapBase<KeyForBase<Key>> {
     friend class Map;
   };
 
-  class iterator : private Base::KeyIteratorBase {
-    using BaseIt = typename Base::KeyIteratorBase;
+  class iterator : private internal::UntypedMapIterator {
+    using BaseIt = internal::UntypedMapIterator;
 
    public:
     using iterator_category = std::forward_iterator_tag;
@@ -1357,7 +1372,7 @@ class Map : private internal::KeyMapBase<KeyForBase<Key>> {
   }
   template <typename K = key_type>
   iterator find(const key_arg<K>& key) ABSL_ATTRIBUTE_LIFETIME_BOUND {
-    auto res = this->FindHelper(key);
+    auto res = this->FindHelper(TS::ToView(key));
     return iterator(static_cast<Node*>(res.node), this, res.bucket);
   }
@@ -1463,27 +1478,8 @@ class Map : private internal::KeyMapBase<KeyForBase<Key>> {
   }
 
   void clear() {
-    for (size_type b = 0; b < this->num_buckets_; b++) {
-      internal::NodeBase* node;
-      if (this->TableEntryIsNonEmptyList(b)) {
-        node = internal::TableEntryToNode(this->table_[b]);
-        this->table_[b] = TableEntryPtr{};
-      } else if (this->TableEntryIsTree(b)) {
-        Tree* tree = internal::TableEntryToTree<Tree>(this->table_[b]);
-        this->table_[b] = TableEntryPtr{};
-        node = NodeFromTreeIterator(tree->begin());
-        this->DestroyTree(tree);
-      } else {
-        continue;
-      }
-      do {
-        auto* next = node->next;
-        DestroyNode(static_cast<Node*>(node));
-        node = next;
-      } while (node != nullptr);
-    }
-    this->num_elements_ = 0;
-    this->index_of_first_non_null_ = this->num_buckets_;
+    if (this->num_buckets_ == internal::kGlobalEmptyTableSize) return;
+    this->ClearTable(this->template MakeClearInput<Node>(true));
   }
 
   // Assign
@@ -1523,6 +1519,8 @@ class Map : private internal::KeyMapBase<KeyForBase<Key>> {
 
   // Linked-list nodes, as one would expect for a chaining hash table.
   struct Node : Base::KeyNode {
+    using key_type = Key;
+    using mapped_type = T;
     static constexpr internal::MapNodeSizeInfoT size_info() {
       return internal::MakeNodeInfo(sizeof(Node),
                                     PROTOBUF_FIELD_OFFSET(Node, kv.second));
     }
     value_type kv;
   };
 
-  using Tree = internal::TreeForMap<Key>;
+  using Tree = internal::TreeForMap;
   using TreeIterator = typename Tree::iterator;
   using TableEntryPtr = internal::TableEntryPtr;
@@ -1550,8 +1548,7 @@ class Map : private internal::KeyMapBase<KeyForBase<Key>> {
   }
 
   size_t SpaceUsedInternal() const {
-    return internal::SpaceUsedInTable<Key>(this->table_, this->num_buckets_,
-                                           this->num_elements_, sizeof(Node));
+    return this->SpaceUsedInTable(sizeof(Node));
   }
 
   // We try to construct `init_type` from `Args` with a fall back to
@@ -1570,14 +1567,14 @@ class Map : private internal::KeyMapBase<KeyForBase<Key>> {
   template <typename K, typename... Args>
   std::pair<iterator, bool> TryEmplaceInternal(K&& k, Args&&... args) {
-    auto p = this->FindHelper(k);
+    auto p = this->FindHelper(TS::ToView(k));
     // Case 1: key was already present.
     if (p.node != nullptr)
       return std::make_pair(
           iterator(static_cast<Node*>(p.node), this, p.bucket), false);
     // Case 2: insert.
     if (this->ResizeIfLoadIsOutOfRange(this->num_elements_ + 1)) {
-      p = this->FindHelper(k);
+      p = this->FindHelper(TS::ToView(k));
     }
     const size_type b = p.bucket;  // bucket number
     // If K is not key_type, make the conversion to key_type explicit.
diff --git a/src/google/protobuf/map_field.cc b/src/google/protobuf/map_field.cc
index d9a66a7339..24ce61d8e7 100644
--- a/src/google/protobuf/map_field.cc
+++ b/src/google/protobuf/map_field.cc
@@ -44,6 +44,25 @@ namespace protobuf {
 namespace internal {
 
 using ::google::protobuf::internal::DownCast;
 
+VariantKey RealKeyToVariantKey<MapKey>::operator()(const MapKey& value) const {
+  switch (value.type()) {
+    case FieldDescriptor::CPPTYPE_STRING:
+      return VariantKey(value.GetStringValue());
+    case FieldDescriptor::CPPTYPE_INT64:
+      return VariantKey(value.GetInt64Value());
+    case FieldDescriptor::CPPTYPE_INT32:
+      return VariantKey(value.GetInt32Value());
+    case FieldDescriptor::CPPTYPE_UINT64:
+      return VariantKey(value.GetUInt64Value());
+    case FieldDescriptor::CPPTYPE_UINT32:
+      return VariantKey(value.GetUInt32Value());
+    case FieldDescriptor::CPPTYPE_BOOL:
+      return VariantKey(static_cast<uint64_t>(value.GetBoolValue()));
+    default:
+      ABSL_ASSUME(false);
+  }
+}
+
 MapFieldBase::~MapFieldBase() {
   ABSL_DCHECK_EQ(arena(), nullptr);
   delete maybe_payload();
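The RealKeyToVariantKey<MapKey> overload above is what lets the reflection
path share hashing with generated maps: every supported MapKey kind maps onto
VariantKey, and every integral key is widened to uint64_t before hashing
(matching TransparentSupport and VariantKey::Hash in map.h). A small sketch of
the widening rule, with hypothetical Toy* helper names:

#include <cstdint>
#include <functional>

inline size_t ToyHashIntegralKey(uint32_t key) {
  return std::hash<uint64_t>{}(uint64_t{key});  // typed-map path
}
inline size_t ToyHashVariantKey(uint64_t integral) {
  return std::hash<uint64_t>{}(integral);       // VariantKey path
}
// ToyHashIntegralKey(k) == ToyHashVariantKey(uint64_t{k}) for every k, so
// bucket choices agree and the DCHECK in KeyMapBase::BucketNumber holds. The
// std::hash<MapKey> rewrite in the next file relies on the same property.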
"Unsupported"; - break; - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: - return hash()(map_key.GetStringValue()); - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { - auto value = map_key.GetInt64Value(); - return hash()(value); - } - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { - auto value = map_key.GetInt32Value(); - return hash()(map_key.GetInt32Value()); - } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { - auto value = map_key.GetUInt64Value(); - return hash()(map_key.GetUInt64Value()); - } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { - auto value = map_key.GetUInt32Value(); - return hash()(map_key.GetUInt32Value()); - } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { - return hash()(map_key.GetBoolValue()); - } - } - ABSL_LOG(FATAL) << "Can't get here."; - return 0; + return ::google::protobuf::internal::RealKeyToVariantKey<::google::protobuf::MapKey>{}(map_key) + .Hash(); } bool operator()(const google::protobuf::MapKey& map_key1, const google::protobuf::MapKey& map_key2) const { @@ -944,8 +922,8 @@ class PROTOBUF_EXPORT MapIterator { // This field provides the storage for Map<...>::const_iterator. We use // reinterpret_cast to get the right type. The real iterator is trivially // destructible/copyable, so no need to manage that. - alignas(internal::MapIteratorPayload) char map_iter_buffer_[sizeof( - internal::MapIteratorPayload)]{}; + alignas(internal::UntypedMapIterator) char map_iter_buffer_[sizeof( + internal::UntypedMapIterator)]{}; // Point to a MapField to call helper methods implemented in MapField. // MapIterator does not own this object. internal::MapFieldBase* map_; diff --git a/src/google/protobuf/map_test.inc b/src/google/protobuf/map_test.inc index e8f171ab42..f75939ca85 100644 --- a/src/google/protobuf/map_test.inc +++ b/src/google/protobuf/map_test.inc @@ -332,7 +332,7 @@ namespace std { template <> // NOLINT struct hash { size_t operator()(const google::protobuf::internal::MoveTestKey& key) const { - return hash{}(key.data); + return hash{}(key.data); } }; } // namespace std @@ -340,6 +340,14 @@ struct hash { namespace google { namespace protobuf { namespace internal { + +template <> +struct RealKeyToVariantKey { + VariantKey operator()(const MoveTestKey& value) const { + return VariantKey(value.data); + } +}; + namespace { TEST_F(MapImplTest, OperatorBracketRValue) {