From 5d586d3ae339dee8b2ed5022c7181ad1ac8e7e59 Mon Sep 17 00:00:00 2001 From: AJ Heller Date: Thu, 6 Jun 2024 13:21:45 -0700 Subject: [PATCH] [EventEngine] Fix race between connection and its deadline timer on Windows (#36709) Closes #36709 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/36709 from drfloob:fix-win-ee-use-after-free dd4ae2683e14caf4fd339147e8139613e15dfafd PiperOrigin-RevId: 641000031 --- include/grpc/event_engine/event_engine.h | 19 +- src/core/BUILD | 1 + src/core/lib/event_engine/event_engine.cc | 33 ++- src/core/lib/event_engine/thread_local.h | 2 +- src/core/lib/event_engine/trace.h | 8 +- src/core/lib/event_engine/windows/iocp.h | 2 +- .../lib/event_engine/windows/win_socket.cc | 36 ++- .../lib/event_engine/windows/win_socket.h | 9 +- .../event_engine/windows/windows_endpoint.cc | 11 +- .../event_engine/windows/windows_engine.cc | 280 ++++++++++++------ .../lib/event_engine/windows/windows_engine.h | 161 ++++++++-- src/core/lib/gprpp/dump_args.cc | 2 +- .../event_engine/event_engine_test_utils.cc | 11 + .../event_engine/event_engine_test_utils.h | 8 + 14 files changed, 431 insertions(+), 152 deletions(-) diff --git a/include/grpc/event_engine/event_engine.h b/include/grpc/event_engine/event_engine.h index 7e2cd0346aa..add6593aa7b 100644 --- a/include/grpc/event_engine/event_engine.h +++ b/include/grpc/event_engine/event_engine.h @@ -132,8 +132,6 @@ class EventEngine : public std::enable_shared_from_this, struct TaskHandle { intptr_t keys[2]; static const GRPC_DLL TaskHandle kInvalid; - friend bool operator==(const TaskHandle& lhs, const TaskHandle& rhs); - friend bool operator!=(const TaskHandle& lhs, const TaskHandle& rhs); }; /// A handle to a cancellable connection attempt. 
/// @@ -141,10 +139,6 @@ class EventEngine : public std::enable_shared_from_this, struct ConnectionHandle { intptr_t keys[2]; static const GRPC_DLL ConnectionHandle kInvalid; - friend bool operator==(const ConnectionHandle& lhs, - const ConnectionHandle& rhs); - friend bool operator!=(const ConnectionHandle& lhs, - const ConnectionHandle& rhs); }; /// Thin wrapper around a platform-specific sockaddr type. A sockaddr struct /// exists on all platforms that gRPC supports. @@ -496,6 +490,19 @@ void EventEngineFactoryReset(); /// Create an EventEngine using the default factory. std::unique_ptr CreateEventEngine(); +bool operator==(const EventEngine::TaskHandle& lhs, + const EventEngine::TaskHandle& rhs); +bool operator!=(const EventEngine::TaskHandle& lhs, + const EventEngine::TaskHandle& rhs); +std::ostream& operator<<(std::ostream& out, + const EventEngine::TaskHandle& handle); +bool operator==(const EventEngine::ConnectionHandle& lhs, + const EventEngine::ConnectionHandle& rhs); +bool operator!=(const EventEngine::ConnectionHandle& lhs, + const EventEngine::ConnectionHandle& rhs); +std::ostream& operator<<(std::ostream& out, + const EventEngine::ConnectionHandle& handle); + } // namespace experimental } // namespace grpc_event_engine diff --git a/src/core/BUILD b/src/core/BUILD index 79359edb982..5b06f6e331d 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -2493,6 +2493,7 @@ grpc_cc_library( "ares_resolver", "channel_args_endpoint_config", "common_event_engine_closures", + "dump_args", "error", "event_engine_common", "event_engine_tcp_socket_utils", diff --git a/src/core/lib/event_engine/event_engine.cc b/src/core/lib/event_engine/event_engine.cc index 6815e2eb50b..74f24fdc9ce 100644 --- a/src/core/lib/event_engine/event_engine.cc +++ b/src/core/lib/event_engine/event_engine.cc @@ -11,6 +11,8 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. +#include "absl/strings/str_cat.h" + #include #include @@ -21,24 +23,47 @@ const EventEngine::TaskHandle EventEngine::TaskHandle::kInvalid = {-1, -1}; const EventEngine::ConnectionHandle EventEngine::ConnectionHandle::kInvalid = { -1, -1}; +namespace { +template +bool eq(const T& lhs, const T& rhs) { + return lhs.keys[0] == rhs.keys[0] && lhs.keys[1] == rhs.keys[1]; +} +template +std::ostream& printout(std::ostream& out, const T& handle) { + out << absl::StrCat("{", absl::Hex(handle.keys[0], absl::kZeroPad16), ",", + absl::Hex(handle.keys[1], absl::kZeroPad16), "}"); + return out; +} +} // namespace + bool operator==(const EventEngine::TaskHandle& lhs, const EventEngine::TaskHandle& rhs) { - return lhs.keys[0] == rhs.keys[0] && lhs.keys[1] == rhs.keys[1]; + return eq(lhs, rhs); } bool operator!=(const EventEngine::TaskHandle& lhs, const EventEngine::TaskHandle& rhs) { - return !(lhs == rhs); + return !eq(lhs, rhs); +} + +std::ostream& operator<<(std::ostream& out, + const EventEngine::TaskHandle& handle) { + return printout(out, handle); } bool operator==(const EventEngine::ConnectionHandle& lhs, const EventEngine::ConnectionHandle& rhs) { - return lhs.keys[0] == rhs.keys[0] && lhs.keys[1] == rhs.keys[1]; + return eq(lhs, rhs); } bool operator!=(const EventEngine::ConnectionHandle& lhs, const EventEngine::ConnectionHandle& rhs) { - return !(lhs == rhs); + return !eq(lhs, rhs); +} + +std::ostream& operator<<(std::ostream& out, + const EventEngine::ConnectionHandle& handle) { + return printout(out, handle); } } // namespace experimental diff --git a/src/core/lib/event_engine/thread_local.h b/src/core/lib/event_engine/thread_local.h index 986df908aca..62eeed3b08d 100644 --- a/src/core/lib/event_engine/thread_local.h +++ b/src/core/lib/event_engine/thread_local.h @@ -29,4 +29,4 @@ class ThreadLocal { } // namespace experimental } // namespace grpc_event_engine 
-#endif // GRPC_SRC_CORE_LIB_EVENT_ENGINE_THREAD_LOCAL_H \ No newline at end of file +#endif // GRPC_SRC_CORE_LIB_EVENT_ENGINE_THREAD_LOCAL_H diff --git a/src/core/lib/event_engine/trace.h b/src/core/lib/event_engine/trace.h index 3d0a9e1d5df..a81302281d5 100644 --- a/src/core/lib/event_engine/trace.h +++ b/src/core/lib/event_engine/trace.h @@ -27,22 +27,22 @@ extern grpc_core::TraceFlag grpc_event_engine_endpoint_trace; #define GRPC_EVENT_ENGINE_TRACE(format, ...) \ if (GRPC_TRACE_FLAG_ENABLED(grpc_event_engine_trace)) { \ - gpr_log(GPR_DEBUG, "(event_engine) " format, __VA_ARGS__); \ + gpr_log(GPR_ERROR, "(event_engine) " format, __VA_ARGS__); \ } #define GRPC_EVENT_ENGINE_ENDPOINT_TRACE(format, ...) \ if (GRPC_TRACE_FLAG_ENABLED(grpc_event_engine_endpoint_trace)) { \ - gpr_log(GPR_DEBUG, "(event_engine endpoint) " format, __VA_ARGS__); \ + gpr_log(GPR_ERROR, "(event_engine endpoint) " format, __VA_ARGS__); \ } #define GRPC_EVENT_ENGINE_POLLER_TRACE(format, ...) \ if (GRPC_TRACE_FLAG_ENABLED(grpc_event_engine_poller_trace)) { \ - gpr_log(GPR_DEBUG, "(event_engine poller) " format, __VA_ARGS__); \ + gpr_log(GPR_ERROR, "(event_engine poller) " format, __VA_ARGS__); \ } #define GRPC_EVENT_ENGINE_DNS_TRACE(format, ...) 
\ if (GRPC_TRACE_FLAG_ENABLED(grpc_event_engine_dns_trace)) { \ - gpr_log(GPR_DEBUG, "(event_engine dns) " format, __VA_ARGS__); \ + gpr_log(GPR_ERROR, "(event_engine dns) " format, __VA_ARGS__); \ } #endif // GRPC_SRC_CORE_LIB_EVENT_ENGINE_TRACE_H diff --git a/src/core/lib/event_engine/windows/iocp.h b/src/core/lib/event_engine/windows/iocp.h index c36b9582122..9c972ed0f92 100644 --- a/src/core/lib/event_engine/windows/iocp.h +++ b/src/core/lib/event_engine/windows/iocp.h @@ -32,7 +32,7 @@ namespace experimental { class IOCP final : public Poller { public: explicit IOCP(ThreadPool* thread_pool) noexcept; - ~IOCP(); + ~IOCP() override; // Not copyable IOCP(const IOCP&) = delete; IOCP& operator=(const IOCP&) = delete; diff --git a/src/core/lib/event_engine/windows/win_socket.cc b/src/core/lib/event_engine/windows/win_socket.cc index 883efc6b0b8..c5dc520cd7c 100644 --- a/src/core/lib/event_engine/windows/win_socket.cc +++ b/src/core/lib/event_engine/windows/win_socket.cc @@ -69,13 +69,19 @@ void WinSocket::Shutdown() { int status = WSAIoctl(socket_, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid), &DisconnectEx, sizeof(DisconnectEx), &ioctl_num_bytes, NULL, NULL); - - if (status == 0) { - DisconnectEx(socket_, NULL, 0, 0); - } else { + if (status != 0) { char* utf8_message = gpr_format_message(WSAGetLastError()); LOG(INFO) << "Unable to retrieve DisconnectEx pointer : " << utf8_message; gpr_free(utf8_message); + } else if (DisconnectEx(socket_, NULL, 0, 0) == FALSE) { + auto last_error = WSAGetLastError(); + // DisconnectEx may be called when the socket is not connected. Ignore that + // error, and log all others. 
+ if (last_error != WSAENOTCONN) { + char* utf8_message = gpr_format_message(last_error); + LOG(INFO) << "DisconnectEx failed: " << utf8_message; + gpr_free(utf8_message); + } } closesocket(socket_); GRPC_EVENT_ENGINE_ENDPOINT_TRACE("WinSocket::%p socket closed", this); @@ -91,7 +97,7 @@ void WinSocket::Shutdown(const grpc_core::DebugLocation& location, void WinSocket::NotifyOnReady(OpState& info, EventEngine::Closure* closure) { if (IsShutdown()) { - info.SetError(WSAESHUTDOWN); + info.SetResult(WSAESHUTDOWN, 0, "NotifyOnReady"); thread_pool_->Run(closure); return; }; @@ -130,12 +136,13 @@ void WinSocket::OpState::SetReady() { win_socket_->thread_pool_->Run(closure); } -void WinSocket::OpState::SetError(int wsa_error) { - result_ = OverlappedResult{/*wsa_error=*/wsa_error, /*bytes_transferred=*/0}; -} - -void WinSocket::OpState::SetResult(OverlappedResult result) { - result_ = result; +void WinSocket::OpState::SetResult(int wsa_error, DWORD bytes, + absl::string_view context) { + bytes = wsa_error == 0 ? bytes : 0; + result_ = OverlappedResult{ + /*wsa_error=*/wsa_error, /*bytes_transferred=*/bytes, + /*error_status=*/wsa_error == 0 ? absl::OkStatus() + : GRPC_WSA_ERROR(wsa_error, context)}; } void WinSocket::OpState::SetErrorStatus(absl::Status error_status) { @@ -149,16 +156,15 @@ void WinSocket::OpState::GetOverlappedResult() { void WinSocket::OpState::GetOverlappedResult(SOCKET sock) { if (win_socket_->IsShutdown()) { - result_ = OverlappedResult{/*wsa_error=*/WSA_OPERATION_ABORTED, - /*bytes_transferred=*/0}; + SetResult(WSA_OPERATION_ABORTED, 0, "GetOverlappedResult"); return; } DWORD flags = 0; DWORD bytes; BOOL success = WSAGetOverlappedResult(sock, &overlapped_, &bytes, FALSE, &flags); - result_ = OverlappedResult{/*wsa_error=*/success ? 0 : WSAGetLastError(), - /*bytes_transferred=*/bytes}; + auto wsa_error = success ? 
0 : WSAGetLastError(); + SetResult(wsa_error, bytes, "WSAGetOverlappedResult"); } bool WinSocket::IsShutdown() { return is_shutdown_.load(); } diff --git a/src/core/lib/event_engine/windows/win_socket.h b/src/core/lib/event_engine/windows/win_socket.h index c6d52f41def..be453d26e6d 100644 --- a/src/core/lib/event_engine/windows/win_socket.h +++ b/src/core/lib/event_engine/windows/win_socket.h @@ -47,12 +47,11 @@ class WinSocket { // the WinSocket's ThreadPool. Otherwise, a "pending iocp" flag will // be set. void SetReady(); - // Set WSA error results for a completed op. - void SetError(int wsa_error); - // Set an OverlappedResult. Useful when WSARecv returns immediately. - void SetResult(OverlappedResult result); + // Set WSA result for a completed op. + // If the error is non-zero, bytes will be overridden to 0. + void SetResult(int wsa_error, DWORD bytes, absl::string_view context); // Set error results for a completed op. - // This is a manual override, meant to override any WSA status code. + // This is a manual override, meant to ignore any WSA status code. void SetErrorStatus(absl::Status error_status); // Retrieve the results of an overlapped operation (via Winsock API) and // store them locally. diff --git a/src/core/lib/event_engine/windows/windows_endpoint.cc b/src/core/lib/event_engine/windows/windows_endpoint.cc index c3edbebaa47..7ade5c95b20 100644 --- a/src/core/lib/event_engine/windows/windows_endpoint.cc +++ b/src/core/lib/event_engine/windows/windows_endpoint.cc @@ -102,8 +102,7 @@ void WindowsEndpoint::AsyncIOState::DoTcpRead(SliceBuffer* buffer) { int wsa_error = status == 0 ? 0 : WSAGetLastError(); if (wsa_error != WSAEWOULDBLOCK) { // Data or some error was returned immediately. 
- socket->read_info()->SetResult( - {/*wsa_error=*/wsa_error, /*bytes_read=*/bytes_read}); + socket->read_info()->SetResult(wsa_error, bytes_read, "WSARecv"); thread_pool->Run(&handle_read_event); return; } @@ -120,9 +119,8 @@ void WindowsEndpoint::AsyncIOState::DoTcpRead(SliceBuffer* buffer) { if (wsa_error != 0 && wsa_error != WSA_IO_PENDING) { // The async read attempt returned an error immediately. socket->UnregisterReadCallback(); - socket->read_info()->SetErrorStatus(GRPC_WSA_ERROR( - wsa_error, - absl::StrFormat("WindowsEndpont::%p Read failed", this).c_str())); + socket->read_info()->SetResult( + wsa_error, 0, absl::StrFormat("WindowsEndpont::%p Read failed", this)); thread_pool->Run(&handle_read_event); } } @@ -220,8 +218,7 @@ bool WindowsEndpoint::Write(absl::AnyInvocable on_writable, int wsa_error = WSAGetLastError(); if (wsa_error != WSA_IO_PENDING) { io_state_->socket->UnregisterWriteCallback(); - io_state_->socket->write_info()->SetErrorStatus( - GRPC_WSA_ERROR(wsa_error, "WSASend")); + io_state_->socket->write_info()->SetResult(wsa_error, 0, "WSASend"); io_state_->thread_pool->Run(&io_state_->handle_write_event); } } diff --git a/src/core/lib/event_engine/windows/windows_engine.cc b/src/core/lib/event_engine/windows/windows_engine.cc index 69a0f94ff53..b674606c9b7 100644 --- a/src/core/lib/event_engine/windows/windows_engine.cc +++ b/src/core/lib/event_engine/windows/windows_engine.cc @@ -16,8 +16,10 @@ #ifdef GPR_WINDOWS #include +#include #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -43,6 +45,7 @@ #include "src/core/lib/event_engine/windows/windows_engine.h" #include "src/core/lib/event_engine/windows/windows_listener.h" #include "src/core/lib/gprpp/crash.h" +#include "src/core/lib/gprpp/dump_args.h" #include "src/core/lib/gprpp/sync.h" #include "src/core/lib/gprpp/time.h" #include "src/core/lib/iomgr/error.h" @@ -50,13 +53,104 @@ 
namespace grpc_event_engine { namespace experimental { -namespace { -EventEngine::OnConnectCallback CreateCrashingOnConnectCallback() { - return [](absl::StatusOr>) { - grpc_core::Crash("Internal Error: OnConnect callback called when unset"); - }; +std::ostream& operator<<( + std::ostream& out, + const WindowsEventEngine::ConnectionState& connection_state) { + out << "ConnectionState::" << &connection_state + << ": connection_state.address=" + << ResolvedAddressToURI(connection_state.address_) << "," + << GRPC_DUMP_ARGS(connection_state.has_run_, + connection_state.connection_handle_, + connection_state.timer_handle_); + return out; } -} // namespace + +// ---- ConnectionState ---- + +WindowsEventEngine::ConnectionState::ConnectionState( + std::shared_ptr engine, + std::unique_ptr socket, EventEngine::ResolvedAddress address, + MemoryAllocator allocator, + EventEngine::OnConnectCallback on_connect_user_callback) + : socket_(std::move(socket)), + address_(address), + allocator_(std::move(allocator)), + on_connect_user_callback_(std::move(on_connect_user_callback)), + engine_(std::move(engine)) { + CHECK(socket_ != nullptr); + connection_handle_ = ConnectionHandle{reinterpret_cast(this), + engine_->aba_token_.fetch_add(1)}; +} + +void WindowsEventEngine::ConnectionState::Start(Duration timeout) { + on_connected_cb_ = + std::make_unique(engine_.get(), shared_from_this()); + socket_->NotifyOnWrite(on_connected_cb_.get()); + deadline_timer_cb_ = std::make_unique( + engine_.get(), shared_from_this()); + timer_handle_ = engine_->RunAfter(timeout, deadline_timer_cb_.get()); +} + +EventEngine::OnConnectCallback +WindowsEventEngine::ConnectionState::TakeCallback() { + return std::exchange(on_connect_user_callback_, nullptr); +} + +std::unique_ptr +WindowsEventEngine::ConnectionState::FinishConnectingAndMakeEndpoint( + ThreadPool* thread_pool) { + ChannelArgsEndpointConfig cfg; + return std::make_unique(address_, std::move(socket_), + std::move(allocator_), cfg, + 
thread_pool, engine_); +} + +void WindowsEventEngine::ConnectionState::AbortOnConnect() { + on_connected_cb_.reset(); +} + +void WindowsEventEngine::ConnectionState::AbortDeadlineTimer() { + deadline_timer_cb_.reset(); +} + +void WindowsEventEngine::ConnectionState::OnConnectedCallback::Run() { + DCHECK_NE(connection_state_, nullptr) + << "ConnectionState::OnConnectedCallback::" << this + << " has already run. It should only ever run once."; + bool has_run; + { + grpc_core::MutexLock lock(&connection_state_->mu_); + has_run = std::exchange(connection_state_->has_run_, true); + } + // This could race with the deadline timer. If so, the engine's + // OnConnectCompleted callback should not run, and the refs should be + // released. + if (has_run) { + connection_state_.reset(); + return; + } + engine_->OnConnectCompleted(std::move(connection_state_)); +} + +void WindowsEventEngine::ConnectionState::DeadlineTimerCallback::Run() { + DCHECK_NE(connection_state_, nullptr) + << "ConnectionState::DeadlineTimerCallback::" << this + << " has already run. It should only ever run once."; + bool has_run; + { + grpc_core::MutexLock lock(&connection_state_->mu_); + has_run = std::exchange(connection_state_->has_run_, true); + } + // This could race with the on connected callback. If so, the engine's + // OnDeadlineTimerFired callback should not run, and the refs should be + // released. 
+ if (has_run) { + connection_state_.reset(); + return; + } + engine_->OnDeadlineTimerFired(std::move(connection_state_)); +} + // ---- IOCPWorkClosure ---- WindowsEventEngine::IOCPWorkClosure::IOCPWorkClosure(ThreadPool* thread_pool, @@ -66,6 +160,7 @@ WindowsEventEngine::IOCPWorkClosure::IOCPWorkClosure(ThreadPool* thread_pool, } void WindowsEventEngine::IOCPWorkClosure::Run() { + if (done_signal_.HasBeenNotified()) return; auto result = iocp_->Work(std::chrono::seconds(60), [this] { workers_.fetch_add(1); thread_pool_->Run(this); @@ -256,35 +351,54 @@ void WindowsEventEngine::OnConnectCompleted( EventEngine::OnConnectCallback cb; { // Connection attempt complete! - grpc_core::MutexLock lock(&state->mu); - cb = std::move(state->on_connected_user_callback); - state->on_connected_user_callback = CreateCrashingOnConnectCallback(); - state->on_connected = nullptr; + grpc_core::MutexLock lock(&state->mu()); + // return early if we cannot cancel the connection timeout timer. + int erased_handles = 0; { grpc_core::MutexLock handle_lock(&connection_mu_); - known_connection_handles_.erase(state->connection_handle); + erased_handles = + known_connection_handles_.erase(state->connection_handle()); } - const auto& overlapped_result = state->socket->write_info()->result(); - // return early if we cannot cancel the connection timeout timer. - if (!Cancel(state->timer_handle)) return; + if (erased_handles != 1 || !Cancel(state->timer_handle())) { + GRPC_EVENT_ENGINE_TRACE( + "%s", "Not accepting connection since the deadline timer has fired"); + return; + } + // Release refs held by the deadline timer. 
+ state->AbortDeadlineTimer(); + const auto& overlapped_result = state->socket()->write_info()->result(); if (!overlapped_result.error_status.ok()) { - state->socket->Shutdown(DEBUG_LOCATION, "ConnectEx failure"); + state->socket()->Shutdown(DEBUG_LOCATION, "ConnectEx failure"); endpoint = overlapped_result.error_status; } else if (overlapped_result.wsa_error != 0) { - state->socket->Shutdown(DEBUG_LOCATION, "ConnectEx failure"); + state->socket()->Shutdown(DEBUG_LOCATION, "ConnectEx failure"); endpoint = GRPC_WSA_ERROR(overlapped_result.wsa_error, "ConnectEx"); } else { - ChannelArgsEndpointConfig cfg; - endpoint = std::make_unique( - state->address, std::move(state->socket), std::move(state->allocator), - cfg, thread_pool_.get(), shared_from_this()); + endpoint = state->FinishConnectingAndMakeEndpoint(thread_pool_.get()); } + cb = state->TakeCallback(); } // This code should be running in a thread pool thread already, so the // callback can be run directly. + state.reset(); cb(std::move(endpoint)); } +void WindowsEventEngine::OnDeadlineTimerFired( + std::shared_ptr connection_state) { + bool cancelled = false; + EventEngine::OnConnectCallback cb; + { + grpc_core::MutexLock lock(&connection_state->mu()); + cancelled = CancelConnectFromDeadlineTimer(connection_state.get()); + if (cancelled) cb = connection_state->TakeCallback(); + } + if (cancelled) { + connection_state.reset(); + cb(absl::DeadlineExceededError("Connection timed out")); + } +} + EventEngine::ConnectionHandle WindowsEventEngine::Connect( OnConnectCallback on_connect, const ResolvedAddress& addr, const EndpointConfig& /* args */, MemoryAllocator memory_allocator, @@ -367,65 +481,61 @@ EventEngine::ConnectionHandle WindowsEventEngine::Connect( return EventEngine::ConnectionHandle::kInvalid; } // Prepare the socket to receive a connection - auto connection_state = std::make_shared(); - grpc_core::MutexLock lock(&connection_state->mu); - connection_state->socket = iocp_.Watch(sock); - 
CHECK(connection_state->socket != nullptr); - auto* info = connection_state->socket->write_info(); - connection_state->address = address; - connection_state->allocator = std::move(memory_allocator); - connection_state->on_connected_user_callback = std::move(on_connect); - connection_state->on_connected = - SelfDeletingClosure::Create([this, connection_state]() mutable { - OnConnectCompleted(std::move(connection_state)); - }); - connection_state->timer_handle = - RunAfter(timeout, [this, connection_state]() { - grpc_core::ReleasableMutexLock lock(&connection_state->mu); - if (CancelConnectFromDeadlineTimer(connection_state.get())) { - auto cb = std::move(connection_state->on_connected_user_callback); - connection_state->on_connected_user_callback = - CreateCrashingOnConnectCallback(); - lock.Release(); - cb(absl::DeadlineExceededError("Connection timed out")); - } - // else: The connection attempt could not be canceled. We can assume - // the connection callback will be called. - }); - // Connect - connection_state->socket->NotifyOnWrite(connection_state->on_connected); + auto connection_state = std::make_shared( + std::static_pointer_cast(shared_from_this()), + /*socket=*/iocp_.Watch(sock), address, + /*memory_allocator=*/std::move(memory_allocator), + /*on_connect_user_callback=*/std::move(on_connect)); + grpc_core::MutexLock lock(&connection_state->mu()); + auto* info = connection_state->socket()->write_info(); + { + grpc_core::MutexLock connection_handle_lock(&connection_mu_); + known_connection_handles_.insert(connection_state->connection_handle()); + } + connection_state->Start(timeout); bool success = - ConnectEx(connection_state->socket->raw_socket(), address.address(), + ConnectEx(connection_state->socket()->raw_socket(), address.address(), address.size(), nullptr, 0, nullptr, info->overlapped()); // It wouldn't be unusual to get a success immediately. But we'll still get an // IOCP notification, so let's ignore it. 
- if (!success) { - int last_error = WSAGetLastError(); - if (last_error != ERROR_IO_PENDING) { - if (!Cancel(connection_state->timer_handle)) { - return EventEngine::ConnectionHandle::kInvalid; - } - connection_state->socket->Shutdown(DEBUG_LOCATION, "ConnectEx"); - Run([connection_state = std::move(connection_state), - status = GRPC_WSA_ERROR(WSAGetLastError(), "ConnectEx")]() mutable { - EventEngine::OnConnectCallback cb; - { - grpc_core::MutexLock lock(&connection_state->mu); - cb = std::move(connection_state->on_connected_user_callback); - connection_state->on_connected_user_callback = - CreateCrashingOnConnectCallback(); - } - cb(status); - }); - return EventEngine::ConnectionHandle::kInvalid; - } + if (success) return connection_state->connection_handle(); + // Otherwise, we need to handle an error or a pending IO Event. + int last_error = WSAGetLastError(); + if (last_error == ERROR_IO_PENDING) { + // Overlapped I/O operation is in progress. + return connection_state->connection_handle(); + } + // Time to abort the connection. + // The on-connect callback won't run, so we must clean up its state. + connection_state->AbortOnConnect(); + int erased_handles = 0; + { + grpc_core::MutexLock connection_handle_lock(&connection_mu_); + erased_handles = + known_connection_handles_.erase(connection_state->connection_handle()); + } + CHECK_EQ(erased_handles, 1) << "Did not find connection handle " + << connection_state->connection_handle() + << " after a synchronous connection failure. " + "This should not be possible."; + connection_state->socket()->Shutdown(DEBUG_LOCATION, "ConnectEx"); + if (!Cancel(connection_state->timer_handle())) { + // The deadline timer will run, or is running. 
+ return EventEngine::ConnectionHandle::kInvalid; } - connection_state->connection_handle = - ConnectionHandle{reinterpret_cast(connection_state.get()), - aba_token_.fetch_add(1)}; - grpc_core::MutexLock connection_handle_lock(&connection_mu_); - known_connection_handles_.insert(connection_state->connection_handle); - return connection_state->connection_handle; + // The deadline timer won't run, so we must clean up its state. + connection_state->AbortDeadlineTimer(); + Run([connection_state = std::move(connection_state), + status = GRPC_WSA_ERROR(WSAGetLastError(), "ConnectEx")]() mutable { + EventEngine::OnConnectCallback cb; + { + grpc_core::MutexLock lock(&connection_state->mu()); + cb = connection_state->TakeCallback(); + } + connection_state.reset(); + cb(std::move(status)); + }); + return EventEngine::ConnectionHandle::kInvalid; } bool WindowsEventEngine::CancelConnect(EventEngine::ConnectionHandle handle) { @@ -437,17 +547,20 @@ bool WindowsEventEngine::CancelConnect(EventEngine::ConnectionHandle handle) { // Erase the connection handle, which may be unknown { grpc_core::MutexLock lock(&connection_mu_); - if (!known_connection_handles_.contains(handle)) { + if (known_connection_handles_.erase(handle) != 1) { GRPC_EVENT_ENGINE_TRACE( "Unknown connection handle: %s", HandleToString(handle).c_str()); return false; } - known_connection_handles_.erase(handle); } auto* connection_state = reinterpret_cast(handle.keys[0]); - grpc_core::MutexLock state_lock(&connection_state->mu); - if (!Cancel(connection_state->timer_handle)) return false; + grpc_core::MutexLock state_lock(&connection_state->mu()); + // The connection cannot be cancelled if the deadline timer is already firing. + if (!Cancel(connection_state->timer_handle())) return false; + // The deadline timer was cancelled, so we must clean up its state. + connection_state->AbortDeadlineTimer(); + // The on-connect callback will run when the socket shutdown event occurs. 
return CancelConnectInternalStateLocked(connection_state); } @@ -456,20 +569,21 @@ bool WindowsEventEngine::CancelConnectFromDeadlineTimer( // Erase the connection handle, which is guaranteed to exist. { grpc_core::MutexLock lock(&connection_mu_); - CHECK(known_connection_handles_.erase( - connection_state->connection_handle) == 1); + if (known_connection_handles_.erase( + connection_state->connection_handle()) != 1) { + return false; + } } return CancelConnectInternalStateLocked(connection_state); } bool WindowsEventEngine::CancelConnectInternalStateLocked( ConnectionState* connection_state) { - connection_state->socket->Shutdown(DEBUG_LOCATION, "CancelConnect"); + connection_state->socket()->Shutdown(DEBUG_LOCATION, "CancelConnect"); // Release the connection_state shared_ptr owned by the connected callback. - delete connection_state->on_connected; GRPC_EVENT_ENGINE_TRACE("Successfully cancelled connection %s", HandleToString( - connection_state->connection_handle) + connection_state->connection_handle()) .c_str()); return true; } diff --git a/src/core/lib/event_engine/windows/windows_engine.h b/src/core/lib/event_engine/windows/windows_engine.h index 150e7facf3f..2047ab6cd84 100644 --- a/src/core/lib/event_engine/windows/windows_engine.h +++ b/src/core/lib/event_engine/windows/windows_engine.h @@ -44,8 +44,6 @@ namespace grpc_event_engine { namespace experimental { -// TODO(ctiller): KeepsGrpcInitialized is an interim measure to ensure that -// EventEngine is shut down before we shut down iomgr. class WindowsEventEngine : public EventEngine, public grpc_core::KeepsGrpcInitialized { public: @@ -105,24 +103,133 @@ class WindowsEventEngine : public EventEngine, IOCP* poller() { return &iocp_; } private: - // State of an active connection. - // Managed by a shared_ptr, owned exclusively by the timeout callback and the - // OnConnectCompleted callback herein. 
- struct ConnectionState { - // everything is guarded by mu; - grpc_core::Mutex mu + // The state of an active connection. + // + // This object is managed by a shared_ptr, which is owned by: + // 1) the deadline timer callback, and + // 2) the OnConnectCompleted callback. + class ConnectionState : public std::enable_shared_from_this { + public: + ConnectionState(std::shared_ptr engine, + std::unique_ptr socket, + EventEngine::ResolvedAddress address, + MemoryAllocator allocator, + EventEngine::OnConnectCallback on_connect_user_callback); + + // Starts the deadline timer, and sets up the socket to notify on writes. + // + // This cannot be done in the constructor since shared_from_this is required + // for the callbacks to hold a ref to this object. + void Start(Duration timeout) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns the user's callback and resets it to nullptr to ensure it only + // runs once. + OnConnectCallback TakeCallback() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Create an Endpoint, transferring held object ownership to the endpoint. + // + // This can only be called once, and the connection state is no longer valid + // after an endpoint has been created. Callers must guarantee that the + // deadline timer callback will not be run. + std::unique_ptr FinishConnectingAndMakeEndpoint( + ThreadPool* thread_pool) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Release all refs to the on-connect callback. + void AbortOnConnect() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Release all refs to the deadline timer callback. + void AbortDeadlineTimer() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // TODO(hork): this is unsafe. Whatever needs the socket should likely + // delegate responsibility to this object. 
+ WinSocket* socket() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return socket_.get(); + } + + const EventEngine::ConnectionHandle& connection_handle() { + return connection_handle_; + } + const EventEngine::TaskHandle& timer_handle() { return timer_handle_; } + + grpc_core::Mutex& mu() ABSL_LOCK_RETURNED(mu_) { return mu_; } + + private: + // Required for the custom operator<< overload to see the private + // ConnectionState internal state. + friend std::ostream& operator<<(std::ostream& out, + const ConnectionState& connection_state) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(connection_state.mu_); + + // Stateful closure for the endpoint's on-connect callback. + // + // Once created, this closure must be Run or deleted to release the held + // refs. + class OnConnectedCallback : public EventEngine::Closure { + public: + OnConnectedCallback(WindowsEventEngine* engine, + std::shared_ptr connection_state) + : engine_(engine), connection_state_(std::move(connection_state)) {} + + // Runs the WindowsEventEngine's OnConnectCompleted if the deadline timer + // hasn't fired first. + void Run() override; + + private: + WindowsEventEngine* engine_; + std::shared_ptr connection_state_; + }; + + // Stateful closure for the deadline timer. + // + // Once created, this closure must be Run or deleted to release the held + // refs. + class DeadlineTimerCallback : public EventEngine::Closure { + public: + DeadlineTimerCallback(WindowsEventEngine* engine, + std::shared_ptr connection_state) + : engine_(engine), connection_state_(std::move(connection_state)) {} + + // Runs the WindowsEventEngine's OnDeadlineTimerFired if the on-connect + // callback hasn't run first. 
+ void Run() override; + + private: + WindowsEventEngine* engine_; + std::shared_ptr connection_state_; + }; + + // everything is guarded by mu_; + grpc_core::Mutex mu_ ABSL_ACQUIRED_BEFORE(WindowsEventEngine::connection_mu_); - EventEngine::ConnectionHandle connection_handle ABSL_GUARDED_BY(mu); - EventEngine::TaskHandle timer_handle ABSL_GUARDED_BY(mu) = + // Endpoint connection state. + std::unique_ptr socket_ ABSL_GUARDED_BY(mu_); + EventEngine::ResolvedAddress address_ ABSL_GUARDED_BY(mu_); + MemoryAllocator allocator_ ABSL_GUARDED_BY(mu_); + EventEngine::OnConnectCallback on_connect_user_callback_ + ABSL_GUARDED_BY(mu_); + // This guarantees the EventEngine survives long enough to execute these + // deadline timer or on-connect callbacks. + std::shared_ptr engine_ ABSL_GUARDED_BY(mu_); + // Owned closures. These hold refs to this object. + std::unique_ptr on_connected_cb_ ABSL_GUARDED_BY(mu_); + std::unique_ptr deadline_timer_cb_ + ABSL_GUARDED_BY(mu_); + // Their respective method handles. + EventEngine::ConnectionHandle connection_handle_ ABSL_GUARDED_BY(mu_) = + EventEngine::ConnectionHandle::kInvalid; + EventEngine::TaskHandle timer_handle_ ABSL_GUARDED_BY(mu_) = EventEngine::TaskHandle::kInvalid; - EventEngine::OnConnectCallback on_connected_user_callback - ABSL_GUARDED_BY(mu); - EventEngine::Closure* on_connected ABSL_GUARDED_BY(mu); - std::unique_ptr socket ABSL_GUARDED_BY(mu); - EventEngine::ResolvedAddress address ABSL_GUARDED_BY(mu); - MemoryAllocator allocator ABSL_GUARDED_BY(mu); + // Flag to ensure that only one of the event closures will complete its + // responsibilities. + bool has_run_ ABSL_GUARDED_BY(mu_) = false; }; + // Required for the custom operator<< overload to see the private + // ConnectionState type. 
+ friend std::ostream& operator<<(std::ostream& out, + const ConnectionState& connection_state); + + struct TimerClosure; + // A poll worker which schedules itself unless kicked class IOCPWorkClosure : public EventEngine::Closure { public: @@ -137,25 +244,29 @@ class WindowsEventEngine : public EventEngine, IOCP* iocp_; }; + // Called via IOCP notifications when a connection is ready to be processed. + // Either this or the deadline timer will run, never both. void OnConnectCompleted(std::shared_ptr state); - // CancelConnect called from within the deadline timer. - // In this case, the connection_state->mu is already locked, and timer - // cancellation is not possible. + // Called after a timeout when no connection has been established. + // Either this or the on-connect callback will run, never both. + void OnDeadlineTimerFired(std::shared_ptr state); + + // CancelConnect, called from within the deadline timer. + // Timer cancellation is not possible. bool CancelConnectFromDeadlineTimer(ConnectionState* connection_state) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(connection_state->mu); + ABSL_EXCLUSIVE_LOCKS_REQUIRED(connection_state->mu()); - // Completes the connection cancellation logic after checking handle validity - // and optionally cancelling deadline timers. + // Completes the connection cancellation logic after checking handle + // validity and optionally cancelling deadline timers. 
bool CancelConnectInternalStateLocked(ConnectionState* connection_state) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(connection_state->mu); + ABSL_EXCLUSIVE_LOCKS_REQUIRED(connection_state->mu()); - struct TimerClosure; EventEngine::TaskHandle RunAfterInternal(Duration when, absl::AnyInvocable cb); grpc_core::Mutex task_mu_; TaskHandleSet known_handles_ ABSL_GUARDED_BY(task_mu_); - grpc_core::Mutex connection_mu_ ABSL_ACQUIRED_AFTER(ConnectionState::mu); + grpc_core::Mutex connection_mu_; grpc_core::CondVar connection_cv_; ConnectionHandleSet known_connection_handles_ ABSL_GUARDED_BY(connection_mu_); std::atomic aba_token_{0}; diff --git a/src/core/lib/gprpp/dump_args.cc b/src/core/lib/gprpp/dump_args.cc index e5bc183246b..d4400bbf296 100644 --- a/src/core/lib/gprpp/dump_args.cc +++ b/src/core/lib/gprpp/dump_args.cc @@ -51,4 +51,4 @@ std::ostream& operator<<(std::ostream& out, const DumpArgs& args) { } } // namespace dump_args_detail -} // namespace grpc_core \ No newline at end of file +} // namespace grpc_core diff --git a/test/core/event_engine/event_engine_test_utils.cc b/test/core/event_engine/event_engine_test_utils.cc index d026f50d0b9..7fb8f96b713 100644 --- a/test/core/event_engine/event_engine_test_utils.cc +++ b/test/core/event_engine/event_engine_test_utils.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -38,6 +39,7 @@ #include "src/core/lib/event_engine/channel_args_endpoint_config.h" #include "src/core/lib/event_engine/tcp_socket_utils.h" +#include "src/core/lib/gprpp/crash.h" #include "src/core/lib/gprpp/notification.h" #include "src/core/lib/gprpp/time.h" #include "src/core/lib/resource_quota/memory_quota.h" @@ -78,10 +80,19 @@ std::string GetNextSendMessage() { } void WaitForSingleOwner(std::shared_ptr engine) { + WaitForSingleOwnerWithTimeout(std::move(engine), std::chrono::hours{24}); +} + +void WaitForSingleOwnerWithTimeout(std::shared_ptr engine, + EventEngine::Duration timeout) { int n = 0; + auto start = 
std::chrono::system_clock::now(); while (engine.use_count() > 1) { ++n; if (n % 100 == 0) AsanAssertNoLeaks(); + if (std::chrono::system_clock::now() - start > timeout) { + grpc_core::Crash("Timed out waiting for a single EventEngine owner"); + } GRPC_LOG_EVERY_N_SEC(2, GPR_INFO, "engine.use_count() = %ld", engine.use_count()); absl::SleepFor(absl::Milliseconds(100)); diff --git a/test/core/event_engine/event_engine_test_utils.h b/test/core/event_engine/event_engine_test_utils.h index debef7a0449..1b71dc549f7 100644 --- a/test/core/event_engine/event_engine_test_utils.h +++ b/test/core/event_engine/event_engine_test_utils.h @@ -52,6 +52,14 @@ std::string GetNextSendMessage(); // Usage: WaitForSingleOwner(std::move(engine)) void WaitForSingleOwner(std::shared_ptr engine); +// Waits until the use_count of the EventEngine shared_ptr has reached 1 +// and returns. +// Callers must give up their ref, or this method will block forever. +// This version will CRASH after the given timeout. +// Usage: WaitForSingleOwnerWithTimeout(std::move(engine), 30s) +void WaitForSingleOwnerWithTimeout(std::shared_ptr engine, + EventEngine::Duration timeout); + // A helper method to exchange data between two endpoints. It is assumed // that both endpoints are connected. The data (specified as a string) is // written by the sender_endpoint and read by the receiver_endpoint. It