From 122a1996aed7dedb5d25847462d5bd696f90f1b1 Mon Sep 17 00:00:00 2001 From: "Mark D. Roth" Date: Mon, 5 Feb 2024 14:41:40 -0800 Subject: [PATCH] [promise] add optional is_acceptable callback to Observable (#35789) Closes #35789 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35789 from markdroth:observable_is_acceptable 75a888a769ee30fe19ea3cd7ec36f07aea59ef75 PiperOrigin-RevId: 604437699 --- src/core/BUILD | 5 ++- src/core/lib/promise/observable.h | 63 +++++++++++++++++++++++----- test/core/promise/observable_test.cc | 15 +++++++ 3 files changed, 72 insertions(+), 11 deletions(-) diff --git a/src/core/BUILD b/src/core/BUILD index 6ee421b2c23..39562a53211 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -1068,7 +1068,10 @@ grpc_cc_library( hdrs = [ "lib/promise/observable.h", ], - external_deps = ["absl/container:flat_hash_set"], + external_deps = [ + "absl/container:flat_hash_set", + "absl/functional:any_invocable", + ], language = "c++", deps = [ "activity", diff --git a/src/core/lib/promise/observable.h b/src/core/lib/promise/observable.h index 335fc393754..3c270bef34e 100644 --- a/src/core/lib/promise/observable.h +++ b/src/core/lib/promise/observable.h @@ -18,6 +18,7 @@ #include #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "src/core/lib/gprpp/sync.h" #include "src/core/lib/promise/activity.h" @@ -37,7 +38,13 @@ class Observable { void Set(T value) { state_->Set(std::move(value)); } // Returns a promise that resolves to a T when the value becomes != current. - auto Next(T current) { return Observer(state_, std::move(current)); } + auto Next(T current) { return ObserverIfChanged(state_, std::move(current)); } + + // Same as Next(), except it resolves only once is_acceptable returns + // true for the new value. + auto NextWhen(absl::AnyInvocable is_acceptable) { + return ObserverWhen(state_, std::move(is_acceptable)); + } private: // Forward declaration so we can form pointers to Observer in State. @@ -92,13 +99,13 @@ class Observable { T value_ ABSL_GUARDED_BY(mu_); }; - // Observer is a promise that resolves to a T when the value becomes != - // current. + // A promise that resolves to a T when ShouldReturn() returns true. + // Subclasses must implement ShouldReturn(). class Observer { public: - Observer(RefCountedPtr state, T current) - : state_(std::move(state)), current_(std::move(current)) {} - ~Observer() { + explicit Observer(RefCountedPtr state) : state_(std::move(state)) {} + + virtual ~Observer() { // If we saw a pending at all then we *may* be in the set of observers. // If not we're definitely not and we can avoid taking the lock at all. if (!saw_pending_) return; @@ -109,8 +116,7 @@ class Observable { Observer(const Observer&) = delete; Observer& operator=(const Observer&) = delete; - Observer(Observer&& other) noexcept - : state_(std::move(other.state_)), current_(std::move(other.current_)) { + Observer(Observer&& other) noexcept : state_(std::move(other.state_)) { GPR_ASSERT(other.waker_.is_unwakeable()); GPR_ASSERT(!other.saw_pending_); } @@ -118,10 +124,12 @@ class Observable { void Wakeup() { waker_.WakeupAsync(); } + virtual bool ShouldReturn(const T& current) = 0; + Poll operator()() { MutexLock lock(state_->mu()); // Check if the value has changed yet. - if (current_ != state_->current()) { + if (ShouldReturn(state_->current())) { if (saw_pending_ && !waker_.is_unwakeable()) state_->Remove(this); return state_->current(); } @@ -133,11 +141,46 @@ class Observable { private: RefCountedPtr state_; - T current_; Waker waker_; bool saw_pending_ = false; }; + // An observer that resolves to a T when the value becomes != current. + class ObserverIfChanged : public Observer { + public: + ObserverIfChanged(RefCountedPtr state, T current) + : Observer(std::move(state)), current_(std::move(current)) {} + + ObserverIfChanged(ObserverIfChanged&& other) noexcept + : Observer(std::move(other)), current_(std::move(other.current_)) {} + + bool ShouldReturn(const T& current) override { return current_ != current; } + + private: + T current_; + }; + + // A promise that resolves to a T when is_acceptable returns true for + // the current value. + class ObserverWhen : public Observer { + public: + ObserverWhen(RefCountedPtr state, + absl::AnyInvocable is_acceptable) + : Observer(std::move(state)), + is_acceptable_(std::move(is_acceptable)) {} + + ObserverWhen(ObserverWhen&& other) noexcept + : Observer(std::move(other)), + is_acceptable_(std::move(other.is_acceptable_)) {} + + bool ShouldReturn(const T& current) override { + return is_acceptable_(current); + } + + private: + absl::AnyInvocable is_acceptable_; + }; + RefCountedPtr state_; }; diff --git a/test/core/promise/observable_test.cc b/test/core/promise/observable_test.cc index 09abca0b11a..48f52b4978c 100644 --- a/test/core/promise/observable_test.cc +++ b/test/core/promise/observable_test.cc @@ -129,6 +129,21 @@ TEST(ObservableTest, ChangeValueWakesUp) { EXPECT_THAT(next(), IsReady(2)); } +TEST(ObservableTest, NextWhen) { + StrictMock activity; + activity.Activate(); + Observable observable(1); + auto next = observable.NextWhen([](int i) { return i == 3; }); + EXPECT_THAT(next(), IsPending()); + EXPECT_CALL(activity, WakeupRequested()); + observable.Set(2); + EXPECT_THAT(next(), IsPending()); + EXPECT_CALL(activity, WakeupRequested()); + observable.Set(3); + Mock::VerifyAndClearExpectations(&activity); + EXPECT_THAT(next(), IsReady(3)); +} + TEST(ObservableTest, MultipleActivitiesWakeUp) { StrictMock activity1; StrictMock activity2;