[promises] Add an observer type (#35552)

We've got a few situations coming up with promises that will want a "broadcast new value to everywhere" situation.

Closes #35552

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35552 from ctiller:obs 30fd697ae3
PiperOrigin-RevId: 599609399
pull/35601/head
Craig Tiller 1 year ago committed by Copybara-Service
parent ad1dc17030
commit 98472179fb
  1. 46
      CMakeLists.txt
  2. 29
      build_autogenerated.yaml
  3. 14
      src/core/BUILD
  4. 146
      src/core/lib/promise/observable.h
  5. 19
      test/core/promise/BUILD
  6. 227
      test/core/promise/observable_test.cc
  7. 24
      tools/run_tests/generated/tests.json

46
CMakeLists.txt generated

@ -1193,6 +1193,7 @@ if(gRPC_BUILD_TESTS)
add_dependencies(buildtests_cxx nonblocking_test)
add_dependencies(buildtests_cxx notification_test)
add_dependencies(buildtests_cxx num_external_connectivity_watchers_test)
add_dependencies(buildtests_cxx observable_test)
if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX)
add_dependencies(buildtests_cxx oracle_event_engine_posix_test)
endif()
@ -19834,6 +19835,51 @@ target_link_libraries(num_external_connectivity_watchers_test
)
endif()
if(gRPC_BUILD_TESTS)
add_executable(observable_test
src/core/lib/promise/activity.cc
test/core/promise/observable_test.cc
)
if(WIN32 AND MSVC)
if(BUILD_SHARED_LIBS)
target_compile_definitions(observable_test
PRIVATE
"GPR_DLL_IMPORTS"
)
endif()
endif()
target_compile_features(observable_test PUBLIC cxx_std_14)
target_include_directories(observable_test
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/include
${_gRPC_ADDRESS_SORTING_INCLUDE_DIR}
${_gRPC_RE2_INCLUDE_DIR}
${_gRPC_SSL_INCLUDE_DIR}
${_gRPC_UPB_GENERATED_DIR}
${_gRPC_UPB_GRPC_GENERATED_DIR}
${_gRPC_UPB_INCLUDE_DIR}
${_gRPC_XXHASH_INCLUDE_DIR}
${_gRPC_ZLIB_INCLUDE_DIR}
third_party/googletest/googletest/include
third_party/googletest/googletest
third_party/googletest/googlemock/include
third_party/googletest/googlemock
${_gRPC_PROTO_GENS_DIR}
)
target_link_libraries(observable_test
${_gRPC_ALLTARGETS_LIBRARIES}
gtest
absl::hash
absl::type_traits
absl::statusor
gpr
)
endif()
if(gRPC_BUILD_TESTS)
if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX)

@ -12130,6 +12130,35 @@ targets:
deps:
- gtest
- grpc_test_util
- name: observable_test
gtest: true
build: test
language: c++
headers:
- src/core/lib/gprpp/atomic_utils.h
- src/core/lib/gprpp/notification.h
- src/core/lib/gprpp/orphanable.h
- src/core/lib/gprpp/ref_counted.h
- src/core/lib/gprpp/ref_counted_ptr.h
- src/core/lib/promise/activity.h
- src/core/lib/promise/context.h
- src/core/lib/promise/detail/promise_factory.h
- src/core/lib/promise/detail/promise_like.h
- src/core/lib/promise/detail/status.h
- src/core/lib/promise/loop.h
- src/core/lib/promise/map.h
- src/core/lib/promise/observable.h
- src/core/lib/promise/poll.h
src:
- src/core/lib/promise/activity.cc
- test/core/promise/observable_test.cc
deps:
- gtest
- absl/hash:hash
- absl/meta:type_traits
- absl/status:statusor
- gpr
uses_polling: false
- name: oracle_event_engine_posix_test
gtest: true
build: test

