From e481f6acc5c1bf560eea5095dbbf36e7fae77fd5 Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Fri, 1 Dec 2023 09:05:25 -0800 Subject: [PATCH 1/9] [promises] Generate a better error message for a common mistake (#35191) Closes #35191 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35191 from ctiller:cg-promise-like-fix 0683ffe16943bee355dd97735b972a7315aabb93 PiperOrigin-RevId: 587026178 --- src/core/lib/promise/detail/promise_like.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/core/lib/promise/detail/promise_like.h b/src/core/lib/promise/detail/promise_like.h index 486653856e3..4bec3661642 100644 --- a/src/core/lib/promise/detail/promise_like.h +++ b/src/core/lib/promise/detail/promise_like.h @@ -68,6 +68,10 @@ class PromiseLike { private: GPR_NO_UNIQUE_ADDRESS F f_; + static_assert(!std::is_void::type>::value, + "PromiseLike cannot be used with a function that returns void " + "- return Empty{} instead"); + public: // NOLINTNEXTLINE - internal detail that drastically simplifies calling code. PromiseLike(F&& f) : f_(std::forward(f)) {} From 49f7ee96d195f44e21d5b57106f2564b6e101b15 Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Fri, 1 Dec 2023 11:06:38 -0800 Subject: [PATCH 2/9] [promises] Convert http-client filter to v3 filter API (#35189) Also start building a temporary wrapping layer so that the new style filters can execute as promise filters directly. Closes #35189 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35189 from ctiller:cg-http-cli dbd73e8aca651b3c597d78e47b4d52e173e7b02b PiperOrigin-RevId: 587060881 --- .../filters/http/client/http_client_filter.cc | 54 ++-- .../filters/http/client/http_client_filter.h | 14 +- src/core/lib/channel/promise_based_filter.h | 285 ++++++++++++++++++ 3 files changed, 316 insertions(+), 37 deletions(-) diff --git a/src/core/ext/filters/http/client/http_client_filter.cc b/src/core/ext/filters/http/client/http_client_filter.cc index 8b547bac6ee..07fcba62581 100644 --- a/src/core/ext/filters/http/client/http_client_filter.cc +++ b/src/core/ext/filters/http/client/http_client_filter.cc @@ -51,6 +51,9 @@ namespace grpc_core { +const NoInterceptor HttpClientFilter::Call::OnServerToClientMessage; +const NoInterceptor HttpClientFilter::Call::OnClientToServerMessage; + const grpc_channel_filter HttpClientFilter::kFilter = MakePromiseBasedFilter("http-client"); @@ -105,40 +108,27 @@ Slice UserAgentFromArgs(const ChannelArgs& args, } } // namespace -ArenaPromise HttpClientFilter::MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) { - auto& md = call_args.client_initial_metadata; - if (test_only_use_put_requests_) { - md->Set(HttpMethodMetadata(), HttpMethodMetadata::kPut); +void HttpClientFilter::Call::OnClientInitialMetadata(ClientMetadata& md, + HttpClientFilter* filter) { + if (filter->test_only_use_put_requests_) { + md.Set(HttpMethodMetadata(), HttpMethodMetadata::kPut); } else { - md->Set(HttpMethodMetadata(), HttpMethodMetadata::kPost); + md.Set(HttpMethodMetadata(), HttpMethodMetadata::kPost); } - md->Set(HttpSchemeMetadata(), scheme_); - md->Set(TeMetadata(), TeMetadata::kTrailers); - md->Set(ContentTypeMetadata(), ContentTypeMetadata::kApplicationGrpc); - md->Set(UserAgentMetadata(), user_agent_.Ref()); - - auto* initial_metadata_err = - GetContext()->New>(); - - call_args.server_initial_metadata->InterceptAndMap( - [initial_metadata_err]( - ServerMetadataHandle md) -> absl::optional { - auto r = CheckServerMetadata(md.get()); - if (!r.ok()) { - initial_metadata_err->Set(ServerMetadataFromStatus(r)); - return absl::nullopt; - } - return std::move(md); - }); - - return Race(initial_metadata_err->Wait(), - Map(next_promise_factory(std::move(call_args)), - [](ServerMetadataHandle md) -> ServerMetadataHandle { - auto r = CheckServerMetadata(md.get()); - if (!r.ok()) return ServerMetadataFromStatus(r); - return md; - })); + md.Set(HttpSchemeMetadata(), filter->scheme_); + md.Set(TeMetadata(), TeMetadata::kTrailers); + md.Set(ContentTypeMetadata(), ContentTypeMetadata::kApplicationGrpc); + md.Set(UserAgentMetadata(), filter->user_agent_.Ref()); +} + +absl::Status HttpClientFilter::Call::OnServerInitialMetadata( + ServerMetadata& md) { + return CheckServerMetadata(&md); +} + +absl::Status HttpClientFilter::Call::OnServerTrailingMetadata( + ServerMetadata& md) { + return CheckServerMetadata(&md); } HttpClientFilter::HttpClientFilter(HttpSchemeMetadata::ValueType scheme, diff --git a/src/core/ext/filters/http/client/http_client_filter.h b/src/core/ext/filters/http/client/http_client_filter.h index 298daf03c67..3146ea07385 100644 --- a/src/core/ext/filters/http/client/http_client_filter.h +++ b/src/core/ext/filters/http/client/http_client_filter.h @@ -25,23 +25,27 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_fwd.h" #include "src/core/lib/channel/promise_based_filter.h" -#include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/slice/slice.h" #include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" namespace grpc_core { -class HttpClientFilter : public ChannelFilter { +class HttpClientFilter : public ImplementChannelFilter { public: static const grpc_channel_filter kFilter; static absl::StatusOr Create( const ChannelArgs& args, ChannelFilter::Args filter_args); - // Construct a promise for one call. - ArenaPromise MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) override; + class Call { + public: + void OnClientInitialMetadata(ClientMetadata& md, HttpClientFilter* filter); + absl::Status OnServerInitialMetadata(ServerMetadata& md); + absl::Status OnServerTrailingMetadata(ServerMetadata& md); + static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnServerToClientMessage; + }; private: HttpClientFilter(HttpSchemeMetadata::ValueType scheme, Slice user_agent, diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index 19efe505db2..25ee1230ef1 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -43,6 +43,7 @@ #include #include "src/core/lib/channel/call_finalization.h" +#include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_fwd.h" #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/context.h" @@ -60,6 +61,7 @@ #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/pipe.h" #include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/race.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice_buffer.h" #include "src/core/lib/surface/call.h" @@ -122,6 +124,289 @@ class ChannelFilter { grpc_event_engine::experimental::GetDefaultEventEngine(); }; +struct NoInterceptor {}; + +namespace promise_filter_detail { + +// Determine if a list of interceptors has any that need to asyncronously error +// the promise. If so, we need to allocate a latch for the generated promise for +// the original promise stack polyfill code that's generated. + +inline constexpr bool HasAsyncErrorInterceptor() { return false; } + +inline constexpr bool HasAsyncErrorInterceptor(const NoInterceptor*) { + return false; +} + +template +inline constexpr bool HasAsyncErrorInterceptor(absl::Status (T::*)(A...)) { + return true; +} + +template +inline constexpr bool HasAsyncErrorInterceptor(void (T::*)(A...)) { + return false; +} + +// For the list case we do two interceptors to avoid amiguities with the single +// argument forms above. +template +inline constexpr bool HasAsyncErrorInterceptor(I1 i1, I2 i2, + Interceptors... interceptors) { + return HasAsyncErrorInterceptor(i1) || HasAsyncErrorInterceptor(i2) || + HasAsyncErrorInterceptor(interceptors...); +} + +// Composite for a given channel type to determine if any of its interceptors +// fall into this category: later code should use this. +template +inline constexpr bool CallHasAsyncErrorInterceptor() { + return HasAsyncErrorInterceptor(&Derived::Call::OnClientToServerMessage, + &Derived::Call::OnServerInitialMetadata, + &Derived::Call::OnServerToClientMessage); +} + +// Determine if an interceptor needs to access the channel via one of its +// arguments. If so, we need to allocate a pointer to the channel for the +// generated polyfill promise for the original promise stack. + +inline constexpr bool HasChannelAccess() { return false; } + +inline constexpr bool HasChannelAccess(const NoInterceptor*) { return false; } + +template +inline constexpr bool HasChannelAccess(R (T::*)(A)) { + return false; +} + +template +inline constexpr bool HasChannelAccess(R (T::*)(A, C)) { + return true; +} + +// For the list case we do two interceptors to avoid amiguities with the single +// argument forms above. +template +inline constexpr bool HasChannelAccess(I1 i1, I2 i2, + Interceptors... interceptors) { + return HasChannelAccess(i1) || HasChannelAccess(i2) || + HasChannelAccess(interceptors...); +} + +// Composite for a given channel type to determine if any of its interceptors +// fall into this category: later code should use this. +template +inline constexpr bool CallHasChannelAccess() { + return HasChannelAccess(&Derived::Call::OnClientInitialMetadata, + &Derived::Call::OnClientToServerMessage, + &Derived::Call::OnServerInitialMetadata, + &Derived::Call::OnServerToClientMessage, + &Derived::Call::OnServerTrailingMetadata); +} + +// Given a boolean X export a type: +// either T if X is true +// or an empty type if it is false +template +struct TypeIfNeeded; + +template +struct TypeIfNeeded { + struct Type { + Type() = default; + template + explicit Type(Whatever) : Type() {} + }; +}; + +template +struct TypeIfNeeded { + using Type = T; +}; + +// For the original promise scheme polyfill: +// If a set of interceptors might fail asynchronously, wrap the main +// promise in a race with the cancellation latch. +// If not, just return the main promise. +template +struct RaceAsyncCompletion; + +template <> +struct RaceAsyncCompletion { + template + static Promise Run(Promise x, void*) { + return x; + } +}; + +template <> +struct RaceAsyncCompletion { + template + static Promise Run(Promise x, Latch* latch) { + return Race(latch->Wait(), std::move(x)); + } +}; + +// For the original promise scheme polyfill: data associated with once call. +template +struct FilterCallData { + explicit FilterCallData(Derived* channel) : channel(channel) {} + GPR_NO_UNIQUE_ADDRESS typename Derived::Call call; + GPR_NO_UNIQUE_ADDRESS + typename TypeIfNeeded, + CallHasAsyncErrorInterceptor()>::Type + error_latch; + GPR_NO_UNIQUE_ADDRESS + typename TypeIfNeeded()>::Type + channel; +}; + +template +auto MapResult(const NoInterceptor*, Promise x, void*) { + return x; +} + +template +auto MapResult(absl::Status (Derived::Call::*fn)(ServerMetadata&), Promise x, + FilterCallData* call_data) { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerTrailingMetadata); + return Map(std::move(x), [call_data](ServerMetadataHandle md) { + auto status = call_data->call.OnServerTrailingMetadata(*md); + if (!status.ok()) return ServerMetadataFromStatus(status); + return md; + }); +} + +inline auto RunCall(const NoInterceptor*, CallArgs call_args, + NextPromiseFactory next_promise_factory, void*) { + return next_promise_factory(std::move(call_args)); +} + +template +inline auto RunCall(void (Derived::Call::*fn)(ClientMetadata& md), + CallArgs call_args, NextPromiseFactory next_promise_factory, + FilterCallData* call_data) { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientInitialMetadata); + call_data->call.OnClientInitialMetadata(*call_args.client_initial_metadata); + return next_promise_factory(std::move(call_args)); +} + +template +inline auto RunCall(void (Derived::Call::*fn)(ClientMetadata& md, + Derived* channel), + CallArgs call_args, NextPromiseFactory next_promise_factory, + FilterCallData* call_data) { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientInitialMetadata); + call_data->call.OnClientInitialMetadata(*call_args.client_initial_metadata, + call_data->channel); + return next_promise_factory(std::move(call_args)); +} + +inline void InterceptClientToServerMessage(const NoInterceptor*, void*, + CallArgs&) {} + +inline void InterceptServerInitialMetadata(const NoInterceptor*, void*, + CallArgs&) {} + +template +inline void InterceptServerInitialMetadata( + absl::Status (Derived::Call::*fn)(ServerMetadata&), + FilterCallData* call_data, CallArgs& call_args) { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerInitialMetadata); + call_args.server_initial_metadata->InterceptAndMap( + [call_data]( + ServerMetadataHandle md) -> absl::optional { + auto status = call_data->call.OnServerInitialMetadata(*md); + if (!status.ok() && !call_data->error_latch.is_set()) { + call_data->error_latch.Set(ServerMetadataFromStatus(status)); + return absl::nullopt; + } + return std::move(md); + }); +} + +inline void InterceptServerToClientMessage(const NoInterceptor*, void*, + CallArgs&) {} + +template +absl::enable_if_t>::value, + FilterCallData*> +MakeFilterCall(Derived*) { + static FilterCallData call{nullptr}; + return &call; +} + +template +absl::enable_if_t>::value, + FilterCallData*> +MakeFilterCall(Derived* derived) { + return GetContext()->ManagedNew>(derived); +} + +} // namespace promise_filter_detail + +// Base class for promise-based channel filters. +// Eventually this machinery will move elsewhere (the interception logic will +// move directly into the channel stack, and so filters will just directly +// derive from `ChannelFilter`) +// +// Implements new-style call filters, and polyfills them into the previous +// scheme. +// +// Call filters: +// Derived types should declare a class `Call` with the following members: +// - OnClientInitialMetadata - $VALUE_TYPE = ClientMetadata +// - OnServerInitialMetadata - $VALUE_TYPE = ServerMetadata +// - OnServerToClientMessage - $VALUE_TYPE = Message +// - OnClientToServerMessage - $VALUE_TYPE = Message +// - OnServerTrailingMetadata - $VALUE_TYPE = ServerMetadata +// These members define an interception point for a particular event in +// the call lifecycle. +// The type of these members matters, and is selectable by the class +// author. For $INTERCEPTOR_NAME in the above list: +// - static const NoInterceptor $INTERCEPTOR_NAME: +// defines that this filter does not intercept this event. +// there is zero runtime cost added to handling that event by this filter. +// - void $INTERCEPTOR_NAME($VALUE_TYPE&): +// the filter intercepts this event, and can modify the value. +// it never fails. +// - absl::Status $INTERCEPTOR_NAME($VALUE_TYPE&): +// the filter intercepts this event, and can modify the value. +// it can fail, in which case the call will be aborted. +// - void $INTERCEPTOR_NAME($VALUE_TYPE&, Derived*): +// the filter intercepts this event, and can modify the value. +// it can access the channel via the second argument. +// it never fails. +// - absl::Status $INTERCEPTOR_NAME($VALUE_TYPE&, Derived*): +// the filter intercepts this event, and can modify the value. +// it can access the channel via the second argument. +// it can fail, in which case the call will be aborted. +template +class ImplementChannelFilter : public ChannelFilter { + public: + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) final { + auto* call = promise_filter_detail::MakeFilterCall( + static_cast(this)); + promise_filter_detail::InterceptClientToServerMessage( + &Derived::Call::OnClientToServerMessage, call, call_args); + promise_filter_detail::InterceptServerInitialMetadata( + &Derived::Call::OnServerInitialMetadata, call, call_args); + promise_filter_detail::InterceptServerToClientMessage( + &Derived::Call::OnServerToClientMessage, call, call_args); + return promise_filter_detail::MapResult( + &Derived::Call::OnServerTrailingMetadata, + promise_filter_detail::RaceAsyncCompletion< + promise_filter_detail::CallHasAsyncErrorInterceptor()>:: + Run(promise_filter_detail::RunCall( + &Derived::Call::OnClientInitialMetadata, + std::move(call_args), std::move(next_promise_factory), + call), + &call->error_latch), + call); + } +}; + // Designator for whether a filter is client side or server side. // Please don't use this outside calls to MakePromiseBasedFilter - it's // intended to be deleted once the promise conversion is complete. From 84678829af643474e5a8cd468066be3ef1559fae Mon Sep 17 00:00:00 2001 From: Vignesh Babu Date: Fri, 1 Dec 2023 11:41:36 -0800 Subject: [PATCH 3/9] [EventEngine] Add public methods to allow EventEngine Endpoints to support optional Extensions. PiperOrigin-RevId: 587071965 --- CMakeLists.txt | 35 +++++++ build_autogenerated.yaml | 13 +++ include/grpc/event_engine/event_engine.h | 39 ++++++++ src/core/BUILD | 12 +++ src/core/lib/event_engine/query_extensions.h | 70 ++++++++++++++ .../lib/iomgr/event_engine_shims/endpoint.cc | 13 +++ .../lib/iomgr/event_engine_shims/endpoint.h | 5 + test/core/event_engine/BUILD | 13 +++ .../event_engine/query_extensions_test.cc | 95 +++++++++++++++++++ tools/run_tests/generated/tests.json | 24 +++++ 10 files changed, 319 insertions(+) create mode 100644 src/core/lib/event_engine/query_extensions.h create mode 100644 test/core/event_engine/query_extensions_test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c8979a56ec..9d2a53f4ed6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1229,6 +1229,7 @@ if(gRPC_BUILD_TESTS) add_dependencies(buildtests_cxx proxy_auth_test) add_dependencies(buildtests_cxx qps_json_driver) add_dependencies(buildtests_cxx qps_worker) + add_dependencies(buildtests_cxx query_extensions_test) add_dependencies(buildtests_cxx race_test) add_dependencies(buildtests_cxx random_early_detection_test) add_dependencies(buildtests_cxx raw_end2end_test) @@ -18938,6 +18939,40 @@ target_link_libraries(qps_worker ) +endif() +if(gRPC_BUILD_TESTS) + +add_executable(query_extensions_test + test/core/event_engine/query_extensions_test.cc +) +target_compile_features(query_extensions_test PUBLIC cxx_std_14) +target_include_directories(query_extensions_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_XXHASH_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(query_extensions_test + ${_gRPC_ALLTARGETS_LIBRARIES} + gtest + absl::statusor + gpr +) + + endif() if(gRPC_BUILD_TESTS) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index bd2a9aa0bbd..7eba4bf9d53 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -13309,6 +13309,19 @@ targets: deps: - grpc++_test_config - grpc++_test_util +- name: query_extensions_test + gtest: true + build: test + language: c++ + headers: + - src/core/lib/event_engine/query_extensions.h + src: + - test/core/event_engine/query_extensions_test.cc + deps: + - gtest + - absl/status:statusor + - gpr + uses_polling: false - name: race_test gtest: true build: test diff --git a/include/grpc/event_engine/event_engine.h b/include/grpc/event_engine/event_engine.h index 4beca657625..20cbc64f52f 100644 --- a/include/grpc/event_engine/event_engine.h +++ b/include/grpc/event_engine/event_engine.h @@ -255,6 +255,45 @@ class EventEngine : public std::enable_shared_from_this { /// values are expected to remain valid for the life of the Endpoint. virtual const ResolvedAddress& GetPeerAddress() const = 0; virtual const ResolvedAddress& GetLocalAddress() const = 0; + + /// A method which allows users to query whether an Endpoint implementation + /// supports a specified extension. The name of the extension is provided + /// as an input. + /// + /// An extension could be any type with a unique string id. Each extension + /// may support additional capabilities and if the Endpoint implementation + /// supports the queried extension, it should return a valid pointer to the + /// extension type. + /// + /// E.g., use case of an EventEngine::Endpoint supporting a custom + /// extension. + /// + /// class CustomEndpointExtension { + /// public: + /// static constexpr std::string name = "my.namespace.extension_name"; + /// void Process() { ... } + /// } + /// + /// + /// class CustomEndpoint : + /// public EventEngine::Endpoint, CustomEndpointExtension { + /// public: + /// void* QueryExtension(absl::string_view id) override { + /// if (id == CustomEndpointExtension::name) { + /// return static_cast(this); + /// } + /// return nullptr; + /// } + /// ... + /// } + /// + /// auto ext_ = + /// static_cast( + /// endpoint->QueryExtension(CustomrEndpointExtension::name)); + /// if (ext_ != nullptr) { ext_->Process(); } + /// + /// + virtual void* QueryExtension(absl::string_view /*id*/) { return nullptr; } }; /// Called when a new connection is established. diff --git a/src/core/BUILD b/src/core/BUILD index 351f101b819..e6a4457ac0f 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -1540,6 +1540,18 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "event_engine_query_extensions", + hdrs = [ + "lib/event_engine/query_extensions.h", + ], + external_deps = ["absl/strings"], + deps = [ + "//:event_engine_base_hdrs", + "//:gpr_platform", + ], +) + grpc_cc_library( name = "event_engine_work_queue", hdrs = [ diff --git a/src/core/lib/event_engine/query_extensions.h b/src/core/lib/event_engine/query_extensions.h new file mode 100644 index 00000000000..2ef15ccfdab --- /dev/null +++ b/src/core/lib/event_engine/query_extensions.h @@ -0,0 +1,70 @@ +// Copyright 2023 gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef GRPC_SRC_CORE_LIB_EVENT_ENGINE_QUERY_EXTENSIONS_H +#define GRPC_SRC_CORE_LIB_EVENT_ENGINE_QUERY_EXTENSIONS_H + +#include + +#include "absl/strings/string_view.h" + +#include + +namespace grpc_event_engine { +namespace experimental { + +namespace endpoint_detail { + +template +struct QueryExtensionRecursion; + +template +struct QueryExtensionRecursion { + static void* Query(absl::string_view id, Querying* p) { + if (id == E::EndpointExtensionName()) return static_cast(p); + return QueryExtensionRecursion::Query(id, p); + } +}; + +template +struct QueryExtensionRecursion { + static void* Query(absl::string_view, Querying*) { return nullptr; } +}; + +} // namespace endpoint_detail + +// A helper class to derive from some set of base classes and export +// QueryExtension for them all. +// Endpoint implementations which need to support different extensions just need +// to derive from ExtendedEndpoint class. +template +class ExtendedEndpoint : public EventEngine::Endpoint, public Exports... { + public: + void* QueryExtension(absl::string_view id) override { + return endpoint_detail::QueryExtensionRecursion::Query(id, + this); + } +}; + +/// A helper method which returns a valid pointer if the extension is supported +/// by the endpoint. +template +T* QueryExtension(EventEngine::Endpoint* endpoint) { + return static_cast(endpoint->QueryExtension(T::EndpointExtensionName())); +} + +} // namespace experimental +} // namespace grpc_event_engine + +#endif // GRPC_SRC_CORE_LIB_EVENT_ENGINE_QUERY_EXTENSIONS_H diff --git a/src/core/lib/iomgr/event_engine_shims/endpoint.cc b/src/core/lib/iomgr/event_engine_shims/endpoint.cc index 341fe1e5776..b1e8fdf8904 100644 --- a/src/core/lib/iomgr/event_engine_shims/endpoint.cc +++ b/src/core/lib/iomgr/event_engine_shims/endpoint.cc @@ -69,6 +69,8 @@ class EventEngineEndpointWrapper { explicit EventEngineEndpointWrapper( std::unique_ptr endpoint); + EventEngine::Endpoint* endpoint() { return endpoint_.get(); } + int Fd() { grpc_core::MutexLock lock(&mu_); return fd_; @@ -428,6 +430,17 @@ bool grpc_is_event_engine_endpoint(grpc_endpoint* ep) { return ep->vtable == &grpc_event_engine_endpoint_vtable; } +EventEngine::Endpoint* grpc_get_wrapped_event_engine_endpoint( + grpc_endpoint* ep) { + if (!grpc_is_event_engine_endpoint(ep)) { + return nullptr; + } + auto* eeep = + reinterpret_cast( + ep); + return eeep->wrapper->endpoint(); +} + void grpc_event_engine_endpoint_destroy_and_release_fd( grpc_endpoint* ep, int* fd, grpc_closure* on_release_fd) { auto* eeep = diff --git a/src/core/lib/iomgr/event_engine_shims/endpoint.h b/src/core/lib/iomgr/event_engine_shims/endpoint.h index bc018f1e4d7..efd57c6ea6d 100644 --- a/src/core/lib/iomgr/event_engine_shims/endpoint.h +++ b/src/core/lib/iomgr/event_engine_shims/endpoint.h @@ -31,6 +31,11 @@ grpc_endpoint* grpc_event_engine_endpoint_create( /// Returns true if the passed endpoint is an event engine shim endpoint. bool grpc_is_event_engine_endpoint(grpc_endpoint* ep); +/// Returns the wrapped event engine endpoint if the given grpc_endpoint is an +/// event engine shim endpoint. Otherwise it returns nullptr. +EventEngine::Endpoint* grpc_get_wrapped_event_engine_endpoint( + grpc_endpoint* ep); + /// Destroys the passed in event engine shim endpoint and schedules the /// asynchronous execution of the on_release_fd callback. The int pointer fd is /// set to the underlying endpoint's file descriptor. diff --git a/test/core/event_engine/BUILD b/test/core/event_engine/BUILD index 13543244c14..1cdf7d24cd3 100644 --- a/test/core/event_engine/BUILD +++ b/test/core/event_engine/BUILD @@ -232,3 +232,16 @@ grpc_cc_library( "//src/core:time", ], ) + +grpc_cc_test( + name = "query_extensions_test", + srcs = ["query_extensions_test.cc"], + external_deps = ["gtest"], + language = "C++", + uses_event_engine = False, + uses_polling = False, + deps = [ + "//:gpr_platform", + "//src/core:event_engine_query_extensions", + ], +) diff --git a/test/core/event_engine/query_extensions_test.cc b/test/core/event_engine/query_extensions_test.cc new file mode 100644 index 00000000000..712a496f38c --- /dev/null +++ b/test/core/event_engine/query_extensions_test.cc @@ -0,0 +1,95 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "src/core/lib/event_engine/query_extensions.h" + +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "gtest/gtest.h" + +#include +#include + +#include "src/core/lib/gprpp/crash.h" + +namespace grpc_event_engine { +namespace experimental { +namespace { + +template +class TestExtension { + public: + TestExtension() = default; + ~TestExtension() = default; + + static std::string EndpointExtensionName() { + return "grpc.test.test_extension" + std::to_string(i); + } + + int GetValue() const { return val_; } + + private: + int val_ = i; +}; + +class ExtendedTestEndpoint + : public ExtendedEndpoint, TestExtension<1>, + TestExtension<2>> { + public: + ExtendedTestEndpoint() = default; + ~ExtendedTestEndpoint() override = default; + bool Read(absl::AnyInvocable /*on_read*/, + SliceBuffer* /*buffer*/, const ReadArgs* /*args*/) override { + grpc_core::Crash("Not implemented"); + }; + bool Write(absl::AnyInvocable /*on_writable*/, + SliceBuffer* /*data*/, const WriteArgs* /*args*/) override { + grpc_core::Crash("Not implemented"); + } + /// Returns an address in the format described in DNSResolver. The returned + /// values are expected to remain valid for the life of the Endpoint. + const EventEngine::ResolvedAddress& GetPeerAddress() const override { + grpc_core::Crash("Not implemented"); + } + const EventEngine::ResolvedAddress& GetLocalAddress() const override { + grpc_core::Crash("Not implemented"); + }; +}; + +TEST(QueryExtensionsTest, EndpointSupportsMultipleExtensions) { + ExtendedTestEndpoint endpoint; + TestExtension<0>* extension_0 = QueryExtension>(&endpoint); + TestExtension<1>* extension_1 = QueryExtension>(&endpoint); + TestExtension<2>* extension_2 = QueryExtension>(&endpoint); + + EXPECT_NE(extension_0, nullptr); + EXPECT_NE(extension_1, nullptr); + EXPECT_NE(extension_2, nullptr); + + EXPECT_EQ(extension_0->GetValue(), 0); + EXPECT_EQ(extension_1->GetValue(), 1); + EXPECT_EQ(extension_2->GetValue(), 2); +} +} // namespace + +} // namespace experimental +} // namespace grpc_event_engine + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index 35bdddc0eb0..d58a72accec 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -7189,6 +7189,30 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "query_extensions_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": false + }, { "args": [], "benchmark": false, From 1d4ecf66298234578e5a9781df42b72866db7924 Mon Sep 17 00:00:00 2001 From: "Mark D. Roth" Date: Fri, 1 Dec 2023 12:13:37 -0800 Subject: [PATCH 4/9] [RefCounted] allow RefCounted<> to work for const types (#35188) Closes #35188 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35188 from markdroth:ref_counted_const e2dc753b6b234c74e8374e860f2946c840d3b45c PiperOrigin-RevId: 587081377 --- src/core/lib/gprpp/ref_counted.h | 47 ++++++++++++++++++++++------- test/core/gprpp/ref_counted_test.cc | 11 +++++++ 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/src/core/lib/gprpp/ref_counted.h b/src/core/lib/gprpp/ref_counted.h index cdf692c5ce7..5eaf5cda0f1 100644 --- a/src/core/lib/gprpp/ref_counted.h +++ b/src/core/lib/gprpp/ref_counted.h @@ -219,7 +219,7 @@ class NonPolymorphicRefCount { // Default behavior: Delete the object. struct UnrefDelete { template - void operator()(T* p) { + void operator()(T* p) const { delete p; } }; @@ -231,7 +231,7 @@ struct UnrefDelete { // later by identifying entries for which RefIfNonZero() returns null. struct UnrefNoDelete { template - void operator()(T* /*p*/) {} + void operator()(T* /*p*/) const {} }; // Call the object's dtor but do not delete it. This is useful for cases @@ -239,7 +239,7 @@ struct UnrefNoDelete { // arena). struct UnrefCallDtor { template - void operator()(T* p) { + void operator()(T* p) const { p->~T(); } }; @@ -279,32 +279,44 @@ class RefCounted : public Impl { // Note: Depending on the Impl used, this dtor can be implicitly virtual. ~RefCounted() = default; + // Ref() for mutable types. GRPC_MUST_USE_RESULT RefCountedPtr Ref() { IncrementRefCount(); return RefCountedPtr(static_cast(this)); } - GRPC_MUST_USE_RESULT RefCountedPtr Ref(const DebugLocation& location, const char* reason) { IncrementRefCount(location, reason); return RefCountedPtr(static_cast(this)); } + // Ref() for const types. + GRPC_MUST_USE_RESULT RefCountedPtr Ref() const { + IncrementRefCount(); + return RefCountedPtr(static_cast(this)); + } + GRPC_MUST_USE_RESULT RefCountedPtr Ref( + const DebugLocation& location, const char* reason) const { + IncrementRefCount(location, reason); + return RefCountedPtr(static_cast(this)); + } + // TODO(roth): Once all of our code is converted to C++ and can use // RefCountedPtr<> instead of manual ref-counting, make this method // private, since it will only be used by RefCountedPtr<>, which is a // friend of this class. - void Unref() { + void Unref() const { if (GPR_UNLIKELY(refs_.Unref())) { - unref_behavior_(static_cast(this)); + unref_behavior_(static_cast(this)); } } - void Unref(const DebugLocation& location, const char* reason) { + void Unref(const DebugLocation& location, const char* reason) const { if (GPR_UNLIKELY(refs_.Unref(location, reason))) { - unref_behavior_(static_cast(this)); + unref_behavior_(static_cast(this)); } } + // RefIfNonZero() for mutable types. GRPC_MUST_USE_RESULT RefCountedPtr RefIfNonZero() { return RefCountedPtr(refs_.RefIfNonZero() ? static_cast(this) : nullptr); @@ -316,6 +328,18 @@ class RefCounted : public Impl { : nullptr); } + // RefIfNonZero() for const types. + GRPC_MUST_USE_RESULT RefCountedPtr RefIfNonZero() const { + return RefCountedPtr( + refs_.RefIfNonZero() ? static_cast(this) : nullptr); + } + GRPC_MUST_USE_RESULT RefCountedPtr RefIfNonZero( + const DebugLocation& location, const char* reason) const { + return RefCountedPtr(refs_.RefIfNonZero(location, reason) + ? static_cast(this) + : nullptr); + } + // Not copyable nor movable. RefCounted(const RefCounted&) = delete; RefCounted& operator=(const RefCounted&) = delete; @@ -336,12 +360,13 @@ class RefCounted : public Impl { template friend class RefCountedPtr; - void IncrementRefCount() { refs_.Ref(); } - void IncrementRefCount(const DebugLocation& location, const char* reason) { + void IncrementRefCount() const { refs_.Ref(); } + void IncrementRefCount(const DebugLocation& location, + const char* reason) const { refs_.Ref(location, reason); } - RefCount refs_; + mutable RefCount refs_; GPR_NO_UNIQUE_ADDRESS UnrefBehavior unref_behavior_; }; diff --git a/test/core/gprpp/ref_counted_test.cc b/test/core/gprpp/ref_counted_test.cc index 4d8761ecb1d..7c28cddc6f6 100644 --- a/test/core/gprpp/ref_counted_test.cc +++ b/test/core/gprpp/ref_counted_test.cc @@ -53,6 +53,17 @@ TEST(RefCounted, ExtraRef) { foo->Unref(); } +TEST(RefCounted, Const) { + const Foo* foo = new Foo(); + RefCountedPtr foop = foo->Ref(); + foop.release(); + foop = foo->RefIfNonZero(); + foop.release(); + foo->Unref(); + foo->Unref(); + foo->Unref(); +} + class Value : public RefCounted { public: Value(int value, std::set>* registry) : value_(value) { From 39493f93c06e0adb46dea88a095f4115fb79d0c1 Mon Sep 17 00:00:00 2001 From: Esun Kim Date: Fri, 1 Dec 2023 16:23:11 -0800 Subject: [PATCH 5/9] Making windows/dll test no-op temporarily. This will be reenabled once DLL work is done. PiperOrigin-RevId: 587155376 --- test/distrib/cpp/run_distrib_test_cmake_for_dll.bat | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/distrib/cpp/run_distrib_test_cmake_for_dll.bat b/test/distrib/cpp/run_distrib_test_cmake_for_dll.bat index 2adfaf14141..887c20dcd74 100644 --- a/test/distrib/cpp/run_distrib_test_cmake_for_dll.bat +++ b/test/distrib/cpp/run_distrib_test_cmake_for_dll.bat @@ -78,6 +78,11 @@ popd @rem folders, like the following command trying to imitate. git submodule foreach bash -c "cd $toplevel; rm -rf $name" +@rem TODO(dawidcha): Remove this once this DLL test can pass { +echo Skipped! +exit /b 0 +@rem TODO(dawidcha): Remove this once this DLL test can pass } + @rem Install gRPC @rem NOTE(jtattermusch): The -DProtobuf_USE_STATIC_LIBS=ON is necessary on cmake3.16+ @rem since by default "find_package(Protobuf ...)" uses the cmake's builtin From addd18b186855f998c729daf36e8dad49f84dcc7 Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Fri, 1 Dec 2023 17:52:42 -0800 Subject: [PATCH 6/9] [channel-args] Enforce const-correctness for RefCounted values (#35199) Closes #35199 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35199 from ctiller:refcount a3f856858a31b87c7501806223ee82a8edcb9900 PiperOrigin-RevId: 587178819 --- src/core/lib/channel/channel_args.h | 63 ++++++++++++++++++++++++-- test/core/channel/channel_args_test.cc | 31 +++++++++++++ 2 files changed, 89 insertions(+), 5 deletions(-) diff --git a/src/core/lib/channel/channel_args.h b/src/core/lib/channel/channel_args.h index 78848750294..2c10d955127 100644 --- a/src/core/lib/channel/channel_args.h +++ b/src/core/lib/channel/channel_args.h @@ -183,13 +183,27 @@ struct ChannelArgTypeTraits +struct ChannelArgPointerShouldBeConst { + static constexpr bool kValue = false; +}; + +template +struct ChannelArgPointerShouldBeConst< + T, absl::void_t> { + static constexpr bool kValue = T::ChannelArgUseConstPtr(); +}; + // GetObject support for shared_ptr and RefCountedPtr template struct GetObjectImpl; // std::shared_ptr implementation template struct GetObjectImpl< - T, absl::enable_if_t::value, void>> { + T, absl::enable_if_t::kValue && + SupportedSharedPtrType::value, + void>> { using Result = T*; using ReffedResult = std::shared_ptr; using StoredType = std::shared_ptr*; @@ -210,7 +224,9 @@ struct GetObjectImpl< // RefCountedPtr template struct GetObjectImpl< - T, absl::enable_if_t::value, void>> { + T, absl::enable_if_t::kValue && + !SupportedSharedPtrType::value, + void>> { using Result = T*; using ReffedResult = RefCountedPtr; using StoredType = Result; @@ -226,6 +242,26 @@ struct GetObjectImpl< }; }; +template +struct GetObjectImpl< + T, absl::enable_if_t::kValue && + !SupportedSharedPtrType::value, + void>> { + using Result = const T*; + using ReffedResult = RefCountedPtr; + using StoredType = Result; + static Result Get(StoredType p) { return p; }; + static ReffedResult GetReffed(StoredType p) { + if (p == nullptr) return nullptr; + return p->Ref(); + }; + static ReffedResult GetReffed(StoredType p, const DebugLocation& location, + const char* reason) { + if (p == nullptr) return nullptr; + return p->Ref(location, reason); + }; +}; + // Provide the canonical name for a type's channel arg key template struct ChannelArgNameTraits { @@ -242,6 +278,7 @@ struct ChannelArgNameTraits { return GRPC_INTERNAL_ARG_EVENT_ENGINE; } }; + class ChannelArgs { public: class Pointer { @@ -381,15 +418,29 @@ class ChannelArgs { GRPC_MUST_USE_RESULT auto Set(absl::string_view name, RefCountedPtr value) const -> absl::enable_if_t< - std::is_same>::VTable())>::value, + !ChannelArgPointerShouldBeConst::kValue && + std::is_same>::VTable())>::value, ChannelArgs> { return Set( name, Pointer(value.release(), ChannelArgTypeTraits>::VTable())); } template + GRPC_MUST_USE_RESULT auto Set(absl::string_view name, + RefCountedPtr value) const + -> absl::enable_if_t< + ChannelArgPointerShouldBeConst::kValue && + std::is_same>::VTable())>::value, + ChannelArgs> { + return Set( + name, Pointer(const_cast(value.release()), + ChannelArgTypeTraits>::VTable())); + } + template GRPC_MUST_USE_RESULT absl::enable_if_t< std::is_same< const grpc_arg_pointer_vtable*, @@ -426,6 +477,8 @@ class ChannelArgs { absl::optional GetInt(absl::string_view name) const; absl::optional GetString(absl::string_view name) const; absl::optional GetOwnedString(absl::string_view name) const; + // WARNING: this is broken if `name` represents something that was stored as a + // RefCounted - we will discard the const-ness. void* GetVoidPointer(absl::string_view name) const; template typename GetObjectImpl::StoredType GetPointer( diff --git a/test/core/channel/channel_args_test.cc b/test/core/channel/channel_args_test.cc index 10a05d35e26..fd035ccc12d 100644 --- a/test/core/channel/channel_args_test.cc +++ b/test/core/channel/channel_args_test.cc @@ -209,6 +209,37 @@ TEST(ChannelArgsTest, GetNonOwningEventEngine) { ASSERT_EQ(p.use_count(), 2); } +struct MutableValue : public RefCounted { + static constexpr absl::string_view ChannelArgName() { + return "grpc.test.mutable_value"; + } + static int ChannelArgsCompare(const MutableValue* a, const MutableValue* b) { + return a->i - b->i; + } + int i = 42; +}; + +struct ConstValue : public RefCounted { + static constexpr absl::string_view ChannelArgName() { + return "grpc.test.const_value"; + } + static constexpr bool ChannelArgUseConstPtr() { return true; }; + static int ChannelArgsCompare(const ConstValue* a, const ConstValue* b) { + return a->i - b->i; + } + int i = 42; +}; + +TEST(ChannelArgsTest, SetObjectRespectsMutabilityConstraints) { + auto m = MakeRefCounted(); + auto c = MakeRefCounted(); + auto args = ChannelArgs().SetObject(m).SetObject(c); + RefCountedPtr m1 = args.GetObjectRef(); + RefCountedPtr c1 = args.GetObjectRef(); + EXPECT_EQ(m1.get(), m.get()); + EXPECT_EQ(c1.get(), c.get()); +} + } // namespace grpc_core TEST(GrpcChannelArgsTest, Create) { From 7047cc17a8337f8c841ae1b5a6afe4bc5b8a72c6 Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Fri, 1 Dec 2023 17:53:23 -0800 Subject: [PATCH 7/9] [promises] Migrate http server filter to new API (#35197) Closes #35197 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35197 from ctiller:cg-http-svr cdde418a81dff4ad85c7196dc3e3464e429bf9cc PiperOrigin-RevId: 587178983 --- .../filters/http/server/http_server_filter.cc | 81 +++++++++---------- .../filters/http/server/http_server_filter.h | 14 +++- src/core/lib/channel/promise_based_filter.h | 65 +++++++++++++++ 3 files changed, 115 insertions(+), 45 deletions(-) diff --git a/src/core/ext/filters/http/server/http_server_filter.cc b/src/core/ext/filters/http/server/http_server_filter.cc index 2d4953dd26b..830b931520f 100644 --- a/src/core/ext/filters/http/server/http_server_filter.cc +++ b/src/core/ext/filters/http/server/http_server_filter.cc @@ -49,6 +49,9 @@ namespace grpc_core { +const NoInterceptor HttpServerFilter::Call::OnClientToServerMessage; +const NoInterceptor HttpServerFilter::Call::OnServerToClientMessage; + const grpc_channel_filter HttpServerFilter::kFilter = MakePromiseBasedFilter("http-server"); @@ -71,85 +74,81 @@ ServerMetadataHandle MalformedRequest(absl::string_view explanation) { } } // namespace -ArenaPromise HttpServerFilter::MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) { - const auto& md = call_args.client_initial_metadata; - - auto method = md->get(HttpMethodMetadata()); +ServerMetadataHandle HttpServerFilter::Call::OnClientInitialMetadata( + ClientMetadata& md, HttpServerFilter* filter) { + auto method = md.get(HttpMethodMetadata()); if (method.has_value()) { switch (*method) { case HttpMethodMetadata::kPost: break; case HttpMethodMetadata::kPut: - if (allow_put_requests_) { + if (filter->allow_put_requests_) { break; } ABSL_FALLTHROUGH_INTENDED; case HttpMethodMetadata::kInvalid: case HttpMethodMetadata::kGet: - return Immediate(MalformedRequest("Bad method header")); + return MalformedRequest("Bad method header"); } } else { - return Immediate(MalformedRequest("Missing :method header")); + return MalformedRequest("Missing :method header"); } - auto te = md->Take(TeMetadata()); + auto te = md.Take(TeMetadata()); if (te == TeMetadata::kTrailers) { // Do nothing, ok. } else if (!te.has_value()) { - return Immediate(MalformedRequest("Missing :te header")); + return MalformedRequest("Missing :te header"); } else { - return Immediate(MalformedRequest("Bad :te header")); + return MalformedRequest("Bad :te header"); } - auto scheme = md->Take(HttpSchemeMetadata()); + auto scheme = md.Take(HttpSchemeMetadata()); if (scheme.has_value()) { if (*scheme == HttpSchemeMetadata::kInvalid) { - return Immediate(MalformedRequest("Bad :scheme header")); + return MalformedRequest("Bad :scheme header"); } } else { - return Immediate(MalformedRequest("Missing :scheme header")); + return MalformedRequest("Missing :scheme header"); } - md->Remove(ContentTypeMetadata()); + md.Remove(ContentTypeMetadata()); - Slice* path_slice = md->get_pointer(HttpPathMetadata()); + Slice* path_slice = md.get_pointer(HttpPathMetadata()); if (path_slice == nullptr) { - return Immediate(MalformedRequest("Missing :path header")); + return MalformedRequest("Missing :path header"); } - if (md->get_pointer(HttpAuthorityMetadata()) == nullptr) { - absl::optional host = md->Take(HostMetadata()); + if (md.get_pointer(HttpAuthorityMetadata()) == nullptr) { + absl::optional host = md.Take(HostMetadata()); if (host.has_value()) { - md->Set(HttpAuthorityMetadata(), std::move(*host)); + md.Set(HttpAuthorityMetadata(), std::move(*host)); } } - if (md->get_pointer(HttpAuthorityMetadata()) == nullptr) { - return Immediate(MalformedRequest("Missing :authority header")); + if (md.get_pointer(HttpAuthorityMetadata()) == nullptr) { + return MalformedRequest("Missing :authority header"); } - if (!surface_user_agent_) { - md->Remove(UserAgentMetadata()); + if (!filter->surface_user_agent_) { + md.Remove(UserAgentMetadata()); } - call_args.server_initial_metadata->InterceptAndMap( - [](ServerMetadataHandle md) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[http-server] Write metadata", - Activity::current()->DebugTag().c_str()); - } - FilterOutgoingMetadata(md.get()); - md->Set(HttpStatusMetadata(), 200); - md->Set(ContentTypeMetadata(), ContentTypeMetadata::kApplicationGrpc); - return md; - }); - - return Map(next_promise_factory(std::move(call_args)), - [](ServerMetadataHandle md) -> ServerMetadataHandle { - FilterOutgoingMetadata(md.get()); - return md; - }); + return nullptr; +} + +void HttpServerFilter::Call::OnServerInitialMetadata(ServerMetadata& md) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[http-server] Write metadata", + Activity::current()->DebugTag().c_str()); + } + FilterOutgoingMetadata(&md); + md.Set(HttpStatusMetadata(), 200); + md.Set(ContentTypeMetadata(), ContentTypeMetadata::kApplicationGrpc); +} + +void HttpServerFilter::Call::OnServerTrailingMetadata(ServerMetadata& md) { + FilterOutgoingMetadata(&md); } absl::StatusOr HttpServerFilter::Create( diff --git a/src/core/ext/filters/http/server/http_server_filter.h b/src/core/ext/filters/http/server/http_server_filter.h index bc97bd53b8a..43eeb148957 100644 --- a/src/core/ext/filters/http/server/http_server_filter.h +++ b/src/core/ext/filters/http/server/http_server_filter.h @@ -32,16 +32,22 @@ namespace grpc_core { // Processes metadata on the server side for HTTP2 transports -class HttpServerFilter : public ChannelFilter { +class HttpServerFilter : public ImplementChannelFilter { public: static const grpc_channel_filter kFilter; static absl::StatusOr Create( const ChannelArgs& args, ChannelFilter::Args filter_args); - // Construct a promise for one call. - ArenaPromise MakeCallPromise( - CallArgs call_args, NextPromiseFactory next_promise_factory) override; + class Call { + public: + ServerMetadataHandle OnClientInitialMetadata(ClientMetadata& md, + HttpServerFilter* filter); + void OnServerInitialMetadata(ServerMetadata& md); + void OnServerTrailingMetadata(ServerMetadata& md); + static const NoInterceptor OnClientToServerMessage; + static const NoInterceptor OnServerToClientMessage; + }; private: HttpServerFilter(bool surface_user_agent, bool allow_put_requests) diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index 25ee1230ef1..a94b79bbd0c 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -61,6 +61,7 @@ #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/pipe.h" #include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/promise.h" #include "src/core/lib/promise/race.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice_buffer.h" @@ -143,6 +144,12 @@ inline constexpr bool HasAsyncErrorInterceptor(absl::Status (T::*)(A...)) { return true; } +template +inline constexpr bool HasAsyncErrorInterceptor( + ServerMetadataHandle (T::*)(A...)) { + return true; +} + template inline constexpr bool HasAsyncErrorInterceptor(void (T::*)(A...)) { return false; @@ -277,6 +284,16 @@ auto MapResult(absl::Status (Derived::Call::*fn)(ServerMetadata&), Promise x, }); } +template +auto MapResult(void (Derived::Call::*fn)(ServerMetadata&), Promise x, + FilterCallData* call_data) { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerTrailingMetadata); + return Map(std::move(x), [call_data](ServerMetadataHandle md) { + call_data->call.OnServerTrailingMetadata(*md); + return md; + }); +} + inline auto RunCall(const NoInterceptor*, CallArgs call_args, NextPromiseFactory next_promise_factory, void*) { return next_promise_factory(std::move(call_args)); @@ -291,6 +308,31 @@ inline auto RunCall(void (Derived::Call::*fn)(ClientMetadata& md), return next_promise_factory(std::move(call_args)); } +template +inline auto RunCall( + ServerMetadataHandle (Derived::Call::*fn)(ClientMetadata& md), + CallArgs call_args, NextPromiseFactory next_promise_factory, + FilterCallData* call_data) -> ArenaPromise { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientInitialMetadata); + auto return_md = call_data->call.OnClientInitialMetadata( + *call_args.client_initial_metadata); + if (return_md == nullptr) return next_promise_factory(std::move(call_args)); + return Immediate(std::move(return_md)); +} + +template +inline auto RunCall(ServerMetadataHandle (Derived::Call::*fn)( + ClientMetadata& md, Derived* channel), + CallArgs call_args, NextPromiseFactory next_promise_factory, + FilterCallData* call_data) + -> ArenaPromise { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientInitialMetadata); + auto return_md = call_data->call.OnClientInitialMetadata( + *call_args.client_initial_metadata, call_data->channel); + if (return_md == nullptr) return next_promise_factory(std::move(call_args)); + return Immediate(std::move(return_md)); +} + template inline auto RunCall(void (Derived::Call::*fn)(ClientMetadata& md, Derived* channel), @@ -308,6 +350,18 @@ inline void InterceptClientToServerMessage(const NoInterceptor*, void*, inline void InterceptServerInitialMetadata(const NoInterceptor*, void*, CallArgs&) {} +template +inline void InterceptServerInitialMetadata( + void (Derived::Call::*fn)(ServerMetadata&), + FilterCallData* call_data, CallArgs& call_args) { + GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerInitialMetadata); + call_args.server_initial_metadata->InterceptAndMap( + [call_data](ServerMetadataHandle md) { + call_data->call.OnServerInitialMetadata(*md); + return md; + }); +} + template inline void InterceptServerInitialMetadata( absl::Status (Derived::Call::*fn)(ServerMetadata&), @@ -373,6 +427,11 @@ MakeFilterCall(Derived* derived) { // - absl::Status $INTERCEPTOR_NAME($VALUE_TYPE&): // the filter intercepts this event, and can modify the value. // it can fail, in which case the call will be aborted. +// - ServerMetadataHandle $INTERCEPTOR_NAME($VALUE_TYPE&) +// the filter intercepts this event, and can modify the value. +// the filter can return nullptr for success, or a metadata handle for +// failure (in which case the call will be aborted). +// useful for cases where the exact metadata returned needs to be customized. // - void $INTERCEPTOR_NAME($VALUE_TYPE&, Derived*): // the filter intercepts this event, and can modify the value. // it can access the channel via the second argument. @@ -381,6 +440,12 @@ MakeFilterCall(Derived* derived) { // the filter intercepts this event, and can modify the value. // it can access the channel via the second argument. // it can fail, in which case the call will be aborted. +// - ServerMetadataHandle $INTERCEPTOR_NAME($VALUE_TYPE&, Derived*) +// the filter intercepts this event, and can modify the value. +// it can access the channel via the second argument. +// the filter can return nullptr for success, or a metadata handle for +// failure (in which case the call will be aborted). +// useful for cases where the exact metadata returned needs to be customized. template class ImplementChannelFilter : public ChannelFilter { public: From 207b88186878e67901a8d69095de64c221ed138e Mon Sep 17 00:00:00 2001 From: Tanvi Jagtap <139093547+tanvi-jagtap@users.noreply.github.com> Date: Mon, 4 Dec 2023 11:19:11 -0800 Subject: [PATCH 8/9] [grpc] Remove redundant check (#35161) We dont need this check anymore . Deleting the check from the yaml and the sh file. Closes #35161 PiperOrigin-RevId: 587784923 --- tools/run_tests/sanity/check_do_not_submit.sh | 23 ------------------- tools/run_tests/sanity/sanity_tests.yaml | 1 - 2 files changed, 24 deletions(-) delete mode 100755 tools/run_tests/sanity/check_do_not_submit.sh diff --git a/tools/run_tests/sanity/check_do_not_submit.sh b/tools/run_tests/sanity/check_do_not_submit.sh deleted file mode 100755 index 6e0438cd5cb..00000000000 --- a/tools/run_tests/sanity/check_do_not_submit.sh +++ /dev/null @@ -1,23 +0,0 @@ -#! /bin/bash -# Copyright 2021 The gRPC Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Checks if any file contains "DO NOT SUBMIT" - -cd "$(dirname "$0")/../../.." || exit 1 -git grep -Irn 'DO NOT SUBMIT' -- \ - './*' \ - ':!*check_do_not_submit.sh' \ - ':!third_party/' -test $? -eq 1 || exit 1 diff --git a/tools/run_tests/sanity/sanity_tests.yaml b/tools/run_tests/sanity/sanity_tests.yaml index 439891f75bf..6c4494b00c2 100644 --- a/tools/run_tests/sanity/sanity_tests.yaml +++ b/tools/run_tests/sanity/sanity_tests.yaml @@ -6,7 +6,6 @@ - script: tools/run_tests/sanity/check_buildifier.sh - script: tools/run_tests/sanity/check_cache_mk.sh - script: tools/run_tests/sanity/check_deprecated_grpc++.py -- script: tools/run_tests/sanity/check_do_not_submit.sh - script: tools/run_tests/sanity/check_illegal_terms.sh - script: tools/run_tests/sanity/check_port_platform.py - script: tools/run_tests/sanity/check_include_style.py From 501b895736b196dcbb4efa6df77bb39220c74a07 Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Mon, 4 Dec 2023 12:04:41 -0800 Subject: [PATCH 9/9] [fuzzing-heck] Fix a bug that comes up with promises + work serializer dispatch (#35196) b/310341170 I'm kind of proud of our testing for finding this Closes #35196 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35196 from ctiller:fuzz-no 28c95606f97732ea1069f3eb067484f9a3f84ec1 PiperOrigin-RevId: 587798657 --- src/core/ext/filters/client_channel/client_channel.cc | 8 ++++++++ .../negative_deadline/5769288995635200 | 11 +++++++++++ 2 files changed, 19 insertions(+) create mode 100644 test/core/end2end/end2end_test_corpus/negative_deadline/5769288995635200 diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index b5064ceb20f..197173aa7e0 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -315,6 +315,14 @@ class ClientChannel::PromiseBasedCallData : public ClientChannel::CallData { public: explicit PromiseBasedCallData(ClientChannel* chand) : chand_(chand) {} + ~PromiseBasedCallData() override { + if (was_queued_ && client_initial_metadata_ != nullptr) { + MutexLock lock(&chand_->resolution_mu_); + RemoveCallFromResolverQueuedCallsLocked(); + chand_->resolver_queued_calls_.erase(this); + } + } + ArenaPromise> MakeNameResolutionPromise( CallArgs call_args) { pollent_ = NowOrNever(call_args.polling_entity->WaitAndCopy()).value(); diff --git a/test/core/end2end/end2end_test_corpus/negative_deadline/5769288995635200 b/test/core/end2end/end2end_test_corpus/negative_deadline/5769288995635200 new file mode 100644 index 00000000000..bb65fceedf8 --- /dev/null +++ b/test/core/end2end/end2end_test_corpus/negative_deadline/5769288995635200 @@ -0,0 +1,11 @@ +test_id: 3 +event_engine_actions { + run_delay: 9851624184873214 + run_delay: 1 + run_delay: 1 + run_delay: 0 + run_delay: 53876069761024 +} +config_vars { + experiments: 280384054960896 +}