[avl] Use RefCountedPtr instead of shared_ptr (#33900)

Reduces node size from 112 bytes to 88 bytes on x64 opt builds.

(also delete the unused specialization of `AVL<T, void>`)

---------

Co-authored-by: ctiller <ctiller@users.noreply.github.com>
pull/33870/head
Craig Tiller 2 years ago committed by GitHub
parent b701a5433e
commit 2b3400052d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      CMakeLists.txt
  2. 7
      build_autogenerated.yaml
  3. 2
      src/core/BUILD
  4. 183
      src/core/lib/avl/avl.h

3
CMakeLists.txt generated

@ -6383,8 +6383,7 @@ target_link_libraries(avl_test
${_gRPC_PROTOBUF_LIBRARIES}
${_gRPC_ZLIB_LIBRARIES}
${_gRPC_ALLTARGETS_LIBRARIES}
absl::strings
absl::variant
gpr
)

@ -4776,12 +4776,13 @@ targets:
language: c++
headers:
- src/core/lib/avl/avl.h
- src/core/lib/gpr/useful.h
- src/core/lib/gprpp/atomic_utils.h
- src/core/lib/gprpp/ref_counted.h
- src/core/lib/gprpp/ref_counted_ptr.h
src:
- test/core/avl/avl_test.cc
deps:
- absl/strings:strings
- absl/types:variant
- gpr
uses_polling: false
- name: aws_request_signer_test
gtest: true

@ -1360,8 +1360,10 @@ grpc_cc_library(
"lib/avl/avl.h",
],
deps = [
"ref_counted",
"useful",
"//:gpr_platform",
"//:ref_counted_ptr",
],
)