@ -1041,6 +1041,20 @@ grpc_cc_library(
],
)
grpc_cc_library(
name = "observable",
hdrs = [
"lib/promise/observable.h",
],
external_deps = ["absl/container:flat_hash_set"],
language = "c++",
deps = [
"activity",
"poll",
"//:gpr",
],
)
grpc_cc_library(
name = "for_each",
external_deps = [

@ -0,0 +1,146 @@
// Copyright 2024 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H
#define GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H
#include <grpc/support/port_platform.h>
#include "absl/container/flat_hash_set.h"
#include "src/core/lib/gprpp/sync.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/poll.h"
namespace grpc_core {
// Observable allows broadcasting a value to multiple interested observers.
template <typename T>
class Observable {
public:
// We need to assign a value initially.
explicit Observable(T initial)
: state_(MakeRefCounted<State>(std::move(initial))) {}
// Update the value to something new. Awakes any waiters.
void Set(T value) { state_->Set(std::move(value)); }
// Returns a promise that resolves to a T when the value becomes != current.
auto Next(T current) { return Observer(state_, std::move(current)); }
private:
// Forward declaration so we can form pointers to Observer in State.
class Observer;
// State keeps track of all observable state.
// It's a refcounted object so that promises reading the state are not tied
// to the lifetime of the Observable.
class State : public RefCounted<State> {
public:
explicit State(T value) : value_(std::move(value)) {}
// Update the value and wake all observers.
void Set(T value) {
MutexLock lock(&mu_);
std::swap(value_, value);
WakeAll();
}
// Export our mutex so that Observer can use it.
Mutex* mu() ABSL_LOCK_RETURNED(mu_) { return &mu_; }
// Fetch a ref to the current value.
const T& current() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return value_;
}
// Remove an observer from the set (it no longer needs updates).
void Remove(Observer* observer) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
observers_.erase(observer);
}
// Add an observer to the set (it needs updates).
GRPC_MUST_USE_RESULT Waker Add(Observer* observer)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
observers_.insert(observer);
return Activity::current()->MakeNonOwningWaker();
}
private:
// Wake all observers.
void WakeAll() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
for (auto* observer : observers_) {
observer->Wakeup();
}
}
Mutex mu_;
// All observers that may need an update.
absl::flat_hash_set<Observer*> observers_ ABSL_GUARDED_BY(mu_);
// The current value.
T value_ ABSL_GUARDED_BY(mu_);
};
// Observer is a promise that resolves to a T when the value becomes !=
// current.
class Observer {
public:
Observer(RefCountedPtr<State> state, T current)
: state_(std::move(state)), current_(std::move(current)) {}
~Observer() {
// If we saw a pending at all then we *may* be in the set of observers.
// If not we're definitely not and we can avoid taking the lock at all.
if (!saw_pending_) return;
MutexLock lock(state_->mu());
auto w = std::move(waker_);
state_->Remove(this);
}
Observer(const Observer&) = delete;
Observer& operator=(const Observer&) = delete;
Observer(Observer&& other) noexcept
: state_(std::move(other.state_)), current_(std::move(other.current_)) {
GPR_ASSERT(other.waker_.is_unwakeable());
GPR_ASSERT(!other.saw_pending_);
}
Observer& operator=(Observer&& other) noexcept = delete;
void Wakeup() { waker_.WakeupAsync(); }
Poll<T> operator()() {
MutexLock lock(state_->mu());
// Check if the value has changed yet.
if (current_ != state_->current()) {
if (saw_pending_ && !waker_.is_unwakeable()) state_->Remove(this);
return state_->current();
}
// Record that we saw at least one pending and then register for wakeup.
saw_pending_ = true;
if (waker_.is_unwakeable()) waker_ = state_->Add(this);
return Pending{};
}
private:
RefCountedPtr<State> state_;
T current_;
Waker waker_;
bool saw_pending_ = false;
};
RefCountedPtr<State> state_;
};
} // namespace grpc_core
#endif // GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H

