diff --git a/src/google/protobuf/map.h b/src/google/protobuf/map.h index da654c9073..e84ca694dd 100644 --- a/src/google/protobuf/map.h +++ b/src/google/protobuf/map.h @@ -70,6 +70,7 @@ template class TypeDefinedMapFieldBase; @@ -563,6 +564,7 @@ class PROTOBUF_EXPORT UntypedMapBase { protected: friend class TcParser; friend struct MapTestPeer; + friend struct MapBenchmarkPeer; friend class UntypedMapIterator; struct NodeAndBucket { @@ -900,6 +902,7 @@ class KeyMapBase : public UntypedMapBase { protected: friend class TcParser; friend struct MapTestPeer; + friend struct MapBenchmarkPeer; PROTOBUF_NOINLINE void erase_no_destroy(map_index_t b, KeyNode* node) { TreeIterator tree_it; @@ -1665,6 +1668,7 @@ class Map : private internal::KeyMapBase> { friend class internal::MapFieldLite; friend class internal::TcParser; friend struct internal::MapTestPeer; + friend struct internal::MapBenchmarkPeer; }; namespace internal { diff --git a/src/google/protobuf/map_probe_benchmark.cc b/src/google/protobuf/map_probe_benchmark.cc new file mode 100644 index 0000000000..550de51a4e --- /dev/null +++ b/src/google/protobuf/map_probe_benchmark.cc @@ -0,0 +1,308 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2023 Google Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +#include +#include +#include +#include +#include +#include + +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "google/protobuf/map.h" + +namespace google::protobuf::internal { +struct MapBenchmarkPeer { + template + static double LoadFactor(const T& map) { + return static_cast(map.size()) / + static_cast(map.num_buckets_); + } + + template + static double GetMeanProbeLength(const T& map) { + double total_probe_cost = 0; + for (map_index_t b = 0; b < map.num_buckets_; ++b) { + if (map.TableEntryIsList(b)) { + auto* node = internal::TableEntryToNode(map.table_[b]); + size_t cost = 0; + while (node != nullptr) { + total_probe_cost += static_cast(cost); + cost++; + node = node->next; + } + } else if (map.TableEntryIsTree(b)) { + // Overhead factor to account for more costly binary search. + constexpr double kTreeOverhead = 2.0; + size_t tree_size = TableEntryToTree(map.table_[b])->size(); + total_probe_cost += kTreeOverhead * static_cast(tree_size) * + std::log2(tree_size); + } + } + return total_probe_cost / map.size(); + } + + template + static double GetPercentTree(const T& map) { + size_t total_tree_size = 0; + for (map_index_t b = 0; b < map.num_buckets_; ++b) { + if (map.TableEntryIsTree(b)) { + total_tree_size += TableEntryToTree(map.table_[b])->size(); + } + } + return static_cast(total_tree_size) / + static_cast(map.size()); + } +}; +} // namespace protobuf +} // namespace google::internal + +namespace { + +using Peer = google::protobuf::internal::MapBenchmarkPeer; + +absl::BitGen& GlobalBitGen() { + static auto* value = new absl::BitGen; + return *value; +} + +template +using Table = google::protobuf::Map; + +struct LoadSizes { + size_t min_load; + size_t max_load; +}; + +LoadSizes GetMinMaxLoadSizes() { + static const auto sizes = [] { + Table t; + + // First, fill enough to have a good distribution. + constexpr size_t kMinSize = 10000; + while (t.size() < kMinSize) t[static_cast(t.size())]; + + const auto reach_min_load_factor = [&] { + const double lf = Peer::LoadFactor(t); + while (lf <= Peer::LoadFactor(t)) t[static_cast(t.size())]; + }; + + // Then, insert until we reach min load factor. + reach_min_load_factor(); + const size_t min_load_size = t.size(); + + // Keep going until we hit min load factor again, then go back one. + t[static_cast(t.size())]; + reach_min_load_factor(); + + return LoadSizes{min_load_size, t.size() - 1}; + }(); + return sizes; +} + +struct Ratios { + double min_load; + double avg_load; + double max_load; + double percent_tree; +}; + +template +Ratios CollectMeanProbeLengths() { + const auto min_max_sizes = GetMinMaxLoadSizes(); + + ElemFn elem; + using Key = decltype(elem()); + Table t; + + Ratios result; + while (t.size() < min_max_sizes.min_load) t[elem()]; + result.min_load = Peer::GetMeanProbeLength(t); + + while (t.size() < (min_max_sizes.min_load + min_max_sizes.max_load) / 2) + t[elem()]; + result.avg_load = Peer::GetMeanProbeLength(t); + + while (t.size() < min_max_sizes.max_load) t[elem()]; + result.max_load = Peer::GetMeanProbeLength(t); + result.percent_tree = Peer::GetPercentTree(t); + + return result; +} + +constexpr char kStringFormat[] = "/path/to/file/name-%07d-of-9999999.txt"; + +template +struct String { + std::string value; + static std::string Make(uint32_t v) { + return {small ? absl::StrCat(v) : absl::StrFormat(kStringFormat, v)}; + } +}; + +template +struct Sequential { + T operator()() const { return current++; } + mutable T current{}; +}; + +template +struct Sequential> { + std::string operator()() const { return String::Make(current++); } + mutable uint32_t current = 0; +}; + +template +struct AlmostSequential { + mutable Sequential current; + + auto operator()() const -> decltype(current()) { + while (absl::Uniform(GlobalBitGen(), 0.0, 1.0) <= percent_skip / 100.) + current(); + return current(); + } +}; + +struct Uniform { + template + T operator()(T) const { + return absl::Uniform(absl::IntervalClosed, GlobalBitGen(), T{0}, ~T{0}); + } +}; + +struct Gaussian { + template + T operator()(T) const { + double d; + do { + d = absl::Gaussian(GlobalBitGen(), 1e6, 1e4); + } while (d <= 0 || d > std::numeric_limits::max() / 2); + return static_cast(d); + } +}; + +struct Zipf { + template + T operator()(T) const { + return absl::Zipf(GlobalBitGen(), std::numeric_limits::max(), 1.6); + } +}; + +template +struct Random { + T operator()() const { return Dist{}(T{}); } +}; + +template +struct Random, Dist> { + std::string operator()() const { + return String::Make(Random{}()); + } +}; + +template +std::string Name(); + +std::string Name(uint64_t*) { return "u64"; } + +template +std::string Name(String*) { + return small ? "StrS" : "StrL"; +} + +template +std::string Name(Sequential*) { + return "Sequential"; +} + +template +std::string Name(AlmostSequential*) { + return absl::StrCat("AlmostSeq_", P); +} + +template +std::string Name(Random*) { + return "UnifRand"; +} + +template +std::string Name(Random*) { + return "GausRand"; +} + +template +std::string Name(Random*) { + return "ZipfRand"; +} + +template +std::string Name() { + return Name(static_cast(nullptr)); +} + +struct Result { + std::string name; + std::string dist_name; + Ratios ratios; +}; + +template +void RunForTypeAndDistribution(std::vector& results) { + results.push_back({Name(), Name(), CollectMeanProbeLengths()}); +} + +template +void RunForType(std::vector& results) { + RunForTypeAndDistribution>(results); + RunForTypeAndDistribution>(results); + RunForTypeAndDistribution>(results); + RunForTypeAndDistribution>(results); + RunForTypeAndDistribution>(results); + RunForTypeAndDistribution>(results); +} + +} // namespace + +int main(int argc, char** argv) { + std::vector results; + RunForType(results); + RunForType>(results); + RunForType>(results); + + absl::PrintF("{\n"); + absl::PrintF(" \"benchmarks\": [\n"); + absl::string_view comma; + for (const auto& result : results) { + auto print = [&](absl::string_view stat, double Ratios::*val) { + std::string name = + absl::StrCat(result.name, "/", result.dist_name, "/", stat); + absl::PrintF(" %s{\n", comma); + absl::PrintF(" \"cpu_time\": 0,\n"); + absl::PrintF(" \"real_time\": 0,\n"); + absl::PrintF(" \"allocs_per_iter\": %f,\n", result.ratios.*val); + + absl::PrintF(" \"iterations\": 1,\n"); + absl::PrintF(" \"name\": \"%s\",\n", name); + absl::PrintF(" \"time_unit\": \"ns\"\n"); + absl::PrintF(" }\n"); + comma = ","; + }; + print("min", &Ratios::min_load); + print("avg", &Ratios::avg_load); + print("max", &Ratios::max_load); + print("tree_percent", &Ratios::percent_tree); + } + absl::PrintF(" ],\n"); + absl::PrintF(" \"context\": {\n"); + absl::PrintF(" }\n"); + absl::PrintF("}\n"); + + return 0; +}