@ -20,10 +20,11 @@
#include <stdlib.h>
#include <algorithm> // IWYU pragma: keep
#include <memory>
#include <utility>
#include "src/core/lib/gpr/useful.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
namespace grpc_core {
@ -42,12 +43,12 @@ class AVL {
template <typename SomethingLikeK>
const V* Lookup(const SomethingLikeK& key) const {
NodePtr n = Get(root_, key);
return n ? &n->kv.second : nullptr;
return n != nullptr ? &n->kv.second : nullptr;
}
const std::pair<K, V>* LookupBelow(const K& key) const {
NodePtr n = GetBelow(root_, *key);
return n ? &n->kv : nullptr;
return n != nullptr ? &n->kv : nullptr;
}
bool Empty() const { return root_ == nullptr; }
@ -95,8 +96,8 @@ class AVL {
private:
struct Node;
typedef std::shared_ptr<Node> NodePtr;
struct Node : public std::enable_shared_from_this<Node> {
typedef RefCountedPtr<Node> NodePtr;
struct Node : public RefCounted<Node, NonPolymorphicRefCount> {
Node(K k, V v, NodePtr l, NodePtr r, long h)
: kv(std::move(k), std::move(v)),
left(std::move(l)),
@ -167,12 +168,12 @@ class AVL {
ForEachImpl(n->right.get(), std::forward<F>(f));
}
static long Height(const NodePtr& n) { return n ? n->height : 0; }
static long Height(const NodePtr& n) { return n != nullptr ? n->height : 0; }
static NodePtr MakeNode(K key, V value, const NodePtr& left,
const NodePtr& right) {
return std::make_shared<Node>(std::move(key), std::move(value), left, right,
1 + std::max(Height(left), Height(right)));
return MakeRefCounted<Node>(std::move(key), std::move(value), left, right,
1 + std::max(Height(left), Height(right)));
}
template <typename SomethingLikeK>
@ -259,7 +260,7 @@ class AVL {
}
static NodePtr AddKey(const NodePtr& node, K key, V value) {
if (!node) {
if (node == nullptr) {
return MakeNode(std::move(key), std::move(value), nullptr, nullptr);
}
if (node->kv.first < key) {
@ -318,170 +319,6 @@ class AVL {
}
};
template <class K>
class AVL<K, void> {
public:
AVL() {}
AVL Add(K key) const { return AVL(AddKey(root_, std::move(key))); }
AVL Remove(const K& key) const { return AVL(RemoveKey(root_, key)); }
bool Lookup(const K& key) const { return Get(root_, key) != nullptr; }
bool Empty() const { return root_ == nullptr; }
template <class F>
void ForEach(F&& f) const {
ForEachImpl(root_.get(), std::forward<F>(f));
}
bool SameIdentity(AVL avl) const { return root_ == avl.root_; }
private:
struct Node;
typedef std::shared_ptr<Node> NodePtr;
struct Node : public std::enable_shared_from_this<Node> {
Node(K k, NodePtr l, NodePtr r, long h)
: key(std::move(k)),
left(std::move(l)),
right(std::move(r)),
height(h) {}
const K key;
const NodePtr left;
const NodePtr right;
const long height;
};
NodePtr root_;
explicit AVL(NodePtr root) : root_(std::move(root)) {}
template <class F>
static void ForEachImpl(const Node* n, F&& f) {
if (n == nullptr) return;
ForEachImpl(n->left.get(), std::forward<F>(f));
f(const_cast<const K&>(n->key));
ForEachImpl(n->right.get(), std::forward<F>(f));
}
static long Height(const NodePtr& n) { return n ? n->height : 0; }
static NodePtr MakeNode(K key, const NodePtr& left, const NodePtr& right) {
return std::make_shared<Node>(std::move(key), left, right,
1 + std::max(Height(left), Height(right)));
}
static NodePtr Get(const NodePtr& node, const K& key) {
if (node == nullptr) {
return nullptr;
}
if (node->key > key) {
return Get(node->left, key);
} else if (node->key < key) {
return Get(node->right, key);
} else {
return node;
}
}
static NodePtr RotateLeft(K key, const NodePtr& left, const NodePtr& right) {
return MakeNode(right->key, MakeNode(std::move(key), left, right->left),
right->right);
}
static NodePtr RotateRight(K key, const NodePtr& left, const NodePtr& right) {
return MakeNode(left->key, left->left,
MakeNode(std::move(key), left->right, right));
}
static NodePtr RotateLeftRight(K key, const NodePtr& left,
const NodePtr& right) {
// rotate_right(..., rotate_left(left), right)
return MakeNode(left->right->key,
MakeNode(left->key, left->left, left->right->left),
MakeNode(std::move(key), left->right->right, right));
}
static NodePtr RotateRightLeft(K key, const NodePtr& left,
const NodePtr& right) {
// rotate_left(..., left, rotate_right(right))
return MakeNode(right->left->key,
MakeNode(std::move(key), left, right->left->left),
MakeNode(right->key, right->left->right, right->right));
}
static NodePtr Rebalance(K key, const NodePtr& left, const NodePtr& right) {
switch (Height(left) - Height(right)) {
case 2:
if (Height(left->left) - Height(left->right) == -1) {
return RotateLeftRight(std::move(key), left, right);
} else {
return RotateRight(std::move(key), left, right);
}
case -2:
if (Height(right->left) - Height(right->right) == 1) {
return RotateRightLeft(std::move(key), left, right);
} else {
return RotateLeft(std::move(key), left, right);
}
default:
return MakeNode(key, left, right);
}
}
static NodePtr AddKey(const NodePtr& node, K key) {
if (!node) {
return MakeNode(std::move(key), nullptr, nullptr);
}
if (node->key < key) {
return Rebalance(node->key, node->left,
AddKey(node->right, std::move(key)));
}
if (key < node->key) {
return Rebalance(node->key, AddKey(node->left, std::move(key)),
node->right);
}
return MakeNode(std::move(key), node->left, node->right);
}
static NodePtr InOrderHead(NodePtr node) {
while (node->left != nullptr) {
node = node->left;
}
return node;
}
static NodePtr InOrderTail(NodePtr node) {
while (node->right != nullptr) {
node = node->right;
}
return node;
}
static NodePtr RemoveKey(const NodePtr& node, const K& key) {
if (node == nullptr) {
return nullptr;
}
if (key < node->key) {
return Rebalance(node->key, RemoveKey(node->left, key), node->right);
} else if (node->key < key) {
return Rebalance(node->key, node->left, RemoveKey(node->right, key));
} else {
if (node->left == nullptr) {
return node->right;
} else if (node->right == nullptr) {
return node->left;
} else if (node->left->height < node->right->height) {
NodePtr h = InOrderHead(node->right);
return Rebalance(h->key, node->left, RemoveKey(node->right, h->key));
} else {
NodePtr h = InOrderTail(node->left);
return Rebalance(h->key, RemoveKey(node->left, h->key), node->right);
}
}
abort();
}
};
} // namespace grpc_core
#endif // GRPC_SRC_CORE_LIB_AVL_AVL_H

Loading…
Cancel
Save