diff --git a/BUILD b/BUILD index 12f7524abc4..d79912a34e1 100644 --- a/BUILD +++ b/BUILD @@ -3164,7 +3164,7 @@ grpc_cc_library( ], language = "c++", visibility = ["@grpc:http"], - deps = [ + deps = ["//src/core:call_promise", "channel_stack_builder", "config", "gpr", diff --git a/src/core/BUILD b/src/core/BUILD index 5597c92cbbd..f300ccb76ab 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -400,6 +400,34 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "call_promise", + hdrs = ["lib/channel/call_promise.h"], + external_deps = [ + "absl/status", + "absl/strings", + "absl/types:variant", + ], + language = "c++", + public_hdrs = [ + "lib/promise/map_pipe.h", + ], + deps = [ + "activity", + "construct_destruct", + "for_each", + "map", + "pipe", + "poll", + "promise_factory", + "promise_like", + "promise_status", + "promise_trace", + "//:gpr", + "//:grpc_trace", + ], +) + grpc_cc_library( name = "map_pipe", external_deps = ["absl/status"], diff --git a/src/core/ext/filters/http/client/http_client_filter.cc b/src/core/ext/filters/http/client/http_client_filter.cc index 88ec6ce9a54..00c6dcf9e5d 100644 --- a/src/core/ext/filters/http/client/http_client_filter.cc +++ b/src/core/ext/filters/http/client/http_client_filter.cc @@ -36,6 +36,7 @@ #include #include +#include "src/core/lib/channel/call_promise.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/promise/context.h" @@ -117,24 +118,15 @@ ArenaPromise HttpClientFilter::MakeCallPromise( md->Set(ContentTypeMetadata(), ContentTypeMetadata::kApplicationGrpc); md->Set(UserAgentMetadata(), user_agent_.Ref()); - auto* read_latch = GetContext()->New>(); - auto* write_latch = - std::exchange(call_args.server_initial_metadata, read_latch); - - return TryConcurrently( - Seq(next_promise_factory(std::move(call_args)), - [](ServerMetadataHandle md) -> ServerMetadataHandle { - auto r = CheckServerMetadata(md.get()); - if (!r.ok()) return ServerMetadataFromStatus(r); - return md; - })) - .NecessaryPull(Seq(read_latch->Wait(), - [write_latch](ServerMetadata** md) -> absl::Status { - auto r = *md == nullptr ? absl::OkStatus() - : CheckServerMetadata(*md); - write_latch->Set(*md); - return r; - })); + return CallPromiseBuilder() + .OnServerInitialMetadata( + [](ServerMetadataHandle md) { return CheckServerMetadata(md.get()); }) + .OnServerTrailingMetadata([](ServerMetadataHandle md) { + auto r = CheckServerMetadata(md.get()); + if (!r.ok()) return ServerMetadataFromStatus(r); + return md; + }) + .BuildServer(std::move(call_args), std::move(next_promise_factory)); } HttpClientFilter::HttpClientFilter(HttpSchemeMetadata::ValueType scheme, diff --git a/src/core/lib/channel/call_promise.h b/src/core/lib/channel/call_promise.h new file mode 100644 index 00000000000..664b4bbf12b --- /dev/null +++ b/src/core/lib/channel/call_promise.h @@ -0,0 +1,277 @@ +// Copyright 2022 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CALL_PROMISE_H +#define CALL_PROMISE_H + +#include + +#include "src/core/lib/promise/map.h" +#include "src/core/lib/promise/seq.h" + +namespace grpc_core { + +namespace call_promise_detail { + +template +class MainLoop; + +template +class MainLoop { + public: + static auto MakePromise(CallArgs call_args, Empty, Empty, Empty, Empty, + NextPromiseFactory f) { + return f(std::move(call_args)); + } +}; + +template +class OnServerInitialMetadataHandler; + +inline auto WrapServerMetadataInHandle(ServerMetadata** p) { + return ServerMetadataHandle(*p, Arena::PooledDeleter(nullptr)); +} + +template <> +class OnServerInitialMetadataHandler { + public: + template + static auto Wrap(OnServerInitialMetadataType f, CallArgs& args) { + return Map(args.server_initial_metadata->Wait(), + [f = std::move(f)](ServerMetadata** p) { + return f(WrapServerMetadataInHandle(p)); + }); + } +}; + +template <> +class OnServerInitialMetadataHandler { + public: + template + static auto Wrap(OnServerInitialMetadataType f, CallArgs& args) { + auto* read_latch = GetContext()->New>(); + auto* write_latch = std::exchange(args.server_initial_metadata, read_latch); + return Seq( + Map(read_latch->Wait(), WrapServerMetadataInHandle), + Map(std::move(f), [read_latch, write_latch](ServerMetadataHandle h) { + if (h.get() != read_latch->Get()) { + *read_latch->Get() = std::move(*h.get()); + } + write_latch->Set(read_latch->Get()); + return absl::OkStatus(); + })); + } +}; + +template +class MainLoop { + public: + static auto MakePromise( + CallArgs call_args, + OnServerInitialMetadataType on_server_initial_metadata, Empty, Empty, + NextPromiseFactory f) { + auto wrapped_on_server_initial_metadata = + OnServerInitialMetadataHandler()))>:: + Wrap(std::move(on_server_initial_metadata), call_args); + return TryConcurrently(f(std::move(call_args))) + .NecessaryPull(wrapped_on_server_initial_metadata); + } +}; + +template +class AddBracketingMetadata { + public: + template + static auto MakePromise(OnClientInitialMetadataType start, Middle middle, + OnServerTrailingMetadataType end) { + return Seq(std::move(start), std::move(middle), std::move(end)); + } +}; + +template +class AddBracketingMetadata { + public: + template + static auto MakePromise(Empty, Middle middle, + OnServerTrailingMetadataType end) { + return Seq(std::move(middle), std::move(end)); + } +}; + +template +class AddBracketingMetadata { + public: + template + static auto MakePromise(OnClientInitialMetadataType start, Middle middle, + Empty) { + return Seq(std::move(start), std::move(middle)); + } +}; + +template <> +class AddBracketingMetadata { + public: + template + static auto MakePromise(Empty, Middle middle, Empty) { + return middle; + } +}; + +template +class CallPromiseBuilder { + public: + CallPromiseBuilder() = default; + + CallPromiseBuilder(OnClientInitialMetadataType on_client_initial_metadata, + OnServerInitialMetadataType on_server_initial_metadata, + OnServerTrailingMetadataType on_server_trailing_metadata, + MapOutgoingMessageType map_outgoing_message, + MapIncomingMessageType map_incoming_message) + : on_client_initial_metadata_(std::move(on_client_initial_metadata)), + on_server_initial_metadata_(std::move(on_server_initial_metadata)), + on_server_trailing_metadata_(std::move(on_server_trailing_metadata)), + map_outgoing_message_(std::move(map_outgoing_message)), + map_incoming_message_(std::move(map_incoming_message)) {} + + template + CallPromiseBuilder + OnClientInitialMetadata(F f) { + static_assert(std::is_same::value, + "OnClientInitialMetadata already set"); + return CallPromiseBuilder{ + std::forward(f), on_server_initial_metadata_, + on_server_trailing_metadata_, map_outgoing_message_, + map_incoming_message_}; + } + + template + CallPromiseBuilder + OnServerInitialMetadata(F f) { + static_assert(std::is_same::value, + "OnServerInitialMetadata already set"); + return CallPromiseBuilder{ + on_client_initial_metadata_, std::forward(f), + on_server_trailing_metadata_, map_outgoing_message_, + map_incoming_message_}; + } + + template + CallPromiseBuilder + OnServerTrailingMetadata(F f) { + static_assert(std::is_same::value, + "OnServerTrailingMetadata already set"); + return CallPromiseBuilder{ + on_client_initial_metadata_, on_server_initial_metadata_, + std::forward(f), map_outgoing_message_, map_incoming_message_}; + } + + template + CallPromiseBuilder + MapOutgoingMessage(F f) { + static_assert(std::is_same::value, + "MapOutgoingMessage already set"); + return CallPromiseBuilder< + OnClientInitialMetadataType, OnServerInitialMetadataType, + OnServerTrailingMetadataType, F, MapIncomingMessageType>{ + on_client_initial_metadata_, on_server_initial_metadata_, + on_server_trailing_metadata_, std::forward(f), + map_incoming_message_}; + } + + template + CallPromiseBuilder + MapIncomingMessage(F f) { + static_assert(std::is_same::value, + "MapIncomingMessage already set"); + return CallPromiseBuilder< + OnClientInitialMetadataType, OnServerInitialMetadataType, + OnServerTrailingMetadataType, MapOutgoingMessageType, F>{ + on_client_initial_metadata_, on_server_initial_metadata_, + on_server_trailing_metadata_, map_outgoing_message_, + std::forward(f)}; + } + + auto BuildClient(CallArgs call_args, + NextPromiseFactory next_promise_factory) { + return AddBracketingMetadata:: + MakePromise( + std::move(on_client_initial_metadata_), + MainLoop:: + MakePromise(std::move(call_args), + std::move(on_server_initial_metadata_), + std::move(map_outgoing_message_), + std::move(map_incoming_message_), + std::move(next_promise_factory)), + std::move(on_server_trailing_metadata_)); + } + + auto BuildServer(CallArgs call_args, + NextPromiseFactory next_promise_factory) { + return AddBracketingMetadata:: + MakePromise( + std::move(on_client_initial_metadata_), + MainLoop:: + MakePromise(std::move(call_args), + std::move(on_server_initial_metadata_), + std::move(map_outgoing_message_), + std::move(map_incoming_message_), + std::move(next_promise_factory)), + std::move(on_server_trailing_metadata_)); + } + + private: + GPR_NO_UNIQUE_ADDRESS OnClientInitialMetadataType on_client_initial_metadata_; + GPR_NO_UNIQUE_ADDRESS OnServerInitialMetadataType on_server_initial_metadata_; + GPR_NO_UNIQUE_ADDRESS OnServerTrailingMetadataType + on_server_trailing_metadata_; + GPR_NO_UNIQUE_ADDRESS MapOutgoingMessageType map_outgoing_message_; + GPR_NO_UNIQUE_ADDRESS MapIncomingMessageType map_incoming_message_; +}; + +} // namespace call_promise_detail + +using CallPromiseBuilder = + call_promise_detail::CallPromiseBuilder; + +} // namespace grpc_core + +#endif diff --git a/src/core/lib/channel/connected_channel.cc b/src/core/lib/channel/connected_channel.cc index ca4877aec13..e8ea129eba3 100644 --- a/src/core/lib/channel/connected_channel.cc +++ b/src/core/lib/channel/connected_channel.cc @@ -929,7 +929,7 @@ class ServerStream final : public ConnectedChannelStream { } Poll Poll() { - absl::MutexLock lock(mu()); + MutexLock lock(mu()); if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, "%s[connected] PollConnectedChannel: %s", diff --git a/src/core/lib/promise/latch.h b/src/core/lib/promise/latch.h index 04bb1a2b400..d27384eec1c 100644 --- a/src/core/lib/promise/latch.h +++ b/src/core/lib/promise/latch.h @@ -89,6 +89,11 @@ class Latch { waiter_.Wake(); } + const T& Get() const { + GPR_DEBUG_ASSERT(has_value_); + return value_; + } + private: std::string DebugTag() { return absl::StrCat(Activity::current()->DebugTag(), " LATCH[0x", diff --git a/tools/distrib/gen_compilation_database.py b/tools/distrib/gen_compilation_database.py index f2c93ade71c..c218fed973c 100755 --- a/tools/distrib/gen_compilation_database.py +++ b/tools/distrib/gen_compilation_database.py @@ -39,12 +39,12 @@ def generateCompilationDatabase(args): "--remote_download_outputs=all", ] - subprocess.check_call(["bazel", "build"] + bazel_options + [ + subprocess.check_call(["tools/bazel", "build"] + bazel_options + [ "--aspects=@bazel_compdb//:aspects.bzl%compilation_database_aspect", "--output_groups=compdb_files,header_files" ] + args.bazel_targets) - execroot = subprocess.check_output(["bazel", "info", "execution_root"] + + execroot = subprocess.check_output(["tools/bazel", "info", "execution_root"] + bazel_options).decode().strip() compdb = []