mirror of https://github.com/grpc/grpc.git
[promises] Add an observer type (#35552)
We've got a few situations coming up with promises that will want a "broadcast new value to everywhere" situation.
Closes #35552
COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35552 from ctiller:obs 30fd697ae3
PiperOrigin-RevId: 599609399
pull/35601/head
parent
ad1dc17030
commit
98472179fb
7 changed files with 505 additions and 0 deletions
@ -0,0 +1,146 @@ |
||||
// Copyright 2024 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H |
||||
#define GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H |
||||
|
||||
#include <grpc/support/port_platform.h> |
||||
|
||||
#include "absl/container/flat_hash_set.h" |
||||
|
||||
#include "src/core/lib/gprpp/sync.h" |
||||
#include "src/core/lib/promise/activity.h" |
||||
#include "src/core/lib/promise/poll.h" |
||||
|
||||
namespace grpc_core { |
||||
|
||||
// Observable allows broadcasting a value to multiple interested observers.
|
||||
template <typename T> |
||||
class Observable { |
||||
public: |
||||
// We need to assign a value initially.
|
||||
explicit Observable(T initial) |
||||
: state_(MakeRefCounted<State>(std::move(initial))) {} |
||||
|
||||
// Update the value to something new. Awakes any waiters.
|
||||
void Set(T value) { state_->Set(std::move(value)); } |
||||
|
||||
// Returns a promise that resolves to a T when the value becomes != current.
|
||||
auto Next(T current) { return Observer(state_, std::move(current)); } |
||||
|
||||
private: |
||||
// Forward declaration so we can form pointers to Observer in State.
|
||||
class Observer; |
||||
|
||||
// State keeps track of all observable state.
|
||||
// It's a refcounted object so that promises reading the state are not tied
|
||||
// to the lifetime of the Observable.
|
||||
class State : public RefCounted<State> { |
||||
public: |
||||
explicit State(T value) : value_(std::move(value)) {} |
||||
|
||||
// Update the value and wake all observers.
|
||||
void Set(T value) { |
||||
MutexLock lock(&mu_); |
||||
std::swap(value_, value); |
||||
WakeAll(); |
||||
} |
||||
|
||||
// Export our mutex so that Observer can use it.
|
||||
Mutex* mu() ABSL_LOCK_RETURNED(mu_) { return &mu_; } |
||||
|
||||
// Fetch a ref to the current value.
|
||||
const T& current() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
||||
return value_; |
||||
} |
||||
|
||||
// Remove an observer from the set (it no longer needs updates).
|
||||
void Remove(Observer* observer) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
||||
observers_.erase(observer); |
||||
} |
||||
|
||||
// Add an observer to the set (it needs updates).
|
||||
GRPC_MUST_USE_RESULT Waker Add(Observer* observer) |
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
||||
observers_.insert(observer); |
||||
return Activity::current()->MakeNonOwningWaker(); |
||||
} |
||||
|
||||
private: |
||||
// Wake all observers.
|
||||
void WakeAll() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
||||
for (auto* observer : observers_) { |
||||
observer->Wakeup(); |
||||
} |
||||
} |
||||
|
||||
Mutex mu_; |
||||
// All observers that may need an update.
|
||||
absl::flat_hash_set<Observer*> observers_ ABSL_GUARDED_BY(mu_); |
||||
// The current value.
|
||||
T value_ ABSL_GUARDED_BY(mu_); |
||||
}; |
||||
|
||||
// Observer is a promise that resolves to a T when the value becomes !=
|
||||
// current.
|
||||
class Observer { |
||||
public: |
||||
Observer(RefCountedPtr<State> state, T current) |
||||
: state_(std::move(state)), current_(std::move(current)) {} |
||||
~Observer() { |
||||
// If we saw a pending at all then we *may* be in the set of observers.
|
||||
// If not we're definitely not and we can avoid taking the lock at all.
|
||||
if (!saw_pending_) return; |
||||
MutexLock lock(state_->mu()); |
||||
auto w = std::move(waker_); |
||||
state_->Remove(this); |
||||
} |
||||
|
||||
Observer(const Observer&) = delete; |
||||
Observer& operator=(const Observer&) = delete; |
||||
Observer(Observer&& other) noexcept |
||||
: state_(std::move(other.state_)), current_(std::move(other.current_)) { |
||||
GPR_ASSERT(other.waker_.is_unwakeable()); |
||||
GPR_ASSERT(!other.saw_pending_); |
||||
} |
||||
Observer& operator=(Observer&& other) noexcept = delete; |
||||
|
||||
void Wakeup() { waker_.WakeupAsync(); } |
||||
|
||||
Poll<T> operator()() { |
||||
MutexLock lock(state_->mu()); |
||||
// Check if the value has changed yet.
|
||||
if (current_ != state_->current()) { |
||||
if (saw_pending_ && !waker_.is_unwakeable()) state_->Remove(this); |
||||
return state_->current(); |
||||
} |
||||
// Record that we saw at least one pending and then register for wakeup.
|
||||
saw_pending_ = true; |
||||
if (waker_.is_unwakeable()) waker_ = state_->Add(this); |
||||
return Pending{}; |
||||
} |
||||
|
||||
private: |
||||
RefCountedPtr<State> state_; |
||||
T current_; |
||||
Waker waker_; |
||||
bool saw_pending_ = false; |
||||
}; |
||||
|
||||
RefCountedPtr<State> state_; |
||||
}; |
||||
|
||||
} // namespace grpc_core
|
||||
|
||||
#endif // GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H
|
@ -0,0 +1,227 @@ |
||||
// Copyright 2024 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/core/lib/promise/observable.h" |
||||
|
||||
#include <cstdint> |
||||
#include <limits> |
||||
#include <thread> |
||||
#include <vector> |
||||
|
||||
#include "absl/strings/str_join.h" |
||||
#include "gmock/gmock.h" |
||||
#include "gtest/gtest.h" |
||||
|
||||
#include "src/core/lib/gprpp/notification.h" |
||||
#include "src/core/lib/promise/loop.h" |
||||
#include "src/core/lib/promise/map.h" |
||||
|
||||
using testing::Mock; |
||||
using testing::StrictMock; |
||||
|
||||
namespace grpc_core { |
||||
namespace { |
||||
|
||||
class MockActivity : public Activity, public Wakeable { |
||||
public: |
||||
MOCK_METHOD(void, WakeupRequested, ()); |
||||
|
||||
void ForceImmediateRepoll(WakeupMask) override { WakeupRequested(); } |
||||
void Orphan() override {} |
||||
Waker MakeOwningWaker() override { return Waker(this, 0); } |
||||
Waker MakeNonOwningWaker() override { return Waker(this, 0); } |
||||
void Wakeup(WakeupMask) override { WakeupRequested(); } |
||||
void WakeupAsync(WakeupMask) override { WakeupRequested(); } |
||||
void Drop(WakeupMask) override {} |
||||
std::string DebugTag() const override { return "MockActivity"; } |
||||
std::string ActivityDebugTag(WakeupMask) const override { return DebugTag(); } |
||||
|
||||
void Activate() { |
||||
if (scoped_activity_ != nullptr) return; |
||||
scoped_activity_ = std::make_unique<ScopedActivity>(this); |
||||
} |
||||
|
||||
void Deactivate() { scoped_activity_.reset(); } |
||||
|
||||
private: |
||||
std::unique_ptr<ScopedActivity> scoped_activity_; |
||||
}; |
||||
|
||||
MATCHER(IsPending, "") { |
||||
if (arg.ready()) { |
||||
*result_listener << "is ready"; |
||||
return false; |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
MATCHER(IsReady, "") { |
||||
if (arg.pending()) { |
||||
*result_listener << "is pending"; |
||||
return false; |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
MATCHER_P(IsReady, value, "") { |
||||
if (arg.pending()) { |
||||
*result_listener << "is pending"; |
||||
return false; |
||||
} |
||||
if (arg.value() != value) { |
||||
*result_listener << "is " << ::testing::PrintToString(arg.value()); |
||||
return false; |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
TEST(ObservableTest, ImmediateNext) { |
||||
Observable<int> observable(1); |
||||
auto next = observable.Next(0); |
||||
EXPECT_THAT(next(), IsReady(1)); |
||||
} |
||||
|
||||
TEST(ObservableTest, SetBecomesImmediateNext1) { |
||||
Observable<int> observable(0); |
||||
auto next = observable.Next(0); |
||||
observable.Set(1); |
||||
EXPECT_THAT(next(), IsReady(1)); |
||||
} |
||||
|
||||
TEST(ObservableTest, SetBecomesImmediateNext2) { |
||||
Observable<int> observable(0); |
||||
observable.Set(1); |
||||
auto next = observable.Next(0); |
||||
EXPECT_THAT(next(), IsReady(1)); |
||||
} |
||||
|
||||
TEST(ObservableTest, SameValueGetsPending) { |
||||
StrictMock<MockActivity> activity; |
||||
activity.Activate(); |
||||
Observable<int> observable(1); |
||||
auto next = observable.Next(1); |
||||
EXPECT_THAT(next(), IsPending()); |
||||
EXPECT_THAT(next(), IsPending()); |
||||
EXPECT_THAT(next(), IsPending()); |
||||
EXPECT_THAT(next(), IsPending()); |
||||
} |
||||
|
||||
TEST(ObservableTest, ChangeValueWakesUp) { |
||||
StrictMock<MockActivity> activity; |
||||
activity.Activate(); |
||||
Observable<int> observable(1); |
||||
auto next = observable.Next(1); |
||||
EXPECT_THAT(next(), IsPending()); |
||||
EXPECT_CALL(activity, WakeupRequested()); |
||||
observable.Set(2); |
||||
Mock::VerifyAndClearExpectations(&activity); |
||||
EXPECT_THAT(next(), IsReady(2)); |
||||
} |
||||
|
||||
TEST(ObservableTest, MultipleActivitiesWakeUp) { |
||||
StrictMock<MockActivity> activity1; |
||||
StrictMock<MockActivity> activity2; |
||||
Observable<int> observable(1); |
||||
auto next1 = observable.Next(1); |
||||
auto next2 = observable.Next(1); |
||||
{ |
||||
activity1.Activate(); |
||||
EXPECT_THAT(next1(), IsPending()); |
||||
} |
||||
{ |
||||
activity2.Activate(); |
||||
EXPECT_THAT(next2(), IsPending()); |
||||
} |
||||
EXPECT_CALL(activity1, WakeupRequested()); |
||||
EXPECT_CALL(activity2, WakeupRequested()); |
||||
observable.Set(2); |
||||
Mock::VerifyAndClearExpectations(&activity1); |
||||
Mock::VerifyAndClearExpectations(&activity2); |
||||
EXPECT_THAT(next1(), IsReady(2)); |
||||
EXPECT_THAT(next2(), IsReady(2)); |
||||
} |
||||
|
||||
class ThreadWakeupScheduler { |
||||
public: |
||||
template <typename ActivityType> |
||||
class BoundScheduler { |
||||
public: |
||||
explicit BoundScheduler(ThreadWakeupScheduler) {} |
||||
void ScheduleWakeup() { |
||||
std::thread t( |
||||
[this] { static_cast<ActivityType*>(this)->RunScheduledWakeup(); }); |
||||
t.detach(); |
||||
} |
||||
}; |
||||
}; |
||||
|
||||
TEST(ObservableTest, Stress) { |
||||
static constexpr uint64_t kEnd = std::numeric_limits<uint64_t>::max(); |
||||
std::vector<uint64_t> values1; |
||||
std::vector<uint64_t> values2; |
||||
uint64_t current1 = 0; |
||||
uint64_t current2 = 0; |
||||
Notification done1; |
||||
Notification done2; |
||||
Observable<uint64_t> observable(0); |
||||
auto activity1 = MakeActivity( |
||||
Loop([&observable, ¤t1, &values1] { |
||||
return Map( |
||||
observable.Next(current1), |
||||
[&values1, ¤t1](uint64_t value) -> LoopCtl<absl::Status> { |
||||
values1.push_back(value); |
||||
current1 = value; |
||||
if (value == kEnd) return absl::OkStatus(); |
||||
return Continue{}; |
||||
}); |
||||
}), |
||||
ThreadWakeupScheduler(), [&done1](absl::Status status) { |
||||
EXPECT_TRUE(status.ok()) << status.ToString(); |
||||
done1.Notify(); |
||||
}); |
||||
auto activity2 = MakeActivity( |
||||
Loop([&observable, ¤t2, &values2] { |
||||
return Map( |
||||
observable.Next(current2), |
||||
[&values2, ¤t2](uint64_t value) -> LoopCtl<absl::Status> { |
||||
values2.push_back(value); |
||||
current2 = value; |
||||
if (value == kEnd) return absl::OkStatus(); |
||||
return Continue{}; |
||||
}); |
||||
}), |
||||
ThreadWakeupScheduler(), [&done2](absl::Status status) { |
||||
EXPECT_TRUE(status.ok()) << status.ToString(); |
||||
done2.Notify(); |
||||
}); |
||||
for (uint64_t i = 0; i < 1000000; i++) { |
||||
observable.Set(i); |
||||
} |
||||
observable.Set(kEnd); |
||||
done1.WaitForNotification(); |
||||
done2.WaitForNotification(); |
||||
ASSERT_GE(values1.size(), 1); |
||||
ASSERT_GE(values2.size(), 1); |
||||
EXPECT_EQ(values1.back(), kEnd); |
||||
EXPECT_EQ(values2.back(), kEnd); |
||||
} |
||||
|
||||
} // namespace
|
||||
} // namespace grpc_core
|
||||
|
||||
int main(int argc, char** argv) { |
||||
gpr_log_verbosity_init(); |
||||
::testing::InitGoogleTest(&argc, argv); |
||||
return RUN_ALL_TESTS(); |
||||
} |
Loading…
Reference in new issue