use a separate method for is_acceptable

pull/35789/head
Mark D. Roth 1 year ago
parent 082a306685
commit 4c872578ef
  1. 88
      src/core/lib/promise/observable.h
  2. 4
      test/core/promise/observable_test.cc

@ -38,16 +38,18 @@ class Observable {
void Set(T value) { state_->Set(std::move(value)); }
// Returns a promise that resolves to a T when the value becomes != current.
// If a new value is available but is_acceptable is non-null and returns
// false for that value, does not return that value (i.e., remains pending).
auto Next(T current,
absl::AnyInvocable<bool(const T&)> is_acceptable = nullptr) {
return Observer(state_, std::move(current), std::move(is_acceptable));
auto Next(T current) { return Observer(state_, std::move(current)); }
// Same as Next(), except it resolves only once is_acceptable returns
// true for the new value.
auto NextWhen(absl::AnyInvocable<bool(const T&)> is_acceptable) {
return ObserverWhen(state_, std::move(is_acceptable));
}
private:
// Forward declaration so we can form pointers to Observer in State.
// Forward declaration so we can form pointers to observer types in State.
class Observer;
class ObserverWhen;
// State keeps track of all observable state.
// It's a refcounted object so that promises reading the state are not tied
@ -75,6 +77,9 @@ class Observable {
void Remove(Observer* observer) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
observers_.erase(observer);
}
void Remove(ObserverWhen* observer) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
observer_whens_.erase(observer);
}
// Add an observer to the set (it needs updates).
GRPC_MUST_USE_RESULT Waker Add(Observer* observer)
@ -82,6 +87,11 @@ class Observable {
observers_.insert(observer);
return GetContext<Activity>()->MakeNonOwningWaker();
}
GRPC_MUST_USE_RESULT Waker Add(ObserverWhen* observer)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
observer_whens_.insert(observer);
return GetContext<Activity>()->MakeNonOwningWaker();
}
private:
// Wake all observers.
@ -89,11 +99,15 @@ class Observable {
for (auto* observer : observers_) {
observer->Wakeup();
}
for (auto* observer : observer_whens_) {
observer->Wakeup();
}
}
Mutex mu_;
// All observers that may need an update.
absl::flat_hash_set<Observer*> observers_ ABSL_GUARDED_BY(mu_);
absl::flat_hash_set<ObserverWhen*> observer_whens_ ABSL_GUARDED_BY(mu_);
// The current value.
T value_ ABSL_GUARDED_BY(mu_);
};
@ -102,11 +116,8 @@ class Observable {
// current.
class Observer {
public:
Observer(RefCountedPtr<State> state, T current,
absl::AnyInvocable<bool(const T&)> is_acceptable)
: state_(std::move(state)),
current_(std::move(current)),
is_acceptable_(std::move(is_acceptable)) {}
Observer(RefCountedPtr<State> state, T current)
: state_(std::move(state)), current_(std::move(current)) {}
~Observer() {
// If we saw a pending at all then we *may* be in the set of observers.
@ -120,9 +131,7 @@ class Observable {
Observer(const Observer&) = delete;
Observer& operator=(const Observer&) = delete;
Observer(Observer&& other) noexcept
: state_(std::move(other.state_)),
current_(std::move(other.current_)),
is_acceptable_(std::move(other.is_acceptable_)) {
: state_(std::move(other.state_)), current_(std::move(other.current_)) {
GPR_ASSERT(other.waker_.is_unwakeable());
GPR_ASSERT(!other.saw_pending_);
}
@ -133,8 +142,7 @@ class Observable {
Poll<T> operator()() {
MutexLock lock(state_->mu());
// Check if the value has changed yet.
if (current_ != state_->current() &&
(is_acceptable_ == nullptr || is_acceptable_(state_->current()))) {
if (current_ != state_->current()) {
if (saw_pending_ && !waker_.is_unwakeable()) state_->Remove(this);
return state_->current();
}
@ -147,6 +155,54 @@ class Observable {
private:
RefCountedPtr<State> state_;
T current_;
Waker waker_;
bool saw_pending_ = false;
};
// ObserverWhen is a promise that resolves to a T when the value becomes
// acceptable.
class ObserverWhen {
public:
ObserverWhen(RefCountedPtr<State> state,
absl::AnyInvocable<bool(const T&)> is_acceptable)
: state_(std::move(state)), is_acceptable_(std::move(is_acceptable)) {}
~ObserverWhen() {
// If we saw a pending at all then we *may* be in the set of observers.
// If not we're definitely not and we can avoid taking the lock at all.
if (!saw_pending_) return;
MutexLock lock(state_->mu());
auto w = std::move(waker_);
state_->Remove(this);
}
ObserverWhen(const ObserverWhen&) = delete;
ObserverWhen& operator=(const ObserverWhen&) = delete;
ObserverWhen(ObserverWhen&& other) noexcept
: state_(std::move(other.state_)),
is_acceptable_(std::move(other.is_acceptable_)) {
GPR_ASSERT(other.waker_.is_unwakeable());
GPR_ASSERT(!other.saw_pending_);
}
ObserverWhen& operator=(ObserverWhen&& other) noexcept = delete;
void Wakeup() { waker_.WakeupAsync(); }
Poll<T> operator()() {
MutexLock lock(state_->mu());
// Check if the value is acceptable.
if (is_acceptable_(state_->current())) {
if (saw_pending_ && !waker_.is_unwakeable()) state_->Remove(this);
return state_->current();
}
// Record that we saw at least one pending and then register for wakeup.
saw_pending_ = true;
if (waker_.is_unwakeable()) waker_ = state_->Add(this);
return Pending{};
}
private:
RefCountedPtr<State> state_;
absl::AnyInvocable<bool(const T&)> is_acceptable_;
Waker waker_;
bool saw_pending_ = false;

@ -129,11 +129,11 @@ TEST(ObservableTest, ChangeValueWakesUp) {
EXPECT_THAT(next(), IsReady(2));
}
TEST(ObservableTest, IsAcceptable) {
TEST(ObservableTest, NextWhen) {
StrictMock<MockActivity> activity;
activity.Activate();
Observable<int> observable(1);
auto next = observable.Next(0, [](int i) { return i == 3; });
auto next = observable.NextWhen([](int i) { return i == 3; });
EXPECT_THAT(next(), IsPending());
EXPECT_CALL(activity, WakeupRequested());
observable.Set(2);

Loading…
Cancel
Save