diff --git a/upb/mem/BUILD b/upb/mem/BUILD index 71e007e544..f6bc19e7f1 100644 --- a/upb/mem/BUILD +++ b/upb/mem/BUILD @@ -46,10 +46,12 @@ cc_test( deps = [ "@com_google_googletest//:gtest_main", "//upb:mem", + "//upb:mem_internal", "//upb:port", "@com_google_absl//absl/random", "@com_google_absl//absl/random:distributions", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", ], ) diff --git a/upb/mem/arena.c b/upb/mem/arena.c index 7e34f86ee4..b9d3d9d8f9 100644 --- a/upb/mem/arena.c +++ b/upb/mem/arena.c @@ -346,3 +346,23 @@ bool upb_Arena_Fuse(upb_Arena* a1, upb_Arena* a2) { } } } + +void upb_Arena_IncRefFor(upb_Arena* arena, const void* owner) { + _upb_ArenaRoot r; +retry: + r = _upb_Arena_FindRoot(arena); + if (upb_Atomic_CompareExchangeWeak( + &r.root->parent_or_count, &r.tagged_count, + _upb_Arena_TaggedFromRefcount( + _upb_Arena_RefCountFromTagged(r.tagged_count) + 1), + memory_order_release, memory_order_acquire)) { + // We incremented it successfully, so we are done. + return; + } + // We failed update due to parent switching on the arena. + goto retry; +} + +void upb_Arena_DecRefFor(upb_Arena* arena, const void* owner) { + upb_Arena_Free(arena); +} diff --git a/upb/mem/arena.h b/upb/mem/arena.h index e7fd852ff9..e4afceaa43 100644 --- a/upb/mem/arena.h +++ b/upb/mem/arena.h @@ -47,6 +47,9 @@ UPB_API upb_Arena* upb_Arena_Init(void* mem, size_t n, upb_alloc* alloc); UPB_API void upb_Arena_Free(upb_Arena* a); UPB_API bool upb_Arena_Fuse(upb_Arena* a, upb_Arena* b); +void upb_Arena_IncRefFor(upb_Arena* arena, const void* owner); +void upb_Arena_DecRefFor(upb_Arena* arena, const void* owner); + void* _upb_Arena_SlowMalloc(upb_Arena* a, size_t size); size_t upb_Arena_SpaceAllocated(upb_Arena* arena); uint32_t upb_Arena_DebugRefCount(upb_Arena* arena); diff --git a/upb/mem/arena_test.cc b/upb/mem/arena_test.cc index 7020da1344..0b13e1a365 100644 --- a/upb/mem/arena_test.cc +++ b/upb/mem/arena_test.cc @@ -19,6 +19,8 @@ #include "absl/synchronization/notification.h" // Must be last. +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "upb/port/def.inc" namespace { @@ -44,7 +46,7 @@ TEST(ArenaTest, FuseWithInitialBlock) { char buf2[1024]; upb_Arena* arenas[] = {upb_Arena_Init(buf1, 1024, &upb_alloc_global), upb_Arena_Init(buf2, 1024, &upb_alloc_global), - upb_Arena_Init(NULL, 0, &upb_alloc_global)}; + upb_Arena_Init(nullptr, 0, &upb_alloc_global)}; int size = sizeof(arenas) / sizeof(arenas[0]); for (int i = 0; i < size; ++i) { for (int j = 0; j < size; ++j) { @@ -74,6 +76,15 @@ class Environment { if (old != nullptr) upb_Arena_Free(old); } + void RandomIncRefCount(absl::BitGen& gen) { + auto* a = SwapRandomly(gen, nullptr); + if (a != nullptr) { + upb_Arena_IncRefFor(a, nullptr); + upb_Arena_DecRefFor(a, nullptr); + upb_Arena_Free(a); + } + } + void RandomFuse(absl::BitGen& gen) { std::array old; for (auto& o : old) { @@ -168,6 +179,40 @@ TEST(ArenaTest, FuzzFuseFuseRace) { for (auto& t : threads) t.join(); } +TEST(ArenaTest, ArenaIncRef) { + upb_Arena* arena1 = upb_Arena_New(); + EXPECT_EQ(upb_Arena_DebugRefCount(arena1), 1); + upb_Arena_IncRefFor(arena1, nullptr); + EXPECT_EQ(upb_Arena_DebugRefCount(arena1), 2); + upb_Arena_DecRefFor(arena1, nullptr); + EXPECT_EQ(upb_Arena_DebugRefCount(arena1), 1); + upb_Arena_Free(arena1); +} + +TEST(ArenaTest, FuzzFuseIncRefCountRace) { + Environment env; + + absl::Notification done; + std::vector threads; + for (int i = 0; i < 10; ++i) { + threads.emplace_back([&]() { + absl::BitGen gen; + while (!done.HasBeenNotified()) { + env.RandomNewFree(gen); + } + }); + } + + absl::BitGen gen; + auto end = absl::Now() + absl::Seconds(2); + while (absl::Now() < end) { + env.RandomFuse(gen); + env.RandomIncRefCount(gen); + } + done.Notify(); + for (auto& t : threads) t.join(); +} + #endif } // namespace