Remove unnecessary usage of AVL (#27666)

* Remove unnecessary usage of avl

* review feedback

* add missing ref
pull/27784/head
Craig Tiller 3 years ago committed by GitHub
parent 25d8458721
commit 4dc6a11227
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 65
      src/core/ext/filters/client_channel/retry_throttle.cc
  2. 66
      src/core/tsi/ssl/session_cache/ssl_session_cache.cc
  3. 7
      src/core/tsi/ssl/session_cache/ssl_session_cache.h

@ -23,6 +23,7 @@
#include <limits.h> #include <limits.h>
#include <string.h> #include <string.h>
#include <map>
#include <string> #include <string>
#include <grpc/support/alloc.h> #include <grpc/support/alloc.h>
@ -30,7 +31,7 @@
#include <grpc/support/string_util.h> #include <grpc/support/string_util.h>
#include <grpc/support/sync.h> #include <grpc/support/sync.h>
#include "src/core/lib/avl/avl.h" #include "src/core/lib/gprpp/manual_constructor.h"
namespace grpc_core { namespace grpc_core {
namespace internal { namespace internal {
@ -114,55 +115,24 @@ void ServerRetryThrottleData::RecordSuccess() {
static_cast<gpr_atm>(throttle_data->max_milli_tokens_)); static_cast<gpr_atm>(throttle_data->max_milli_tokens_));
} }
//
// avl vtable for string -> server_retry_throttle_data map
//
namespace {
void* copy_server_name(void* key, void* /*unused*/) {
return gpr_strdup(static_cast<const char*>(key));
}
long compare_server_name(void* key1, void* key2, void* /*unused*/) {
return strcmp(static_cast<const char*>(key1), static_cast<const char*>(key2));
}
void destroy_server_retry_throttle_data(void* value, void* /*unused*/) {
ServerRetryThrottleData* throttle_data =
static_cast<ServerRetryThrottleData*>(value);
throttle_data->Unref();
}
void* copy_server_retry_throttle_data(void* value, void* /*unused*/) {
ServerRetryThrottleData* throttle_data =
static_cast<ServerRetryThrottleData*>(value);
return throttle_data->Ref().release();
}
void destroy_server_name(void* key, void* /*unused*/) { gpr_free(key); }
const grpc_avl_vtable avl_vtable = {
destroy_server_name, copy_server_name, compare_server_name,
destroy_server_retry_throttle_data, copy_server_retry_throttle_data};
} // namespace
// //
// ServerRetryThrottleMap // ServerRetryThrottleMap
// //
using StringToDataMap =
std::map<std::string, RefCountedPtr<ServerRetryThrottleData>>;
static gpr_mu g_mu; static gpr_mu g_mu;
static grpc_avl g_avl; static StringToDataMap* g_map;
void ServerRetryThrottleMap::Init() { void ServerRetryThrottleMap::Init() {
gpr_mu_init(&g_mu); gpr_mu_init(&g_mu);
g_avl = grpc_avl_create(&avl_vtable); g_map = new StringToDataMap();
} }
void ServerRetryThrottleMap::Shutdown() { void ServerRetryThrottleMap::Shutdown() {
gpr_mu_destroy(&g_mu); gpr_mu_destroy(&g_mu);
grpc_avl_unref(g_avl, nullptr); delete g_map;
g_map = nullptr;
} }
RefCountedPtr<ServerRetryThrottleData> ServerRetryThrottleMap::GetDataForServer( RefCountedPtr<ServerRetryThrottleData> ServerRetryThrottleMap::GetDataForServer(
@ -170,23 +140,22 @@ RefCountedPtr<ServerRetryThrottleData> ServerRetryThrottleMap::GetDataForServer(
intptr_t milli_token_ratio) { intptr_t milli_token_ratio) {
RefCountedPtr<ServerRetryThrottleData> result; RefCountedPtr<ServerRetryThrottleData> result;
gpr_mu_lock(&g_mu); gpr_mu_lock(&g_mu);
auto it = g_map->find(server_name);
ServerRetryThrottleData* throttle_data = ServerRetryThrottleData* throttle_data =
static_cast<ServerRetryThrottleData*>( it == g_map->end() ? nullptr : it->second.get();
grpc_avl_get(g_avl, const_cast<char*>(server_name.c_str()), nullptr));
if (throttle_data == nullptr || if (throttle_data == nullptr ||
throttle_data->max_milli_tokens() != max_milli_tokens || throttle_data->max_milli_tokens() != max_milli_tokens ||
throttle_data->milli_token_ratio() != milli_token_ratio) { throttle_data->milli_token_ratio() != milli_token_ratio) {
// Entry not found, or found with old parameters. Create a new one. // Entry not found, or found with old parameters. Create a new one.
result = MakeRefCounted<ServerRetryThrottleData>( it = g_map
max_milli_tokens, milli_token_ratio, throttle_data); ->emplace(server_name,
g_avl = grpc_avl_add(g_avl, gpr_strdup(server_name.c_str()), MakeRefCounted<ServerRetryThrottleData>(
result->Ref().release(), nullptr); max_milli_tokens, milli_token_ratio, throttle_data))
} else { .first;
// Entry found. Return a new ref to it. throttle_data = it->second.get();
result = throttle_data->Ref();
} }
gpr_mu_unlock(&g_mu); gpr_mu_unlock(&g_mu);
return result; return throttle_data->Ref();
} }
} // namespace internal } // namespace internal

@ -29,41 +29,18 @@
namespace tsi { namespace tsi {
static void cache_key_avl_destroy(void* /*key*/, void* /*unused*/) {}
static void* cache_key_avl_copy(void* key, void* /*unused*/) { return key; }
static long cache_key_avl_compare(void* key1, void* key2, void* /*unused*/) {
return grpc_slice_cmp(*static_cast<grpc_slice*>(key1),
*static_cast<grpc_slice*>(key2));
}
static void cache_value_avl_destroy(void* /*value*/, void* /*unused*/) {}
static void* cache_value_avl_copy(void* value, void* /*unused*/) {
return value;
}
// AVL only stores pointers, ownership belonges to the linked list.
static const grpc_avl_vtable cache_avl_vtable = {
cache_key_avl_destroy, cache_key_avl_copy, cache_key_avl_compare,
cache_value_avl_destroy, cache_value_avl_copy,
};
/// Node for single cached session. /// Node for single cached session.
class SslSessionLRUCache::Node { class SslSessionLRUCache::Node {
public: public:
Node(const grpc_slice& key, SslSessionPtr session) : key_(key) { Node(const std::string& key, SslSessionPtr session) : key_(key) {
SetSession(std::move(session)); SetSession(std::move(session));
} }
~Node() { grpc_slice_unref_internal(key_); }
// Not copyable nor movable. // Not copyable nor movable.
Node(const Node&) = delete; Node(const Node&) = delete;
Node& operator=(const Node&) = delete; Node& operator=(const Node&) = delete;
void* AvlKey() { return &key_; } const std::string& key() const { return key_; }
/// Returns a copy of the node's cache session. /// Returns a copy of the node's cache session.
SslSessionPtr CopySession() const { return session_->CopySession(); } SslSessionPtr CopySession() const { return session_->CopySession(); }
@ -76,7 +53,7 @@ class SslSessionLRUCache::Node {
private: private:
friend class SslSessionLRUCache; friend class SslSessionLRUCache;
grpc_slice key_; std::string key_;
std::unique_ptr<SslCachedSession> session_; std::unique_ptr<SslCachedSession> session_;
Node* next_ = nullptr; Node* next_ = nullptr;
@ -85,7 +62,6 @@ class SslSessionLRUCache::Node {
SslSessionLRUCache::SslSessionLRUCache(size_t capacity) : capacity_(capacity) { SslSessionLRUCache::SslSessionLRUCache(size_t capacity) : capacity_(capacity) {
GPR_ASSERT(capacity > 0); GPR_ASSERT(capacity > 0);
entry_by_key_ = grpc_avl_create(&cache_avl_vtable);
} }
SslSessionLRUCache::~SslSessionLRUCache() { SslSessionLRUCache::~SslSessionLRUCache() {
@ -95,7 +71,6 @@ SslSessionLRUCache::~SslSessionLRUCache() {
delete node; delete node;
node = next; node = next;
} }
grpc_avl_unref(entry_by_key_, nullptr);
} }
size_t SslSessionLRUCache::Size() { size_t SslSessionLRUCache::Size() {
@ -104,13 +79,12 @@ size_t SslSessionLRUCache::Size() {
} }
SslSessionLRUCache::Node* SslSessionLRUCache::FindLocked( SslSessionLRUCache::Node* SslSessionLRUCache::FindLocked(
const grpc_slice& key) { const std::string& key) {
void* value = auto it = entry_by_key_.find(key);
grpc_avl_get(entry_by_key_, const_cast<grpc_slice*>(&key), nullptr); if (it == entry_by_key_.end()) {
if (value == nullptr) {
return nullptr; return nullptr;
} }
Node* node = static_cast<Node*>(value); Node* node = it->second;
// Move to the beginning. // Move to the beginning.
Remove(node); Remove(node);
PushFront(node); PushFront(node);
@ -120,22 +94,21 @@ SslSessionLRUCache::Node* SslSessionLRUCache::FindLocked(
void SslSessionLRUCache::Put(const char* key, SslSessionPtr session) { void SslSessionLRUCache::Put(const char* key, SslSessionPtr session) {
grpc_core::MutexLock lock(&lock_); grpc_core::MutexLock lock(&lock_);
Node* node = FindLocked(grpc_slice_from_static_string(key)); Node* node = FindLocked(key);
if (node != nullptr) { if (node != nullptr) {
node->SetSession(std::move(session)); node->SetSession(std::move(session));
return; return;
} }
grpc_slice key_slice = grpc_slice_from_copied_string(key); node = new Node(key, std::move(session));
node = new Node(key_slice, std::move(session));
PushFront(node); PushFront(node);
entry_by_key_ = grpc_avl_add(entry_by_key_, node->AvlKey(), node, nullptr); entry_by_key_.emplace(key, node);
AssertInvariants(); AssertInvariants();
if (use_order_list_size_ > capacity_) { if (use_order_list_size_ > capacity_) {
GPR_ASSERT(use_order_list_tail_); GPR_ASSERT(use_order_list_tail_);
node = use_order_list_tail_; node = use_order_list_tail_;
Remove(node); Remove(node);
// Order matters, key is destroyed after deleting node. // Order matters, key is destroyed after deleting node.
entry_by_key_ = grpc_avl_remove(entry_by_key_, node->AvlKey(), nullptr); entry_by_key_.erase(node->key());
delete node; delete node;
AssertInvariants(); AssertInvariants();
} }
@ -144,8 +117,7 @@ void SslSessionLRUCache::Put(const char* key, SslSessionPtr session) {
SslSessionPtr SslSessionLRUCache::Get(const char* key) { SslSessionPtr SslSessionLRUCache::Get(const char* key) {
grpc_core::MutexLock lock(&lock_); grpc_core::MutexLock lock(&lock_);
// Key is only used for lookups. // Key is only used for lookups.
grpc_slice key_slice = grpc_slice_from_static_string(key); Node* node = FindLocked(key);
Node* node = FindLocked(key_slice);
if (node == nullptr) { if (node == nullptr) {
return nullptr; return nullptr;
} }
@ -183,13 +155,6 @@ void SslSessionLRUCache::PushFront(SslSessionLRUCache::Node* node) {
} }
#ifndef NDEBUG #ifndef NDEBUG
static size_t calculate_tree_size(grpc_avl_node* node) {
if (node == nullptr) {
return 0;
}
return 1 + calculate_tree_size(node->left) + calculate_tree_size(node->right);
}
void SslSessionLRUCache::AssertInvariants() { void SslSessionLRUCache::AssertInvariants() {
size_t size = 0; size_t size = 0;
Node* prev = nullptr; Node* prev = nullptr;
@ -197,14 +162,15 @@ void SslSessionLRUCache::AssertInvariants() {
while (current != nullptr) { while (current != nullptr) {
size++; size++;
GPR_ASSERT(current->prev_ == prev); GPR_ASSERT(current->prev_ == prev);
void* node = grpc_avl_get(entry_by_key_, current->AvlKey(), nullptr); auto it = entry_by_key_.find(current->key());
GPR_ASSERT(node == current); GPR_ASSERT(it != entry_by_key_.end());
GPR_ASSERT(it->second == current);
prev = current; prev = current;
current = current->next_; current = current->next_;
} }
GPR_ASSERT(prev == use_order_list_tail_); GPR_ASSERT(prev == use_order_list_tail_);
GPR_ASSERT(size == use_order_list_size_); GPR_ASSERT(size == use_order_list_size_);
GPR_ASSERT(calculate_tree_size(entry_by_key_.root) == use_order_list_size_); GPR_ASSERT(entry_by_key_.size() == use_order_list_size_);
} }
#else #else
void SslSessionLRUCache::AssertInvariants() {} void SslSessionLRUCache::AssertInvariants() {}

@ -28,7 +28,8 @@ extern "C" {
#include <openssl/ssl.h> #include <openssl/ssl.h>
} }
#include "src/core/lib/avl/avl.h" #include <map>
#include "src/core/lib/gprpp/memory.h" #include "src/core/lib/gprpp/memory.h"
#include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/sync.h" #include "src/core/lib/gprpp/sync.h"
@ -72,7 +73,7 @@ class SslSessionLRUCache : public grpc_core::RefCounted<SslSessionLRUCache> {
private: private:
class Node; class Node;
Node* FindLocked(const grpc_slice& key); Node* FindLocked(const std::string& key);
void Remove(Node* node); void Remove(Node* node);
void PushFront(Node* node); void PushFront(Node* node);
void AssertInvariants(); void AssertInvariants();
@ -83,7 +84,7 @@ class SslSessionLRUCache : public grpc_core::RefCounted<SslSessionLRUCache> {
Node* use_order_list_head_ = nullptr; Node* use_order_list_head_ = nullptr;
Node* use_order_list_tail_ = nullptr; Node* use_order_list_tail_ = nullptr;
size_t use_order_list_size_ = 0; size_t use_order_list_size_ = 0;
grpc_avl entry_by_key_; std::map<std::string, Node*> entry_by_key_;
}; };
} // namespace tsi } // namespace tsi

Loading…
Cancel
Save