diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c8979a56ec..9d2a53f4ed6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1229,6 +1229,7 @@ if(gRPC_BUILD_TESTS) add_dependencies(buildtests_cxx proxy_auth_test) add_dependencies(buildtests_cxx qps_json_driver) add_dependencies(buildtests_cxx qps_worker) + add_dependencies(buildtests_cxx query_extensions_test) add_dependencies(buildtests_cxx race_test) add_dependencies(buildtests_cxx random_early_detection_test) add_dependencies(buildtests_cxx raw_end2end_test) @@ -18938,6 +18939,40 @@ target_link_libraries(qps_worker ) +endif() +if(gRPC_BUILD_TESTS) + +add_executable(query_extensions_test + test/core/event_engine/query_extensions_test.cc +) +target_compile_features(query_extensions_test PUBLIC cxx_std_14) +target_include_directories(query_extensions_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_XXHASH_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(query_extensions_test + ${_gRPC_ALLTARGETS_LIBRARIES} + gtest + absl::statusor + gpr +) + + endif() if(gRPC_BUILD_TESTS) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index bd2a9aa0bbd..7eba4bf9d53 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -13309,6 +13309,19 @@ targets: deps: - grpc++_test_config - grpc++_test_util +- name: query_extensions_test + gtest: true + build: test + language: c++ + headers: + - src/core/lib/event_engine/query_extensions.h + src: + - test/core/event_engine/query_extensions_test.cc + deps: + - gtest + - absl/status:statusor + - gpr + uses_polling: false - name: race_test gtest: true build: test diff --git a/include/grpc/event_engine/event_engine.h b/include/grpc/event_engine/event_engine.h index 4beca657625..20cbc64f52f 100644 --- a/include/grpc/event_engine/event_engine.h +++ b/include/grpc/event_engine/event_engine.h @@ -255,6 +255,45 @@ class EventEngine : public std::enable_shared_from_this { /// values are expected to remain valid for the life of the Endpoint. virtual const ResolvedAddress& GetPeerAddress() const = 0; virtual const ResolvedAddress& GetLocalAddress() const = 0; + + /// A method which allows users to query whether an Endpoint implementation + /// supports a specified extension. The name of the extension is provided + /// as an input. + /// + /// An extension could be any type with a unique string id. Each extension + /// may support additional capabilities and if the Endpoint implementation + /// supports the queried extension, it should return a valid pointer to the + /// extension type. + /// + /// E.g., use case of an EventEngine::Endpoint supporting a custom + /// extension. + /// + /// class CustomEndpointExtension { + /// public: + /// static constexpr std::string name = "my.namespace.extension_name"; + /// void Process() { ... } + /// } + /// + /// + /// class CustomEndpoint : + /// public EventEngine::Endpoint, CustomEndpointExtension { + /// public: + /// void* QueryExtension(absl::string_view id) override { + /// if (id == CustomEndpointExtension::name) { + /// return static_cast(this); + /// } + /// return nullptr; + /// } + /// ... + /// } + /// + /// auto ext_ = + /// static_cast( + /// endpoint->QueryExtension(CustomrEndpointExtension::name)); + /// if (ext_ != nullptr) { ext_->Process(); } + /// + /// + virtual void* QueryExtension(absl::string_view /*id*/) { return nullptr; } }; /// Called when a new connection is established. diff --git a/src/core/BUILD b/src/core/BUILD index 351f101b819..e6a4457ac0f 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -1540,6 +1540,18 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "event_engine_query_extensions", + hdrs = [ + "lib/event_engine/query_extensions.h", + ], + external_deps = ["absl/strings"], + deps = [ + "//:event_engine_base_hdrs", + "//:gpr_platform", + ], +) + grpc_cc_library( name = "event_engine_work_queue", hdrs = [ diff --git a/src/core/lib/event_engine/query_extensions.h b/src/core/lib/event_engine/query_extensions.h new file mode 100644 index 00000000000..2ef15ccfdab --- /dev/null +++ b/src/core/lib/event_engine/query_extensions.h @@ -0,0 +1,70 @@ +// Copyright 2023 gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef GRPC_SRC_CORE_LIB_EVENT_ENGINE_QUERY_EXTENSIONS_H +#define GRPC_SRC_CORE_LIB_EVENT_ENGINE_QUERY_EXTENSIONS_H + +#include + +#include "absl/strings/string_view.h" + +#include + +namespace grpc_event_engine { +namespace experimental { + +namespace endpoint_detail { + +template +struct QueryExtensionRecursion; + +template +struct QueryExtensionRecursion { + static void* Query(absl::string_view id, Querying* p) { + if (id == E::EndpointExtensionName()) return static_cast(p); + return QueryExtensionRecursion::Query(id, p); + } +}; + +template +struct QueryExtensionRecursion { + static void* Query(absl::string_view, Querying*) { return nullptr; } +}; + +} // namespace endpoint_detail + +// A helper class to derive from some set of base classes and export +// QueryExtension for them all. +// Endpoint implementations which need to support different extensions just need +// to derive from ExtendedEndpoint class. +template +class ExtendedEndpoint : public EventEngine::Endpoint, public Exports... { + public: + void* QueryExtension(absl::string_view id) override { + return endpoint_detail::QueryExtensionRecursion::Query(id, + this); + } +}; + +/// A helper method which returns a valid pointer if the extension is supported +/// by the endpoint. +template +T* QueryExtension(EventEngine::Endpoint* endpoint) { + return static_cast(endpoint->QueryExtension(T::EndpointExtensionName())); +} + +} // namespace experimental +} // namespace grpc_event_engine + +#endif // GRPC_SRC_CORE_LIB_EVENT_ENGINE_QUERY_EXTENSIONS_H diff --git a/src/core/lib/iomgr/event_engine_shims/endpoint.cc b/src/core/lib/iomgr/event_engine_shims/endpoint.cc index 341fe1e5776..b1e8fdf8904 100644 --- a/src/core/lib/iomgr/event_engine_shims/endpoint.cc +++ b/src/core/lib/iomgr/event_engine_shims/endpoint.cc @@ -69,6 +69,8 @@ class EventEngineEndpointWrapper { explicit EventEngineEndpointWrapper( std::unique_ptr endpoint); + EventEngine::Endpoint* endpoint() { return endpoint_.get(); } + int Fd() { grpc_core::MutexLock lock(&mu_); return fd_; @@ -428,6 +430,17 @@ bool grpc_is_event_engine_endpoint(grpc_endpoint* ep) { return ep->vtable == &grpc_event_engine_endpoint_vtable; } +EventEngine::Endpoint* grpc_get_wrapped_event_engine_endpoint( + grpc_endpoint* ep) { + if (!grpc_is_event_engine_endpoint(ep)) { + return nullptr; + } + auto* eeep = + reinterpret_cast( + ep); + return eeep->wrapper->endpoint(); +} + void grpc_event_engine_endpoint_destroy_and_release_fd( grpc_endpoint* ep, int* fd, grpc_closure* on_release_fd) { auto* eeep = diff --git a/src/core/lib/iomgr/event_engine_shims/endpoint.h b/src/core/lib/iomgr/event_engine_shims/endpoint.h index bc018f1e4d7..efd57c6ea6d 100644 --- a/src/core/lib/iomgr/event_engine_shims/endpoint.h +++ b/src/core/lib/iomgr/event_engine_shims/endpoint.h @@ -31,6 +31,11 @@ grpc_endpoint* grpc_event_engine_endpoint_create( /// Returns true if the passed endpoint is an event engine shim endpoint. bool grpc_is_event_engine_endpoint(grpc_endpoint* ep); +/// Returns the wrapped event engine endpoint if the given grpc_endpoint is an +/// event engine shim endpoint. Otherwise it returns nullptr. +EventEngine::Endpoint* grpc_get_wrapped_event_engine_endpoint( + grpc_endpoint* ep); + /// Destroys the passed in event engine shim endpoint and schedules the /// asynchronous execution of the on_release_fd callback. The int pointer fd is /// set to the underlying endpoint's file descriptor. diff --git a/test/core/event_engine/BUILD b/test/core/event_engine/BUILD index 13543244c14..1cdf7d24cd3 100644 --- a/test/core/event_engine/BUILD +++ b/test/core/event_engine/BUILD @@ -232,3 +232,16 @@ grpc_cc_library( "//src/core:time", ], ) + +grpc_cc_test( + name = "query_extensions_test", + srcs = ["query_extensions_test.cc"], + external_deps = ["gtest"], + language = "C++", + uses_event_engine = False, + uses_polling = False, + deps = [ + "//:gpr_platform", + "//src/core:event_engine_query_extensions", + ], +) diff --git a/test/core/event_engine/query_extensions_test.cc b/test/core/event_engine/query_extensions_test.cc new file mode 100644 index 00000000000..712a496f38c --- /dev/null +++ b/test/core/event_engine/query_extensions_test.cc @@ -0,0 +1,95 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "src/core/lib/event_engine/query_extensions.h" + +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "gtest/gtest.h" + +#include +#include + +#include "src/core/lib/gprpp/crash.h" + +namespace grpc_event_engine { +namespace experimental { +namespace { + +template +class TestExtension { + public: + TestExtension() = default; + ~TestExtension() = default; + + static std::string EndpointExtensionName() { + return "grpc.test.test_extension" + std::to_string(i); + } + + int GetValue() const { return val_; } + + private: + int val_ = i; +}; + +class ExtendedTestEndpoint + : public ExtendedEndpoint, TestExtension<1>, + TestExtension<2>> { + public: + ExtendedTestEndpoint() = default; + ~ExtendedTestEndpoint() override = default; + bool Read(absl::AnyInvocable /*on_read*/, + SliceBuffer* /*buffer*/, const ReadArgs* /*args*/) override { + grpc_core::Crash("Not implemented"); + }; + bool Write(absl::AnyInvocable /*on_writable*/, + SliceBuffer* /*data*/, const WriteArgs* /*args*/) override { + grpc_core::Crash("Not implemented"); + } + /// Returns an address in the format described in DNSResolver. The returned + /// values are expected to remain valid for the life of the Endpoint. + const EventEngine::ResolvedAddress& GetPeerAddress() const override { + grpc_core::Crash("Not implemented"); + } + const EventEngine::ResolvedAddress& GetLocalAddress() const override { + grpc_core::Crash("Not implemented"); + }; +}; + +TEST(QueryExtensionsTest, EndpointSupportsMultipleExtensions) { + ExtendedTestEndpoint endpoint; + TestExtension<0>* extension_0 = QueryExtension>(&endpoint); + TestExtension<1>* extension_1 = QueryExtension>(&endpoint); + TestExtension<2>* extension_2 = QueryExtension>(&endpoint); + + EXPECT_NE(extension_0, nullptr); + EXPECT_NE(extension_1, nullptr); + EXPECT_NE(extension_2, nullptr); + + EXPECT_EQ(extension_0->GetValue(), 0); + EXPECT_EQ(extension_1->GetValue(), 1); + EXPECT_EQ(extension_2->GetValue(), 2); +} +} // namespace + +} // namespace experimental +} // namespace grpc_event_engine + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index 35bdddc0eb0..d58a72accec 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -7189,6 +7189,30 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "query_extensions_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": false + }, { "args": [], "benchmark": false,