[party] Decouple `Party` from `Arena` (#36229)

Following up to #33961 `Party` no longer needs to refer to `Arena`, and decoupling gives us a few more degrees of freedom in the design of a final `CallSpine`.

Also flesh out `LogStateChange` usage so that all state transitions are traced when that tracer is enabled.

Closes #36229

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/36229 from ctiller:arenaless-party 51ae8eb898
PiperOrigin-RevId: 621525912
pull/36239/head
Craig Tiller 8 months ago committed by Copybara-Service
parent aa67587bac
commit e0cade4daa
  1. 2
      src/core/lib/promise/party.cc
  2. 26
      src/core/lib/promise/party.h
  3. 4
      src/core/lib/surface/call.cc
  4. 23
      src/core/lib/transport/batch_builder.cc
  5. 4
      src/core/lib/transport/batch_builder.h
  6. 9
      src/core/lib/transport/call_spine.h
  7. 12
      test/core/promise/party_test.cc

@ -182,7 +182,6 @@ Party::~Party() {}
void Party::CancelRemainingParticipants() {
ScopedActivity activity(this);
promise_detail::Context<Arena> arena_ctx(arena_);
for (size_t i = 0; i < party_detail::kMaxParticipants; i++) {
if (auto* p =
participants_[i].exchange(nullptr, std::memory_order_acquire)) {
@ -265,7 +264,6 @@ void Party::RunLocked() {
bool Party::RunParty() {
ScopedActivity activity(this);
promise_detail::Context<Arena> arena_ctx(arena_);
return sync_.RunParty([this](int i) { return RunOneParticipant(i); });
}

@ -41,7 +41,6 @@
#include "src/core/lib/promise/detail/promise_factory.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/promise/trace.h"
#include "src/core/lib/resource_quota/arena.h"
// Two implementations of party synchronization are provided: one using a single
// atomic, the other using a mutex and a set of state variables.
@ -130,7 +129,10 @@ class PartySyncUsingAtomics {
if (poll_one_participant(i)) {
const uint64_t allocated_bit = (1u << i << kAllocatedShift);
prev_state &= ~allocated_bit;
state_.fetch_and(~allocated_bit, std::memory_order_release);
uint64_t finished_prev_state =
state_.fetch_and(~allocated_bit, std::memory_order_release);
LogStateChange("Run:ParticipantComplete", finished_prev_state,
finished_prev_state & ~allocated_bit);
}
}
// Try to CAS the state we expected to have (with no wakeups or adds)
@ -208,7 +210,8 @@ class PartySyncUsingAtomics {
// Now we need to wake up the party.
state = state_.fetch_or(wakeup_mask | kLocked, std::memory_order_release);
LogStateChange("AddParticipantsAndRef:Wakeup", state, state | kLocked);
LogStateChange("AddParticipantsAndRef:Wakeup", state,
state | wakeup_mask | kLocked);
// If the party was already locked, we're done.
return ((state & kLocked) == 0);
@ -229,7 +232,7 @@ class PartySyncUsingAtomics {
void LogStateChange(const char* op, uint64_t prev_state, uint64_t new_state,
DebugLocation loc = {}) {
if (grpc_trace_party_state.enabled()) {
gpr_log(loc.file(), loc.line(), GPR_LOG_SEVERITY_DEBUG,
gpr_log(loc.file(), loc.line(), GPR_LOG_SEVERITY_INFO,
"Party %p %30s: %016" PRIx64 " -> %016" PRIx64, this, op,
prev_state, new_state);
}
@ -404,8 +407,6 @@ class Party : public Activity, private Wakeable {
return RefCountedPtr<Party>(this);
}
Arena* arena() const { return arena_; }
// Return a promise that resolves to Empty{} when the current party poll is
// complete.
// This is useful for implementing batching and the like: we can hold some
@ -438,8 +439,7 @@ class Party : public Activity, private Wakeable {
};
protected:
explicit Party(Arena* arena, size_t initial_refs)
: sync_(initial_refs), arena_(arena) {}
explicit Party(size_t initial_refs) : sync_(initial_refs) {}
~Party() override;
// Main run loop. Must be locked.
@ -491,13 +491,13 @@ class Party : public Activity, private Wakeable {
auto p = promise_();
if (auto* r = p.value_if_ready()) {
on_complete_(std::move(*r));
GetContext<Arena>()->DeletePooled(this);
delete this;
return true;
}
return false;
}
void Destroy() override { GetContext<Arena>()->DeletePooled(this); }
void Destroy() override { delete this; }
private:
union {
@ -626,7 +626,6 @@ class Party : public Activity, private Wakeable {
#error No synchronization method defined
#endif
Arena* const arena_;
uint8_t currently_polling_ = kNotPolling;
// All current participants, using a tagged format.
// If the lower bit is unset, then this is a Participant*.
@ -646,9 +645,8 @@ void Party::BulkSpawner::Spawn(absl::string_view name, Factory promise_factory,
gpr_log(GPR_DEBUG, "%s[bulk_spawn] On %p queue %s",
party_->DebugTag().c_str(), this, std::string(name).c_str());
}
participants_[num_participants_++] =
party_->arena_->NewPooled<ParticipantImpl<Factory, OnComplete>>(
name, std::move(promise_factory), std::move(on_complete));
participants_[num_participants_++] = new ParticipantImpl<Factory, OnComplete>(
name, std::move(promise_factory), std::move(on_complete));
}
template <typename Factory, typename OnComplete>

@ -1980,7 +1980,7 @@ class BasicPromiseBasedCall : public Call,
const grpc_call_create_args& args)
: Call(arena, args.server_transport_data == nullptr, args.send_deadline,
args.channel->Ref()),
Party(arena, initial_internal_refs),
Party(initial_internal_refs),
external_refs_(initial_external_refs),
cq_(args.cq) {
if (args.cq != nullptr) {
@ -2869,6 +2869,7 @@ class ClientPromiseBasedCall final : public PromiseBasedCall {
}
Party& party() override { return *call_; }
Arena* arena() override { return call_->arena(); }
void IncrementRefCount() override { refs_.Ref(); }
void Unref() override {
@ -3724,6 +3725,7 @@ class ServerCallSpine final : public CallSpineInterface,
}
Latch<ServerMetadataHandle>& cancel_latch() override { return cancel_latch_; }
Party& party() override { return *this; }
Arena* arena() override { return BasicPromiseBasedCall::arena(); }
void IncrementRefCount() override { InternalRef("CallSpine"); }
void Unref() override { InternalUnref("CallSpine"); }

@ -66,25 +66,16 @@ BatchBuilder::Batch::Batch(grpc_transport_stream_op_batch_payload* payload,
}
BatchBuilder::Batch::~Batch() {
auto* arena = party->arena();
if (grpc_call_trace.enabled()) {
gpr_log(GPR_DEBUG, "%s[connected] [batch %p] Destroy",
GetContext<Activity>()->DebugTag().c_str(), this);
}
if (pending_receive_message != nullptr) {
arena->DeletePooled(pending_receive_message);
}
if (pending_receive_initial_metadata != nullptr) {
arena->DeletePooled(pending_receive_initial_metadata);
}
if (pending_receive_trailing_metadata != nullptr) {
arena->DeletePooled(pending_receive_trailing_metadata);
}
if (pending_sends != nullptr) {
arena->DeletePooled(pending_sends);
}
delete pending_receive_message;
delete pending_receive_initial_metadata;
delete pending_receive_trailing_metadata;
delete pending_sends;
if (batch.cancel_stream) {
arena->DeletePooled(batch.payload);
delete batch.payload;
}
#ifndef NDEBUG
grpc_stream_unref(stream_refcount, "pending-batch");
@ -171,8 +162,8 @@ BatchBuilder::Batch* BatchBuilder::MakeCancel(
void BatchBuilder::Cancel(Target target, absl::Status status) {
auto* batch = MakeCancel(target.stream_refcount, std::move(status));
batch->batch.on_complete = NewClosure(
[batch](absl::Status) { batch->party->arena()->DeletePooled(batch); });
batch->batch.on_complete =
NewClosure([batch](absl::Status) { delete batch; });
batch->PerformWith(target);
}

@ -208,7 +208,7 @@ class BatchBuilder {
void IncrementRefCount() { ++refs; }
void Unref() {
if (--refs == 0) party->arena()->DeletePooled(this);
if (--refs == 0) delete this;
}
RefCountedPtr<Batch> Ref() {
IncrementRefCount();
@ -224,7 +224,7 @@ class BatchBuilder {
template <typename T>
T* GetInitializedCompletion(T*(Batch::*field)) {
if (this->*field != nullptr) return this->*field;
this->*field = party->arena()->NewPooled<T>(Ref());
this->*field = new T(Ref());
if (grpc_call_trace.enabled()) {
gpr_log(GPR_DEBUG, "%sAdd batch closure for %s @ %s",
DebugPrefix().c_str(),

@ -64,6 +64,7 @@ class CallSpineInterface {
if (on_done_ != nullptr) std::exchange(on_done_, nullptr)();
}
virtual Party& party() = 0;
virtual Arena* arena() = 0;
virtual void IncrementRefCount() = 0;
virtual void Unref() = 0;
@ -170,6 +171,7 @@ class CallSpine final : public CallSpineInterface, public Party {
}
Latch<ServerMetadataHandle>& cancel_latch() override { return cancel_latch_; }
Party& party() override { return *this; }
Arena* arena() override { return arena_; }
void IncrementRefCount() override { Party::IncrementRefCount(); }
void Unref() override { Party::Unref(); }
@ -177,7 +179,7 @@ class CallSpine final : public CallSpineInterface, public Party {
friend class Arena;
CallSpine(grpc_event_engine::experimental::EventEngine* event_engine,
Arena* arena)
: Party(arena, 1), event_engine_(event_engine) {}
: Party(1), arena_(arena), event_engine_(event_engine) {}
class ScopedContext : public ScopedActivity,
public promise_detail::Context<Arena> {
@ -206,6 +208,7 @@ class CallSpine final : public CallSpineInterface, public Party {
return event_engine_;
}
Arena* arena_;
// Initial metadata from client to server
Pipe<ClientMetadataHandle> client_initial_metadata_{arena()};
// Initial metadata from server to client
@ -313,7 +316,7 @@ class CallInitiator {
return spine_->party().SpawnWaitable(name, std::move(promise_factory));
}
Arena* arena() { return spine_->party().arena(); }
Arena* arena() { return spine_->arena(); }
private:
RefCountedPtr<CallSpineInterface> spine_;
@ -398,7 +401,7 @@ class CallHandler {
return spine_->party().SpawnWaitable(name, std::move(promise_factory));
}
Arena* arena() { return spine_->party().arena(); }
Arena* arena() { return spine_->arena(); }
private:
RefCountedPtr<CallSpineInterface> spine_;

@ -231,17 +231,9 @@ TYPED_TEST(PartySyncTest, UnrefWhileRunning) {
///////////////////////////////////////////////////////////////////////////////
// PartyTest
class AllocatorOwner {
protected:
~AllocatorOwner() { arena_->Destroy(); }
MemoryAllocator memory_allocator_ = MemoryAllocator(
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator("test"));
Arena* arena_ = Arena::Create(1024, &memory_allocator_);
};
class TestParty final : public AllocatorOwner, public Party {
class TestParty final : public Party {
public:
TestParty() : Party(AllocatorOwner::arena_, 1) {}
TestParty() : Party(1) {}
~TestParty() override {}
std::string DebugTag() const override { return "TestParty"; }

Loading…
Cancel
Save