diff --git a/BUILD b/BUILD index 4d182f30292..1e1eb64b4c8 100644 --- a/BUILD +++ b/BUILD @@ -642,6 +642,20 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "dual_ref_counted", + language = "c++", + public_hdrs = ["src/core/lib/gprpp/dual_ref_counted.h"], + deps = [ + "atomic", + "debug_location", + "gpr_base", + "grpc_trace", + "orphanable", + "ref_counted_ptr", + ], +) + grpc_cc_library( name = "ref_counted_ptr", language = "c++", diff --git a/CMakeLists.txt b/CMakeLists.txt index 2495f0e8d22..56ac713074b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -808,6 +808,7 @@ if(gRPC_BUILD_TESTS) add_dependencies(buildtests_cxx context_list_test) add_dependencies(buildtests_cxx delegating_channel_test) add_dependencies(buildtests_cxx destroy_grpclb_channel_with_active_connect_stress_test) + add_dependencies(buildtests_cxx dual_ref_counted_test) add_dependencies(buildtests_cxx duplicate_header_bad_client_test) add_dependencies(buildtests_cxx end2end_test) add_dependencies(buildtests_cxx error_details_test) @@ -10553,6 +10554,45 @@ target_link_libraries(destroy_grpclb_channel_with_active_connect_stress_test ) +endif() +if(gRPC_BUILD_TESTS) + +add_executable(dual_ref_counted_test + test/core/gprpp/dual_ref_counted_test.cc + third_party/googletest/googletest/src/gtest-all.cc + third_party/googletest/googlemock/src/gmock-all.cc +) + +target_include_directories(dual_ref_counted_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(dual_ref_counted_test + ${_gRPC_PROTOBUF_LIBRARIES} + ${_gRPC_ALLTARGETS_LIBRARIES} + grpc_test_util + grpc + gpr + address_sorting + upb + ${_gRPC_GFLAGS_LIBRARIES} +) + + endif() if(gRPC_BUILD_TESTS) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 7da93bb56e4..ada639cd9ee 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -5628,6 +5628,20 @@ targets: - gpr - address_sorting - upb +- name: dual_ref_counted_test + gtest: true + build: test + language: c++ + headers: + - src/core/lib/gprpp/dual_ref_counted.h + src: + - test/core/gprpp/dual_ref_counted_test.cc + deps: + - grpc_test_util + - grpc + - gpr + - address_sorting + - upb - name: duplicate_header_bad_client_test gtest: true build: test @@ -6750,7 +6764,8 @@ targets: gtest: true build: test language: c++ - headers: [] + headers: + - src/core/lib/gprpp/dual_ref_counted.h src: - test/core/gprpp/ref_counted_ptr_test.cc deps: diff --git a/src/core/lib/gprpp/dual_ref_counted.h b/src/core/lib/gprpp/dual_ref_counted.h new file mode 100644 index 00000000000..66935bbdea1 --- /dev/null +++ b/src/core/lib/gprpp/dual_ref_counted.h @@ -0,0 +1,336 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef GRPC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H +#define GRPC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H + +#include + +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/atomic.h" +#include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" + +namespace grpc_core { + +// DualRefCounted is an interface for reference-counted objects with two +// classes of refs: strong refs (usually just called "refs") and weak refs. +// This supports cases where an object needs to start shutting down when +// all external callers are done with it (represented by strong refs) but +// cannot be destroyed until all internal callbacks are complete +// (represented by weak refs). +// +// Each class of refs can be incremented and decremented independently. +// Objects start with 1 strong ref and 0 weak refs at instantiation. +// When the strong refcount reaches 0, the object's Orphan() method is called. +// When the weak refcount reaches 0, the object is destroyed. +// +// This will be used by CRTP (curiously-recurring template pattern), e.g.: +// class MyClass : public RefCounted { ... }; +template +class DualRefCounted : public Orphanable { + public: + virtual ~DualRefCounted() = default; + + RefCountedPtr Ref() GRPC_MUST_USE_RESULT { + IncrementRefCount(); + return RefCountedPtr(static_cast(this)); + } + + RefCountedPtr Ref(const DebugLocation& location, + const char* reason) GRPC_MUST_USE_RESULT { + IncrementRefCount(location, reason); + return RefCountedPtr(static_cast(this)); + } + + void Unref() { + // Convert strong ref to weak ref. + const uint64_t prev_ref_pair = + refs_.FetchAdd(MakeRefPair(-1, 1), MemoryOrder::ACQ_REL); + const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); +#ifndef NDEBUG + const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); + if (trace_flag_ != nullptr && trace_flag_->enabled()) { + gpr_log(GPR_INFO, "%s:%p unref %d -> %d, weak_ref %d -> %d", + trace_flag_->name(), this, strong_refs, strong_refs - 1, + weak_refs, weak_refs + 1); + } + GPR_ASSERT(strong_refs > 0); +#endif + if (GPR_UNLIKELY(strong_refs == 1)) { + Orphan(); + } + // Now drop the weak ref. + WeakUnref(); + } + void Unref(const DebugLocation& location, const char* reason) { + const uint64_t prev_ref_pair = + refs_.FetchAdd(MakeRefPair(-1, 1), MemoryOrder::ACQ_REL); + const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); +#ifndef NDEBUG + const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); + if (trace_flag_ != nullptr && trace_flag_->enabled()) { + gpr_log(GPR_INFO, "%s:%p %s:%d unref %d -> %d, weak_ref %d -> %d) %s", + trace_flag_->name(), this, location.file(), location.line(), + strong_refs, strong_refs - 1, weak_refs, weak_refs + 1, reason); + } + GPR_ASSERT(strong_refs > 0); +#else + // Avoid unused-parameter warnings for debug-only parameters + (void)location; + (void)reason; +#endif + if (GPR_UNLIKELY(strong_refs == 1)) { + Orphan(); + } + // Now drop the weak ref. + WeakUnref(location, reason); + } + + RefCountedPtr RefIfNonZero() GRPC_MUST_USE_RESULT { + uint64_t prev_ref_pair = refs_.Load(MemoryOrder::ACQUIRE); + do { + const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); +#ifndef NDEBUG + const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); + if (trace_flag_ != nullptr && trace_flag_->enabled()) { + gpr_log(GPR_INFO, "%s:%p ref_if_non_zero %d -> %d (weak_refs=%d)", + trace_flag_->name(), this, strong_refs, strong_refs + 1, + weak_refs); + } +#endif + if (strong_refs == 0) return nullptr; + } while (!refs_.CompareExchangeWeak( + &prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0), MemoryOrder::ACQ_REL, + MemoryOrder::ACQUIRE)); + return RefCountedPtr(static_cast(this)); + } + + RefCountedPtr RefIfNonZero(const DebugLocation& location, + const char* reason) GRPC_MUST_USE_RESULT { + uint64_t prev_ref_pair = refs_.Load(MemoryOrder::ACQUIRE); + do { + const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); +#ifndef NDEBUG + const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); + if (trace_flag_ != nullptr && trace_flag_->enabled()) { + gpr_log(GPR_INFO, + "%s:%p %s:%d ref_if_non_zero %d -> %d (weak_refs=%d) %s", + trace_flag_->name(), this, location.file(), location.line(), + strong_refs, strong_refs + 1, weak_refs, reason); + } +#else + // Avoid unused-parameter warnings for debug-only parameters + (void)location; + (void)reason; +#endif + if (strong_refs == 0) return nullptr; + } while (!refs_.CompareExchangeWeak( + &prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0), MemoryOrder::ACQ_REL, + MemoryOrder::ACQUIRE)); + return RefCountedPtr(static_cast(this)); + } + + WeakRefCountedPtr WeakRef() GRPC_MUST_USE_RESULT { + IncrementWeakRefCount(); + return WeakRefCountedPtr(static_cast(this)); + } + + WeakRefCountedPtr WeakRef(const DebugLocation& location, + const char* reason) GRPC_MUST_USE_RESULT { + IncrementWeakRefCount(location, reason); + return WeakRefCountedPtr(static_cast(this)); + } + + void WeakUnref() { +#ifndef NDEBUG + // Grab a copy of the trace flag before the atomic change, since we + // can't safely access it afterwards if we're going to be freed. + auto* trace_flag = trace_flag_; +#endif + const uint64_t prev_ref_pair = + refs_.FetchSub(MakeRefPair(0, 1), MemoryOrder::ACQ_REL); + const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); +#ifndef NDEBUG + const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); + if (trace_flag != nullptr && trace_flag->enabled()) { + gpr_log(GPR_INFO, "%s:%p weak_unref %d -> %d (refs=%d)", + trace_flag->name(), this, weak_refs, weak_refs - 1, strong_refs); + } + GPR_ASSERT(weak_refs > 0); +#endif + if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) { + delete static_cast(this); + } + } + void WeakUnref(const DebugLocation& location, const char* reason) { +#ifndef NDEBUG + // Grab a copy of the trace flag before the atomic change, since we + // can't safely access it afterwards if we're going to be freed. + auto* trace_flag = trace_flag_; +#endif + const uint64_t prev_ref_pair = + refs_.FetchSub(MakeRefPair(0, 1), MemoryOrder::ACQ_REL); + const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); +#ifndef NDEBUG + const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); + if (trace_flag != nullptr && trace_flag->enabled()) { + gpr_log(GPR_INFO, "%s:%p %s:%d weak_unref %d -> %d (refs=%d) %s", + trace_flag->name(), this, location.file(), location.line(), + weak_refs, weak_refs - 1, strong_refs, reason); + } + GPR_ASSERT(weak_refs > 0); +#else + // Avoid unused-parameter warnings for debug-only parameters + (void)location; + (void)reason; +#endif + if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) { + delete static_cast(this); + } + } + + // Not copyable nor movable. + DualRefCounted(const DualRefCounted&) = delete; + DualRefCounted& operator=(const DualRefCounted&) = delete; + + protected: + // TraceFlagT is defined to accept both DebugOnlyTraceFlag and TraceFlag. + // Note: RefCount tracing is only enabled on debug builds, even when a + // TraceFlag is used. + template + explicit DualRefCounted( + TraceFlagT* +#ifndef NDEBUG + // Leave unnamed if NDEBUG to avoid unused parameter warning + trace_flag +#endif + = nullptr, + int32_t initial_refcount = 1) + : +#ifndef NDEBUG + trace_flag_(trace_flag), +#endif + refs_(MakeRefPair(initial_refcount, 0)) { + } + + private: + // Allow RefCountedPtr<> to access IncrementRefCount(). + template + friend class RefCountedPtr; + // Allow WeakRefCountedPtr<> to access IncrementWeakRefCount(). + template + friend class WeakRefCountedPtr; + + // First 32 bits are strong refs, next 32 bits are weak refs. + static uint64_t MakeRefPair(uint32_t strong, uint32_t weak) { + return (static_cast(strong) << 32) + static_cast(weak); + } + static uint32_t GetStrongRefs(uint64_t ref_pair) { + return static_cast(ref_pair >> 32); + } + static uint32_t GetWeakRefs(uint64_t ref_pair) { + return static_cast(ref_pair & 0xffffffffu); + } + + void IncrementRefCount() { +#ifndef NDEBUG + const uint64_t prev_ref_pair = + refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED); + const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); + const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); + GPR_ASSERT(strong_refs != 0); + if (trace_flag_ != nullptr && trace_flag_->enabled()) { + gpr_log(GPR_INFO, "%s:%p ref %d -> %d; (weak_refs=%d)", + trace_flag_->name(), this, strong_refs, strong_refs + 1, + weak_refs); + } +#else + refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED); +#endif + } + void IncrementRefCount(const DebugLocation& location, const char* reason) { +#ifndef NDEBUG + const uint64_t prev_ref_pair = + refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED); + const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); + const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); + GPR_ASSERT(strong_refs != 0); + if (trace_flag_ != nullptr && trace_flag_->enabled()) { + gpr_log(GPR_INFO, "%s:%p %s:%d ref %d -> %d (weak_refs=%d) %s", + trace_flag_->name(), this, location.file(), location.line(), + strong_refs, strong_refs + 1, weak_refs, reason); + } +#else + // Use conditionally-important parameters + (void)location; + (void)reason; + refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED); +#endif + } + + void IncrementWeakRefCount() { +#ifndef NDEBUG + const uint64_t prev_ref_pair = + refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED); + const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); + const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); + if (trace_flag_ != nullptr && trace_flag_->enabled()) { + gpr_log(GPR_INFO, "%s:%p weak_ref %d -> %d; (refs=%d)", + trace_flag_->name(), this, weak_refs, weak_refs + 1, strong_refs); + } +#else + refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED); +#endif + } + void IncrementWeakRefCount(const DebugLocation& location, + const char* reason) { +#ifndef NDEBUG + const uint64_t prev_ref_pair = + refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED); + const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); + const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); + if (trace_flag_ != nullptr && trace_flag_->enabled()) { + gpr_log(GPR_INFO, "%s:%p %s:%d weak_ref %d -> %d (refs=%d) %s", + trace_flag_->name(), this, location.file(), location.line(), + weak_refs, weak_refs + 1, strong_refs, reason); + } +#else + // Use conditionally-important parameters + (void)location; + (void)reason; + refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED); +#endif + } + +#ifndef NDEBUG + TraceFlag* trace_flag_; +#endif + Atomic refs_; +}; + +} // namespace grpc_core + +#endif /* GRPC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H */ diff --git a/src/core/lib/gprpp/ref_counted_ptr.h b/src/core/lib/gprpp/ref_counted_ptr.h index 179491b22c2..c28e7625323 100644 --- a/src/core/lib/gprpp/ref_counted_ptr.h +++ b/src/core/lib/gprpp/ref_counted_ptr.h @@ -177,6 +177,154 @@ class RefCountedPtr { T* value_ = nullptr; }; +// A smart pointer class for objects that provide IncrementWeakRefCount() and +// WeakUnref() methods, such as those provided by the DualRefCounted base class. +template +class WeakRefCountedPtr { + public: + WeakRefCountedPtr() {} + WeakRefCountedPtr(std::nullptr_t) {} + + // If value is non-null, we take ownership of a ref to it. + template + explicit WeakRefCountedPtr(Y* value) { + value_ = value; + } + + // Move ctors. + WeakRefCountedPtr(WeakRefCountedPtr&& other) { + value_ = other.value_; + other.value_ = nullptr; + } + template + WeakRefCountedPtr(WeakRefCountedPtr&& other) { + value_ = static_cast(other.value_); + other.value_ = nullptr; + } + + // Move assignment. + WeakRefCountedPtr& operator=(WeakRefCountedPtr&& other) { + reset(other.value_); + other.value_ = nullptr; + return *this; + } + template + WeakRefCountedPtr& operator=(WeakRefCountedPtr&& other) { + reset(other.value_); + other.value_ = nullptr; + return *this; + } + + // Copy ctors. + WeakRefCountedPtr(const WeakRefCountedPtr& other) { + if (other.value_ != nullptr) other.value_->IncrementWeakRefCount(); + value_ = other.value_; + } + template + WeakRefCountedPtr(const WeakRefCountedPtr& other) { + static_assert(std::has_virtual_destructor::value, + "T does not have a virtual dtor"); + if (other.value_ != nullptr) other.value_->IncrementWeakRefCount(); + value_ = static_cast(other.value_); + } + + // Copy assignment. + WeakRefCountedPtr& operator=(const WeakRefCountedPtr& other) { + // Note: Order of reffing and unreffing is important here in case value_ + // and other.value_ are the same object. + if (other.value_ != nullptr) other.value_->IncrementWeakRefCount(); + reset(other.value_); + return *this; + } + template + WeakRefCountedPtr& operator=(const WeakRefCountedPtr& other) { + static_assert(std::has_virtual_destructor::value, + "T does not have a virtual dtor"); + // Note: Order of reffing and unreffing is important here in case value_ + // and other.value_ are the same object. + if (other.value_ != nullptr) other.value_->IncrementWeakRefCount(); + reset(other.value_); + return *this; + } + + ~WeakRefCountedPtr() { + if (value_ != nullptr) value_->WeakUnref(); + } + + void swap(WeakRefCountedPtr& other) { std::swap(value_, other.value_); } + + // If value is non-null, we take ownership of a ref to it. + void reset(T* value = nullptr) { + if (value_ != nullptr) value_->WeakUnref(); + value_ = value; + } + void reset(const DebugLocation& location, const char* reason, + T* value = nullptr) { + if (value_ != nullptr) value_->WeakUnref(location, reason); + value_ = value; + } + template + void reset(Y* value = nullptr) { + static_assert(std::has_virtual_destructor::value, + "T does not have a virtual dtor"); + if (value_ != nullptr) value_->WeakUnref(); + value_ = static_cast(value); + } + template + void reset(const DebugLocation& location, const char* reason, + Y* value = nullptr) { + static_assert(std::has_virtual_destructor::value, + "T does not have a virtual dtor"); + if (value_ != nullptr) value_->WeakUnref(location, reason); + value_ = static_cast(value); + } + + // TODO(roth): This method exists solely as a transition mechanism to allow + // us to pass a ref to idiomatic C code that does not use WeakRefCountedPtr<>. + // Once all of our code has been converted to idiomatic C++, this + // method should go away. + T* release() { + T* value = value_; + value_ = nullptr; + return value; + } + + T* get() const { return value_; } + + T& operator*() const { return *value_; } + T* operator->() const { return value_; } + + template + bool operator==(const WeakRefCountedPtr& other) const { + return value_ == other.value_; + } + + template + bool operator==(const Y* other) const { + return value_ == other; + } + + bool operator==(std::nullptr_t) const { return value_ == nullptr; } + + template + bool operator!=(const WeakRefCountedPtr& other) const { + return value_ != other.value_; + } + + template + bool operator!=(const Y* other) const { + return value_ != other; + } + + bool operator!=(std::nullptr_t) const { return value_ != nullptr; } + + private: + template + friend class WeakRefCountedPtr; + + T* value_ = nullptr; +}; + template inline RefCountedPtr MakeRefCounted(Args&&... args) { return RefCountedPtr(new T(std::forward(args)...)); @@ -187,6 +335,11 @@ bool operator<(const RefCountedPtr& p1, const RefCountedPtr& p2) { return p1.get() < p2.get(); } +template +bool operator<(const WeakRefCountedPtr& p1, const WeakRefCountedPtr& p2) { + return p1.get() < p2.get(); +} + } // namespace grpc_core #endif /* GRPC_CORE_LIB_GPRPP_REF_COUNTED_PTR_H */ diff --git a/test/core/gprpp/BUILD b/test/core/gprpp/BUILD index 4c9e82fa868..cdd72e9e9ba 100644 --- a/test/core/gprpp/BUILD +++ b/test/core/gprpp/BUILD @@ -121,6 +121,19 @@ grpc_cc_test( ], ) +grpc_cc_test( + name = "dual_ref_counted_test", + srcs = ["dual_ref_counted_test.cc"], + external_deps = [ + "gtest", + ], + language = "C++", + deps = [ + "//:dual_ref_counted", + "//test/core/util:grpc_test_util", + ], +) + grpc_cc_test( name = "ref_counted_ptr_test", srcs = ["ref_counted_ptr_test.cc"], @@ -129,6 +142,7 @@ grpc_cc_test( ], language = "C++", deps = [ + "//:dual_ref_counted", "//:ref_counted", "//:ref_counted_ptr", "//test/core/util:grpc_test_util", diff --git a/test/core/gprpp/dual_ref_counted_test.cc b/test/core/gprpp/dual_ref_counted_test.cc new file mode 100644 index 00000000000..ac3e5d96061 --- /dev/null +++ b/test/core/gprpp/dual_ref_counted_test.cc @@ -0,0 +1,112 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "src/core/lib/gprpp/dual_ref_counted.h" + +#include + +#include +#include + +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +class Foo : public DualRefCounted { + public: + Foo() = default; + ~Foo() { GPR_ASSERT(shutting_down_); } + + void Orphan() override { shutting_down_ = true; } + + private: + bool shutting_down_ = false; +}; + +TEST(DualRefCounted, Basic) { + Foo* foo = new Foo(); + foo->Unref(); +} + +TEST(DualRefCounted, ExtraRef) { + Foo* foo = new Foo(); + foo->Ref().release(); + foo->Unref(); + foo->Unref(); +} + +TEST(DualRefCounted, ExtraWeakRef) { + Foo* foo = new Foo(); + foo->WeakRef().release(); + foo->Unref(); + foo->WeakUnref(); +} + +TEST(DualRefCounted, RefIfNonZero) { + Foo* foo = new Foo(); + foo->WeakRef().release(); + { + RefCountedPtr foop = foo->RefIfNonZero(); + EXPECT_NE(foop.get(), nullptr); + } + foo->Unref(); + { + RefCountedPtr foop = foo->RefIfNonZero(); + EXPECT_EQ(foop.get(), nullptr); + } + foo->WeakUnref(); +} + +// Note: We use DebugOnlyTraceFlag instead of TraceFlag to ensure that +// things build properly in both debug and non-debug cases. +DebugOnlyTraceFlag foo_tracer(true, "foo"); + +class FooWithTracing : public DualRefCounted { + public: + FooWithTracing() : DualRefCounted(&foo_tracer) {} + ~FooWithTracing() { GPR_ASSERT(shutting_down_); } + + void Orphan() override { shutting_down_ = true; } + + private: + bool shutting_down_ = false; +}; + +TEST(DualRefCountedWithTracing, Basic) { + FooWithTracing* foo = new FooWithTracing(); + foo->Ref(DEBUG_LOCATION, "extra_ref").release(); + foo->Unref(DEBUG_LOCATION, "extra_ref"); + foo->WeakRef(DEBUG_LOCATION, "extra_ref").release(); + foo->WeakUnref(DEBUG_LOCATION, "extra_ref"); + // Can use the no-argument methods, too. + foo->Ref().release(); + foo->Unref(); + foo->WeakRef().release(); + foo->WeakUnref(); + foo->Unref(DEBUG_LOCATION, "original_ref"); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/ref_counted_ptr_test.cc b/test/core/gprpp/ref_counted_ptr_test.cc index e392c2058e4..3907b21dc1a 100644 --- a/test/core/gprpp/ref_counted_ptr_test.cc +++ b/test/core/gprpp/ref_counted_ptr_test.cc @@ -22,6 +22,7 @@ #include +#include "src/core/lib/gprpp/dual_ref_counted.h" #include "src/core/lib/gprpp/memory.h" #include "src/core/lib/gprpp/ref_counted.h" #include "test/core/util/test_config.h" @@ -30,6 +31,10 @@ namespace grpc_core { namespace testing { namespace { +// +// RefCountedPtr<> tests +// + class Foo : public RefCounted { public: Foo() : value_(0) {} @@ -53,27 +58,27 @@ TEST(RefCountedPtr, ExplicitConstructor) { RefCountedPtr foo(new Foo()); } TEST(RefCountedPtr, MoveConstructor) { RefCountedPtr foo(new Foo()); RefCountedPtr foo2(std::move(foo)); - EXPECT_EQ(nullptr, foo.get()); + EXPECT_EQ(nullptr, foo.get()); // NOLINT EXPECT_NE(nullptr, foo2.get()); } TEST(RefCountedPtr, MoveAssignment) { RefCountedPtr foo(new Foo()); RefCountedPtr foo2 = std::move(foo); - EXPECT_EQ(nullptr, foo.get()); + EXPECT_EQ(nullptr, foo.get()); // NOLINT EXPECT_NE(nullptr, foo2.get()); } TEST(RefCountedPtr, CopyConstructor) { RefCountedPtr foo(new Foo()); - const RefCountedPtr& foo2(foo); + RefCountedPtr foo2(foo); EXPECT_NE(nullptr, foo.get()); EXPECT_EQ(foo.get(), foo2.get()); } TEST(RefCountedPtr, CopyAssignment) { RefCountedPtr foo(new Foo()); - const RefCountedPtr& foo2 = foo; + RefCountedPtr foo2 = foo; EXPECT_NE(nullptr, foo.get()); EXPECT_EQ(foo.get(), foo2.get()); } @@ -250,6 +255,263 @@ TEST(RefCountedPtr, CanPassSubclassToFunctionExpectingSubclass) { FunctionTakingSubclass(p); } +// +// WeakRefCountedPtr<> tests +// + +class Bar : public DualRefCounted { + public: + Bar() : value_(0) {} + + explicit Bar(int value) : value_(value) {} + + ~Bar() { GPR_ASSERT(shutting_down_); } + + void Orphan() override { shutting_down_ = true; } + + int value() const { return value_; } + + private: + int value_; + bool shutting_down_ = false; +}; + +TEST(WeakRefCountedPtr, DefaultConstructor) { WeakRefCountedPtr bar; } + +TEST(WeakRefCountedPtr, ExplicitConstructorEmpty) { + WeakRefCountedPtr bar(nullptr); +} + +TEST(WeakRefCountedPtr, ExplicitConstructor) { + RefCountedPtr bar_strong(new Bar()); + bar_strong->WeakRef().release(); + WeakRefCountedPtr bar(bar_strong.get()); +} + +TEST(WeakRefCountedPtr, MoveConstructor) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2(std::move(bar)); + EXPECT_EQ(nullptr, bar.get()); // NOLINT + EXPECT_NE(nullptr, bar2.get()); +} + +TEST(WeakRefCountedPtr, MoveAssignment) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2 = std::move(bar); + EXPECT_EQ(nullptr, bar.get()); // NOLINT + EXPECT_NE(nullptr, bar2.get()); +} + +TEST(WeakRefCountedPtr, CopyConstructor) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2(bar); + EXPECT_NE(nullptr, bar.get()); + EXPECT_EQ(bar.get(), bar2.get()); +} + +TEST(WeakRefCountedPtr, CopyAssignment) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2 = bar; + EXPECT_NE(nullptr, bar.get()); + EXPECT_EQ(bar.get(), bar2.get()); +} + +TEST(WeakRefCountedPtr, CopyAssignmentWhenEmpty) { + WeakRefCountedPtr bar; + WeakRefCountedPtr bar2; + bar2 = bar; + EXPECT_EQ(nullptr, bar.get()); + EXPECT_EQ(nullptr, bar2.get()); +} + +TEST(WeakRefCountedPtr, CopyAssignmentToSelf) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + bar = *&bar; // The "*&" avoids warnings from LLVM -Wself-assign. +} + +TEST(WeakRefCountedPtr, EnclosedScope) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + { + WeakRefCountedPtr bar2(std::move(bar)); + EXPECT_EQ(nullptr, bar.get()); + EXPECT_NE(nullptr, bar2.get()); + } + EXPECT_EQ(nullptr, bar.get()); +} + +TEST(WeakRefCountedPtr, ResetFromNullToNonNull) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar; + EXPECT_EQ(nullptr, bar.get()); + bar_strong->WeakRef().release(); + bar.reset(bar_strong.get()); + EXPECT_NE(nullptr, bar.get()); +} + +TEST(WeakRefCountedPtr, ResetFromNonNullToNonNull) { + RefCountedPtr bar_strong(new Bar()); + RefCountedPtr bar2_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + EXPECT_NE(nullptr, bar.get()); + bar2_strong->WeakRef().release(); + bar.reset(bar2_strong.get()); + EXPECT_NE(nullptr, bar.get()); + EXPECT_NE(bar_strong.get(), bar.get()); +} + +TEST(WeakRefCountedPtr, ResetFromNonNullToNull) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + EXPECT_NE(nullptr, bar.get()); + bar.reset(); + EXPECT_EQ(nullptr, bar.get()); +} + +TEST(WeakRefCountedPtr, ResetFromNullToNull) { + WeakRefCountedPtr bar; + EXPECT_EQ(nullptr, bar.get()); + bar.reset(); + EXPECT_EQ(nullptr, bar.get()); +} + +TEST(WeakRefCountedPtr, DerefernceOperators) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + bar->value(); + Bar& bar_ref = *bar; + bar_ref.value(); +} + +TEST(WeakRefCountedPtr, EqualityOperators) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2 = bar; + WeakRefCountedPtr empty; + // Test equality between RefCountedPtrs. + EXPECT_EQ(bar, bar2); + EXPECT_NE(bar, empty); + // Test equality with bare pointers. + EXPECT_EQ(bar, bar.get()); + EXPECT_EQ(empty, nullptr); + EXPECT_NE(bar, nullptr); +} + +TEST(WeakRefCountedPtr, Swap) { + RefCountedPtr bar_strong(new Bar()); + RefCountedPtr bar2_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2 = bar2_strong->WeakRef(); + bar.swap(bar2); + EXPECT_EQ(bar_strong.get(), bar2.get()); + EXPECT_EQ(bar2_strong.get(), bar.get()); + WeakRefCountedPtr bar3; + bar3.swap(bar2); + EXPECT_EQ(nullptr, bar2.get()); + EXPECT_EQ(bar_strong.get(), bar3.get()); +} + +TraceFlag bar_tracer(true, "bar"); + +class BarWithTracing : public DualRefCounted { + public: + BarWithTracing() : DualRefCounted(&bar_tracer) {} + + ~BarWithTracing() { GPR_ASSERT(shutting_down_); } + + void Orphan() override { shutting_down_ = true; } + + private: + bool shutting_down_ = false; +}; + +TEST(WeakRefCountedPtr, RefCountedWithTracing) { + RefCountedPtr bar_strong(new BarWithTracing()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2 = bar->WeakRef(DEBUG_LOCATION, "bar"); + bar2.release(); + bar->WeakUnref(DEBUG_LOCATION, "bar"); +} + +class WeakBaseClass : public DualRefCounted { + public: + WeakBaseClass() {} + + ~WeakBaseClass() { GPR_ASSERT(shutting_down_); } + + void Orphan() override { shutting_down_ = true; } + + private: + bool shutting_down_ = false; +}; + +class WeakSubclass : public WeakBaseClass { + public: + WeakSubclass() {} +}; + +TEST(WeakRefCountedPtr, ConstructFromWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr p(strong->WeakRef().release()); +} + +TEST(WeakRefCountedPtr, CopyAssignFromWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr b; + EXPECT_EQ(nullptr, b.get()); + WeakRefCountedPtr s = strong->WeakRef(); + b = s; + EXPECT_NE(nullptr, b.get()); +} + +TEST(WeakRefCountedPtr, MoveAssignFromWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr b; + EXPECT_EQ(nullptr, b.get()); + WeakRefCountedPtr s = strong->WeakRef(); + b = std::move(s); + EXPECT_NE(nullptr, b.get()); +} + +TEST(WeakRefCountedPtr, ResetFromWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr b; + EXPECT_EQ(nullptr, b.get()); + b.reset(strong->WeakRef().release()); + EXPECT_NE(nullptr, b.get()); +} + +TEST(WeakRefCountedPtr, EqualityWithWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr b = strong->WeakRef(); + EXPECT_EQ(b, strong.get()); +} + +void FunctionTakingWeakBaseClass(WeakRefCountedPtr p) { + p.reset(); // To appease clang-tidy. +} + +TEST(WeakRefCountedPtr, CanPassWeakSubclassToFunctionExpectingWeakBaseClass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr p = strong->WeakRef(); + FunctionTakingWeakBaseClass(p); +} + +void FunctionTakingWeakSubclass(WeakRefCountedPtr p) { + p.reset(); // To appease clang-tidy. +} + +TEST(WeakRefCountedPtr, CanPassWeakSubclassToFunctionExpectingWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr p = strong->WeakRef(); + FunctionTakingWeakSubclass(p); +} + } // namespace } // namespace testing } // namespace grpc_core diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index bd2040f432c..0d8d3d7c586 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -4265,6 +4265,30 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "dual_ref_counted_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": true + }, { "args": [], "benchmark": false,