diff --git a/BUILD b/BUILD index b33418d118b..e29ba34178b 100644 --- a/BUILD +++ b/BUILD @@ -2285,6 +2285,7 @@ grpc_cc_library( external_deps = [ "absl/base:core_headers", "absl/container:inlined_vector", + "absl/functional:any_invocable", "absl/log:check", "absl/log:log", "absl/status", @@ -2309,6 +2310,7 @@ grpc_cc_library( "grpc_trace", "handshaker", "iomgr", + "orphanable", "promise", "ref_counted_ptr", "resource_quota_api", @@ -3192,9 +3194,11 @@ grpc_cc_library( external_deps = [ "absl/base:core_headers", "absl/container:inlined_vector", + "absl/functional:any_invocable", "absl/log:check", "absl/log:log", "absl/status", + "absl/status:statusor", "absl/strings:str_format", ], language = "c++", @@ -3211,6 +3215,7 @@ grpc_cc_library( "grpc_public_hdrs", "grpc_trace", "iomgr", + "orphanable", "ref_counted_ptr", "//src/core:channel_args", "//src/core:closure", diff --git a/examples/cpp/retry/BUILD b/examples/cpp/retry/BUILD new file mode 100644 index 00000000000..1802c75721e --- /dev/null +++ b/examples/cpp/retry/BUILD @@ -0,0 +1,38 @@ +# Copyright 2024 the gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +cc_binary( + name = "client", + srcs = ["client.cc"], + defines = ["BAZEL_BUILD"], + deps = [ + "//:grpc++", + "//examples/protos:helloworld_cc_grpc", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_binary( + name = "server", + srcs = ["server.cc"], + defines = ["BAZEL_BUILD"], + deps = [ + "//:grpc++", + "//:grpc++_reflection", + "//examples/protos:helloworld_cc_grpc", + "@com_google_absl//absl/strings:str_format", + ], +) diff --git a/examples/cpp/retry/CMakeLists.txt b/examples/cpp/retry/CMakeLists.txt new file mode 100644 index 00000000000..f01e360bb7a --- /dev/null +++ b/examples/cpp/retry/CMakeLists.txt @@ -0,0 +1,73 @@ +# Copyright 2024 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cmake build file for C++ retry example. +# Assumes protobuf and gRPC have been installed using cmake. +# See cmake_externalproject/CMakeLists.txt for all-in-one cmake build +# that automatically builds all the dependencies before building retry. + +cmake_minimum_required(VERSION 3.8) + +project(Retry C CXX) + +include(../cmake/common.cmake) + +# Proto file +get_filename_component(hw_proto "../../protos/helloworld.proto" ABSOLUTE) +get_filename_component(hw_proto_path "${hw_proto}" PATH) + +# Generated sources +set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/helloworld.pb.cc") +set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/helloworld.pb.h") +set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/helloworld.grpc.pb.cc") +set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/helloworld.grpc.pb.h") +add_custom_command( + OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${hw_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${hw_proto}" + DEPENDS "${hw_proto}") + +# Include generated *.pb.h files +include_directories("${CMAKE_CURRENT_BINARY_DIR}") + +# hw_grpc_proto +add_library(hw_grpc_proto + ${hw_grpc_srcs} + ${hw_grpc_hdrs} + ${hw_proto_srcs} + ${hw_proto_hdrs}) +target_link_libraries(hw_grpc_proto + absl::check + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF}) + +# Targets (client|server) +foreach(_target + client server) + add_executable(${_target} "${_target}.cc") + target_link_libraries(${_target} + hw_grpc_proto + absl::check + absl::flags + absl::flags_parse + absl::log + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF}) +endforeach() diff --git a/examples/cpp/retry/README.md b/examples/cpp/retry/README.md new file mode 100644 index 00000000000..54ea22c5135 --- /dev/null +++ b/examples/cpp/retry/README.md @@ -0,0 +1,69 @@ +# Retry + +This example shows how to enable and configure retry on gRPC clients. + +## Documentation + +[gRFC for client-side retry support](https://github.com/grpc/proposal/blob/master/A6-client-retries.md) + +## Try it + +This example includes a service implementation that fails requests three times with status +code `Unavailable`, then passes the fourth. The client is configured to make four retry attempts +when receiving an `Unavailable` status code. + +First start the server: + +```bash +$ ./server +``` + +Then run the client: + +```bash +$ ./client +``` + +Expected server output: + +``` +Server listening on 0.0.0.0:50052 +return UNAVAILABLE +return UNAVAILABLE +return UNAVAILABLE +return OK +``` + +Expected client output: + +``` +Greeter received: Hello world +``` + +## Usage + +### Define your retry policy + +Retry is enabled via the service config, which can be provided by the name resolver or +a [GRPC_ARG_SERVICE_CONFIG](https://github.com/grpc/grpc/blob/master/include/grpc/impl/channel_arg_names.h#L207-L209) channel argument. In the below config, we set retry policy for the "helloworld.Greeter" service. + +`maxAttempts`: how many times to attempt the RPC before failing. + +`initialBackoff`, `maxBackoff`, `backoffMultiplier`: configures delay between attempts. + +`retryableStatusCodes`: Retry only when receiving these status codes. + +```c++ +constexpr absl::string_view kRetryPolicy = + "{\"methodConfig\" : [{" + " \"name\" : [{\"service\": \"helloworld.Greeter\"}]," + " \"waitForReady\": true," + " \"retryPolicy\": {" + " \"maxAttempts\": 4," + " \"initialBackoff\": \"1s\"," + " \"maxBackoff\": \"120s\"," + " \"backoffMultiplier\": 1.0," + " \"retryableStatusCodes\": [\"UNAVAILABLE\"]" + " }" + "}]}"; +``` diff --git a/examples/cpp/retry/client.cc b/examples/cpp/retry/client.cc new file mode 100644 index 00000000000..3d490f25d5c --- /dev/null +++ b/examples/cpp/retry/client.cc @@ -0,0 +1,98 @@ +/* + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "absl/strings/string_view.h" + +#include +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +constexpr absl::string_view kTargetAddress = "localhost:50052"; + +// clang-format off +constexpr absl::string_view kRetryPolicy = + "{\"methodConfig\" : [{" + " \"name\" : [{\"service\": \"helloworld.Greeter\"}]," + " \"waitForReady\": true," + " \"retryPolicy\": {" + " \"maxAttempts\": 4," + " \"initialBackoff\": \"1s\"," + " \"maxBackoff\": \"120s\"," + " \"backoffMultiplier\": 1.0," + " \"retryableStatusCodes\": [\"UNAVAILABLE\"]" + " }" + "}]}"; +// clang-format on + +class GreeterClient { + public: + GreeterClient(std::shared_ptr channel) + : stub_(Greeter::NewStub(channel)) {} + + // Assembles the client's payload, sends it and presents the response back + // from the server. + std::string SayHello(const std::string& user) { + // Data we are sending to the server. + HelloRequest request; + request.set_name(user); + // Container for the data we expect from the server. + HelloReply reply; + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + // The actual RPC. + Status status = stub_->SayHello(&context, request, &reply); + // Act upon its status. + if (status.ok()) { + return reply.message(); + } else { + std::cout << status.error_code() << ": " << status.error_message() + << std::endl; + return "RPC failed"; + } + } + + private: + std::unique_ptr stub_; +}; + +int main() { + auto channel_args = grpc::ChannelArguments(); + channel_args.SetServiceConfigJSON(std::string(kRetryPolicy)); + GreeterClient greeter(grpc::CreateCustomChannel( + std::string(kTargetAddress), grpc::InsecureChannelCredentials(), + channel_args)); + std::string user("world"); + std::string reply = greeter.SayHello(user); + std::cout << "Greeter received: " << reply << std::endl; + return 0; +} diff --git a/examples/cpp/retry/server.cc b/examples/cpp/retry/server.cc new file mode 100644 index 00000000000..f6a8c602473 --- /dev/null +++ b/examples/cpp/retry/server.cc @@ -0,0 +1,86 @@ +/* + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using grpc::StatusCode; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +// Logic and data behind the server's behavior. +class GreeterServiceImpl final : public Greeter::Service { + public: + Status SayHello(ServerContext* context, const HelloRequest* request, + HelloReply* reply) override { + if (++request_counter_ % request_modulo_ != 0) { + // Return an OK status for every request_modulo_ number of requests, + // return UNAVAILABLE otherwise. + std::cout << "return UNAVAILABLE" << std::endl; + return Status(StatusCode::UNAVAILABLE, ""); + } + std::string prefix("Hello "); + reply->set_message(prefix + request->name()); + std::cout << "return OK" << std::endl; + return Status::OK; + } + + private: + static constexpr int request_modulo_ = 4; + int request_counter_ = 0; +}; + +void RunServer(uint16_t port) { + std::string server_address = absl::StrFormat("0.0.0.0:%d", port); + GreeterServiceImpl service; + + grpc::EnableDefaultHealthCheckService(true); + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + ServerBuilder builder; + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + builder.RegisterService(&service); + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. + server->Wait(); +} + +int main(int argc, char** argv) { + RunServer(/*port=*/50052); + return 0; +} diff --git a/src/core/BUILD b/src/core/BUILD index 93bab4ad4e6..efebfbca3cb 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -1301,6 +1301,7 @@ grpc_cc_library( ], external_deps = [ "absl/base:core_headers", + "absl/functional:any_invocable", "absl/log:check", "absl/status", "absl/status:statusor", @@ -1343,6 +1344,7 @@ grpc_cc_library( "handshaker/endpoint_info/endpoint_info_handshaker.h", ], external_deps = [ + "absl/functional:any_invocable", "absl/status", ], language = "c++", diff --git a/src/core/ext/transport/chaotic_good/client/chaotic_good_connector.cc b/src/core/ext/transport/chaotic_good/client/chaotic_good_connector.cc index 0f71d509ba4..cecb5fa53af 100644 --- a/src/core/ext/transport/chaotic_good/client/chaotic_good_connector.cc +++ b/src/core/ext/transport/chaotic_good/client/chaotic_good_connector.cc @@ -256,20 +256,23 @@ void ChaoticGoodConnector::Connect(const Args& args, Result* result, error); return; } - auto* p = self.release(); auto* chaotic_good_ext = grpc_event_engine::experimental::QueryExtension< grpc_event_engine::experimental::ChaoticGoodExtension>( - endpoint.value().get()); + endpoint->get()); if (chaotic_good_ext != nullptr) { chaotic_good_ext->EnableStatsCollection(/*is_control_channel=*/true); chaotic_good_ext->UseMemoryQuota( ResourceQuota::Default()->memory_quota()); } + auto* p = self.get(); p->handshake_mgr_->DoHandshake( - grpc_event_engine_endpoint_create(std::move(endpoint.value())), + OrphanablePtr( + grpc_event_engine_endpoint_create(std::move(*endpoint))), p->args_.channel_args, p->args_.deadline, nullptr /* acceptor */, - OnHandshakeDone, p); + [self = std::move(self)](absl::StatusOr result) { + self->OnHandshakeDone(std::move(result)); + }); }; event_engine_->Connect( std::move(on_connect), *resolved_addr_, @@ -280,45 +283,37 @@ void ChaoticGoodConnector::Connect(const Args& args, Result* result, std::chrono::seconds(kTimeoutSecs)); } -void ChaoticGoodConnector::OnHandshakeDone(void* arg, grpc_error_handle error) { - auto* args = static_cast(arg); - RefCountedPtr self( - static_cast(args->user_data)); - grpc_slice_buffer_destroy(args->read_buffer); - gpr_free(args->read_buffer); +void ChaoticGoodConnector::OnHandshakeDone( + absl::StatusOr result) { // Start receiving setting frames; { - MutexLock lock(&self->mu_); - if (!error.ok() || self->is_shutdown_) { - if (error.ok()) { + MutexLock lock(&mu_); + if (!result.ok() || is_shutdown_) { + absl::Status error = result.status(); + if (result.ok()) { error = GRPC_ERROR_CREATE("connector shutdown"); - // We were shut down after handshaking completed successfully, so - // destroy the endpoint here. - if (args->endpoint != nullptr) { - grpc_endpoint_destroy(args->endpoint); - } } - self->result_->Reset(); - ExecCtx::Run(DEBUG_LOCATION, std::exchange(self->notify_, nullptr), - error); + result_->Reset(); + ExecCtx::Run(DEBUG_LOCATION, std::exchange(notify_, nullptr), error); return; } } - if (args->endpoint != nullptr) { + if ((*result)->endpoint != nullptr) { CHECK(grpc_event_engine::experimental::grpc_is_event_engine_endpoint( - args->endpoint)); - self->control_endpoint_ = PromiseEndpoint( - grpc_event_engine::experimental:: - grpc_take_wrapped_event_engine_endpoint(args->endpoint), - SliceBuffer()); + (*result)->endpoint.get())); + control_endpoint_ = + PromiseEndpoint(grpc_event_engine::experimental:: + grpc_take_wrapped_event_engine_endpoint( + (*result)->endpoint.release()), + SliceBuffer()); auto activity = MakeActivity( - [self] { + [self = RefAsSubclass()] { return TrySeq(ControlEndpointWriteSettingsFrame(self), ControlEndpointReadSettingsFrame(self), []() { return absl::OkStatus(); }); }, - EventEngineWakeupScheduler(self->event_engine_), - [self](absl::Status status) { + EventEngineWakeupScheduler(event_engine_), + [self = RefAsSubclass()](absl::Status status) { if (GRPC_TRACE_FLAG_ENABLED(chaotic_good)) { gpr_log(GPR_INFO, "ChaoticGoodConnector::OnHandshakeDone: %s", status.ToString().c_str()); @@ -338,17 +333,19 @@ void ChaoticGoodConnector::OnHandshakeDone(void* arg, grpc_error_handle error) { status); } }, - self->arena_, self->event_engine_.get()); - MutexLock lock(&self->mu_); - if (!self->is_shutdown_) { - self->connect_activity_ = std::move(activity); + arena_, event_engine_.get()); + MutexLock lock(&mu_); + if (!is_shutdown_) { + connect_activity_ = std::move(activity); } } else { // Handshaking succeeded but there is no endpoint. - MutexLock lock(&self->mu_); - self->result_->Reset(); + MutexLock lock(&mu_); + result_->Reset(); auto error = GRPC_ERROR_CREATE("handshake complete with empty endpoint."); - ExecCtx::Run(DEBUG_LOCATION, std::exchange(self->notify_, nullptr), error); + ExecCtx::Run( + DEBUG_LOCATION, std::exchange(notify_, nullptr), + absl::InternalError("handshake complete with empty endpoint.")); } } diff --git a/src/core/ext/transport/chaotic_good/client/chaotic_good_connector.h b/src/core/ext/transport/chaotic_good/client/chaotic_good_connector.h index b8db7a52502..caf2c564a14 100644 --- a/src/core/ext/transport/chaotic_good/client/chaotic_good_connector.h +++ b/src/core/ext/transport/chaotic_good/client/chaotic_good_connector.h @@ -77,7 +77,7 @@ class ChaoticGoodConnector : public SubchannelConnector { RefCountedPtr self); static auto WaitForDataEndpointSetup( RefCountedPtr self); - static void OnHandshakeDone(void* arg, grpc_error_handle error); + void OnHandshakeDone(absl::StatusOr result); RefCountedPtr arena_ = SimpleArenaAllocator()->MakeArena(); Mutex mu_; diff --git a/src/core/ext/transport/chaotic_good/server/chaotic_good_server.cc b/src/core/ext/transport/chaotic_good/server/chaotic_good_server.cc index 6964d4d422c..b38e1e4a9df 100644 --- a/src/core/ext/transport/chaotic_good/server/chaotic_good_server.cc +++ b/src/core/ext/transport/chaotic_good/server/chaotic_good_server.cc @@ -211,9 +211,12 @@ ChaoticGoodServerListener::ActiveConnection::HandshakingState::HandshakingState( void ChaoticGoodServerListener::ActiveConnection::HandshakingState::Start( std::unique_ptr endpoint) { handshake_mgr_->DoHandshake( - grpc_event_engine_endpoint_create(std::move(endpoint)), - connection_->args(), GetConnectionDeadline(), nullptr, OnHandshakeDone, - Ref().release()); + OrphanablePtr( + grpc_event_engine_endpoint_create(std::move(endpoint))), + connection_->args(), GetConnectionDeadline(), nullptr, + [self = Ref()](absl::StatusOr result) { + self->OnHandshakeDone(std::move(result)); + }); } auto ChaoticGoodServerListener::ActiveConnection::HandshakingState:: @@ -384,33 +387,28 @@ auto ChaoticGoodServerListener::ActiveConnection::HandshakingState:: } void ChaoticGoodServerListener::ActiveConnection::HandshakingState:: - OnHandshakeDone(void* arg, grpc_error_handle error) { - auto* args = static_cast(arg); - CHECK_NE(args, nullptr); - RefCountedPtr self( - static_cast(args->user_data)); - grpc_slice_buffer_destroy(args->read_buffer); - gpr_free(args->read_buffer); - if (!error.ok()) { - self->connection_->Done( - absl::StrCat("Handshake failed: ", StatusToString(error))); + OnHandshakeDone(absl::StatusOr result) { + if (!result.ok()) { + connection_->Done( + absl::StrCat("Handshake failed: ", result.status().ToString())); return; } - if (args->endpoint == nullptr) { - self->connection_->Done("Server handshake done but has empty endpoint."); + CHECK_NE(*result, nullptr); + if ((*result)->endpoint == nullptr) { + connection_->Done("Server handshake done but has empty endpoint."); return; } CHECK(grpc_event_engine::experimental::grpc_is_event_engine_endpoint( - args->endpoint)); + (*result)->endpoint.get())); auto ee_endpoint = grpc_event_engine::experimental::grpc_take_wrapped_event_engine_endpoint( - args->endpoint); + (*result)->endpoint.release()); auto* chaotic_good_ext = grpc_event_engine::experimental::QueryExtension< grpc_event_engine::experimental::ChaoticGoodExtension>(ee_endpoint.get()); - self->connection_->endpoint_ = + connection_->endpoint_ = PromiseEndpoint(std::move(ee_endpoint), SliceBuffer()); auto activity = MakeActivity( - [self, chaotic_good_ext]() { + [self = Ref(), chaotic_good_ext]() { return TrySeq( Race(EndpointReadSettingsFrame(self), TrySeq(Sleep(Timestamp::Now() + kConnectionDeadline), @@ -430,8 +428,8 @@ void ChaoticGoodServerListener::ActiveConnection::HandshakingState:: return EndpointWriteSettingsFrame(self, is_control_endpoint); }); }, - EventEngineWakeupScheduler(self->connection_->listener_->event_engine_), - [self](absl::Status status) { + EventEngineWakeupScheduler(connection_->listener_->event_engine_), + [self = Ref()](absl::Status status) { if (!status.ok()) { self->connection_->Done( absl::StrCat("Server setting frame handling failed: ", @@ -440,11 +438,10 @@ void ChaoticGoodServerListener::ActiveConnection::HandshakingState:: self->connection_->Done(); } }, - self->connection_->arena_.get(), - self->connection_->listener_->event_engine_.get()); - MutexLock lock(&self->connection_->mu_); - if (self->connection_->orphaned_) return; - self->connection_->receive_settings_activity_ = std::move(activity); + connection_->arena_.get(), connection_->listener_->event_engine_.get()); + MutexLock lock(&connection_->mu_); + if (connection_->orphaned_) return; + connection_->receive_settings_activity_ = std::move(activity); } Timestamp ChaoticGoodServerListener::ActiveConnection::HandshakingState:: diff --git a/src/core/ext/transport/chaotic_good/server/chaotic_good_server.h b/src/core/ext/transport/chaotic_good/server/chaotic_good_server.h index 8511790c03b..8da7bc7513e 100644 --- a/src/core/ext/transport/chaotic_good/server/chaotic_good_server.h +++ b/src/core/ext/transport/chaotic_good/server/chaotic_good_server.h @@ -104,7 +104,7 @@ class ChaoticGoodServerListener final : public Server::ListenerInterface { static auto DataEndpointWriteSettingsFrame( RefCountedPtr self); - static void OnHandshakeDone(void* arg, grpc_error_handle error); + void OnHandshakeDone(absl::StatusOr result); Timestamp GetConnectionDeadline(); const RefCountedPtr connection_; const RefCountedPtr handshake_mgr_; diff --git a/src/core/ext/transport/chttp2/client/chttp2_connector.cc b/src/core/ext/transport/chttp2/client/chttp2_connector.cc index 8711eece3bf..6fb92f0d6f8 100644 --- a/src/core/ext/transport/chttp2/client/chttp2_connector.cc +++ b/src/core/ext/transport/chttp2/client/chttp2_connector.cc @@ -120,10 +120,12 @@ void Chttp2Connector::Connect(const Args& args, Result* result, CoreConfiguration::Get().handshaker_registry().AddHandshakers( HANDSHAKER_CLIENT, channel_args, args_.interested_parties, handshake_mgr_.get()); - Ref().release(); // Ref held by OnHandshakeDone(). - handshake_mgr_->DoHandshake(nullptr /* endpoint */, channel_args, - args.deadline, nullptr /* acceptor */, - OnHandshakeDone, this); + handshake_mgr_->DoHandshake( + /*endpoint=*/nullptr, channel_args, args.deadline, /*acceptor=*/nullptr, + [self = RefAsSubclass()]( + absl::StatusOr result) { + self->OnHandshakeDone(std::move(result)); + }); } void Chttp2Connector::Shutdown(grpc_error_handle error) { @@ -135,54 +137,42 @@ void Chttp2Connector::Shutdown(grpc_error_handle error) { } } -void Chttp2Connector::OnHandshakeDone(void* arg, grpc_error_handle error) { - auto* args = static_cast(arg); - Chttp2Connector* self = static_cast(args->user_data); - { - MutexLock lock(&self->mu_); - if (!error.ok() || self->shutdown_) { - if (error.ok()) { - error = GRPC_ERROR_CREATE("connector shutdown"); - // We were shut down after handshaking completed successfully, so - // destroy the endpoint here. - if (args->endpoint != nullptr) { - grpc_endpoint_destroy(args->endpoint); - grpc_slice_buffer_destroy(args->read_buffer); - gpr_free(args->read_buffer); - } - } - self->result_->Reset(); - NullThenSchedClosure(DEBUG_LOCATION, &self->notify_, error); - } else if (args->endpoint != nullptr) { - self->result_->transport = - grpc_create_chttp2_transport(args->args, args->endpoint, true); - CHECK_NE(self->result_->transport, nullptr); - self->result_->socket_node = - grpc_chttp2_transport_get_socket_node(self->result_->transport); - self->result_->channel_args = args->args; - self->Ref().release(); // Ref held by OnReceiveSettings() - GRPC_CLOSURE_INIT(&self->on_receive_settings_, OnReceiveSettings, self, - grpc_schedule_on_exec_ctx); - grpc_chttp2_transport_start_reading( - self->result_->transport, args->read_buffer, - &self->on_receive_settings_, self->args_.interested_parties, nullptr); - self->timer_handle_ = self->event_engine_->RunAfter( - self->args_.deadline - Timestamp::Now(), - [self = self->RefAsSubclass()] { - ApplicationCallbackExecCtx callback_exec_ctx; - ExecCtx exec_ctx; - self->OnTimeout(); - }); - } else { - // If the handshaking succeeded but there is no endpoint, then the - // handshaker may have handed off the connection to some external - // code. Just verify that exit_early flag is set. - DCHECK(args->exit_early); - NullThenSchedClosure(DEBUG_LOCATION, &self->notify_, error); +void Chttp2Connector::OnHandshakeDone(absl::StatusOr result) { + MutexLock lock(&mu_); + if (!result.ok() || shutdown_) { + if (result.ok()) { + result = GRPC_ERROR_CREATE("connector shutdown"); } - self->handshake_mgr_.reset(); + result_->Reset(); + NullThenSchedClosure(DEBUG_LOCATION, ¬ify_, result.status()); + } else if ((*result)->endpoint != nullptr) { + result_->transport = grpc_create_chttp2_transport( + (*result)->args, std::move((*result)->endpoint), true); + CHECK_NE(result_->transport, nullptr); + result_->socket_node = + grpc_chttp2_transport_get_socket_node(result_->transport); + result_->channel_args = std::move((*result)->args); + Ref().release(); // Ref held by OnReceiveSettings() + GRPC_CLOSURE_INIT(&on_receive_settings_, OnReceiveSettings, this, + grpc_schedule_on_exec_ctx); + grpc_chttp2_transport_start_reading( + result_->transport, (*result)->read_buffer.c_slice_buffer(), + &on_receive_settings_, args_.interested_parties, nullptr); + timer_handle_ = + event_engine_->RunAfter(args_.deadline - Timestamp::Now(), + [self = RefAsSubclass()] { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + self->OnTimeout(); + }); + } else { + // If the handshaking succeeded but there is no endpoint, then the + // handshaker may have handed off the connection to some external + // code. Just verify that exit_early flag is set. + DCHECK((*result)->exit_early); + NullThenSchedClosure(DEBUG_LOCATION, ¬ify_, result.status()); } - self->Unref(); + handshake_mgr_.reset(); } void Chttp2Connector::OnReceiveSettings(void* arg, grpc_error_handle error) { @@ -380,12 +370,12 @@ grpc_channel* grpc_channel_create_from_fd(const char* target, int fd, int flags = fcntl(fd, F_GETFL, 0); CHECK_EQ(fcntl(fd, F_SETFL, flags | O_NONBLOCK), 0); - grpc_endpoint* client = grpc_tcp_create_from_fd( + grpc_core::OrphanablePtr client(grpc_tcp_create_from_fd( grpc_fd_create(fd, "client", true), grpc_event_engine::experimental::ChannelArgsEndpointConfig(final_args), - "fd-client"); + "fd-client")); grpc_core::Transport* transport = - grpc_create_chttp2_transport(final_args, client, true); + grpc_create_chttp2_transport(final_args, std::move(client), true); CHECK(transport); auto channel = grpc_core::ChannelCreate( target, final_args, GRPC_CLIENT_DIRECT_CHANNEL, transport); diff --git a/src/core/ext/transport/chttp2/client/chttp2_connector.h b/src/core/ext/transport/chttp2/client/chttp2_connector.h index 679c7db6ce9..0cb08474ca6 100644 --- a/src/core/ext/transport/chttp2/client/chttp2_connector.h +++ b/src/core/ext/transport/chttp2/client/chttp2_connector.h @@ -41,7 +41,7 @@ class Chttp2Connector : public SubchannelConnector { void Shutdown(grpc_error_handle error) override; private: - static void OnHandshakeDone(void* arg, grpc_error_handle error); + void OnHandshakeDone(absl::StatusOr result); static void OnReceiveSettings(void* arg, grpc_error_handle error); void OnTimeout() ABSL_LOCKS_EXCLUDED(mu_); diff --git a/src/core/ext/transport/chttp2/server/chttp2_server.cc b/src/core/ext/transport/chttp2/server/chttp2_server.cc index b20d2a55548..37196c21b60 100644 --- a/src/core/ext/transport/chttp2/server/chttp2_server.cc +++ b/src/core/ext/transport/chttp2/server/chttp2_server.cc @@ -107,6 +107,13 @@ const char kUnixUriPrefix[] = "unix:"; const char kUnixAbstractUriPrefix[] = "unix-abstract:"; const char kVSockUriPrefix[] = "vsock:"; +struct AcceptorDeleter { + void operator()(grpc_tcp_server_acceptor* acceptor) const { + gpr_free(acceptor); + } +}; +using AcceptorPtr = std::unique_ptr; + class Chttp2ServerListener : public Server::ListenerInterface { public: static grpc_error_handle Create(Server* server, grpc_resolved_address* addr, @@ -167,15 +174,15 @@ class Chttp2ServerListener : public Server::ListenerInterface { class HandshakingState : public InternallyRefCounted { public: HandshakingState(RefCountedPtr connection_ref, - grpc_pollset* accepting_pollset, - grpc_tcp_server_acceptor* acceptor, + grpc_pollset* accepting_pollset, AcceptorPtr acceptor, const ChannelArgs& args); ~HandshakingState() override; void Orphan() override; - void Start(grpc_endpoint* endpoint, const ChannelArgs& args); + void Start(OrphanablePtr endpoint, + const ChannelArgs& args); // Needed to be able to grab an external ref in // ActiveConnection::Start() @@ -184,10 +191,10 @@ class Chttp2ServerListener : public Server::ListenerInterface { private: void OnTimeout() ABSL_LOCKS_EXCLUDED(&connection_->mu_); static void OnReceiveSettings(void* arg, grpc_error_handle /* error */); - static void OnHandshakeDone(void* arg, grpc_error_handle error); + void OnHandshakeDone(absl::StatusOr result); RefCountedPtr const connection_; grpc_pollset* const accepting_pollset_; - grpc_tcp_server_acceptor* acceptor_; + AcceptorPtr acceptor_; RefCountedPtr handshake_mgr_ ABSL_GUARDED_BY(&connection_->mu_); // State for enforcing handshake timeout on receiving HTTP/2 settings. @@ -198,8 +205,7 @@ class Chttp2ServerListener : public Server::ListenerInterface { grpc_pollset_set* const interested_parties_; }; - ActiveConnection(grpc_pollset* accepting_pollset, - grpc_tcp_server_acceptor* acceptor, + ActiveConnection(grpc_pollset* accepting_pollset, AcceptorPtr acceptor, EventEngine* event_engine, const ChannelArgs& args, MemoryOwner memory_owner); ~ActiveConnection() override; @@ -209,7 +215,7 @@ class Chttp2ServerListener : public Server::ListenerInterface { void SendGoAway(); void Start(RefCountedPtr listener, - grpc_endpoint* endpoint, const ChannelArgs& args); + OrphanablePtr endpoint, const ChannelArgs& args); // Needed to be able to grab an external ref in // Chttp2ServerListener::OnAccept() @@ -367,11 +373,11 @@ Timestamp GetConnectionDeadline(const ChannelArgs& args) { Chttp2ServerListener::ActiveConnection::HandshakingState::HandshakingState( RefCountedPtr connection_ref, - grpc_pollset* accepting_pollset, grpc_tcp_server_acceptor* acceptor, + grpc_pollset* accepting_pollset, AcceptorPtr acceptor, const ChannelArgs& args) : connection_(std::move(connection_ref)), accepting_pollset_(accepting_pollset), - acceptor_(acceptor), + acceptor_(std::move(acceptor)), handshake_mgr_(MakeRefCounted()), deadline_(GetConnectionDeadline(args)), interested_parties_(grpc_pollset_set_create()) { @@ -387,7 +393,6 @@ Chttp2ServerListener::ActiveConnection::HandshakingState::~HandshakingState() { grpc_pollset_set_del_pollset(interested_parties_, accepting_pollset_); } grpc_pollset_set_destroy(interested_parties_); - gpr_free(acceptor_); } void Chttp2ServerListener::ActiveConnection::HandshakingState::Orphan() { @@ -401,16 +406,18 @@ void Chttp2ServerListener::ActiveConnection::HandshakingState::Orphan() { } void Chttp2ServerListener::ActiveConnection::HandshakingState::Start( - grpc_endpoint* endpoint, const ChannelArgs& channel_args) { - Ref().release(); // Held by OnHandshakeDone + OrphanablePtr endpoint, const ChannelArgs& channel_args) { RefCountedPtr handshake_mgr; { MutexLock lock(&connection_->mu_); if (handshake_mgr_ == nullptr) return; handshake_mgr = handshake_mgr_; } - handshake_mgr->DoHandshake(endpoint, channel_args, deadline_, acceptor_, - OnHandshakeDone, this); + handshake_mgr->DoHandshake( + std::move(endpoint), channel_args, deadline_, acceptor_.get(), + [self = Ref()](absl::StatusOr result) { + self->OnHandshakeDone(std::move(result)); + }); } void Chttp2ServerListener::ActiveConnection::HandshakingState::OnTimeout() { @@ -444,61 +451,50 @@ void Chttp2ServerListener::ActiveConnection::HandshakingState:: } void Chttp2ServerListener::ActiveConnection::HandshakingState::OnHandshakeDone( - void* arg, grpc_error_handle error) { - auto* args = static_cast(arg); - HandshakingState* self = static_cast(args->user_data); + absl::StatusOr result) { OrphanablePtr handshaking_state_ref; RefCountedPtr handshake_mgr; bool cleanup_connection = false; bool release_connection = false; { - MutexLock connection_lock(&self->connection_->mu_); - if (!error.ok() || self->connection_->shutdown_) { - std::string error_str = StatusToString(error); + MutexLock connection_lock(&connection_->mu_); + if (!result.ok() || connection_->shutdown_) { cleanup_connection = true; release_connection = true; - if (error.ok() && args->endpoint != nullptr) { - // We were shut down or stopped serving after handshaking completed - // successfully, so destroy the endpoint here. - grpc_endpoint_destroy(args->endpoint); - grpc_slice_buffer_destroy(args->read_buffer); - gpr_free(args->read_buffer); - } } else { // If the handshaking succeeded but there is no endpoint, then the // handshaker may have handed off the connection to some external // code, so we can just clean up here without creating a transport. - if (args->endpoint != nullptr) { + if ((*result)->endpoint != nullptr) { RefCountedPtr transport = - grpc_create_chttp2_transport(args->args, args->endpoint, false) + grpc_create_chttp2_transport((*result)->args, + std::move((*result)->endpoint), false) ->Ref(); grpc_error_handle channel_init_err = - self->connection_->listener_->server_->SetupTransport( - transport.get(), self->accepting_pollset_, args->args, + connection_->listener_->server_->SetupTransport( + transport.get(), accepting_pollset_, (*result)->args, grpc_chttp2_transport_get_socket_node(transport.get())); if (channel_init_err.ok()) { // Use notify_on_receive_settings callback to enforce the // handshake deadline. - self->connection_->transport_ = + connection_->transport_ = DownCast(transport.get())->Ref(); - self->Ref().release(); // Held by OnReceiveSettings(). - GRPC_CLOSURE_INIT(&self->on_receive_settings_, OnReceiveSettings, - self, grpc_schedule_on_exec_ctx); + Ref().release(); // Held by OnReceiveSettings(). + GRPC_CLOSURE_INIT(&on_receive_settings_, OnReceiveSettings, this, + grpc_schedule_on_exec_ctx); // If the listener has been configured with a config fetcher, we // need to watch on the transport being closed so that we can an // updated list of active connections. grpc_closure* on_close = nullptr; - if (self->connection_->listener_->config_fetcher_watcher_ != - nullptr) { + if (connection_->listener_->config_fetcher_watcher_ != nullptr) { // Refs helds by OnClose() - self->connection_->Ref().release(); - on_close = &self->connection_->on_close_; + connection_->Ref().release(); + on_close = &connection_->on_close_; } else { // Remove the connection from the connections_ map since OnClose() // will not be invoked when a config fetcher is set. auto connection_quota = - self->connection_->listener_->connection_quota_->Ref() - .release(); + connection_->listener_->connection_quota_->Ref().release(); auto on_close_transport = [](void* arg, grpc_error_handle /*handle*/) { ConnectionQuota* connection_quota = @@ -511,11 +507,10 @@ void Chttp2ServerListener::ActiveConnection::HandshakingState::OnHandshakeDone( cleanup_connection = true; } grpc_chttp2_transport_start_reading( - transport.get(), args->read_buffer, &self->on_receive_settings_, - nullptr, on_close); - self->timer_handle_ = self->connection_->event_engine_->RunAfter( - self->deadline_ - Timestamp::Now(), - [self = self->Ref()]() mutable { + transport.get(), (*result)->read_buffer.c_slice_buffer(), + &on_receive_settings_, nullptr, on_close); + timer_handle_ = connection_->event_engine_->RunAfter( + deadline_ - Timestamp::Now(), [self = Ref()]() mutable { ApplicationCallbackExecCtx callback_exec_ctx; ExecCtx exec_ctx; self->OnTimeout(); @@ -527,8 +522,6 @@ void Chttp2ServerListener::ActiveConnection::HandshakingState::OnHandshakeDone( LOG(ERROR) << "Failed to create channel: " << StatusToString(channel_init_err); transport->Orphan(); - grpc_slice_buffer_destroy(args->read_buffer); - gpr_free(args->read_buffer); cleanup_connection = true; release_connection = true; } @@ -541,25 +534,21 @@ void Chttp2ServerListener::ActiveConnection::HandshakingState::OnHandshakeDone( // shutdown the handshake when the listener needs to stop serving. // Avoid calling the destructor of HandshakeManager and HandshakingState // from within the critical region. - handshake_mgr = std::move(self->handshake_mgr_); - handshaking_state_ref = std::move(self->connection_->handshaking_state_); + handshake_mgr = std::move(handshake_mgr_); + handshaking_state_ref = std::move(connection_->handshaking_state_); } - gpr_free(self->acceptor_); - self->acceptor_ = nullptr; OrphanablePtr connection; if (cleanup_connection) { - MutexLock listener_lock(&self->connection_->listener_->mu_); + MutexLock listener_lock(&connection_->listener_->mu_); if (release_connection) { - self->connection_->listener_->connection_quota_->ReleaseConnections(1); + connection_->listener_->connection_quota_->ReleaseConnections(1); } - auto it = self->connection_->listener_->connections_.find( - self->connection_.get()); - if (it != self->connection_->listener_->connections_.end()) { + auto it = connection_->listener_->connections_.find(connection_.get()); + if (it != connection_->listener_->connections_.end()) { connection = std::move(it->second); - self->connection_->listener_->connections_.erase(it); + connection_->listener_->connections_.erase(it); } } - self->Unref(); } // @@ -567,11 +556,11 @@ void Chttp2ServerListener::ActiveConnection::HandshakingState::OnHandshakeDone( // Chttp2ServerListener::ActiveConnection::ActiveConnection( - grpc_pollset* accepting_pollset, grpc_tcp_server_acceptor* acceptor, + grpc_pollset* accepting_pollset, AcceptorPtr acceptor, EventEngine* event_engine, const ChannelArgs& args, MemoryOwner memory_owner) : handshaking_state_(memory_owner.MakeOrphanable( - Ref(), accepting_pollset, acceptor, args)), + Ref(), accepting_pollset, std::move(acceptor), args)), event_engine_(event_engine) { GRPC_CLOSURE_INIT(&on_close_, ActiveConnection::OnClose, this, grpc_schedule_on_exec_ctx); @@ -625,29 +614,24 @@ void Chttp2ServerListener::ActiveConnection::SendGoAway() { } void Chttp2ServerListener::ActiveConnection::Start( - RefCountedPtr listener, grpc_endpoint* endpoint, - const ChannelArgs& args) { - RefCountedPtr handshaking_state_ref; + RefCountedPtr listener, + OrphanablePtr endpoint, const ChannelArgs& args) { listener_ = std::move(listener); if (listener_->tcp_server_ != nullptr) { grpc_tcp_server_ref(listener_->tcp_server_); } + RefCountedPtr handshaking_state_ref; { - ReleasableMutexLock lock(&mu_); - if (shutdown_) { - lock.Release(); - // If the Connection is already shutdown at this point, it implies the - // owning Chttp2ServerListener and all associated ActiveConnections have - // been orphaned. The generated endpoints need to be shutdown here to - // ensure the tcp connections are closed appropriately. - grpc_endpoint_destroy(endpoint); - return; - } + MutexLock lock(&mu_); + // If the Connection is already shutdown at this point, it implies the + // owning Chttp2ServerListener and all associated ActiveConnections have + // been orphaned. + if (shutdown_) return; // Hold a ref to HandshakingState to allow starting the handshake outside // the critical region. handshaking_state_ref = handshaking_state_->Ref(); } - handshaking_state_ref->Start(endpoint, args); + handshaking_state_ref->Start(std::move(endpoint), args); } void Chttp2ServerListener::ActiveConnection::OnClose( @@ -841,48 +825,41 @@ void Chttp2ServerListener::AcceptConnectedEndpoint( void Chttp2ServerListener::OnAccept(void* arg, grpc_endpoint* tcp, grpc_pollset* accepting_pollset, - grpc_tcp_server_acceptor* acceptor) { + grpc_tcp_server_acceptor* server_acceptor) { Chttp2ServerListener* self = static_cast(arg); ChannelArgs args = self->args_; + OrphanablePtr endpoint(tcp); + AcceptorPtr acceptor(server_acceptor); RefCountedPtr connection_manager; { MutexLock lock(&self->mu_); connection_manager = self->connection_manager_; } - auto endpoint_cleanup = [&]() { - grpc_endpoint_destroy(tcp); - gpr_free(acceptor); - }; if (!self->connection_quota_->AllowIncomingConnection( - self->memory_quota_, grpc_endpoint_get_peer(tcp))) { - endpoint_cleanup(); + self->memory_quota_, grpc_endpoint_get_peer(endpoint.get()))) { return; } if (self->config_fetcher_ != nullptr) { if (connection_manager == nullptr) { - endpoint_cleanup(); return; } absl::StatusOr args_result = connection_manager->UpdateChannelArgsForConnection(args, tcp); if (!args_result.ok()) { - endpoint_cleanup(); return; } grpc_error_handle error; args = self->args_modifier_(*args_result, &error); if (!error.ok()) { - endpoint_cleanup(); return; } } auto memory_owner = self->memory_quota_->CreateMemoryOwner(); EventEngine* const event_engine = self->args_.GetObject(); auto connection = memory_owner.MakeOrphanable( - accepting_pollset, acceptor, event_engine, args, std::move(memory_owner)); - // We no longer own acceptor - acceptor = nullptr; + accepting_pollset, std::move(acceptor), event_engine, args, + std::move(memory_owner)); // Hold a ref to connection to allow starting handshake outside the // critical region RefCountedPtr connection_ref = connection->Ref(); @@ -902,10 +879,8 @@ void Chttp2ServerListener::OnAccept(void* arg, grpc_endpoint* tcp, self->connections_.emplace(connection.get(), std::move(connection)); } } - if (connection != nullptr) { - endpoint_cleanup(); - } else { - connection_ref->Start(std::move(listener_ref), tcp, args); + if (connection == nullptr) { + connection_ref->Start(std::move(listener_ref), std::move(endpoint), args); } } @@ -1161,15 +1136,17 @@ void grpc_server_add_channel_from_fd(grpc_server* server, int fd, std::string name = absl::StrCat("fd:", fd); auto memory_quota = server_args.GetObject()->memory_quota(); - grpc_endpoint* server_endpoint = grpc_tcp_create_from_fd( - grpc_fd_create(fd, name.c_str(), true), - grpc_event_engine::experimental::ChannelArgsEndpointConfig(server_args), - name); + grpc_core::OrphanablePtr server_endpoint( + grpc_tcp_create_from_fd( + grpc_fd_create(fd, name.c_str(), true), + grpc_event_engine::experimental::ChannelArgsEndpointConfig( + server_args), + name)); for (grpc_pollset* pollset : core_server->pollsets()) { - grpc_endpoint_add_to_pollset(server_endpoint, pollset); + grpc_endpoint_add_to_pollset(server_endpoint.get(), pollset); } grpc_core::Transport* transport = grpc_create_chttp2_transport( - server_args, server_endpoint, false // is_client + server_args, std::move(server_endpoint), false // is_client ); grpc_error_handle error = core_server->SetupTransport(transport, nullptr, server_args, nullptr); diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc index fe00b2e4bdc..349e6ceb575 100644 --- a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc +++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc @@ -84,6 +84,7 @@ #include "src/core/lib/gprpp/status_helper.h" #include "src/core/lib/gprpp/time.h" #include "src/core/lib/iomgr/combiner.h" +#include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/iomgr/error.h" #include "src/core/lib/iomgr/ev_posix.h" #include "src/core/lib/iomgr/event_engine_shims/endpoint.h" @@ -378,8 +379,6 @@ grpc_chttp2_transport::~grpc_chttp2_transport() { channelz_socket.reset(); } - if (ep != nullptr) grpc_endpoint_destroy(ep); - grpc_slice_buffer_destroy(&qbuf); grpc_error_handle error = GRPC_ERROR_CREATE("Transport destroyed"); @@ -495,7 +494,7 @@ static void read_channel_args(grpc_chttp2_transport* t, .value_or(GRPC_ENABLE_CHANNELZ_DEFAULT)) { t->channelz_socket = grpc_core::MakeRefCounted( - std::string(grpc_endpoint_get_local_address(t->ep)), + std::string(grpc_endpoint_get_local_address(t->ep.get())), std::string(t->peer_string.as_string_view()), absl::StrCat(t->GetTransportName(), " ", t->peer_string.as_string_view()), @@ -589,11 +588,11 @@ using grpc_event_engine::experimental::QueryExtension; using grpc_event_engine::experimental::TcpTraceExtension; grpc_chttp2_transport::grpc_chttp2_transport( - const grpc_core::ChannelArgs& channel_args, grpc_endpoint* ep, - bool is_client) - : ep(ep), + const grpc_core::ChannelArgs& channel_args, + grpc_core::OrphanablePtr endpoint, bool is_client) + : ep(std::move(endpoint)), peer_string( - grpc_core::Slice::FromCopiedString(grpc_endpoint_get_peer(ep))), + grpc_core::Slice::FromCopiedString(grpc_endpoint_get_peer(ep.get()))), memory_owner(channel_args.GetObject() ->memory_quota() ->CreateMemoryOwner()), @@ -617,10 +616,11 @@ grpc_chttp2_transport::grpc_chttp2_transport( context_list = new grpc_core::ContextList(); if (channel_args.GetBool(GRPC_ARG_TCP_TRACING_ENABLED).value_or(false) && - grpc_event_engine::experimental::grpc_is_event_engine_endpoint(ep)) { + grpc_event_engine::experimental::grpc_is_event_engine_endpoint( + ep.get())) { auto epte = QueryExtension( grpc_event_engine::experimental::grpc_get_wrapped_event_engine_endpoint( - ep)); + ep.get())); if (epte != nullptr) { epte->InitializeAndReturnTcpTracer(); } @@ -763,17 +763,16 @@ static void close_transport_locked(grpc_chttp2_transport* t, CHECK(t->write_state == GRPC_CHTTP2_WRITE_STATE_IDLE); if (t->interested_parties_until_recv_settings != nullptr) { grpc_endpoint_delete_from_pollset_set( - t->ep, t->interested_parties_until_recv_settings); + t->ep.get(), t->interested_parties_until_recv_settings); t->interested_parties_until_recv_settings = nullptr; } grpc_core::MutexLock lock(&t->ep_destroy_mu); - grpc_endpoint_destroy(t->ep); - t->ep = nullptr; + t->ep.reset(); } if (t->notify_on_receive_settings != nullptr) { if (t->interested_parties_until_recv_settings != nullptr) { grpc_endpoint_delete_from_pollset_set( - t->ep, t->interested_parties_until_recv_settings); + t->ep.get(), t->interested_parties_until_recv_settings); t->interested_parties_until_recv_settings = nullptr; } grpc_core::ExecCtx::Run(DEBUG_LOCATION, t->notify_on_receive_settings, @@ -1061,7 +1060,7 @@ static void write_action(grpc_chttp2_transport* t) { << (t->is_client ? "CLIENT" : "SERVER") << "[" << t << "]: Write " << t->outbuf.Length() << " bytes"; t->write_size_policy.BeginWrite(t->outbuf.Length()); - grpc_endpoint_write(t->ep, t->outbuf.c_slice_buffer(), + grpc_endpoint_write(t->ep.get(), t->outbuf.c_slice_buffer(), grpc_core::InitTransportClosure( t->Ref(), &t->write_action_end_locked), cl, max_frame_size); @@ -1939,13 +1938,13 @@ static void perform_transport_op_locked(void* stream_op, if (op->bind_pollset) { if (t->ep != nullptr) { - grpc_endpoint_add_to_pollset(t->ep, op->bind_pollset); + grpc_endpoint_add_to_pollset(t->ep.get(), op->bind_pollset); } } if (op->bind_pollset_set) { if (t->ep != nullptr) { - grpc_endpoint_add_to_pollset_set(t->ep, op->bind_pollset_set); + grpc_endpoint_add_to_pollset_set(t->ep.get(), op->bind_pollset_set); } } @@ -2763,7 +2762,7 @@ static void continue_read_action_locked( grpc_core::RefCountedPtr t) { const bool urgent = !t->goaway_error.ok(); auto* tp = t.get(); - grpc_endpoint_read(tp->ep, &tp->read_buffer, + grpc_endpoint_read(tp->ep.get(), &tp->read_buffer, grpc_core::InitTransportClosure( std::move(t), &tp->read_action_locked), urgent, grpc_chttp2_min_read_progress_size(tp)); @@ -3026,7 +3025,7 @@ void grpc_chttp2_transport::SetPollset(grpc_stream* /*gs*/, // actually uses pollsets. if (strcmp(grpc_get_poll_strategy_name(), "poll") != 0) return; grpc_core::MutexLock lock(&ep_destroy_mu); - if (ep != nullptr) grpc_endpoint_add_to_pollset(ep, pollset); + if (ep != nullptr) grpc_endpoint_add_to_pollset(ep.get(), pollset); } void grpc_chttp2_transport::SetPollsetSet(grpc_stream* /*gs*/, @@ -3036,7 +3035,7 @@ void grpc_chttp2_transport::SetPollsetSet(grpc_stream* /*gs*/, // actually uses pollsets. if (strcmp(grpc_get_poll_strategy_name(), "poll") != 0) return; grpc_core::MutexLock lock(&ep_destroy_mu); - if (ep != nullptr) grpc_endpoint_add_to_pollset_set(ep, pollset_set); + if (ep != nullptr) grpc_endpoint_add_to_pollset_set(ep.get(), pollset_set); } // @@ -3215,9 +3214,9 @@ grpc_chttp2_transport_get_socket_node(grpc_core::Transport* transport) { } grpc_core::Transport* grpc_create_chttp2_transport( - const grpc_core::ChannelArgs& channel_args, grpc_endpoint* ep, - bool is_client) { - return new grpc_chttp2_transport(channel_args, ep, is_client); + const grpc_core::ChannelArgs& channel_args, + grpc_core::OrphanablePtr ep, bool is_client) { + return new grpc_chttp2_transport(channel_args, std::move(ep), is_client); } void grpc_chttp2_transport_start_reading( @@ -3228,7 +3227,6 @@ void grpc_chttp2_transport_start_reading( auto t = reinterpret_cast(transport)->Ref(); if (read_buffer != nullptr) { grpc_slice_buffer_move_into(read_buffer, &t->read_buffer); - gpr_free(read_buffer); } auto* tp = t.get(); tp->combiner->Run( @@ -3240,7 +3238,7 @@ void grpc_chttp2_transport_start_reading( if (t->ep != nullptr && interested_parties_until_recv_settings != nullptr) { grpc_endpoint_delete_from_pollset_set( - t->ep, interested_parties_until_recv_settings); + t->ep.get(), interested_parties_until_recv_settings); } grpc_core::ExecCtx::Run(DEBUG_LOCATION, notify_on_receive_settings, t->closed_with_error); diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.h b/src/core/ext/transport/chttp2/transport/chttp2_transport.h index 1bcb8a9ae10..d2b93298038 100644 --- a/src/core/ext/transport/chttp2/transport/chttp2_transport.h +++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.h @@ -44,8 +44,8 @@ /// from the caller; if the caller still needs the resource_user after creating /// a transport, the caller must take another ref. grpc_core::Transport* grpc_create_chttp2_transport( - const grpc_core::ChannelArgs& channel_args, grpc_endpoint* ep, - bool is_client); + const grpc_core::ChannelArgs& channel_args, + grpc_core::OrphanablePtr ep, bool is_client); grpc_core::RefCountedPtr grpc_chttp2_transport_get_socket_node(grpc_core::Transport* transport); diff --git a/src/core/ext/transport/chttp2/transport/frame_settings.cc b/src/core/ext/transport/chttp2/transport/frame_settings.cc index bea244c8187..07b07d41b13 100644 --- a/src/core/ext/transport/chttp2/transport/frame_settings.cc +++ b/src/core/ext/transport/chttp2/transport/frame_settings.cc @@ -110,7 +110,7 @@ grpc_error_handle grpc_chttp2_settings_parser_parse(void* p, if (t->notify_on_receive_settings != nullptr) { if (t->interested_parties_until_recv_settings != nullptr) { grpc_endpoint_delete_from_pollset_set( - t->ep, t->interested_parties_until_recv_settings); + t->ep.get(), t->interested_parties_until_recv_settings); t->interested_parties_until_recv_settings = nullptr; } grpc_core::ExecCtx::Run(DEBUG_LOCATION, diff --git a/src/core/ext/transport/chttp2/transport/internal.h b/src/core/ext/transport/chttp2/transport/internal.h index 6e077a835ed..1f41d7ea10f 100644 --- a/src/core/ext/transport/chttp2/transport/internal.h +++ b/src/core/ext/transport/chttp2/transport/internal.h @@ -226,7 +226,8 @@ typedef enum { struct grpc_chttp2_transport final : public grpc_core::FilterStackTransport, public grpc_core::KeepsGrpcInitialized { grpc_chttp2_transport(const grpc_core::ChannelArgs& channel_args, - grpc_endpoint* ep, bool is_client); + grpc_core::OrphanablePtr endpoint, + bool is_client); ~grpc_chttp2_transport() override; void Orphan() override; @@ -257,7 +258,7 @@ struct grpc_chttp2_transport final : public grpc_core::FilterStackTransport, grpc_pollset_set* pollset_set) override; void PerformOp(grpc_transport_op* op) override; - grpc_endpoint* ep; + grpc_core::OrphanablePtr ep; grpc_core::Mutex ep_destroy_mu; // Guards endpoint destruction only. grpc_core::Slice peer_string; diff --git a/src/core/ext/transport/chttp2/transport/parsing.cc b/src/core/ext/transport/chttp2/transport/parsing.cc index e0a8297b31a..851670b5555 100644 --- a/src/core/ext/transport/chttp2/transport/parsing.cc +++ b/src/core/ext/transport/chttp2/transport/parsing.cc @@ -717,7 +717,7 @@ static grpc_error_handle init_header_frame_parser(grpc_chttp2_transport* t, gpr_log(GPR_INFO, "[t:%p fd:%d peer:%s] Accepting new stream; " "num_incoming_streams_before_settings_ack=%u", - t, grpc_endpoint_get_fd(t->ep), + t, grpc_endpoint_get_fd(t->ep.get()), std::string(t->peer_string.as_string_view()).c_str(), t->num_incoming_streams_before_settings_ack); } diff --git a/src/core/ext/transport/chttp2/transport/writing.cc b/src/core/ext/transport/chttp2/transport/writing.cc index 7112e1ffbf3..6ffdcea46f0 100644 --- a/src/core/ext/transport/chttp2/transport/writing.cc +++ b/src/core/ext/transport/chttp2/transport/writing.cc @@ -676,7 +676,7 @@ grpc_chttp2_begin_write_result grpc_chttp2_begin_write( num_stream_bytes = t->outbuf.c_slice_buffer()->length - orig_len; s->byte_counter += static_cast(num_stream_bytes); ++s->write_counter; - if (s->traced && grpc_endpoint_can_track_err(t->ep)) { + if (s->traced && grpc_endpoint_can_track_err(t->ep.get())) { grpc_core::CopyContextFn copy_context_fn = grpc_core::GrpcHttp2GetCopyContextFn(); if (copy_context_fn != nullptr && diff --git a/src/core/handshaker/endpoint_info/endpoint_info_handshaker.cc b/src/core/handshaker/endpoint_info/endpoint_info_handshaker.cc index 7db1842d5d8..153eddbeecc 100644 --- a/src/core/handshaker/endpoint_info/endpoint_info_handshaker.cc +++ b/src/core/handshaker/endpoint_info/endpoint_info_handshaker.cc @@ -17,7 +17,9 @@ #include "src/core/handshaker/endpoint_info/endpoint_info_handshaker.h" #include +#include +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include @@ -38,17 +40,17 @@ namespace { class EndpointInfoHandshaker : public Handshaker { public: - const char* name() const override { return "endpoint_info"; } + absl::string_view name() const override { return "endpoint_info"; } - void DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, - grpc_closure* on_handshake_done, - HandshakerArgs* args) override { + void DoHandshake( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done) override { args->args = args->args .Set(GRPC_ARG_ENDPOINT_LOCAL_ADDRESS, - grpc_endpoint_get_local_address(args->endpoint)) + grpc_endpoint_get_local_address(args->endpoint.get())) .Set(GRPC_ARG_ENDPOINT_PEER_ADDRESS, - grpc_endpoint_get_peer(args->endpoint)); - ExecCtx::Run(DEBUG_LOCATION, on_handshake_done, absl::OkStatus()); + grpc_endpoint_get_peer(args->endpoint.get())); + InvokeOnHandshakeDone(args, std::move(on_handshake_done), absl::OkStatus()); } void Shutdown(grpc_error_handle /*why*/) override {} diff --git a/src/core/handshaker/handshaker.cc b/src/core/handshaker/handshaker.cc index 4279755cb6d..13ad9375053 100644 --- a/src/core/handshaker/handshaker.cc +++ b/src/core/handshaker/handshaker.cc @@ -23,8 +23,10 @@ #include #include +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include @@ -38,23 +40,37 @@ #include "src/core/lib/debug/trace.h" #include "src/core/lib/gprpp/debug_location.h" #include "src/core/lib/gprpp/status_helper.h" +#include "src/core/lib/gprpp/time.h" +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/error.h" #include "src/core/lib/iomgr/event_engine_shims/endpoint.h" #include "src/core/lib/iomgr/exec_ctx.h" +using ::grpc_event_engine::experimental::EventEngine; + namespace grpc_core { -namespace { +void Handshaker::InvokeOnHandshakeDone( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done, + absl::Status status) { + args->event_engine->Run([on_handshake_done = std::move(on_handshake_done), + status = std::move(status)]() mutable { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + on_handshake_done(std::move(status)); + // Destroy callback while ExecCtx is still in scope. + on_handshake_done = nullptr; + }); +} -using ::grpc_event_engine::experimental::EventEngine; +namespace { std::string HandshakerArgsString(HandshakerArgs* args) { - size_t read_buffer_length = - args->read_buffer != nullptr ? args->read_buffer->length : 0; - return absl::StrFormat( - "{endpoint=%p, args=%s, read_buffer=%p (length=%" PRIuPTR - "), exit_early=%d}", - args->endpoint, args->args.ToString(), args->read_buffer, - read_buffer_length, args->exit_early); + return absl::StrFormat("{endpoint=%p, args=%s, read_buffer.Length()=%" PRIuPTR + ", exit_early=%d}", + args->endpoint.get(), args->args.ToString(), + args->read_buffer.Length(), args->exit_early); } } // namespace @@ -69,155 +85,129 @@ void HandshakeManager::Add(RefCountedPtr handshaker) { gpr_log( GPR_INFO, "handshake_manager %p: adding handshaker %s [%p] at index %" PRIuPTR, - this, handshaker->name(), handshaker.get(), handshakers_.size()); + this, std::string(handshaker->name()).c_str(), handshaker.get(), + handshakers_.size()); } handshakers_.push_back(std::move(handshaker)); } -HandshakeManager::~HandshakeManager() { handshakers_.clear(); } +void HandshakeManager::DoHandshake( + OrphanablePtr endpoint, const ChannelArgs& channel_args, + Timestamp deadline, grpc_tcp_server_acceptor* acceptor, + absl::AnyInvocable)> + on_handshake_done) { + MutexLock lock(&mu_); + CHECK_EQ(index_, 0u); + on_handshake_done_ = std::move(on_handshake_done); + // Construct handshaker args. These will be passed through all + // handshakers and eventually be freed by the on_handshake_done callback. + args_.endpoint = std::move(endpoint); + args_.deadline = deadline; + args_.args = channel_args; + args_.event_engine = args_.args.GetObject(); + args_.acceptor = acceptor; + if (acceptor != nullptr && acceptor->external_connection && + acceptor->pending_data != nullptr) { + grpc_slice_buffer_swap(args_.read_buffer.c_slice_buffer(), + &(acceptor->pending_data->data.raw.slice_buffer)); + // TODO(vigneshbabu): For connections accepted through event engine + // listeners, the ownership of the byte buffer received is transferred to + // this callback and it is thus this callback's duty to delete it. + // Make this hack default once event engine is rolled out. + if (grpc_event_engine::experimental::grpc_is_event_engine_endpoint( + args_.endpoint.get())) { + grpc_byte_buffer_destroy(acceptor->pending_data); + } + } + // Start deadline timer, which owns a ref. + const Duration time_to_deadline = deadline - Timestamp::Now(); + deadline_timer_handle_ = + args_.event_engine->RunAfter(time_to_deadline, [self = Ref()]() mutable { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + self->Shutdown(GRPC_ERROR_CREATE("Handshake timed out")); + // HandshakeManager deletion might require an active ExecCtx. + self.reset(); + }); + // Start first handshaker. + CallNextHandshakerLocked(absl::OkStatus()); +} -void HandshakeManager::Shutdown(grpc_error_handle why) { - { - MutexLock lock(&mu_); +void HandshakeManager::Shutdown(absl::Status error) { + MutexLock lock(&mu_); + if (!is_shutdown_) { + if (GRPC_TRACE_FLAG_ENABLED(handshaker)) { + gpr_log(GPR_INFO, "handshake_manager %p: Shutdown() called: %s", this, + error.ToString().c_str()); + } + is_shutdown_ = true; // Shutdown the handshaker that's currently in progress, if any. - if (!is_shutdown_ && index_ > 0) { - is_shutdown_ = true; - handshakers_[index_ - 1]->Shutdown(why); + if (index_ > 0) { + if (GRPC_TRACE_FLAG_ENABLED(handshaker)) { + gpr_log( + GPR_INFO, + "handshake_manager %p: shutting down handshaker at index %" PRIuPTR, + this, index_ - 1); + } + handshakers_[index_ - 1]->Shutdown(std::move(error)); } } } -// Helper function to call either the next handshaker or the -// on_handshake_done callback. -// Returns true if we've scheduled the on_handshake_done callback. -bool HandshakeManager::CallNextHandshakerLocked(grpc_error_handle error) { +void HandshakeManager::CallNextHandshakerLocked(absl::Status error) { if (GRPC_TRACE_FLAG_ENABLED(handshaker)) { gpr_log(GPR_INFO, "handshake_manager %p: error=%s shutdown=%d index=%" PRIuPTR ", args=%s", - this, StatusToString(error).c_str(), is_shutdown_, index_, + this, error.ToString().c_str(), is_shutdown_, index_, HandshakerArgsString(&args_).c_str()); } CHECK(index_ <= handshakers_.size()); // If we got an error or we've been shut down or we're exiting early or // we've finished the last handshaker, invoke the on_handshake_done - // callback. Otherwise, call the next handshaker. + // callback. if (!error.ok() || is_shutdown_ || args_.exit_early || index_ == handshakers_.size()) { if (error.ok() && is_shutdown_) { error = GRPC_ERROR_CREATE("handshaker shutdown"); - // It is possible that the endpoint has already been destroyed by - // a shutdown call while this callback was sitting on the ExecCtx - // with no error. - if (args_.endpoint != nullptr) { - grpc_endpoint_destroy(args_.endpoint); - args_.endpoint = nullptr; - } - if (args_.read_buffer != nullptr) { - grpc_slice_buffer_destroy(args_.read_buffer); - gpr_free(args_.read_buffer); - args_.read_buffer = nullptr; - } - args_.args = ChannelArgs(); + args_.endpoint.reset(); } if (GRPC_TRACE_FLAG_ENABLED(handshaker)) { gpr_log(GPR_INFO, "handshake_manager %p: handshaking complete -- scheduling " "on_handshake_done with error=%s", - this, StatusToString(error).c_str()); + this, error.ToString().c_str()); } // Cancel deadline timer, since we're invoking the on_handshake_done // callback now. - event_engine_->Cancel(deadline_timer_handle_); - ExecCtx::Run(DEBUG_LOCATION, &on_handshake_done_, error); + args_.event_engine->Cancel(deadline_timer_handle_); is_shutdown_ = true; - } else { - auto handshaker = handshakers_[index_]; - if (GRPC_TRACE_FLAG_ENABLED(handshaker)) { - gpr_log( - GPR_INFO, - "handshake_manager %p: calling handshaker %s [%p] at index %" PRIuPTR, - this, handshaker->name(), handshaker.get(), index_); - } - handshaker->DoHandshake(acceptor_, &call_next_handshaker_, &args_); - } - ++index_; - return is_shutdown_; -} - -void HandshakeManager::CallNextHandshakerFn(void* arg, - grpc_error_handle error) { - auto* mgr = static_cast(arg); - bool done; - { - MutexLock lock(&mgr->mu_); - done = mgr->CallNextHandshakerLocked(error); - } - // If we're invoked the final callback, we won't be coming back - // to this function, so we can release our reference to the - // handshake manager. - if (done) { - mgr->Unref(); - } -} - -void HandshakeManager::DoHandshake(grpc_endpoint* endpoint, - const ChannelArgs& channel_args, - Timestamp deadline, - grpc_tcp_server_acceptor* acceptor, - grpc_iomgr_cb_func on_handshake_done, - void* user_data) { - bool done; - { - MutexLock lock(&mu_); - CHECK_EQ(index_, 0u); - // Construct handshaker args. These will be passed through all - // handshakers and eventually be freed by the on_handshake_done callback. - args_.endpoint = endpoint; - args_.deadline = deadline; - args_.args = channel_args; - args_.user_data = user_data; - args_.read_buffer = - static_cast(gpr_malloc(sizeof(*args_.read_buffer))); - grpc_slice_buffer_init(args_.read_buffer); - if (acceptor != nullptr && acceptor->external_connection && - acceptor->pending_data != nullptr) { - grpc_slice_buffer_swap(args_.read_buffer, - &(acceptor->pending_data->data.raw.slice_buffer)); - // TODO(vigneshbabu): For connections accepted through event engine - // listeners, the ownership of the byte buffer received is transferred to - // this callback and it is thus this callback's duty to delete it. - // Make this hack default once event engine is rolled out. - if (grpc_event_engine::experimental::grpc_is_event_engine_endpoint( - endpoint)) { - grpc_byte_buffer_destroy(acceptor->pending_data); - } - } - // Initialize state needed for calling handshakers. - acceptor_ = acceptor; - GRPC_CLOSURE_INIT(&call_next_handshaker_, - &HandshakeManager::CallNextHandshakerFn, this, - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&on_handshake_done_, on_handshake_done, &args_, - grpc_schedule_on_exec_ctx); - // Start deadline timer, which owns a ref. - const Duration time_to_deadline = deadline - Timestamp::Now(); - event_engine_ = args_.args.GetObjectRef(); - deadline_timer_handle_ = - event_engine_->RunAfter(time_to_deadline, [self = Ref()]() mutable { - ApplicationCallbackExecCtx callback_exec_ctx; - ExecCtx exec_ctx; - self->Shutdown(GRPC_ERROR_CREATE("Handshake timed out")); - // HandshakeManager deletion might require an active ExecCtx. - self.reset(); - }); - // Start first handshaker, which also owns a ref. - Ref().release(); - done = CallNextHandshakerLocked(absl::OkStatus()); + absl::StatusOr result(&args_); + if (!error.ok()) result = std::move(error); + args_.event_engine->Run([on_handshake_done = std::move(on_handshake_done_), + result = std::move(result)]() mutable { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + on_handshake_done(std::move(result)); + // Destroy callback while ExecCtx is still in scope. + on_handshake_done = nullptr; + }); + return; } - if (done) { - Unref(); + // Call the next handshaker. + auto handshaker = handshakers_[index_]; + if (GRPC_TRACE_FLAG_ENABLED(handshaker)) { + gpr_log( + GPR_INFO, + "handshake_manager %p: calling handshaker %s [%p] at index %" PRIuPTR, + this, std::string(handshaker->name()).c_str(), handshaker.get(), + index_); } + ++index_; + handshaker->DoHandshake(&args_, [self = Ref()](absl::Status error) mutable { + MutexLock lock(&self->mu_); + self->CallNextHandshakerLocked(std::move(error)); + }); } } // namespace grpc_core diff --git a/src/core/handshaker/handshaker.h b/src/core/handshaker/handshaker.h index f5df3824081..04beed4a966 100644 --- a/src/core/handshaker/handshaker.h +++ b/src/core/handshaker/handshaker.h @@ -31,6 +31,7 @@ #include #include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/sync.h" @@ -39,6 +40,7 @@ #include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/iomgr/error.h" #include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/slice/slice_buffer.h" namespace grpc_core { @@ -49,34 +51,35 @@ namespace grpc_core { /// /// In general, handshakers should be used via a handshake manager. -/// Arguments passed through handshakers and to the on_handshake_done callback. +/// Arguments passed through handshakers and back to the caller. /// /// For handshakers, all members are input/output parameters; for /// example, a handshaker may read from or write to \a endpoint and /// then later replace it with a wrapped endpoint. Similarly, a /// handshaker may modify \a args. /// -/// A handshaker takes ownership of the members while a handshake is in -/// progress. Upon failure or shutdown of an in-progress handshaker, -/// the handshaker is responsible for destroying the members and setting -/// them to NULL before invoking the on_handshake_done callback. -/// -/// For the on_handshake_done callback, all members are input arguments, -/// which the callback takes ownership of. +/// A handshaker takes ownership of the members when this struct is +/// passed to DoHandshake(). It passes ownership back to the caller +/// when it invokes on_handshake_done. struct HandshakerArgs { - grpc_endpoint* endpoint = nullptr; + OrphanablePtr endpoint; ChannelArgs args; - grpc_slice_buffer* read_buffer = nullptr; + // Any bytes read from the endpoint that are not consumed by the + // handshaker must be passed back via this buffer. + SliceBuffer read_buffer; // A handshaker may set this to true before invoking on_handshake_done // to indicate that subsequent handshakers should be skipped. bool exit_early = false; - // User data passed through the handshake manager. Not used by - // individual handshakers. - void* user_data = nullptr; + // EventEngine to use for async work. + // (This is just a convenience to avoid digging it out of args.) + grpc_event_engine::experimental::EventEngine* event_engine = nullptr; // Deadline associated with the handshake. // TODO(anramach): Move this out of handshake args after EventEngine // is the default. Timestamp deadline; + // TODO(roth): Make this go away somehow as part of the EventEngine + // migration? + grpc_tcp_server_acceptor* acceptor = nullptr; }; /// @@ -86,11 +89,23 @@ struct HandshakerArgs { class Handshaker : public RefCounted { public: ~Handshaker() override = default; - virtual void Shutdown(grpc_error_handle why) = 0; - virtual void DoHandshake(grpc_tcp_server_acceptor* acceptor, - grpc_closure* on_handshake_done, - HandshakerArgs* args) = 0; - virtual const char* name() const = 0; + virtual absl::string_view name() const = 0; + virtual void DoHandshake( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done) = 0; + virtual void Shutdown(absl::Status error) = 0; + + protected: + // Helper function to safely invoke on_handshake_done asynchronously. + // + // Note that on_handshake_done may complete in another thread as soon + // as this method returns, so the handshaker object may be destroyed + // by the callback unless the caller of this method is holding its own + // ref to the handshaker. + static void InvokeOnHandshakeDone( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done, + absl::Status status); }; // @@ -100,16 +115,11 @@ class Handshaker : public RefCounted { class HandshakeManager : public RefCounted { public: HandshakeManager(); - ~HandshakeManager() override; /// Adds a handshaker to the handshake manager. /// Takes ownership of \a handshaker. void Add(RefCountedPtr handshaker) ABSL_LOCKS_EXCLUDED(mu_); - /// Shuts down the handshake manager (e.g., to clean up when the operation is - /// aborted in the middle). - void Shutdown(grpc_error_handle why) ABSL_LOCKS_EXCLUDED(mu_); - /// Invokes handshakers in the order they were added. /// Takes ownership of \a endpoint, and then passes that ownership to /// the \a on_handshake_done callback. @@ -122,41 +132,39 @@ class HandshakeManager : public RefCounted { /// absl::OkStatus(), then handshaking failed and the handshaker has done /// the necessary clean-up. Otherwise, the callback takes ownership of /// the arguments. - void DoHandshake(grpc_endpoint* endpoint, const ChannelArgs& channel_args, - Timestamp deadline, grpc_tcp_server_acceptor* acceptor, - grpc_iomgr_cb_func on_handshake_done, void* user_data) - ABSL_LOCKS_EXCLUDED(mu_); + void DoHandshake(OrphanablePtr endpoint, + const ChannelArgs& channel_args, Timestamp deadline, + grpc_tcp_server_acceptor* acceptor, + absl::AnyInvocable)> + on_handshake_done) ABSL_LOCKS_EXCLUDED(mu_); - private: - bool CallNextHandshakerLocked(grpc_error_handle error) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + /// Shuts down the handshake manager (e.g., to clean up when the operation is + /// aborted in the middle). + void Shutdown(absl::Status error) ABSL_LOCKS_EXCLUDED(mu_); + private: // A function used as the handshaker-done callback when chaining // handshakers together. - static void CallNextHandshakerFn(void* arg, grpc_error_handle error) - ABSL_LOCKS_EXCLUDED(mu_); + void CallNextHandshakerLocked(absl::Status error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - static const size_t HANDSHAKERS_INIT_SIZE = 2; + static const size_t kHandshakerListInlineSize = 2; Mutex mu_; bool is_shutdown_ ABSL_GUARDED_BY(mu_) = false; - // An array of handshakers added via grpc_handshake_manager_add(). - absl::InlinedVector, HANDSHAKERS_INIT_SIZE> - handshakers_ ABSL_GUARDED_BY(mu_); // The index of the handshaker to invoke next and closure to invoke it. size_t index_ ABSL_GUARDED_BY(mu_) = 0; - grpc_closure call_next_handshaker_ ABSL_GUARDED_BY(mu_); - // The acceptor to call the handshakers with. - grpc_tcp_server_acceptor* acceptor_ ABSL_GUARDED_BY(mu_); - // The final callback and user_data to invoke after the last handshaker. - grpc_closure on_handshake_done_ ABSL_GUARDED_BY(mu_); + // An array of handshakers added via Add(). + absl::InlinedVector, kHandshakerListInlineSize> + handshakers_ ABSL_GUARDED_BY(mu_); // Handshaker args. HandshakerArgs args_ ABSL_GUARDED_BY(mu_); + // The final callback to invoke after the last handshaker. + absl::AnyInvocable)> on_handshake_done_ + ABSL_GUARDED_BY(mu_); // Deadline timer across all handshakers. grpc_event_engine::experimental::EventEngine::TaskHandle deadline_timer_handle_ ABSL_GUARDED_BY(mu_); - std::shared_ptr event_engine_ - ABSL_GUARDED_BY(mu_); }; } // namespace grpc_core diff --git a/src/core/handshaker/http_connect/http_connect_handshaker.cc b/src/core/handshaker/http_connect/http_connect_handshaker.cc index de3de3b843f..13fecfc8df5 100644 --- a/src/core/handshaker/http_connect/http_connect_handshaker.cc +++ b/src/core/handshaker/http_connect/http_connect_handshaker.cc @@ -23,6 +23,7 @@ #include #include +#include #include "absl/base/thread_annotations.h" #include "absl/status/status.h" @@ -50,6 +51,8 @@ #include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/iomgr/iomgr_fwd.h" #include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/slice/slice.h" +#include "src/core/lib/slice/slice_buffer.h" #include "src/core/util/http_client/format_request.h" #include "src/core/util/http_client/parser.h" #include "src/core/util/string.h" @@ -61,165 +64,148 @@ namespace { class HttpConnectHandshaker : public Handshaker { public: HttpConnectHandshaker(); - void Shutdown(grpc_error_handle why) override; - void DoHandshake(grpc_tcp_server_acceptor* acceptor, - grpc_closure* on_handshake_done, - HandshakerArgs* args) override; - const char* name() const override { return "http_connect"; } + absl::string_view name() const override { return "http_connect"; } + void DoHandshake( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done) override; + void Shutdown(absl::Status error) override; private: ~HttpConnectHandshaker() override; - void CleanupArgsForFailureLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - void HandshakeFailedLocked(grpc_error_handle error) + void HandshakeFailedLocked(absl::Status error) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - static void OnWriteDone(void* arg, grpc_error_handle error); - static void OnReadDone(void* arg, grpc_error_handle error); + void FinishLocked(absl::Status error) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void OnWriteDone(absl::Status error); + void OnReadDone(absl::Status error); + bool OnReadDoneLocked(absl::Status error) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); static void OnWriteDoneScheduler(void* arg, grpc_error_handle error); static void OnReadDoneScheduler(void* arg, grpc_error_handle error); Mutex mu_; - bool is_shutdown_ ABSL_GUARDED_BY(mu_) = false; - // Read buffer to destroy after a shutdown. - grpc_slice_buffer* read_buffer_to_destroy_ ABSL_GUARDED_BY(mu_) = nullptr; - // State saved while performing the handshake. HandshakerArgs* args_ = nullptr; - grpc_closure* on_handshake_done_ = nullptr; + absl::AnyInvocable on_handshake_done_ + ABSL_GUARDED_BY(mu_); // Objects for processing the HTTP CONNECT request and response. - grpc_slice_buffer write_buffer_ ABSL_GUARDED_BY(mu_); - grpc_closure request_done_closure_ ABSL_GUARDED_BY(mu_); - grpc_closure response_read_closure_ ABSL_GUARDED_BY(mu_); + SliceBuffer write_buffer_ ABSL_GUARDED_BY(mu_); + grpc_closure on_write_done_scheduler_ ABSL_GUARDED_BY(mu_); + grpc_closure on_read_done_scheduler_ ABSL_GUARDED_BY(mu_); grpc_http_parser http_parser_ ABSL_GUARDED_BY(mu_); grpc_http_response http_response_ ABSL_GUARDED_BY(mu_); }; HttpConnectHandshaker::~HttpConnectHandshaker() { - if (read_buffer_to_destroy_ != nullptr) { - grpc_slice_buffer_destroy(read_buffer_to_destroy_); - gpr_free(read_buffer_to_destroy_); - } - grpc_slice_buffer_destroy(&write_buffer_); grpc_http_parser_destroy(&http_parser_); grpc_http_response_destroy(&http_response_); } -// Set args fields to nullptr, saving the endpoint and read buffer for -// later destruction. -void HttpConnectHandshaker::CleanupArgsForFailureLocked() { - read_buffer_to_destroy_ = args_->read_buffer; - args_->read_buffer = nullptr; - args_->args = ChannelArgs(); -} - // If the handshake failed or we're shutting down, clean up and invoke the // callback with the error. -void HttpConnectHandshaker::HandshakeFailedLocked(grpc_error_handle error) { +void HttpConnectHandshaker::HandshakeFailedLocked(absl::Status error) { if (error.ok()) { // If we were shut down after an endpoint operation succeeded but // before the endpoint callback was invoked, we need to generate our // own error. error = GRPC_ERROR_CREATE("Handshaker shutdown"); } - if (!is_shutdown_) { - // Not shutting down, so the handshake failed. Clean up before - // invoking the callback. - grpc_endpoint_destroy(args_->endpoint); - args_->endpoint = nullptr; - CleanupArgsForFailureLocked(); - // Set shutdown to true so that subsequent calls to - // http_connect_handshaker_shutdown() do nothing. - is_shutdown_ = true; - } // Invoke callback. - ExecCtx::Run(DEBUG_LOCATION, on_handshake_done_, error); + FinishLocked(std::move(error)); +} + +void HttpConnectHandshaker::FinishLocked(absl::Status error) { + InvokeOnHandshakeDone(args_, std::move(on_handshake_done_), std::move(error)); } // This callback can be invoked inline while already holding onto the mutex. To // avoid deadlocks, schedule OnWriteDone on ExecCtx. +// TODO(roth): This hop will no longer be needed when we migrate to the +// EventEngine endpoint API. void HttpConnectHandshaker::OnWriteDoneScheduler(void* arg, grpc_error_handle error) { auto* handshaker = static_cast(arg); - ExecCtx::Run(DEBUG_LOCATION, - GRPC_CLOSURE_INIT(&handshaker->request_done_closure_, - &HttpConnectHandshaker::OnWriteDone, - handshaker, grpc_schedule_on_exec_ctx), - error); + handshaker->args_->event_engine->Run( + [handshaker, error = std::move(error)]() mutable { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + handshaker->OnWriteDone(std::move(error)); + }); } // Callback invoked when finished writing HTTP CONNECT request. -void HttpConnectHandshaker::OnWriteDone(void* arg, grpc_error_handle error) { - auto* handshaker = static_cast(arg); - ReleasableMutexLock lock(&handshaker->mu_); - if (!error.ok() || handshaker->is_shutdown_) { +void HttpConnectHandshaker::OnWriteDone(absl::Status error) { + ReleasableMutexLock lock(&mu_); + if (!error.ok() || args_->endpoint == nullptr) { // If the write failed or we're shutting down, clean up and invoke the // callback with the error. - handshaker->HandshakeFailedLocked(error); + HandshakeFailedLocked(error); lock.Release(); - handshaker->Unref(); + Unref(); } else { // Otherwise, read the response. // The read callback inherits our ref to the handshaker. grpc_endpoint_read( - handshaker->args_->endpoint, handshaker->args_->read_buffer, - GRPC_CLOSURE_INIT(&handshaker->response_read_closure_, - &HttpConnectHandshaker::OnReadDoneScheduler, - handshaker, grpc_schedule_on_exec_ctx), + args_->endpoint.get(), args_->read_buffer.c_slice_buffer(), + GRPC_CLOSURE_INIT(&on_read_done_scheduler_, + &HttpConnectHandshaker::OnReadDoneScheduler, this, + grpc_schedule_on_exec_ctx), /*urgent=*/true, /*min_progress_size=*/1); } } // This callback can be invoked inline while already holding onto the mutex. To // avoid deadlocks, schedule OnReadDone on ExecCtx. +// TODO(roth): This hop will no longer be needed when we migrate to the +// EventEngine endpoint API. void HttpConnectHandshaker::OnReadDoneScheduler(void* arg, grpc_error_handle error) { auto* handshaker = static_cast(arg); - ExecCtx::Run(DEBUG_LOCATION, - GRPC_CLOSURE_INIT(&handshaker->response_read_closure_, - &HttpConnectHandshaker::OnReadDone, handshaker, - grpc_schedule_on_exec_ctx), - error); + handshaker->args_->event_engine->Run( + [handshaker, error = std::move(error)]() mutable { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + handshaker->OnReadDone(std::move(error)); + }); } // Callback invoked for reading HTTP CONNECT response. -void HttpConnectHandshaker::OnReadDone(void* arg, grpc_error_handle error) { - auto* handshaker = static_cast(arg); - ReleasableMutexLock lock(&handshaker->mu_); - if (!error.ok() || handshaker->is_shutdown_) { +void HttpConnectHandshaker::OnReadDone(absl::Status error) { + bool done; + { + MutexLock lock(&mu_); + done = OnReadDoneLocked(std::move(error)); + } + if (done) Unref(); +} + +bool HttpConnectHandshaker::OnReadDoneLocked(absl::Status error) { + if (!error.ok() || args_->endpoint == nullptr) { // If the read failed or we're shutting down, clean up and invoke the // callback with the error. - handshaker->HandshakeFailedLocked(error); - goto done; + HandshakeFailedLocked(std::move(error)); + return true; } // Add buffer to parser. - for (size_t i = 0; i < handshaker->args_->read_buffer->count; ++i) { - if (GRPC_SLICE_LENGTH(handshaker->args_->read_buffer->slices[i]) > 0) { + while (args_->read_buffer.Count() > 0) { + Slice slice = args_->read_buffer.TakeFirst(); + if (slice.length() > 0) { size_t body_start_offset = 0; - error = grpc_http_parser_parse(&handshaker->http_parser_, - handshaker->args_->read_buffer->slices[i], + error = grpc_http_parser_parse(&http_parser_, slice.c_slice(), &body_start_offset); if (!error.ok()) { - handshaker->HandshakeFailedLocked(error); - goto done; + HandshakeFailedLocked(std::move(error)); + return true; } - if (handshaker->http_parser_.state == GRPC_HTTP_BODY) { + if (http_parser_.state == GRPC_HTTP_BODY) { // Remove the data we've already read from the read buffer, // leaving only the leftover bytes (if any). - grpc_slice_buffer tmp_buffer; - grpc_slice_buffer_init(&tmp_buffer); - if (body_start_offset < - GRPC_SLICE_LENGTH(handshaker->args_->read_buffer->slices[i])) { - grpc_slice_buffer_add( - &tmp_buffer, - grpc_slice_split_tail(&handshaker->args_->read_buffer->slices[i], - body_start_offset)); + SliceBuffer tmp_buffer; + if (body_start_offset < slice.length()) { + tmp_buffer.Append(slice.Split(body_start_offset)); } - grpc_slice_buffer_addn(&tmp_buffer, - &handshaker->args_->read_buffer->slices[i + 1], - handshaker->args_->read_buffer->count - i - 1); - grpc_slice_buffer_swap(handshaker->args_->read_buffer, &tmp_buffer); - grpc_slice_buffer_destroy(&tmp_buffer); + tmp_buffer.TakeAndAppend(args_->read_buffer); + tmp_buffer.Swap(&args_->read_buffer); break; } } @@ -235,65 +221,46 @@ void HttpConnectHandshaker::OnReadDone(void* arg, grpc_error_handle error) { // need to fix the HTTP parser to understand when the body is // complete (e.g., handling chunked transfer encoding or looking // at the Content-Length: header). - if (handshaker->http_parser_.state != GRPC_HTTP_BODY) { - grpc_slice_buffer_reset_and_unref(handshaker->args_->read_buffer); + if (http_parser_.state != GRPC_HTTP_BODY) { + args_->read_buffer.Clear(); grpc_endpoint_read( - handshaker->args_->endpoint, handshaker->args_->read_buffer, - GRPC_CLOSURE_INIT(&handshaker->response_read_closure_, - &HttpConnectHandshaker::OnReadDoneScheduler, - handshaker, grpc_schedule_on_exec_ctx), + args_->endpoint.get(), args_->read_buffer.c_slice_buffer(), + GRPC_CLOSURE_INIT(&on_read_done_scheduler_, + &HttpConnectHandshaker::OnReadDoneScheduler, this, + grpc_schedule_on_exec_ctx), /*urgent=*/true, /*min_progress_size=*/1); - return; + return false; } // Make sure we got a 2xx response. - if (handshaker->http_response_.status < 200 || - handshaker->http_response_.status >= 300) { + if (http_response_.status < 200 || http_response_.status >= 300) { error = GRPC_ERROR_CREATE(absl::StrCat("HTTP proxy returned response code ", - handshaker->http_response_.status)); - handshaker->HandshakeFailedLocked(error); - goto done; + http_response_.status)); + HandshakeFailedLocked(std::move(error)); + return true; } // Success. Invoke handshake-done callback. - ExecCtx::Run(DEBUG_LOCATION, handshaker->on_handshake_done_, error); -done: - // Set shutdown to true so that subsequent calls to - // http_connect_handshaker_shutdown() do nothing. - handshaker->is_shutdown_ = true; - lock.Release(); - handshaker->Unref(); + FinishLocked(absl::OkStatus()); + return true; } // // Public handshaker methods // -void HttpConnectHandshaker::Shutdown(grpc_error_handle /*why*/) { - { - MutexLock lock(&mu_); - if (!is_shutdown_) { - is_shutdown_ = true; - grpc_endpoint_destroy(args_->endpoint); - args_->endpoint = nullptr; - CleanupArgsForFailureLocked(); - } - } +void HttpConnectHandshaker::Shutdown(absl::Status /*error*/) { + MutexLock lock(&mu_); + if (on_handshake_done_ != nullptr) args_->endpoint.reset(); } -void HttpConnectHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, - grpc_closure* on_handshake_done, - HandshakerArgs* args) { +void HttpConnectHandshaker::DoHandshake( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done) { // Check for HTTP CONNECT channel arg. // If not found, invoke on_handshake_done without doing anything. absl::optional server_name = args->args.GetString(GRPC_ARG_HTTP_CONNECT_SERVER); if (!server_name.has_value()) { - // Set shutdown to true so that subsequent calls to - // http_connect_handshaker_shutdown() do nothing. - { - MutexLock lock(&mu_); - is_shutdown_ = true; - } - ExecCtx::Run(DEBUG_LOCATION, on_handshake_done, absl::OkStatus()); + InvokeOnHandshakeDone(args, std::move(on_handshake_done), absl::OkStatus()); return; } // Get headers from channel args. @@ -311,7 +278,6 @@ void HttpConnectHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, gpr_malloc(sizeof(grpc_http_header) * num_header_strings)); for (size_t i = 0; i < num_header_strings; ++i) { char* sep = strchr(header_strings[i], ':'); - if (sep == nullptr) { gpr_log(GPR_ERROR, "skipping unparseable HTTP CONNECT header: %s", header_strings[i]); @@ -326,9 +292,9 @@ void HttpConnectHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, // Save state in the handshaker object. MutexLock lock(&mu_); args_ = args; - on_handshake_done_ = on_handshake_done; + on_handshake_done_ = std::move(on_handshake_done); // Log connection via proxy. - std::string proxy_name(grpc_endpoint_get_peer(args->endpoint)); + std::string proxy_name(grpc_endpoint_get_peer(args->endpoint.get())); std::string server_name_string(*server_name); gpr_log(GPR_INFO, "Connecting to server %s via HTTP proxy %s", server_name_string.c_str(), proxy_name.c_str()); @@ -342,7 +308,7 @@ void HttpConnectHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, request.body = nullptr; grpc_slice request_slice = grpc_httpcli_format_connect_request( &request, server_name_string.c_str(), server_name_string.c_str()); - grpc_slice_buffer_add(&write_buffer_, request_slice); + write_buffer_.Append(Slice(request_slice)); // Clean up. gpr_free(headers); for (size_t i = 0; i < num_header_strings; ++i) { @@ -352,15 +318,14 @@ void HttpConnectHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, // Take a new ref to be held by the write callback. Ref().release(); grpc_endpoint_write( - args->endpoint, &write_buffer_, - GRPC_CLOSURE_INIT(&request_done_closure_, + args->endpoint.get(), write_buffer_.c_slice_buffer(), + GRPC_CLOSURE_INIT(&on_write_done_scheduler_, &HttpConnectHandshaker::OnWriteDoneScheduler, this, grpc_schedule_on_exec_ctx), nullptr, /*max_frame_size=*/INT_MAX); } HttpConnectHandshaker::HttpConnectHandshaker() { - grpc_slice_buffer_init(&write_buffer_); grpc_http_parser_init(&http_parser_, GRPC_HTTP_RESPONSE, &http_response_); } diff --git a/src/core/handshaker/security/secure_endpoint.cc b/src/core/handshaker/security/secure_endpoint.cc index 18fde152438..972b54d619f 100644 --- a/src/core/handshaker/security/secure_endpoint.cc +++ b/src/core/handshaker/security/secure_endpoint.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include "absl/base/thread_annotations.h" #include "absl/log/check.h" @@ -43,9 +44,11 @@ #include "src/core/lib/debug/trace.h" #include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/sync.h" #include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/iomgr/error.h" #include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/iomgr/iomgr_fwd.h" @@ -64,17 +67,18 @@ static void on_read(void* user_data, grpc_error_handle error); static void on_write(void* user_data, grpc_error_handle error); namespace { -struct secure_endpoint { - secure_endpoint(const grpc_endpoint_vtable* vtable, +struct secure_endpoint : public grpc_endpoint { + secure_endpoint(const grpc_endpoint_vtable* vtbl, tsi_frame_protector* protector, tsi_zero_copy_grpc_protector* zero_copy_protector, - grpc_endpoint* transport, grpc_slice* leftover_slices, + grpc_core::OrphanablePtr endpoint, + grpc_slice* leftover_slices, const grpc_channel_args* channel_args, size_t leftover_nslices) - : wrapped_ep(transport), + : wrapped_ep(std::move(endpoint)), protector(protector), zero_copy_protector(zero_copy_protector) { - base.vtable = vtable; + this->vtable = vtbl; gpr_mu_init(&protector_mu); GRPC_CLOSURE_INIT(&on_read, ::on_read, this, grpc_schedule_on_exec_ctx); GRPC_CLOSURE_INIT(&on_write, ::on_write, this, grpc_schedule_on_exec_ctx); @@ -117,8 +121,7 @@ struct secure_endpoint { gpr_mu_destroy(&protector_mu); } - grpc_endpoint base; - grpc_endpoint* wrapped_ep; + grpc_core::OrphanablePtr wrapped_ep; struct tsi_frame_protector* protector; struct tsi_zero_copy_grpc_protector* zero_copy_protector; gpr_mu protector_mu; @@ -365,8 +368,8 @@ static void endpoint_read(grpc_endpoint* secure_ep, grpc_slice_buffer* slices, return; } - grpc_endpoint_read(ep->wrapped_ep, &ep->source_buffer, &ep->on_read, urgent, - /*min_progress_size=*/ep->min_progress_size); + grpc_endpoint_read(ep->wrapped_ep.get(), &ep->source_buffer, &ep->on_read, + urgent, /*min_progress_size=*/ep->min_progress_size); } static void flush_write_staging_buffer(secure_endpoint* ep, uint8_t** cur, @@ -500,52 +503,52 @@ static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices, // output_buffer at any time until the write completes. SECURE_ENDPOINT_REF(ep, "write"); ep->write_cb = cb; - grpc_endpoint_write(ep->wrapped_ep, &ep->output_buffer, &ep->on_write, arg, - max_frame_size); + grpc_endpoint_write(ep->wrapped_ep.get(), &ep->output_buffer, &ep->on_write, + arg, max_frame_size); } static void endpoint_destroy(grpc_endpoint* secure_ep) { secure_endpoint* ep = reinterpret_cast(secure_ep); - grpc_endpoint_destroy(ep->wrapped_ep); + ep->wrapped_ep.reset(); SECURE_ENDPOINT_UNREF(ep, "destroy"); } static void endpoint_add_to_pollset(grpc_endpoint* secure_ep, grpc_pollset* pollset) { secure_endpoint* ep = reinterpret_cast(secure_ep); - grpc_endpoint_add_to_pollset(ep->wrapped_ep, pollset); + grpc_endpoint_add_to_pollset(ep->wrapped_ep.get(), pollset); } static void endpoint_add_to_pollset_set(grpc_endpoint* secure_ep, grpc_pollset_set* pollset_set) { secure_endpoint* ep = reinterpret_cast(secure_ep); - grpc_endpoint_add_to_pollset_set(ep->wrapped_ep, pollset_set); + grpc_endpoint_add_to_pollset_set(ep->wrapped_ep.get(), pollset_set); } static void endpoint_delete_from_pollset_set(grpc_endpoint* secure_ep, grpc_pollset_set* pollset_set) { secure_endpoint* ep = reinterpret_cast(secure_ep); - grpc_endpoint_delete_from_pollset_set(ep->wrapped_ep, pollset_set); + grpc_endpoint_delete_from_pollset_set(ep->wrapped_ep.get(), pollset_set); } static absl::string_view endpoint_get_peer(grpc_endpoint* secure_ep) { secure_endpoint* ep = reinterpret_cast(secure_ep); - return grpc_endpoint_get_peer(ep->wrapped_ep); + return grpc_endpoint_get_peer(ep->wrapped_ep.get()); } static absl::string_view endpoint_get_local_address(grpc_endpoint* secure_ep) { secure_endpoint* ep = reinterpret_cast(secure_ep); - return grpc_endpoint_get_local_address(ep->wrapped_ep); + return grpc_endpoint_get_local_address(ep->wrapped_ep.get()); } static int endpoint_get_fd(grpc_endpoint* secure_ep) { secure_endpoint* ep = reinterpret_cast(secure_ep); - return grpc_endpoint_get_fd(ep->wrapped_ep); + return grpc_endpoint_get_fd(ep->wrapped_ep.get()); } static bool endpoint_can_track_err(grpc_endpoint* secure_ep) { secure_endpoint* ep = reinterpret_cast(secure_ep); - return grpc_endpoint_can_track_err(ep->wrapped_ep); + return grpc_endpoint_can_track_err(ep->wrapped_ep.get()); } static const grpc_endpoint_vtable vtable = {endpoint_read, @@ -559,13 +562,13 @@ static const grpc_endpoint_vtable vtable = {endpoint_read, endpoint_get_fd, endpoint_can_track_err}; -grpc_endpoint* grpc_secure_endpoint_create( +grpc_core::OrphanablePtr grpc_secure_endpoint_create( struct tsi_frame_protector* protector, struct tsi_zero_copy_grpc_protector* zero_copy_protector, - grpc_endpoint* to_wrap, grpc_slice* leftover_slices, - const grpc_channel_args* channel_args, size_t leftover_nslices) { - secure_endpoint* ep = - new secure_endpoint(&vtable, protector, zero_copy_protector, to_wrap, - leftover_slices, channel_args, leftover_nslices); - return &ep->base; + grpc_core::OrphanablePtr to_wrap, + grpc_slice* leftover_slices, const grpc_channel_args* channel_args, + size_t leftover_nslices) { + return grpc_core::MakeOrphanable( + &vtable, protector, zero_copy_protector, std::move(to_wrap), + leftover_slices, channel_args, leftover_nslices); } diff --git a/src/core/handshaker/security/secure_endpoint.h b/src/core/handshaker/security/secure_endpoint.h index a9d6d2088c4..43add1a816b 100644 --- a/src/core/handshaker/security/secure_endpoint.h +++ b/src/core/handshaker/security/secure_endpoint.h @@ -26,15 +26,17 @@ #include #include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/iomgr/endpoint.h" // Takes ownership of protector, zero_copy_protector, and to_wrap, and refs // leftover_slices. If zero_copy_protector is not NULL, protector will never be // used. -grpc_endpoint* grpc_secure_endpoint_create( +grpc_core::OrphanablePtr grpc_secure_endpoint_create( struct tsi_frame_protector* protector, struct tsi_zero_copy_grpc_protector* zero_copy_protector, - grpc_endpoint* to_wrap, grpc_slice* leftover_slices, - const grpc_channel_args* channel_args, size_t leftover_nslices); + grpc_core::OrphanablePtr to_wrap, + grpc_slice* leftover_slices, const grpc_channel_args* channel_args, + size_t leftover_nslices); #endif // GRPC_SRC_CORE_HANDSHAKER_SECURITY_SECURE_ENDPOINT_H diff --git a/src/core/handshaker/security/security_handshaker.cc b/src/core/handshaker/security/security_handshaker.cc index dba7f399e1e..58c9a16eaee 100644 --- a/src/core/handshaker/security/security_handshaker.cc +++ b/src/core/handshaker/security/security_handshaker.cc @@ -28,6 +28,7 @@ #include #include "absl/base/attributes.h" +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -79,11 +80,11 @@ class SecurityHandshaker : public Handshaker { grpc_security_connector* connector, const ChannelArgs& args); ~SecurityHandshaker() override; - void Shutdown(grpc_error_handle why) override; - void DoHandshake(grpc_tcp_server_acceptor* acceptor, - grpc_closure* on_handshake_done, - HandshakerArgs* args) override; - const char* name() const override { return "security"; } + absl::string_view name() const override { return "security"; } + void DoHandshake( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done) override; + void Shutdown(absl::Status error) override; private: grpc_error_handle DoHandshakerNextLocked(const unsigned char* bytes_received, @@ -92,12 +93,11 @@ class SecurityHandshaker : public Handshaker { grpc_error_handle OnHandshakeNextDoneLocked( tsi_result result, const unsigned char* bytes_to_send, size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result); - void HandshakeFailedLocked(grpc_error_handle error); - void CleanupArgsForFailureLocked(); + void HandshakeFailedLocked(absl::Status error); + void Finish(absl::Status status); - static void OnHandshakeDataReceivedFromPeerFn(void* arg, - grpc_error_handle error); - static void OnHandshakeDataSentToPeerFn(void* arg, grpc_error_handle error); + void OnHandshakeDataReceivedFromPeerFn(absl::Status error); + void OnHandshakeDataSentToPeerFn(absl::Status error); static void OnHandshakeDataReceivedFromPeerFnScheduler( void* arg, grpc_error_handle error); static void OnHandshakeDataSentToPeerFnScheduler(void* arg, @@ -117,16 +117,14 @@ class SecurityHandshaker : public Handshaker { Mutex mu_; bool is_shutdown_ = false; - // Read buffer to destroy after a shutdown. - grpc_slice_buffer* read_buffer_to_destroy_ = nullptr; // State saved while performing the handshake. HandshakerArgs* args_ = nullptr; - grpc_closure* on_handshake_done_ = nullptr; + absl::AnyInvocable on_handshake_done_; size_t handshake_buffer_size_; unsigned char* handshake_buffer_; - grpc_slice_buffer outgoing_; + SliceBuffer outgoing_; grpc_closure on_handshake_data_sent_to_peer_; grpc_closure on_handshake_data_received_from_peer_; grpc_closure on_peer_checked_; @@ -146,7 +144,6 @@ SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker, static_cast(gpr_malloc(handshake_buffer_size_))), max_frame_size_( std::max(0, args.GetInt(GRPC_ARG_TSI_MAX_FRAME_SIZE).value_or(0))) { - grpc_slice_buffer_init(&outgoing_); GRPC_CLOSURE_INIT(&on_peer_checked_, &SecurityHandshaker::OnPeerCheckedFn, this, grpc_schedule_on_exec_ctx); } @@ -154,45 +151,30 @@ SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker, SecurityHandshaker::~SecurityHandshaker() { tsi_handshaker_destroy(handshaker_); tsi_handshaker_result_destroy(handshaker_result_); - if (read_buffer_to_destroy_ != nullptr) { - grpc_slice_buffer_destroy(read_buffer_to_destroy_); - gpr_free(read_buffer_to_destroy_); - } gpr_free(handshake_buffer_); - grpc_slice_buffer_destroy(&outgoing_); auth_context_.reset(DEBUG_LOCATION, "handshake"); connector_.reset(DEBUG_LOCATION, "handshake"); } size_t SecurityHandshaker::MoveReadBufferIntoHandshakeBuffer() { - size_t bytes_in_read_buffer = args_->read_buffer->length; + size_t bytes_in_read_buffer = args_->read_buffer.Length(); if (handshake_buffer_size_ < bytes_in_read_buffer) { handshake_buffer_ = static_cast( gpr_realloc(handshake_buffer_, bytes_in_read_buffer)); handshake_buffer_size_ = bytes_in_read_buffer; } size_t offset = 0; - while (args_->read_buffer->count > 0) { - grpc_slice* next_slice = grpc_slice_buffer_peek_first(args_->read_buffer); - memcpy(handshake_buffer_ + offset, GRPC_SLICE_START_PTR(*next_slice), - GRPC_SLICE_LENGTH(*next_slice)); - offset += GRPC_SLICE_LENGTH(*next_slice); - grpc_slice_buffer_remove_first(args_->read_buffer); + while (args_->read_buffer.Count() > 0) { + Slice slice = args_->read_buffer.TakeFirst(); + memcpy(handshake_buffer_ + offset, slice.data(), slice.size()); + offset += slice.size(); } return bytes_in_read_buffer; } -// Set args_ fields to NULL, saving the endpoint and read buffer for -// later destruction. -void SecurityHandshaker::CleanupArgsForFailureLocked() { - read_buffer_to_destroy_ = args_->read_buffer; - args_->read_buffer = nullptr; - args_->args = ChannelArgs(); -} - // If the handshake failed or we're shutting down, clean up and invoke the // callback with the error. -void SecurityHandshaker::HandshakeFailedLocked(grpc_error_handle error) { +void SecurityHandshaker::HandshakeFailedLocked(absl::Status error) { if (error.ok()) { // If we were shut down after the handshake succeeded but before an // endpoint callback was invoked, we need to generate our own error. @@ -200,17 +182,17 @@ void SecurityHandshaker::HandshakeFailedLocked(grpc_error_handle error) { } if (!is_shutdown_) { tsi_handshaker_shutdown(handshaker_); - grpc_endpoint_destroy(args_->endpoint); - args_->endpoint = nullptr; - // Not shutting down, so the write failed. Clean up before - // invoking the callback. - CleanupArgsForFailureLocked(); // Set shutdown to true so that subsequent calls to // security_handshaker_shutdown() do nothing. is_shutdown_ = true; } // Invoke callback. - ExecCtx::Run(DEBUG_LOCATION, on_handshake_done_, error); + Finish(std::move(error)); +} + +void SecurityHandshaker::Finish(absl::Status status) { + InvokeOnHandshakeDone(args_, std::move(on_handshake_done_), + std::move(status)); } namespace { @@ -306,19 +288,18 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) { grpc_slice slice = grpc_slice_from_copied_buffer( reinterpret_cast(unused_bytes), unused_bytes_size); args_->endpoint = grpc_secure_endpoint_create( - protector, zero_copy_protector, args_->endpoint, &slice, + protector, zero_copy_protector, std::move(args_->endpoint), &slice, args_->args.ToC().get(), 1); CSliceUnref(slice); } else { args_->endpoint = grpc_secure_endpoint_create( - protector, zero_copy_protector, args_->endpoint, nullptr, + protector, zero_copy_protector, std::move(args_->endpoint), nullptr, args_->args.ToC().get(), 0); } } else if (unused_bytes_size > 0) { // Not wrapping the endpoint, so just pass along unused bytes. - grpc_slice slice = grpc_slice_from_copied_buffer( - reinterpret_cast(unused_bytes), unused_bytes_size); - grpc_slice_buffer_add(args_->read_buffer, slice); + args_->read_buffer.Append(Slice::FromCopiedBuffer( + reinterpret_cast(unused_bytes), unused_bytes_size)); } // Done with handshaker result. tsi_handshaker_result_destroy(handshaker_result_); @@ -329,11 +310,11 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) { args_->args = args_->args.SetObject( MakeChannelzSecurityFromAuthContext(auth_context_.get())); } - // Invoke callback. - ExecCtx::Run(DEBUG_LOCATION, on_handshake_done_, absl::OkStatus()); // Set shutdown to true so that subsequent calls to // security_handshaker_shutdown() do nothing. is_shutdown_ = true; + // Invoke callback. + Finish(absl::OkStatus()); } void SecurityHandshaker::OnPeerCheckedFn(void* arg, grpc_error_handle error) { @@ -349,8 +330,8 @@ grpc_error_handle SecurityHandshaker::CheckPeerLocked() { return GRPC_ERROR_CREATE(absl::StrCat("Peer extraction failed (", tsi_result_to_string(result), ")")); } - connector_->check_peer(peer, args_->endpoint, args_->args, &auth_context_, - &on_peer_checked_); + connector_->check_peer(peer, args_->endpoint.get(), args_->args, + &auth_context_, &on_peer_checked_); grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name( auth_context_.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME); const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); @@ -374,7 +355,7 @@ grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked( if (result == TSI_INCOMPLETE_DATA) { CHECK_EQ(bytes_to_send_size, 0u); grpc_endpoint_read( - args_->endpoint, args_->read_buffer, + args_->endpoint.get(), args_->read_buffer.c_slice_buffer(), GRPC_CLOSURE_INIT( &on_handshake_data_received_from_peer_, &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler, @@ -388,6 +369,8 @@ grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked( if (security_connector != nullptr) { connector_type = security_connector->type().name(); } + // TODO(roth): Get a better signal from the TSI layer as to what + // status code we should use here. return GRPC_ERROR_CREATE(absl::StrCat( connector_type, " handshake failed (", tsi_result_to_string(result), ")", (tsi_handshake_error_.empty() ? "" : ": "), tsi_handshake_error_)); @@ -399,12 +382,11 @@ grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked( } if (bytes_to_send_size > 0) { // Send data to peer, if needed. - grpc_slice to_send = grpc_slice_from_copied_buffer( - reinterpret_cast(bytes_to_send), bytes_to_send_size); - grpc_slice_buffer_reset_and_unref(&outgoing_); - grpc_slice_buffer_add(&outgoing_, to_send); + outgoing_.Clear(); + outgoing_.Append(Slice::FromCopiedBuffer( + reinterpret_cast(bytes_to_send), bytes_to_send_size)); grpc_endpoint_write( - args_->endpoint, &outgoing_, + args_->endpoint.get(), outgoing_.c_slice_buffer(), GRPC_CLOSURE_INIT( &on_handshake_data_sent_to_peer_, &SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler, this, @@ -413,7 +395,7 @@ grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked( } else if (handshaker_result == nullptr) { // There is nothing to send, but need to read from peer. grpc_endpoint_read( - args_->endpoint, args_->read_buffer, + args_->endpoint.get(), args_->read_buffer.c_slice_buffer(), GRPC_CLOSURE_INIT( &on_handshake_data_received_from_peer_, &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler, @@ -435,7 +417,7 @@ void SecurityHandshaker::OnHandshakeNextDoneGrpcWrapper( grpc_error_handle error = h->OnHandshakeNextDoneLocked( result, bytes_to_send, bytes_to_send_size, handshaker_result); if (!error.ok()) { - h->HandshakeFailedLocked(error); + h->HandshakeFailedLocked(std::move(error)); } else { h.release(); // Avoid unref } @@ -463,102 +445,102 @@ grpc_error_handle SecurityHandshaker::DoHandshakerNextLocked( } // This callback might be run inline while we are still holding on to the mutex, -// so schedule OnHandshakeDataReceivedFromPeerFn on ExecCtx to avoid a deadlock. +// so run OnHandshakeDataReceivedFromPeerFn asynchronously to avoid a deadlock. +// TODO(roth): This will no longer be necessary once we migrate to the +// EventEngine endpoint API. void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler( void* arg, grpc_error_handle error) { - SecurityHandshaker* h = static_cast(arg); - ExecCtx::Run( - DEBUG_LOCATION, - GRPC_CLOSURE_INIT(&h->on_handshake_data_received_from_peer_, - &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn, - h, grpc_schedule_on_exec_ctx), - error); + SecurityHandshaker* handshaker = static_cast(arg); + handshaker->args_->event_engine->Run( + [handshaker, error = std::move(error)]() mutable { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + handshaker->OnHandshakeDataReceivedFromPeerFn(std::move(error)); + }); } -void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn( - void* arg, grpc_error_handle error) { - RefCountedPtr h(static_cast(arg)); - MutexLock lock(&h->mu_); - if (!error.ok() || h->is_shutdown_) { - h->HandshakeFailedLocked( +void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(absl::Status error) { + RefCountedPtr handshaker(this); + MutexLock lock(&mu_); + if (!error.ok() || is_shutdown_) { + HandshakeFailedLocked( GRPC_ERROR_CREATE_REFERENCING("Handshake read failed", &error, 1)); return; } // Copy all slices received. - size_t bytes_received_size = h->MoveReadBufferIntoHandshakeBuffer(); + size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer(); // Call TSI handshaker. - error = h->DoHandshakerNextLocked(h->handshake_buffer_, bytes_received_size); + error = DoHandshakerNextLocked(handshake_buffer_, bytes_received_size); if (!error.ok()) { - h->HandshakeFailedLocked(error); + HandshakeFailedLocked(std::move(error)); } else { - h.release(); // Avoid unref + handshaker.release(); // Avoid unref } } // This callback might be run inline while we are still holding on to the mutex, -// so schedule OnHandshakeDataSentToPeerFn on ExecCtx to avoid a deadlock. +// so run OnHandshakeDataSentToPeerFn asynchronously to avoid a deadlock. +// TODO(roth): This will no longer be necessary once we migrate to the +// EventEngine endpoint API. void SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler( void* arg, grpc_error_handle error) { - SecurityHandshaker* h = static_cast(arg); - ExecCtx::Run( - DEBUG_LOCATION, - GRPC_CLOSURE_INIT(&h->on_handshake_data_sent_to_peer_, - &SecurityHandshaker::OnHandshakeDataSentToPeerFn, h, - grpc_schedule_on_exec_ctx), - error); + SecurityHandshaker* handshaker = static_cast(arg); + handshaker->args_->event_engine->Run( + [handshaker, error = std::move(error)]() mutable { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + handshaker->OnHandshakeDataSentToPeerFn(std::move(error)); + }); } -void SecurityHandshaker::OnHandshakeDataSentToPeerFn(void* arg, - grpc_error_handle error) { - RefCountedPtr h(static_cast(arg)); - MutexLock lock(&h->mu_); - if (!error.ok() || h->is_shutdown_) { - h->HandshakeFailedLocked( +void SecurityHandshaker::OnHandshakeDataSentToPeerFn(absl::Status error) { + RefCountedPtr handshaker(this); + MutexLock lock(&mu_); + if (!error.ok() || is_shutdown_) { + HandshakeFailedLocked( GRPC_ERROR_CREATE_REFERENCING("Handshake write failed", &error, 1)); return; } // We may be done. - if (h->handshaker_result_ == nullptr) { + if (handshaker_result_ == nullptr) { grpc_endpoint_read( - h->args_->endpoint, h->args_->read_buffer, + args_->endpoint.get(), args_->read_buffer.c_slice_buffer(), GRPC_CLOSURE_INIT( - &h->on_handshake_data_received_from_peer_, + &on_handshake_data_received_from_peer_, &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler, - h.get(), grpc_schedule_on_exec_ctx), + this, grpc_schedule_on_exec_ctx), /*urgent=*/true, /*min_progress_size=*/1); } else { - error = h->CheckPeerLocked(); + error = CheckPeerLocked(); if (!error.ok()) { - h->HandshakeFailedLocked(error); + HandshakeFailedLocked(error); return; } } - h.release(); // Avoid unref + handshaker.release(); // Avoid unref } // // public handshaker API // -void SecurityHandshaker::Shutdown(grpc_error_handle why) { +void SecurityHandshaker::Shutdown(grpc_error_handle error) { MutexLock lock(&mu_); if (!is_shutdown_) { is_shutdown_ = true; - connector_->cancel_check_peer(&on_peer_checked_, why); + connector_->cancel_check_peer(&on_peer_checked_, std::move(error)); tsi_handshaker_shutdown(handshaker_); - grpc_endpoint_destroy(args_->endpoint); - args_->endpoint = nullptr; - CleanupArgsForFailureLocked(); + args_->endpoint.reset(); } } -void SecurityHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, - grpc_closure* on_handshake_done, - HandshakerArgs* args) { +void SecurityHandshaker::DoHandshake( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done) { auto ref = Ref(); MutexLock lock(&mu_); args_ = args; - on_handshake_done_ = on_handshake_done; + on_handshake_done_ = std::move(on_handshake_done); size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer(); grpc_error_handle error = DoHandshakerNextLocked(handshake_buffer_, bytes_received_size); @@ -576,19 +558,13 @@ void SecurityHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, class FailHandshaker : public Handshaker { public: explicit FailHandshaker(absl::Status status) : status_(std::move(status)) {} - const char* name() const override { return "security_fail"; } - void Shutdown(grpc_error_handle /*why*/) override {} - void DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, - grpc_closure* on_handshake_done, - HandshakerArgs* args) override { - grpc_endpoint_destroy(args->endpoint); - args->endpoint = nullptr; - args->args = ChannelArgs(); - grpc_slice_buffer_destroy(args->read_buffer); - gpr_free(args->read_buffer); - args->read_buffer = nullptr; - ExecCtx::Run(DEBUG_LOCATION, on_handshake_done, status_); + absl::string_view name() const override { return "security_fail"; } + void DoHandshake( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done) override { + InvokeOnHandshakeDone(args, std::move(on_handshake_done), status_); } + void Shutdown(absl::Status /*error*/) override {} private: ~FailHandshaker() override = default; diff --git a/src/core/handshaker/tcp_connect/tcp_connect_handshaker.cc b/src/core/handshaker/tcp_connect/tcp_connect_handshaker.cc index fd557e47dc5..7822a5cebea 100644 --- a/src/core/handshaker/tcp_connect/tcp_connect_handshaker.cc +++ b/src/core/handshaker/tcp_connect/tcp_connect_handshaker.cc @@ -19,8 +19,10 @@ #include "src/core/handshaker/tcp_connect/tcp_connect_handshaker.h" #include +#include #include "absl/base/thread_annotations.h" +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -61,24 +63,23 @@ namespace { class TCPConnectHandshaker : public Handshaker { public: explicit TCPConnectHandshaker(grpc_pollset_set* pollset_set); - void Shutdown(grpc_error_handle why) override; - void DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, - grpc_closure* on_handshake_done, - HandshakerArgs* args) override; - const char* name() const override { return "tcp_connect"; } + absl::string_view name() const override { return "tcp_connect"; } + void DoHandshake( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done) override; + void Shutdown(absl::Status error) override; private: ~TCPConnectHandshaker() override; - void CleanupArgsForFailureLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - void FinishLocked(grpc_error_handle error) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void FinishLocked(absl::Status error) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); static void Connected(void* arg, grpc_error_handle error); Mutex mu_; bool shutdown_ ABSL_GUARDED_BY(mu_) = false; - // Endpoint and read buffer to destroy after a shutdown. + // Endpoint to destroy after a shutdown. grpc_endpoint* endpoint_to_destroy_ ABSL_GUARDED_BY(mu_) = nullptr; - grpc_slice_buffer* read_buffer_to_destroy_ ABSL_GUARDED_BY(mu_) = nullptr; - grpc_closure* on_handshake_done_ ABSL_GUARDED_BY(mu_) = nullptr; + absl::AnyInvocable on_handshake_done_ + ABSL_GUARDED_BY(mu_); grpc_pollset_set* interested_parties_ = nullptr; grpc_polling_entity pollent_; HandshakerArgs* args_ = nullptr; @@ -99,33 +100,32 @@ TCPConnectHandshaker::TCPConnectHandshaker(grpc_pollset_set* pollset_set) GRPC_CLOSURE_INIT(&connected_, Connected, this, grpc_schedule_on_exec_ctx); } -void TCPConnectHandshaker::Shutdown(grpc_error_handle /*why*/) { +void TCPConnectHandshaker::Shutdown(absl::Status /*error*/) { // TODO(anramach): After migration to EventEngine, cancel the in-progress // TCP connection attempt. - { - MutexLock lock(&mu_); - if (!shutdown_) { - shutdown_ = true; - // If we are shutting down while connecting, respond back with - // handshake done. - // The callback from grpc_tcp_client_connect will perform - // the necessary clean up. - if (on_handshake_done_ != nullptr) { - CleanupArgsForFailureLocked(); - FinishLocked(GRPC_ERROR_CREATE("tcp handshaker shutdown")); - } + MutexLock lock(&mu_); + if (!shutdown_) { + shutdown_ = true; + // If we are shutting down while connecting, respond back with + // handshake done. + // The callback from grpc_tcp_client_connect will perform + // the necessary clean up. + if (on_handshake_done_ != nullptr) { + // TODO(roth): When we remove the legacy grpc_error APIs, propagate the + // status passed to shutdown as part of the message here. + FinishLocked(GRPC_ERROR_CREATE("tcp handshaker shutdown")); } } } -void TCPConnectHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, - grpc_closure* on_handshake_done, - HandshakerArgs* args) { +void TCPConnectHandshaker::DoHandshake( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done) { { MutexLock lock(&mu_); - on_handshake_done_ = on_handshake_done; + on_handshake_done_ = std::move(on_handshake_done); } - CHECK_EQ(args->endpoint, nullptr); + CHECK_EQ(args->endpoint.get(), nullptr); args_ = args; absl::StatusOr uri = URI::Parse( args->args.GetString(GRPC_ARG_TCP_HANDSHAKER_RESOLVED_ADDRESS).value()); @@ -149,7 +149,7 @@ void TCPConnectHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, Ref().release(); // Ref held by callback. // As we fake the TCP client connection failure when shutdown is called // we don't want to pass args->endpoint directly. - // Instead pass endpoint_ and swap this endpoint to + // Instead pass endpoint_to_destroy_ and swap this endpoint to // args endpoint on success. grpc_tcp_client_connect( &connected_, &endpoint_to_destroy_, interested_parties_, @@ -171,21 +171,19 @@ void TCPConnectHandshaker::Connected(void* arg, grpc_error_handle error) { self->endpoint_to_destroy_ = nullptr; } if (!self->shutdown_) { - self->CleanupArgsForFailureLocked(); self->shutdown_ = true; - self->FinishLocked(error); + self->FinishLocked(std::move(error)); } else { - // The on_handshake_done_ is already as part of shutdown when - // connecting So nothing to be done here other than unrefing the - // error. + // The on_handshake_done_ callback was already invoked as part of + // shutdown when connecting, so nothing to be done here. } return; } CHECK_NE(self->endpoint_to_destroy_, nullptr); - self->args_->endpoint = self->endpoint_to_destroy_; + self->args_->endpoint.reset(self->endpoint_to_destroy_); self->endpoint_to_destroy_ = nullptr; if (self->bind_endpoint_to_pollset_) { - grpc_endpoint_add_to_pollset_set(self->args_->endpoint, + grpc_endpoint_add_to_pollset_set(self->args_->endpoint.get(), self->interested_parties_); } self->FinishLocked(absl::OkStatus()); @@ -196,25 +194,14 @@ TCPConnectHandshaker::~TCPConnectHandshaker() { if (endpoint_to_destroy_ != nullptr) { grpc_endpoint_destroy(endpoint_to_destroy_); } - if (read_buffer_to_destroy_ != nullptr) { - grpc_slice_buffer_destroy(read_buffer_to_destroy_); - gpr_free(read_buffer_to_destroy_); - } grpc_pollset_set_destroy(interested_parties_); } -void TCPConnectHandshaker::CleanupArgsForFailureLocked() { - read_buffer_to_destroy_ = args_->read_buffer; - args_->read_buffer = nullptr; - args_->args = ChannelArgs(); -} - -void TCPConnectHandshaker::FinishLocked(grpc_error_handle error) { +void TCPConnectHandshaker::FinishLocked(absl::Status error) { if (interested_parties_ != nullptr) { grpc_polling_entity_del_from_pollset_set(&pollent_, interested_parties_); } - ExecCtx::Run(DEBUG_LOCATION, on_handshake_done_, error); - on_handshake_done_ = nullptr; + InvokeOnHandshakeDone(args_, std::move(on_handshake_done_), std::move(error)); } // diff --git a/src/core/lib/iomgr/endpoint.h b/src/core/lib/iomgr/endpoint.h index a0f5d8429f5..c4b70abe4f3 100644 --- a/src/core/lib/iomgr/endpoint.h +++ b/src/core/lib/iomgr/endpoint.h @@ -101,6 +101,8 @@ bool grpc_endpoint_can_track_err(grpc_endpoint* ep); struct grpc_endpoint { const grpc_endpoint_vtable* vtable; + + void Orphan() { grpc_endpoint_destroy(this); } }; #endif // GRPC_SRC_CORE_LIB_IOMGR_ENDPOINT_H diff --git a/src/core/util/http_client/httpcli.cc b/src/core/util/http_client/httpcli.cc index 5921496ef24..12f13347470 100644 --- a/src/core/util/http_client/httpcli.cc +++ b/src/core/util/http_client/httpcli.cc @@ -35,6 +35,7 @@ #include #include +#include "src/core/handshaker/handshaker.h" #include "src/core/handshaker/handshaker_registry.h" #include "src/core/handshaker/tcp_connect/tcp_connect_handshaker.h" #include "src/core/lib/address_utils/sockaddr_utils.h" @@ -192,9 +193,7 @@ HttpRequest::HttpRequest( HttpRequest::~HttpRequest() { grpc_channel_args_destroy(channel_args_); grpc_http_parser_destroy(&parser_); - if (own_endpoint_ && ep_ != nullptr) { - grpc_endpoint_destroy(ep_); - } + ep_.reset(); CSliceUnref(request_text_); grpc_iomgr_unregister_object(&iomgr_obj_); grpc_slice_buffer_destroy(&incoming_); @@ -231,10 +230,7 @@ void HttpRequest::Orphan() { handshake_mgr_->Shutdown( GRPC_ERROR_CREATE("HTTP request cancelled during handshake")); } - if (own_endpoint_ && ep_ != nullptr) { - grpc_endpoint_destroy(ep_); - ep_ = nullptr; - } + ep_.reset(); } Unref(); } @@ -288,36 +284,30 @@ void HttpRequest::StartWrite() { CSliceRef(request_text_); grpc_slice_buffer_add(&outgoing_, request_text_); Ref().release(); // ref held by pending write - grpc_endpoint_write(ep_, &outgoing_, &done_write_, nullptr, + grpc_endpoint_write(ep_.get(), &outgoing_, &done_write_, nullptr, /*max_frame_size=*/INT_MAX); } -void HttpRequest::OnHandshakeDone(void* arg, grpc_error_handle error) { - auto* args = static_cast(arg); - RefCountedPtr req(static_cast(args->user_data)); +void HttpRequest::OnHandshakeDone(absl::StatusOr result) { if (g_test_only_on_handshake_done_intercept != nullptr) { // Run this testing intercept before the lock so that it has a chance to // do things like calling Orphan on the request - g_test_only_on_handshake_done_intercept(req.get()); + g_test_only_on_handshake_done_intercept(this); } - MutexLock lock(&req->mu_); - req->own_endpoint_ = true; - if (!error.ok()) { - req->handshake_mgr_.reset(); - req->NextAddress(error); + MutexLock lock(&mu_); + if (!result.ok()) { + handshake_mgr_.reset(); + NextAddress(result.status()); return; } - // Handshake completed, so we own fields in args - grpc_slice_buffer_destroy(args->read_buffer); - gpr_free(args->read_buffer); - req->ep_ = args->endpoint; - req->handshake_mgr_.reset(); - if (req->cancelled_) { - req->NextAddress( - GRPC_ERROR_CREATE("HTTP request cancelled during handshake")); + // Handshake completed, so get the endpoint. + ep_ = std::move((*result)->endpoint); + handshake_mgr_.reset(); + if (cancelled_) { + NextAddress(GRPC_ERROR_CREATE("HTTP request cancelled during handshake")); return; } - req->StartWrite(); + StartWrite(); } void HttpRequest::DoHandshake(const grpc_resolved_address* addr) { @@ -343,13 +333,11 @@ void HttpRequest::DoHandshake(const grpc_resolved_address* addr) { handshake_mgr_ = MakeRefCounted(); CoreConfiguration::Get().handshaker_registry().AddHandshakers( HANDSHAKER_CLIENT, args, pollset_set_, handshake_mgr_.get()); - Ref().release(); // ref held by pending handshake - grpc_endpoint* ep = ep_; - ep_ = nullptr; - own_endpoint_ = false; - handshake_mgr_->DoHandshake(ep, args, deadline_, - /*acceptor=*/nullptr, OnHandshakeDone, - /*user_data=*/this); + handshake_mgr_->DoHandshake( + nullptr, args, deadline_, /*acceptor=*/nullptr, + [self = Ref()](absl::StatusOr result) { + self->OnHandshakeDone(std::move(result)); + }); } void HttpRequest::NextAddress(grpc_error_handle error) { diff --git a/src/core/util/http_client/httpcli.h b/src/core/util/http_client/httpcli.h index 2ad2810f027..7101ebf7ed8 100644 --- a/src/core/util/http_client/httpcli.h +++ b/src/core/util/http_client/httpcli.h @@ -186,7 +186,7 @@ class HttpRequest : public InternallyRefCounted { void DoRead() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { Ref().release(); // ref held by pending read - grpc_endpoint_read(ep_, &incoming_, &on_read_, /*urgent=*/true, + grpc_endpoint_read(ep_.get(), &incoming_, &on_read_, /*urgent=*/true, /*min_progress_size=*/1); } @@ -221,7 +221,7 @@ class HttpRequest : public InternallyRefCounted { void StartWrite() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - static void OnHandshakeDone(void* arg, grpc_error_handle error); + void OnHandshakeDone(absl::StatusOr result); void DoHandshake(const grpc_resolved_address* addr) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -240,7 +240,7 @@ class HttpRequest : public InternallyRefCounted { grpc_closure continue_on_read_after_schedule_on_exec_ctx_; grpc_closure done_write_; grpc_closure continue_done_write_after_schedule_on_exec_ctx_; - grpc_endpoint* ep_ = nullptr; + OrphanablePtr ep_; grpc_closure* on_done_; ResourceQuotaRefPtr resource_quota_; grpc_polling_entity* pollent_; @@ -248,7 +248,6 @@ class HttpRequest : public InternallyRefCounted { const absl::optional> test_only_generate_response_; Mutex mu_; RefCountedPtr handshake_mgr_ ABSL_GUARDED_BY(mu_); - bool own_endpoint_ ABSL_GUARDED_BY(mu_) = true; bool cancelled_ ABSL_GUARDED_BY(mu_) = false; grpc_http_parser parser_ ABSL_GUARDED_BY(mu_); std::vector addresses_ ABSL_GUARDED_BY(mu_); diff --git a/test/core/bad_client/bad_client.cc b/test/core/bad_client/bad_client.cc index c3e1bf920a7..012dd6abddb 100644 --- a/test/core/bad_client/bad_client.cc +++ b/test/core/bad_client/bad_client.cc @@ -227,7 +227,7 @@ void grpc_run_bad_client_test( grpc_core::CoreConfiguration::Get() .channel_args_preconditioning() .PreconditionChannelArgs(server_args.ToC().get()), - sfd.server, false); + grpc_core::OrphanablePtr(sfd.server), false); server_setup_transport(&a, transport); grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr, nullptr); diff --git a/test/core/bad_connection/close_fd_test.cc b/test/core/bad_connection/close_fd_test.cc index 823cecf78de..a2afa61fd89 100644 --- a/test/core/bad_connection/close_fd_test.cc +++ b/test/core/bad_connection/close_fd_test.cc @@ -36,6 +36,7 @@ #include "src/core/channelz/channelz.h" #include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/iomgr/error.h" @@ -121,8 +122,9 @@ static void client_setup_transport(grpc_core::Transport* transport) { static void init_client() { grpc_core::ExecCtx exec_ctx; grpc_core::Transport* transport; - transport = grpc_create_chttp2_transport(grpc_core::ChannelArgs(), - g_ctx.ep->client, true); + transport = grpc_create_chttp2_transport( + grpc_core::ChannelArgs(), + grpc_core::OrphanablePtr(g_ctx.ep->client), true); client_setup_transport(transport); CHECK(g_ctx.client); grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr, @@ -136,8 +138,9 @@ static void init_server() { g_ctx.server = grpc_server_create(nullptr, nullptr); grpc_server_register_completion_queue(g_ctx.server, g_ctx.cq, nullptr); grpc_server_start(g_ctx.server); - transport = grpc_create_chttp2_transport(grpc_core::ChannelArgs(), - g_ctx.ep->server, false); + transport = grpc_create_chttp2_transport( + grpc_core::ChannelArgs(), + grpc_core::OrphanablePtr(g_ctx.ep->server), false); server_setup_transport(transport); grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr, nullptr); diff --git a/test/core/end2end/fixtures/sockpair_fixture.h b/test/core/end2end/fixtures/sockpair_fixture.h index 93714b96e05..354f052a769 100644 --- a/test/core/end2end/fixtures/sockpair_fixture.h +++ b/test/core/end2end/fixtures/sockpair_fixture.h @@ -79,11 +79,12 @@ class SockpairFixture : public CoreTestFixture { auto server_channel_args = CoreConfiguration::Get() .channel_args_preconditioning() .PreconditionChannelArgs(args.ToC().get()); - auto* server_endpoint = std::exchange(ep_.server, nullptr); + OrphanablePtr server_endpoint( + std::exchange(ep_.server, nullptr)); EXPECT_NE(server_endpoint, nullptr); + grpc_endpoint_add_to_pollset(server_endpoint.get(), grpc_cq_pollset(cq)); transport = grpc_create_chttp2_transport(server_channel_args, - server_endpoint, false); - grpc_endpoint_add_to_pollset(server_endpoint, grpc_cq_pollset(cq)); + std::move(server_endpoint), false); Server* core_server = Server::FromC(server); grpc_error_handle error = core_server->SetupTransport( transport, nullptr, core_server->channel_args(), nullptr); @@ -106,9 +107,11 @@ class SockpairFixture : public CoreTestFixture { .ToC() .get()); Transport* transport; - auto* client_endpoint = std::exchange(ep_.client, nullptr); + OrphanablePtr client_endpoint( + std::exchange(ep_.client, nullptr)); EXPECT_NE(client_endpoint, nullptr); - transport = grpc_create_chttp2_transport(args, client_endpoint, true); + transport = + grpc_create_chttp2_transport(args, std::move(client_endpoint), true); auto channel = ChannelCreate("socketpair-target", args, GRPC_CLIENT_DIRECT_CHANNEL, transport); grpc_channel* client; diff --git a/test/core/end2end/fuzzers/client_fuzzer.cc b/test/core/end2end/fuzzers/client_fuzzer.cc index 23660d3a047..f924c302241 100644 --- a/test/core/end2end/fuzzers/client_fuzzer.cc +++ b/test/core/end2end/fuzzers/client_fuzzer.cc @@ -29,6 +29,7 @@ #include "src/core/lib/config/core_configuration.h" #include "src/core/lib/experiments/config.h" #include "src/core/lib/gprpp/env.h" +#include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/iomgr/exec_ctx.h" @@ -69,7 +70,10 @@ class ClientFuzzer final : public BasicFuzzer { .PreconditionChannelArgs(nullptr) .SetIfUnset(GRPC_ARG_DEFAULT_AUTHORITY, "test-authority"); Transport* transport = grpc_create_chttp2_transport( - args, mock_endpoint_controller_->TakeCEndpoint(), true); + args, + OrphanablePtr( + mock_endpoint_controller_->TakeCEndpoint()), + true); channel_ = ChannelCreate("test-target", args, GRPC_CLIENT_DIRECT_CHANNEL, transport) ->release() diff --git a/test/core/end2end/tests/max_connection_idle.cc b/test/core/end2end/tests/max_connection_idle.cc index 466e72ff5b5..ab9371485ab 100644 --- a/test/core/end2end/tests/max_connection_idle.cc +++ b/test/core/end2end/tests/max_connection_idle.cc @@ -90,7 +90,9 @@ CORE_END2END_TEST(RetryHttp2Test, MaxConnectionIdle) { .Set(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, Duration::Seconds(1).millis()) .Set(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, Duration::Seconds(1).millis()) - .Set(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS, Duration::Seconds(5).millis()) + .Set(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS, + g_is_fuzzing_core_e2e_tests ? Duration::Minutes(5).millis() + : Duration::Seconds(5).millis()) // Avoid transparent retries for this test. .Set(GRPC_ARG_ENABLE_RETRIES, false)); InitServer( diff --git a/test/core/handshake/readahead_handshaker_server_ssl.cc b/test/core/handshake/readahead_handshaker_server_ssl.cc index c5331e1abb4..b15edfc32c6 100644 --- a/test/core/handshake/readahead_handshaker_server_ssl.cc +++ b/test/core/handshake/readahead_handshaker_server_ssl.cc @@ -18,6 +18,8 @@ #include +#include "absl/base/thread_annotations.h" +#include "absl/strings/string_view.h" #include "gtest/gtest.h" #include @@ -28,6 +30,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/config/core_configuration.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/iomgr/error.h" @@ -49,15 +52,52 @@ namespace grpc_core { class ReadAheadHandshaker : public Handshaker { public: - ~ReadAheadHandshaker() override {} - const char* name() const override { return "read_ahead"; } - void Shutdown(grpc_error_handle /*why*/) override {} - void DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, - grpc_closure* on_handshake_done, - HandshakerArgs* args) override { - grpc_endpoint_read(args->endpoint, args->read_buffer, on_handshake_done, - /*urgent=*/false, /*min_progress_size=*/1); + absl::string_view name() const override { return "read_ahead"; } + + void DoHandshake( + HandshakerArgs* args, + absl::AnyInvocable on_handshake_done) override { + MutexLock lock(&mu_); + args_ = args; + on_handshake_done_ = std::move(on_handshake_done); + Ref().release(); // Held by callback. + GRPC_CLOSURE_INIT(&on_read_done_, OnReadDone, this, nullptr); + grpc_endpoint_read(args->endpoint.get(), args->read_buffer.c_slice_buffer(), + &on_read_done_, /*urgent=*/false, + /*min_progress_size=*/1); + } + + void Shutdown(absl::Status /*error*/) override { + MutexLock lock(&mu_); + if (on_handshake_done_ != nullptr) args_->endpoint.reset(); } + + private: + static void OnReadDone(void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + // Need an async hop here, because grpc_endpoint_read() may invoke + // the callback synchronously, leading to deadlock. + // TODO(roth): This async hop will no longer be necessary once we + // switch to the EventEngine endpoint API. + self->args_->event_engine->Run( + [self = RefCountedPtr(self), + error = std::move(error)]() mutable { + absl::AnyInvocable on_handshake_done; + { + MutexLock lock(&self->mu_); + on_handshake_done = std::move(self->on_handshake_done_); + } + on_handshake_done(std::move(error)); + }); + } + + grpc_closure on_read_done_; + + Mutex mu_; + // Mutex guards args_->endpoint but not the rest of the struct. + HandshakerArgs* args_ = nullptr; + absl::AnyInvocable on_handshake_done_ + ABSL_GUARDED_BY(&mu_); }; class ReadAheadHandshakerFactory : public HandshakerFactory { diff --git a/test/core/security/secure_endpoint_test.cc b/test/core/security/secure_endpoint_test.cc index baf90302f69..b09e04c4b19 100644 --- a/test/core/security/secure_endpoint_test.cc +++ b/test/core/security/secure_endpoint_test.cc @@ -161,9 +161,11 @@ static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair( } if (leftover_nslices == 0) { - f.client_ep = grpc_secure_endpoint_create(fake_read_protector, - fake_read_zero_copy_protector, - tcp.client, nullptr, &args, 0); + f.client_ep = grpc_secure_endpoint_create( + fake_read_protector, fake_read_zero_copy_protector, + grpc_core::OrphanablePtr(tcp.client), + nullptr, &args, 0) + .release(); } else { unsigned i; tsi_result result; @@ -206,15 +208,19 @@ static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair( reinterpret_cast(encrypted_buffer), total_buffer_size - buffer_size); f.client_ep = grpc_secure_endpoint_create( - fake_read_protector, fake_read_zero_copy_protector, tcp.client, - &encrypted_leftover, &args, 1); + fake_read_protector, fake_read_zero_copy_protector, + grpc_core::OrphanablePtr(tcp.client), + &encrypted_leftover, &args, 1) + .release(); grpc_slice_unref(encrypted_leftover); gpr_free(encrypted_buffer); } - f.server_ep = grpc_secure_endpoint_create(fake_write_protector, - fake_write_zero_copy_protector, - tcp.server, nullptr, &args, 0); + f.server_ep = grpc_secure_endpoint_create( + fake_write_protector, fake_write_zero_copy_protector, + grpc_core::OrphanablePtr(tcp.server), + nullptr, &args, 0) + .release(); grpc_resource_quota_unref( static_cast(a[1].value.pointer.p)); return f; diff --git a/test/core/security/ssl_server_fuzzer.cc b/test/core/security/ssl_server_fuzzer.cc index d4310ebfa8c..b2905467bc8 100644 --- a/test/core/security/ssl_server_fuzzer.cc +++ b/test/core/security/ssl_server_fuzzer.cc @@ -15,7 +15,9 @@ // limitations under the License. // // + #include "absl/log/check.h" +#include "absl/synchronization/notification.h" #include #include @@ -35,6 +37,7 @@ #define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" #define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" +using grpc_core::HandshakerArgs; using grpc_event_engine::experimental::EventEngine; using grpc_event_engine::experimental::GetDefaultEventEngine; @@ -43,20 +46,6 @@ bool squelch = true; // Turning this on will fail the leak check. bool leak_check = false; -struct handshake_state { - grpc_core::Notification done_signal; -}; - -static void on_handshake_done(void* arg, grpc_error_handle error) { - grpc_core::HandshakerArgs* args = - static_cast(arg); - struct handshake_state* state = - static_cast(args->user_data); - // The fuzzer should not pass the handshake. - CHECK(!error.ok()); - state->done_signal.Notify(); -} - extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { if (squelch) { grpc_disable_all_absl_logs(); @@ -91,21 +80,26 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { grpc_core::Timestamp deadline = grpc_core::Duration::Seconds(1) + grpc_core::Timestamp::Now(); - struct handshake_state state; auto handshake_mgr = grpc_core::MakeRefCounted(); auto channel_args = grpc_core::ChannelArgs().SetObject(std::move(engine)); sc->add_handshakers(channel_args, nullptr, handshake_mgr.get()); - handshake_mgr->DoHandshake(mock_endpoint_controller->TakeCEndpoint(), + absl::Notification handshake_completed; + handshake_mgr->DoHandshake(grpc_core::OrphanablePtr( + mock_endpoint_controller->TakeCEndpoint()), channel_args, deadline, nullptr /* acceptor */, - on_handshake_done, &state); + [&](absl::StatusOr result) { + // The fuzzer should not pass the handshake. + CHECK(!result.ok()); + handshake_completed.Notify(); + }); grpc_core::ExecCtx::Get()->Flush(); // If the given string happens to be part of the correct client hello, the // server will wait for more data. Explicitly fail the server by shutting // down the handshake manager. - if (!state.done_signal.WaitForNotificationWithTimeout(absl::Seconds(3))) { + if (!handshake_completed.WaitForNotificationWithTimeout(absl::Seconds(3))) { handshake_mgr->Shutdown( absl::DeadlineExceededError("handshake did not fail as expected")); } diff --git a/test/core/surface/channel_init_test.cc b/test/core/surface/channel_init_test.cc index 3dc6b5ec243..398ed607bf8 100644 --- a/test/core/surface/channel_init_test.cc +++ b/test/core/surface/channel_init_test.cc @@ -346,6 +346,5 @@ TEST(ChannelInitTest, CanCreateFilterWithCall) { int main(int argc, char** argv) { grpc::testing::TestEnvironment env(&argc, argv); ::testing::InitGoogleTest(&argc, argv); - grpc::testing::TestGrpcScope grpc_scope; return RUN_ALL_TESTS(); } diff --git a/test/core/transport/chttp2/graceful_shutdown_test.cc b/test/core/transport/chttp2/graceful_shutdown_test.cc index c9040edc798..f4e4da7268f 100644 --- a/test/core/transport/chttp2/graceful_shutdown_test.cc +++ b/test/core/transport/chttp2/graceful_shutdown_test.cc @@ -51,6 +51,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gprpp/crash.h" #include "src/core/lib/gprpp/notification.h" +#include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/sync.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/endpoint.h" @@ -94,8 +95,9 @@ class GracefulShutdownTest : public ::testing::Test { grpc_server_register_completion_queue(server_, cq_, nullptr); grpc_server_start(server_); fds_ = grpc_iomgr_create_endpoint_pair("fixture", nullptr); - auto* transport = grpc_create_chttp2_transport(core_server->channel_args(), - fds_.server, false); + auto* transport = grpc_create_chttp2_transport( + core_server->channel_args(), OrphanablePtr(fds_.server), + false); grpc_endpoint_add_to_pollset(fds_.server, grpc_cq_pollset(cq_)); CHECK(core_server->SetupTransport(transport, nullptr, core_server->channel_args(), diff --git a/test/core/transport/chttp2/ping_configuration_test.cc b/test/core/transport/chttp2/ping_configuration_test.cc index 4b91a3019c1..7524f4872ad 100644 --- a/test/core/transport/chttp2/ping_configuration_test.cc +++ b/test/core/transport/chttp2/ping_configuration_test.cc @@ -57,7 +57,9 @@ TEST_F(ConfigurationTest, ClientKeepaliveDefaults) { ExecCtx exec_ctx; grpc_chttp2_transport* t = reinterpret_cast(grpc_create_chttp2_transport( - args_, mock_endpoint_controller_->TakeCEndpoint(), + args_, + OrphanablePtr( + mock_endpoint_controller_->TakeCEndpoint()), /*is_client=*/true)); EXPECT_EQ(t->keepalive_time, Duration::Infinity()); EXPECT_EQ(t->keepalive_timeout, Duration::Infinity()); @@ -74,7 +76,9 @@ TEST_F(ConfigurationTest, ClientKeepaliveExplicitArgs) { args_ = args_.Set(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 3); grpc_chttp2_transport* t = reinterpret_cast(grpc_create_chttp2_transport( - args_, mock_endpoint_controller_->TakeCEndpoint(), + args_, + OrphanablePtr( + mock_endpoint_controller_->TakeCEndpoint()), /*is_client=*/true)); EXPECT_EQ(t->keepalive_time, Duration::Seconds(20)); EXPECT_EQ(t->keepalive_timeout, Duration::Seconds(10)); @@ -87,7 +91,9 @@ TEST_F(ConfigurationTest, ServerKeepaliveDefaults) { ExecCtx exec_ctx; grpc_chttp2_transport* t = reinterpret_cast(grpc_create_chttp2_transport( - args_, mock_endpoint_controller_->TakeCEndpoint(), + args_, + OrphanablePtr( + mock_endpoint_controller_->TakeCEndpoint()), /*is_client=*/false)); EXPECT_EQ(t->keepalive_time, Duration::Hours(2)); EXPECT_EQ(t->keepalive_timeout, Duration::Seconds(20)); @@ -111,7 +117,9 @@ TEST_F(ConfigurationTest, ServerKeepaliveExplicitArgs) { args_ = args_.Set(GRPC_ARG_HTTP2_MAX_PING_STRIKES, 0); grpc_chttp2_transport* t = reinterpret_cast(grpc_create_chttp2_transport( - args_, mock_endpoint_controller_->TakeCEndpoint(), + args_, + OrphanablePtr( + mock_endpoint_controller_->TakeCEndpoint()), /*is_client=*/false)); EXPECT_EQ(t->keepalive_time, Duration::Seconds(20)); EXPECT_EQ(t->keepalive_timeout, Duration::Seconds(10)); @@ -140,7 +148,9 @@ TEST_F(ConfigurationTest, ModifyClientDefaults) { // which does not override the defaults. grpc_chttp2_transport* t = reinterpret_cast(grpc_create_chttp2_transport( - args_, mock_endpoint_controller_->TakeCEndpoint(), + args_, + OrphanablePtr( + mock_endpoint_controller_->TakeCEndpoint()), /*is_client=*/true)); EXPECT_EQ(t->keepalive_time, Duration::Seconds(20)); EXPECT_EQ(t->keepalive_timeout, Duration::Seconds(10)); @@ -167,7 +177,9 @@ TEST_F(ConfigurationTest, ModifyServerDefaults) { // which does not override the defaults. grpc_chttp2_transport* t = reinterpret_cast(grpc_create_chttp2_transport( - args_, mock_endpoint_controller_->TakeCEndpoint(), + args_, + OrphanablePtr( + mock_endpoint_controller_->TakeCEndpoint()), /*is_client=*/false)); EXPECT_EQ(t->keepalive_time, Duration::Seconds(20)); EXPECT_EQ(t->keepalive_timeout, Duration::Seconds(10)); diff --git a/test/cpp/microbenchmarks/fullstack_fixtures.h b/test/cpp/microbenchmarks/fullstack_fixtures.h index 30f00e88543..664fa70a25a 100644 --- a/test/cpp/microbenchmarks/fullstack_fixtures.h +++ b/test/cpp/microbenchmarks/fullstack_fixtures.h @@ -181,7 +181,9 @@ class EndpointPairFixture : public BaseFixture { grpc_core::Server::FromC(server_->c_server()); grpc_core::ChannelArgs server_args = core_server->channel_args(); server_transport_ = grpc_create_chttp2_transport( - server_args, endpoints.server, false /* is_client */); + server_args, + grpc_core::OrphanablePtr(endpoints.server), + /*is_client=*/false); for (grpc_pollset* pollset : core_server->pollsets()) { grpc_endpoint_add_to_pollset(endpoints.server, pollset); } @@ -207,8 +209,9 @@ class EndpointPairFixture : public BaseFixture { .channel_args_preconditioning() .PreconditionChannelArgs(&tmp_args); } - client_transport_ = - grpc_create_chttp2_transport(c_args, endpoints.client, true); + client_transport_ = grpc_create_chttp2_transport( + c_args, grpc_core::OrphanablePtr(endpoints.client), + /*is_client=*/true); CHECK(client_transport_); grpc_channel* channel = grpc_core::ChannelCreate("target", c_args, GRPC_CLIENT_DIRECT_CHANNEL, diff --git a/test/cpp/performance/writes_per_rpc_test.cc b/test/cpp/performance/writes_per_rpc_test.cc index ea37648f7dc..54a1d6faadf 100644 --- a/test/cpp/performance/writes_per_rpc_test.cc +++ b/test/cpp/performance/writes_per_rpc_test.cc @@ -17,6 +17,7 @@ // #include +#include #include @@ -118,14 +119,14 @@ class InProcessCHTTP2 { { grpc_core::Server* core_server = grpc_core::Server::FromC(server_->c_server()); - grpc_endpoint* iomgr_server_endpoint = - grpc_event_engine_endpoint_create(std::move(listener_endpoint)); - grpc_core::Transport* transport = grpc_create_chttp2_transport( - core_server->channel_args(), iomgr_server_endpoint, - /*is_client=*/false); + grpc_core::OrphanablePtr iomgr_server_endpoint( + grpc_event_engine_endpoint_create(std::move(listener_endpoint))); for (grpc_pollset* pollset : core_server->pollsets()) { - grpc_endpoint_add_to_pollset(iomgr_server_endpoint, pollset); + grpc_endpoint_add_to_pollset(iomgr_server_endpoint.get(), pollset); } + grpc_core::Transport* transport = grpc_create_chttp2_transport( + core_server->channel_args(), std::move(iomgr_server_endpoint), + /*is_client=*/false); CHECK(GRPC_LOG_IF_ERROR( "SetupTransport", core_server->SetupTransport(transport, nullptr, @@ -143,9 +144,10 @@ class InProcessCHTTP2 { args = args.Set(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, INT_MAX) .Set(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, INT_MAX) .Set(GRPC_ARG_HTTP2_BDP_PROBE, 0); + grpc_core::OrphanablePtr endpoint( + grpc_event_engine_endpoint_create(std::move(client_endpoint))); grpc_core::Transport* transport = grpc_create_chttp2_transport( - args, grpc_event_engine_endpoint_create(std::move(client_endpoint)), - /*is_client=*/true); + args, std::move(endpoint), /*is_client=*/true); CHECK(transport); grpc_channel* channel = grpc_core::ChannelCreate("target", args, GRPC_CLIENT_DIRECT_CHANNEL,