Add shared_ptr support to ChannelArgs (#31056)

Split off from #30556, this adds the ability to put `std::shared_ptr`s
into ChannelArgs. The EventEngine specialization and preconditioning
will be done separately, when some prework is done there.
pull/31049/head^2
AJ Heller 2 years ago committed by GitHub
parent bc4f98bb36
commit 22df3d9089
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 128
      src/core/lib/channel/channel_args.h
  2. 42
      test/core/channel/channel_args_test.cc

@ -54,7 +54,7 @@ namespace grpc_core {
// ChannelArgs to automatically derive a vtable from a T*. // ChannelArgs to automatically derive a vtable from a T*.
// To participate as a pointer, instances should expose the function: // To participate as a pointer, instances should expose the function:
// // Gets the vtable for this type // // Gets the vtable for this type
// static const grpc_channel_arg_vtable* VTable(); // static const grpc_arg_pointer_vtable* VTable();
// // Performs any mutations required for channel args to own a pointer // // Performs any mutations required for channel args to own a pointer
// // Only needed if ChannelArgs::Set is to be called with a raw pointer. // // Only needed if ChannelArgs::Set is to be called with a raw pointer.
// static void* TakeUnownedPointer(T* p); // static void* TakeUnownedPointer(T* p);
@ -117,6 +117,32 @@ struct ChannelArgTypeTraits<
}; };
}; };
// Specialization for shared_ptr
// Incurs an allocation because shared_ptr.release is not a thing.
template <typename T>
struct is_shared_ptr : std::false_type {};
template <typename T>
struct is_shared_ptr<std::shared_ptr<T>> : std::true_type {};
template <typename T>
struct ChannelArgTypeTraits<T,
absl::enable_if_t<is_shared_ptr<T>::value, void>> {
static void* TakeUnownedPointer(T* p) { return p; }
static const grpc_arg_pointer_vtable* VTable() {
static const grpc_arg_pointer_vtable tbl = {
// copy
[](void* p) -> void* { return new T(*static_cast<T*>(p)); },
// destroy
[](void* p) { delete static_cast<T*>(p); },
// compare
[](void* p1, void* p2) {
return QsortCompare(static_cast<const T*>(p1)->get(),
static_cast<const T*>(p2)->get());
},
};
return &tbl;
};
};
// If a type declares some member 'struct RawPointerChannelArgTag {}' then // If a type declares some member 'struct RawPointerChannelArgTag {}' then
// we automatically generate a vtable for it that does not do any ownership // we automatically generate a vtable for it that does not do any ownership
// management and compares the type by pointer identity. // management and compares the type by pointer identity.
@ -139,6 +165,55 @@ struct ChannelArgTypeTraits<T,
}; };
}; };
// GetObject support for shared_ptr and RefCountedPtr
template <typename T>
struct WrapInSharedPtr
: std::integral_constant<
bool, std::is_base_of<std::enable_shared_from_this<T>, T>::value> {};
template <typename T, typename Ignored = void /* for SFINAE */>
struct GetObjectImpl;
// std::shared_ptr implementation
template <typename T>
struct GetObjectImpl<T, absl::enable_if_t<WrapInSharedPtr<T>::value, void>> {
using Result = T*;
using ReffedResult = std::shared_ptr<T>;
using StoredType = std::shared_ptr<T>*;
static Result Get(StoredType p) { return p->get(); };
static ReffedResult GetReffed(StoredType p) { return ReffedResult(*p); };
static ReffedResult GetReffed(StoredType p,
const DebugLocation& /* location */,
const char* /* reason */) {
return GetReffed(*p);
};
};
// RefCountedPtr
template <typename T>
struct GetObjectImpl<T, absl::enable_if_t<!WrapInSharedPtr<T>::value, void>> {
using Result = T*;
using ReffedResult = RefCountedPtr<T>;
using StoredType = Result;
static Result Get(StoredType p) { return p; };
static ReffedResult GetReffed(StoredType p) {
if (p == nullptr) return nullptr;
return p->Ref();
};
static ReffedResult GetReffed(StoredType p, const DebugLocation& location,
const char* reason) {
if (p == nullptr) return nullptr;
return p->Ref(location, reason);
};
};
// Provide the canonical name for a type's channel arg key
template <typename T>
struct ChannelArgNameTraits {
static absl::string_view ChannelArgName() { return T::ChannelArgName(); }
};
template <typename T>
struct ChannelArgNameTraits<std::shared_ptr<T>> {
static absl::string_view ChannelArgName() { return T::ChannelArgName(); }
};
class ChannelArgs { class ChannelArgs {
public: public:
class Pointer { class Pointer {
@ -243,6 +318,20 @@ class ChannelArgs {
ChannelArgTypeTraits<absl::remove_cvref_t<T>>::VTable())); ChannelArgTypeTraits<absl::remove_cvref_t<T>>::VTable()));
} }
template <typename T> template <typename T>
GRPC_MUST_USE_RESULT absl::enable_if_t<
std::is_same<
const grpc_arg_pointer_vtable*,
decltype(ChannelArgTypeTraits<std::shared_ptr<T>>::VTable())>::value,
ChannelArgs>
Set(absl::string_view name, std::shared_ptr<T> value) const {
auto* store_value = new std::shared_ptr<T>(value);
return Set(
name,
Pointer(ChannelArgTypeTraits<std::shared_ptr<T>>::TakeUnownedPointer(
store_value),
ChannelArgTypeTraits<std::shared_ptr<T>>::VTable()));
}
template <typename T>
GRPC_MUST_USE_RESULT ChannelArgs SetIfUnset(absl::string_view name, GRPC_MUST_USE_RESULT ChannelArgs SetIfUnset(absl::string_view name,
T value) const { T value) const {
if (Contains(name)) return *this; if (Contains(name)) return *this;
@ -251,13 +340,20 @@ class ChannelArgs {
GRPC_MUST_USE_RESULT ChannelArgs Remove(absl::string_view name) const; GRPC_MUST_USE_RESULT ChannelArgs Remove(absl::string_view name) const;
bool Contains(absl::string_view name) const; bool Contains(absl::string_view name) const;
template <typename T>
bool ContainsObject() const {
return Get(ChannelArgNameTraits<T>::ChannelArgName()) != nullptr;
}
absl::optional<int> GetInt(absl::string_view name) const; absl::optional<int> GetInt(absl::string_view name) const;
absl::optional<absl::string_view> GetString(absl::string_view name) const; absl::optional<absl::string_view> GetString(absl::string_view name) const;
absl::optional<std::string> GetOwnedString(absl::string_view name) const; absl::optional<std::string> GetOwnedString(absl::string_view name) const;
void* GetVoidPointer(absl::string_view name) const; void* GetVoidPointer(absl::string_view name) const;
template <typename T> template <typename T>
T* GetPointer(absl::string_view name) const { typename GetObjectImpl<T>::StoredType GetPointer(
return static_cast<T*>(GetVoidPointer(name)); absl::string_view name) const {
return static_cast<typename GetObjectImpl<T>::StoredType>(
GetVoidPointer(name));
} }
absl::optional<Duration> GetDurationFromIntMillis( absl::optional<Duration> GetDurationFromIntMillis(
absl::string_view name) const; absl::string_view name) const;
@ -277,21 +373,25 @@ class ChannelArgs {
return Set(T::ChannelArgName(), std::move(p)); return Set(T::ChannelArgName(), std::move(p));
} }
template <typename T> template <typename T>
T* GetObject() const { GRPC_MUST_USE_RESULT ChannelArgs SetObject(std::shared_ptr<T> p) const {
return GetPointer<T>(T::ChannelArgName()); return Set(ChannelArgNameTraits<T>::ChannelArgName(), std::move(p));
} }
template <typename T> template <typename T>
RefCountedPtr<T> GetObjectRef() const { typename GetObjectImpl<T>::Result GetObject() const {
auto* p = GetObject<T>(); return GetObjectImpl<T>::Get(
if (p == nullptr) return nullptr; GetPointer<T>(ChannelArgNameTraits<T>::ChannelArgName()));
return p->Ref(DEBUG_LOCATION, "ChannelArgs GetObjectRef()");
} }
template <typename T> template <typename T>
RefCountedPtr<T> GetObjectRef(const DebugLocation& location, typename GetObjectImpl<T>::ReffedResult GetObjectRef() const {
const char* reason) const { return GetObjectImpl<T>::GetReffed(
auto* p = GetObject<T>(); GetPointer<T>(ChannelArgNameTraits<T>::ChannelArgName()));
if (p == nullptr) return nullptr; }
return p->Ref(location, reason); template <typename T>
typename GetObjectImpl<T>::ReffedResult GetObjectRef(
const DebugLocation& location, const char* reason) const {
return GetObjectImpl<T>::GetReffed(
GetPointer<T>(ChannelArgNameTraits<T>::ChannelArgName()), location,
reason);
} }
bool operator!=(const ChannelArgs& other) const; bool operator!=(const ChannelArgs& other) const;

@ -124,6 +124,48 @@ TEST(ChannelArgsTest, ToAndFromC) {
gpr_free(ptr); gpr_free(ptr);
} }
// shared_ptrs in ChannelArgs must support enable_shared_from_this
class ShareableObject : public std::enable_shared_from_this<ShareableObject> {
public:
explicit ShareableObject(int n) : n(n) {}
int n;
static int ChannelArgsCompare(const ShareableObject* a,
const ShareableObject* b) {
return a->n - b->n;
}
static absl::string_view ChannelArgName() { return "grpc.test"; }
};
TEST(ChannelArgsTest, StoreAndRetrieveSharedPtr) {
std::shared_ptr<ShareableObject> copied_obj;
{
ChannelArgs channel_args;
auto shared_obj = std::make_shared<ShareableObject>(42);
EXPECT_TRUE(shared_obj.unique());
channel_args = channel_args.SetObject(shared_obj);
EXPECT_FALSE(shared_obj.unique());
copied_obj = channel_args.GetObjectRef<ShareableObject>();
EXPECT_EQ(copied_obj->n, 42);
// Refs: p, copied_obj, and ChannelArgs
EXPECT_EQ(3, copied_obj.use_count());
}
// The p and ChannelArgs are deleted.
EXPECT_TRUE(copied_obj.unique());
EXPECT_EQ(copied_obj->n, 42);
}
TEST(ChannelArgsTest, RetrieveRawPointerFromStoredSharedPtr) {
ChannelArgs channel_args;
auto shared_obj = std::make_shared<ShareableObject>(42);
EXPECT_TRUE(shared_obj.unique());
channel_args = channel_args.SetObject(shared_obj);
EXPECT_FALSE(shared_obj.unique());
ShareableObject* raw_obj = channel_args.GetObject<ShareableObject>();
EXPECT_EQ(raw_obj->n, 42);
// Refs: p and ChannelArgs
EXPECT_EQ(2, shared_obj.use_count());
}
} // namespace grpc_core } // namespace grpc_core
TEST(GrpcChannelArgsTest, Create) { TEST(GrpcChannelArgsTest, Create) {

Loading…
Cancel
Save