@ -484,6 +484,25 @@ grpc_cc_test(
],
)
grpc_cc_test(
name = "observable_test",
srcs = ["observable_test.cc"],
external_deps = [
"absl/strings",
"gtest",
],
language = "c++",
tags = ["promise_test"],
uses_event_engine = False,
uses_polling = False,
deps = [
"//src/core:loop",
"//src/core:map",
"//src/core:notification",
"//src/core:observable",
],
)
grpc_proto_fuzzer(
name = "promise_fuzzer",
srcs = ["promise_fuzzer.cc"],

@ -0,0 +1,227 @@
// Copyright 2024 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/core/lib/promise/observable.h"
#include <cstdint>
#include <limits>
#include <thread>
#include <vector>
#include "absl/strings/str_join.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "src/core/lib/gprpp/notification.h"
#include "src/core/lib/promise/loop.h"
#include "src/core/lib/promise/map.h"
using testing::Mock;
using testing::StrictMock;
namespace grpc_core {
namespace {
class MockActivity : public Activity, public Wakeable {
public:
MOCK_METHOD(void, WakeupRequested, ());
void ForceImmediateRepoll(WakeupMask) override { WakeupRequested(); }
void Orphan() override {}
Waker MakeOwningWaker() override { return Waker(this, 0); }
Waker MakeNonOwningWaker() override { return Waker(this, 0); }
void Wakeup(WakeupMask) override { WakeupRequested(); }
void WakeupAsync(WakeupMask) override { WakeupRequested(); }
void Drop(WakeupMask) override {}
std::string DebugTag() const override { return "MockActivity"; }
std::string ActivityDebugTag(WakeupMask) const override { return DebugTag(); }
void Activate() {
if (scoped_activity_ != nullptr) return;
scoped_activity_ = std::make_unique<ScopedActivity>(this);
}
void Deactivate() { scoped_activity_.reset(); }
private:
std::unique_ptr<ScopedActivity> scoped_activity_;
};
MATCHER(IsPending, "") {
if (arg.ready()) {
*result_listener << "is ready";
return false;
}
return true;
}
MATCHER(IsReady, "") {
if (arg.pending()) {
*result_listener << "is pending";
return false;
}
return true;
}
MATCHER_P(IsReady, value, "") {
if (arg.pending()) {
*result_listener << "is pending";
return false;
}
if (arg.value() != value) {
*result_listener << "is " << ::testing::PrintToString(arg.value());
return false;
}
return true;
}
TEST(ObservableTest, ImmediateNext) {
Observable<int> observable(1);
auto next = observable.Next(0);
EXPECT_THAT(next(), IsReady(1));
}
TEST(ObservableTest, SetBecomesImmediateNext1) {
Observable<int> observable(0);
auto next = observable.Next(0);
observable.Set(1);
EXPECT_THAT(next(), IsReady(1));
}
TEST(ObservableTest, SetBecomesImmediateNext2) {
Observable<int> observable(0);
observable.Set(1);
auto next = observable.Next(0);
EXPECT_THAT(next(), IsReady(1));
}
TEST(ObservableTest, SameValueGetsPending) {
StrictMock<MockActivity> activity;
activity.Activate();
Observable<int> observable(1);
auto next = observable.Next(1);
EXPECT_THAT(next(), IsPending());
EXPECT_THAT(next(), IsPending());
EXPECT_THAT(next(), IsPending());
EXPECT_THAT(next(), IsPending());
}
TEST(ObservableTest, ChangeValueWakesUp) {
StrictMock<MockActivity> activity;
activity.Activate();
Observable<int> observable(1);
auto next = observable.Next(1);
EXPECT_THAT(next(), IsPending());
EXPECT_CALL(activity, WakeupRequested());
observable.Set(2);
Mock::VerifyAndClearExpectations(&activity);
EXPECT_THAT(next(), IsReady(2));
}
TEST(ObservableTest, MultipleActivitiesWakeUp) {
StrictMock<MockActivity> activity1;
StrictMock<MockActivity> activity2;
Observable<int> observable(1);
auto next1 = observable.Next(1);
auto next2 = observable.Next(1);
{
activity1.Activate();
EXPECT_THAT(next1(), IsPending());
}
{
activity2.Activate();
EXPECT_THAT(next2(), IsPending());
}
EXPECT_CALL(activity1, WakeupRequested());
EXPECT_CALL(activity2, WakeupRequested());
observable.Set(2);
Mock::VerifyAndClearExpectations(&activity1);
Mock::VerifyAndClearExpectations(&activity2);
EXPECT_THAT(next1(), IsReady(2));
EXPECT_THAT(next2(), IsReady(2));
}
class ThreadWakeupScheduler {
public:
template <typename ActivityType>
class BoundScheduler {
public:
explicit BoundScheduler(ThreadWakeupScheduler) {}
void ScheduleWakeup() {
std::thread t(
[this] { static_cast<ActivityType*>(this)->RunScheduledWakeup(); });
t.detach();
}
};
};
TEST(ObservableTest, Stress) {
static constexpr uint64_t kEnd = std::numeric_limits<uint64_t>::max();
std::vector<uint64_t> values1;
std::vector<uint64_t> values2;
uint64_t current1 = 0;
uint64_t current2 = 0;
Notification done1;
Notification done2;
Observable<uint64_t> observable(0);
auto activity1 = MakeActivity(
Loop([&observable, &current1, &values1] {
return Map(
observable.Next(current1),
[&values1, &current1](uint64_t value) -> LoopCtl<absl::Status> {
values1.push_back(value);
current1 = value;
if (value == kEnd) return absl::OkStatus();
return Continue{};
});
}),
ThreadWakeupScheduler(), [&done1](absl::Status status) {
EXPECT_TRUE(status.ok()) << status.ToString();
done1.Notify();
});
auto activity2 = MakeActivity(
Loop([&observable, &current2, &values2] {
return Map(
observable.Next(current2),
[&values2, &current2](uint64_t value) -> LoopCtl<absl::Status> {
values2.push_back(value);
current2 = value;
if (value == kEnd) return absl::OkStatus();
return Continue{};
});
}),
ThreadWakeupScheduler(), [&done2](absl::Status status) {
EXPECT_TRUE(status.ok()) << status.ToString();
done2.Notify();
});
for (uint64_t i = 0; i < 1000000; i++) {
observable.Set(i);
}
observable.Set(kEnd);
done1.WaitForNotification();
done2.WaitForNotification();
ASSERT_GE(values1.size(), 1);
ASSERT_GE(values2.size(), 1);
EXPECT_EQ(values1.back(), kEnd);
EXPECT_EQ(values2.back(), kEnd);
}
} // namespace
} // namespace grpc_core
int main(int argc, char** argv) {
gpr_log_verbosity_init();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

@ -6269,6 +6269,30 @@
],
"uses_polling": true
},
{
"args": [],
"benchmark": false,
"ci_platforms": [
"linux",
"mac",
"posix",
"windows"
],
"cpu_cost": 1.0,
"exclude_configs": [],
"exclude_iomgrs": [],
"flaky": false,
"gtest": true,
"language": "c++",
"name": "observable_test",
"platforms": [
"linux",
"mac",
"posix",
"windows"
],
"uses_polling": false
},
{
"args": [],
"benchmark": false,

Loading…
Cancel
Save