[EventEngine] Fix race between connection and its deadline timer on Windows (#36709)

Closes #36709

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/36709 from drfloob:fix-win-ee-use-after-free dd4ae2683e
PiperOrigin-RevId: 641000031
pull/36838/head
AJ Heller 6 months ago committed by Copybara-Service
parent 2e35d4aab5
commit 5d586d3ae3
  1. 19
      include/grpc/event_engine/event_engine.h
  2. 1
      src/core/BUILD
  3. 33
      src/core/lib/event_engine/event_engine.cc
  4. 2
      src/core/lib/event_engine/thread_local.h
  5. 8
      src/core/lib/event_engine/trace.h
  6. 2
      src/core/lib/event_engine/windows/iocp.h
  7. 36
      src/core/lib/event_engine/windows/win_socket.cc
  8. 9
      src/core/lib/event_engine/windows/win_socket.h
  9. 11
      src/core/lib/event_engine/windows/windows_endpoint.cc
  10. 280
      src/core/lib/event_engine/windows/windows_engine.cc
  11. 161
      src/core/lib/event_engine/windows/windows_engine.h
  12. 2
      src/core/lib/gprpp/dump_args.cc
  13. 11
      test/core/event_engine/event_engine_test_utils.cc
  14. 8
      test/core/event_engine/event_engine_test_utils.h

@ -132,8 +132,6 @@ class EventEngine : public std::enable_shared_from_this<EventEngine>,
struct TaskHandle {
intptr_t keys[2];
static const GRPC_DLL TaskHandle kInvalid;
friend bool operator==(const TaskHandle& lhs, const TaskHandle& rhs);
friend bool operator!=(const TaskHandle& lhs, const TaskHandle& rhs);
};
/// A handle to a cancellable connection attempt.
///
@ -141,10 +139,6 @@ class EventEngine : public std::enable_shared_from_this<EventEngine>,
struct ConnectionHandle {
intptr_t keys[2];
static const GRPC_DLL ConnectionHandle kInvalid;
friend bool operator==(const ConnectionHandle& lhs,
const ConnectionHandle& rhs);
friend bool operator!=(const ConnectionHandle& lhs,
const ConnectionHandle& rhs);
};
/// Thin wrapper around a platform-specific sockaddr type. A sockaddr struct
/// exists on all platforms that gRPC supports.
@ -496,6 +490,19 @@ void EventEngineFactoryReset();
/// Create an EventEngine using the default factory.
std::unique_ptr<EventEngine> CreateEventEngine();
bool operator==(const EventEngine::TaskHandle& lhs,
const EventEngine::TaskHandle& rhs);
bool operator!=(const EventEngine::TaskHandle& lhs,
const EventEngine::TaskHandle& rhs);
std::ostream& operator<<(std::ostream& out,
const EventEngine::TaskHandle& handle);
bool operator==(const EventEngine::ConnectionHandle& lhs,
const EventEngine::ConnectionHandle& rhs);
bool operator!=(const EventEngine::ConnectionHandle& lhs,
const EventEngine::ConnectionHandle& rhs);
std::ostream& operator<<(std::ostream& out,
const EventEngine::ConnectionHandle& handle);
} // namespace experimental
} // namespace grpc_event_engine

@ -2493,6 +2493,7 @@ grpc_cc_library(
"ares_resolver",
"channel_args_endpoint_config",
"common_event_engine_closures",
"dump_args",
"error",
"event_engine_common",
"event_engine_tcp_socket_utils",

@ -11,6 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/str_cat.h"
#include <grpc/event_engine/event_engine.h>
#include <grpc/support/port_platform.h>
@ -21,24 +23,47 @@ const EventEngine::TaskHandle EventEngine::TaskHandle::kInvalid = {-1, -1};
const EventEngine::ConnectionHandle EventEngine::ConnectionHandle::kInvalid = {
-1, -1};
namespace {
template <typename T>
bool eq(const T& lhs, const T& rhs) {
return lhs.keys[0] == rhs.keys[0] && lhs.keys[1] == rhs.keys[1];
}
template <typename T>
std::ostream& printout(std::ostream& out, const T& handle) {
out << absl::StrCat("{", absl::Hex(handle.keys[0], absl::kZeroPad16), ",",
absl::Hex(handle.keys[1], absl::kZeroPad16), "}");
return out;
}
} // namespace
bool operator==(const EventEngine::TaskHandle& lhs,
const EventEngine::TaskHandle& rhs) {
return lhs.keys[0] == rhs.keys[0] && lhs.keys[1] == rhs.keys[1];
return eq(lhs, rhs);
}
bool operator!=(const EventEngine::TaskHandle& lhs,
const EventEngine::TaskHandle& rhs) {
return !(lhs == rhs);
return !eq(lhs, rhs);
}
std::ostream& operator<<(std::ostream& out,
const EventEngine::TaskHandle& handle) {
return printout(out, handle);
}
bool operator==(const EventEngine::ConnectionHandle& lhs,
const EventEngine::ConnectionHandle& rhs) {
return lhs.keys[0] == rhs.keys[0] && lhs.keys[1] == rhs.keys[1];
return eq(lhs, rhs);
}
bool operator!=(const EventEngine::ConnectionHandle& lhs,
const EventEngine::ConnectionHandle& rhs) {
return !(lhs == rhs);
return !eq(lhs, rhs);
}
std::ostream& operator<<(std::ostream& out,
const EventEngine::ConnectionHandle& handle) {
return printout(out, handle);
}
} // namespace experimental

@ -29,4 +29,4 @@ class ThreadLocal {
} // namespace experimental
} // namespace grpc_event_engine
#endif // GRPC_SRC_CORE_LIB_EVENT_ENGINE_THREAD_LOCAL_H
#endif // GRPC_SRC_CORE_LIB_EVENT_ENGINE_THREAD_LOCAL_H

@ -27,22 +27,22 @@ extern grpc_core::TraceFlag grpc_event_engine_endpoint_trace;
#define GRPC_EVENT_ENGINE_TRACE(format, ...) \
if (GRPC_TRACE_FLAG_ENABLED(grpc_event_engine_trace)) { \
gpr_log(GPR_DEBUG, "(event_engine) " format, __VA_ARGS__); \
gpr_log(GPR_ERROR, "(event_engine) " format, __VA_ARGS__); \
}
#define GRPC_EVENT_ENGINE_ENDPOINT_TRACE(format, ...) \
if (GRPC_TRACE_FLAG_ENABLED(grpc_event_engine_endpoint_trace)) { \
gpr_log(GPR_DEBUG, "(event_engine endpoint) " format, __VA_ARGS__); \
gpr_log(GPR_ERROR, "(event_engine endpoint) " format, __VA_ARGS__); \
}
#define GRPC_EVENT_ENGINE_POLLER_TRACE(format, ...) \
if (GRPC_TRACE_FLAG_ENABLED(grpc_event_engine_poller_trace)) { \
gpr_log(GPR_DEBUG, "(event_engine poller) " format, __VA_ARGS__); \
gpr_log(GPR_ERROR, "(event_engine poller) " format, __VA_ARGS__); \
}
#define GRPC_EVENT_ENGINE_DNS_TRACE(format, ...) \
if (GRPC_TRACE_FLAG_ENABLED(grpc_event_engine_dns_trace)) { \
gpr_log(GPR_DEBUG, "(event_engine dns) " format, __VA_ARGS__); \
gpr_log(GPR_ERROR, "(event_engine dns) " format, __VA_ARGS__); \
}
#endif // GRPC_SRC_CORE_LIB_EVENT_ENGINE_TRACE_H

@ -32,7 +32,7 @@ namespace experimental {
class IOCP final : public Poller {
public:
explicit IOCP(ThreadPool* thread_pool) noexcept;
~IOCP();
~IOCP() override;
// Not copyable
IOCP(const IOCP&) = delete;
IOCP& operator=(const IOCP&) = delete;

@ -69,13 +69,19 @@ void WinSocket::Shutdown() {
int status = WSAIoctl(socket_, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid,
sizeof(guid), &DisconnectEx, sizeof(DisconnectEx),
&ioctl_num_bytes, NULL, NULL);
if (status == 0) {
DisconnectEx(socket_, NULL, 0, 0);
} else {
if (status != 0) {
char* utf8_message = gpr_format_message(WSAGetLastError());
LOG(INFO) << "Unable to retrieve DisconnectEx pointer : " << utf8_message;
gpr_free(utf8_message);
} else if (DisconnectEx(socket_, NULL, 0, 0) == FALSE) {
auto last_error = WSAGetLastError();
// DisconnectEx may be called when the socket is not connected. Ignore that
// error, and log all others.
if (last_error != WSAENOTCONN) {
char* utf8_message = gpr_format_message(last_error);
LOG(INFO) << "DisconnectEx failed: " << utf8_message;
gpr_free(utf8_message);
}
}
closesocket(socket_);
GRPC_EVENT_ENGINE_ENDPOINT_TRACE("WinSocket::%p socket closed", this);
@ -91,7 +97,7 @@ void WinSocket::Shutdown(const grpc_core::DebugLocation& location,
void WinSocket::NotifyOnReady(OpState& info, EventEngine::Closure* closure) {
if (IsShutdown()) {
info.SetError(WSAESHUTDOWN);
info.SetResult(WSAESHUTDOWN, 0, "NotifyOnReady");
thread_pool_->Run(closure);
return;
};
@ -130,12 +136,13 @@ void WinSocket::OpState::SetReady() {
win_socket_->thread_pool_->Run(closure);
}
void WinSocket::OpState::SetError(int wsa_error) {
result_ = OverlappedResult{/*wsa_error=*/wsa_error, /*bytes_transferred=*/0};
}
void WinSocket::OpState::SetResult(OverlappedResult result) {
result_ = result;
void WinSocket::OpState::SetResult(int wsa_error, DWORD bytes,
absl::string_view context) {
bytes = wsa_error == 0 ? bytes : 0;
result_ = OverlappedResult{
/*wsa_error=*/wsa_error, /*bytes_transferred=*/bytes,
/*error_status=*/wsa_error == 0 ? absl::OkStatus()
: GRPC_WSA_ERROR(wsa_error, context)};
}
void WinSocket::OpState::SetErrorStatus(absl::Status error_status) {
@ -149,16 +156,15 @@ void WinSocket::OpState::GetOverlappedResult() {
void WinSocket::OpState::GetOverlappedResult(SOCKET sock) {
if (win_socket_->IsShutdown()) {
result_ = OverlappedResult{/*wsa_error=*/WSA_OPERATION_ABORTED,
/*bytes_transferred=*/0};
SetResult(WSA_OPERATION_ABORTED, 0, "GetOverlappedResult");
return;
}
DWORD flags = 0;
DWORD bytes;
BOOL success =
WSAGetOverlappedResult(sock, &overlapped_, &bytes, FALSE, &flags);
result_ = OverlappedResult{/*wsa_error=*/success ? 0 : WSAGetLastError(),
/*bytes_transferred=*/bytes};
auto wsa_error = success ? 0 : WSAGetLastError();
SetResult(wsa_error, bytes, "WSAGetOverlappedResult");
}
bool WinSocket::IsShutdown() { return is_shutdown_.load(); }

@ -47,12 +47,11 @@ class WinSocket {
// the WinSocket's ThreadPool. Otherwise, a "pending iocp" flag will
// be set.
void SetReady();
// Set WSA error results for a completed op.
void SetError(int wsa_error);
// Set an OverlappedResult. Useful when WSARecv returns immediately.
void SetResult(OverlappedResult result);
// Set WSA result for a completed op.
// If the error is non-zero, bytes will be overridden to 0.
void SetResult(int wsa_error, DWORD bytes, absl::string_view context);
// Set error results for a completed op.
// This is a manual override, meant to override any WSA status code.
// This is a manual override, meant to ignore any WSA status code.
void SetErrorStatus(absl::Status error_status);
// Retrieve the results of an overlapped operation (via Winsock API) and
// store them locally.

@ -102,8 +102,7 @@ void WindowsEndpoint::AsyncIOState::DoTcpRead(SliceBuffer* buffer) {
int wsa_error = status == 0 ? 0 : WSAGetLastError();
if (wsa_error != WSAEWOULDBLOCK) {
// Data or some error was returned immediately.
socket->read_info()->SetResult(
{/*wsa_error=*/wsa_error, /*bytes_read=*/bytes_read});
socket->read_info()->SetResult(wsa_error, bytes_read, "WSARecv");
thread_pool->Run(&handle_read_event);
return;
}
@ -120,9 +119,8 @@ void WindowsEndpoint::AsyncIOState::DoTcpRead(SliceBuffer* buffer) {
if (wsa_error != 0 && wsa_error != WSA_IO_PENDING) {
// The async read attempt returned an error immediately.
socket->UnregisterReadCallback();
socket->read_info()->SetErrorStatus(GRPC_WSA_ERROR(
wsa_error,
absl::StrFormat("WindowsEndpont::%p Read failed", this).c_str()));
socket->read_info()->SetResult(
wsa_error, 0, absl::StrFormat("WindowsEndpont::%p Read failed", this));
thread_pool->Run(&handle_read_event);
}
}
@ -220,8 +218,7 @@ bool WindowsEndpoint::Write(absl::AnyInvocable<void(absl::Status)> on_writable,
int wsa_error = WSAGetLastError();
if (wsa_error != WSA_IO_PENDING) {
io_state_->socket->UnregisterWriteCallback();
io_state_->socket->write_info()->SetErrorStatus(
GRPC_WSA_ERROR(wsa_error, "WSASend"));
io_state_->socket->write_info()->SetResult(wsa_error, 0, "WSASend");
io_state_->thread_pool->Run(&io_state_->handle_write_event);
}
}

@ -16,8 +16,10 @@
#ifdef GPR_WINDOWS
#include <memory>
#include <ostream>
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
@ -43,6 +45,7 @@
#include "src/core/lib/event_engine/windows/windows_engine.h"
#include "src/core/lib/event_engine/windows/windows_listener.h"
#include "src/core/lib/gprpp/crash.h"
#include "src/core/lib/gprpp/dump_args.h"
#include "src/core/lib/gprpp/sync.h"
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/iomgr/error.h"
@ -50,13 +53,104 @@
namespace grpc_event_engine {
namespace experimental {
namespace {
EventEngine::OnConnectCallback CreateCrashingOnConnectCallback() {
return [](absl::StatusOr<std::unique_ptr<EventEngine::Endpoint>>) {
grpc_core::Crash("Internal Error: OnConnect callback called when unset");
};
std::ostream& operator<<(
std::ostream& out,
const WindowsEventEngine::ConnectionState& connection_state) {
out << "ConnectionState::" << &connection_state
<< ": connection_state.address="
<< ResolvedAddressToURI(connection_state.address_) << ","
<< GRPC_DUMP_ARGS(connection_state.has_run_,
connection_state.connection_handle_,
connection_state.timer_handle_);
return out;
}
} // namespace
// ---- ConnectionState ----
WindowsEventEngine::ConnectionState::ConnectionState(
std::shared_ptr<WindowsEventEngine> engine,
std::unique_ptr<WinSocket> socket, EventEngine::ResolvedAddress address,
MemoryAllocator allocator,
EventEngine::OnConnectCallback on_connect_user_callback)
: socket_(std::move(socket)),
address_(address),
allocator_(std::move(allocator)),
on_connect_user_callback_(std::move(on_connect_user_callback)),
engine_(std::move(engine)) {
CHECK(socket_ != nullptr);
connection_handle_ = ConnectionHandle{reinterpret_cast<intptr_t>(this),
engine_->aba_token_.fetch_add(1)};
}
void WindowsEventEngine::ConnectionState::Start(Duration timeout) {
on_connected_cb_ =
std::make_unique<OnConnectedCallback>(engine_.get(), shared_from_this());
socket_->NotifyOnWrite(on_connected_cb_.get());
deadline_timer_cb_ = std::make_unique<DeadlineTimerCallback>(
engine_.get(), shared_from_this());
timer_handle_ = engine_->RunAfter(timeout, deadline_timer_cb_.get());
}
EventEngine::OnConnectCallback
WindowsEventEngine::ConnectionState::TakeCallback() {
return std::exchange(on_connect_user_callback_, nullptr);
}
std::unique_ptr<WindowsEndpoint>
WindowsEventEngine::ConnectionState::FinishConnectingAndMakeEndpoint(
ThreadPool* thread_pool) {
ChannelArgsEndpointConfig cfg;
return std::make_unique<WindowsEndpoint>(address_, std::move(socket_),
std::move(allocator_), cfg,
thread_pool, engine_);
}
void WindowsEventEngine::ConnectionState::AbortOnConnect() {
on_connected_cb_.reset();
}
void WindowsEventEngine::ConnectionState::AbortDeadlineTimer() {
deadline_timer_cb_.reset();
}
void WindowsEventEngine::ConnectionState::OnConnectedCallback::Run() {
DCHECK_NE(connection_state_, nullptr)
<< "ConnectionState::OnConnectedCallback::" << this
<< " has already run. It should only ever run once.";
bool has_run;
{
grpc_core::MutexLock lock(&connection_state_->mu_);
has_run = std::exchange(connection_state_->has_run_, true);
}
// This could race with the deadline timer. If so, the engine's
// OnConnectCompleted callback should not run, and the refs should be
// released.
if (has_run) {
connection_state_.reset();
return;
}
engine_->OnConnectCompleted(std::move(connection_state_));
}
void WindowsEventEngine::ConnectionState::DeadlineTimerCallback::Run() {
DCHECK_NE(connection_state_, nullptr)
<< "ConnectionState::DeadlineTimerCallback::" << this
<< " has already run. It should only ever run once.";
bool has_run;
{
grpc_core::MutexLock lock(&connection_state_->mu_);
has_run = std::exchange(connection_state_->has_run_, true);
}
// This could race with the on connected callback. If so, the engine's
// OnDeadlineTimerFired callback should not run, and the refs should be
// released.
if (has_run) {
connection_state_.reset();
return;
}
engine_->OnDeadlineTimerFired(std::move(connection_state_));
}
// ---- IOCPWorkClosure ----
WindowsEventEngine::IOCPWorkClosure::IOCPWorkClosure(ThreadPool* thread_pool,
@ -66,6 +160,7 @@ WindowsEventEngine::IOCPWorkClosure::IOCPWorkClosure(ThreadPool* thread_pool,
}
void WindowsEventEngine::IOCPWorkClosure::Run() {
if (done_signal_.HasBeenNotified()) return;
auto result = iocp_->Work(std::chrono::seconds(60), [this] {
workers_.fetch_add(1);
thread_pool_->Run(this);
@ -256,35 +351,54 @@ void WindowsEventEngine::OnConnectCompleted(
EventEngine::OnConnectCallback cb;
{
// Connection attempt complete!
grpc_core::MutexLock lock(&state->mu);
cb = std::move(state->on_connected_user_callback);
state->on_connected_user_callback = CreateCrashingOnConnectCallback();
state->on_connected = nullptr;
grpc_core::MutexLock lock(&state->mu());
// return early if we cannot cancel the connection timeout timer.
int erased_handles = 0;
{
grpc_core::MutexLock handle_lock(&connection_mu_);
known_connection_handles_.erase(state->connection_handle);
erased_handles =
known_connection_handles_.erase(state->connection_handle());
}
const auto& overlapped_result = state->socket->write_info()->result();
// return early if we cannot cancel the connection timeout timer.
if (!Cancel(state->timer_handle)) return;
if (erased_handles != 1 || !Cancel(state->timer_handle())) {
GRPC_EVENT_ENGINE_TRACE(
"%s", "Not accepting connection since the deadline timer has fired");
return;
}
// Release refs held by the deadline timer.
state->AbortDeadlineTimer();
const auto& overlapped_result = state->socket()->write_info()->result();
if (!overlapped_result.error_status.ok()) {
state->socket->Shutdown(DEBUG_LOCATION, "ConnectEx failure");
state->socket()->Shutdown(DEBUG_LOCATION, "ConnectEx failure");
endpoint = overlapped_result.error_status;
} else if (overlapped_result.wsa_error != 0) {
state->socket->Shutdown(DEBUG_LOCATION, "ConnectEx failure");
state->socket()->Shutdown(DEBUG_LOCATION, "ConnectEx failure");
endpoint = GRPC_WSA_ERROR(overlapped_result.wsa_error, "ConnectEx");
} else {
ChannelArgsEndpointConfig cfg;
endpoint = std::make_unique<WindowsEndpoint>(
state->address, std::move(state->socket), std::move(state->allocator),
cfg, thread_pool_.get(), shared_from_this());
endpoint = state->FinishConnectingAndMakeEndpoint(thread_pool_.get());
}
cb = state->TakeCallback();
}
// This code should be running in a thread pool thread already, so the
// callback can be run directly.
state.reset();
cb(std::move(endpoint));
}
void WindowsEventEngine::OnDeadlineTimerFired(
std::shared_ptr<ConnectionState> connection_state) {
bool cancelled = false;
EventEngine::OnConnectCallback cb;
{
grpc_core::MutexLock lock(&connection_state->mu());
cancelled = CancelConnectFromDeadlineTimer(connection_state.get());
if (cancelled) cb = connection_state->TakeCallback();
}
if (cancelled) {
connection_state.reset();
cb(absl::DeadlineExceededError("Connection timed out"));
}
}
EventEngine::ConnectionHandle WindowsEventEngine::Connect(
OnConnectCallback on_connect, const ResolvedAddress& addr,
const EndpointConfig& /* args */, MemoryAllocator memory_allocator,
@ -367,65 +481,61 @@ EventEngine::ConnectionHandle WindowsEventEngine::Connect(
return EventEngine::ConnectionHandle::kInvalid;
}
// Prepare the socket to receive a connection
auto connection_state = std::make_shared<ConnectionState>();
grpc_core::MutexLock lock(&connection_state->mu);
connection_state->socket = iocp_.Watch(sock);
CHECK(connection_state->socket != nullptr);
auto* info = connection_state->socket->write_info();
connection_state->address = address;
connection_state->allocator = std::move(memory_allocator);
connection_state->on_connected_user_callback = std::move(on_connect);
connection_state->on_connected =
SelfDeletingClosure::Create([this, connection_state]() mutable {
OnConnectCompleted(std::move(connection_state));
});
connection_state->timer_handle =
RunAfter(timeout, [this, connection_state]() {
grpc_core::ReleasableMutexLock lock(&connection_state->mu);
if (CancelConnectFromDeadlineTimer(connection_state.get())) {
auto cb = std::move(connection_state->on_connected_user_callback);
connection_state->on_connected_user_callback =
CreateCrashingOnConnectCallback();
lock.Release();
cb(absl::DeadlineExceededError("Connection timed out"));
}
// else: The connection attempt could not be canceled. We can assume
// the connection callback will be called.
});
// Connect
connection_state->socket->NotifyOnWrite(connection_state->on_connected);
auto connection_state = std::make_shared<ConnectionState>(
std::static_pointer_cast<WindowsEventEngine>(shared_from_this()),
/*socket=*/iocp_.Watch(sock), address,
/*memory_allocator=*/std::move(memory_allocator),
/*on_connect_user_callback=*/std::move(on_connect));
grpc_core::MutexLock lock(&connection_state->mu());
auto* info = connection_state->socket()->write_info();
{
grpc_core::MutexLock connection_handle_lock(&connection_mu_);
known_connection_handles_.insert(connection_state->connection_handle());
}
connection_state->Start(timeout);
bool success =
ConnectEx(connection_state->socket->raw_socket(), address.address(),
ConnectEx(connection_state->socket()->raw_socket(), address.address(),
address.size(), nullptr, 0, nullptr, info->overlapped());
// It wouldn't be unusual to get a success immediately. But we'll still get an
// IOCP notification, so let's ignore it.
if (!success) {
int last_error = WSAGetLastError();
if (last_error != ERROR_IO_PENDING) {
if (!Cancel(connection_state->timer_handle)) {
return EventEngine::ConnectionHandle::kInvalid;
}
connection_state->socket->Shutdown(DEBUG_LOCATION, "ConnectEx");
Run([connection_state = std::move(connection_state),
status = GRPC_WSA_ERROR(WSAGetLastError(), "ConnectEx")]() mutable {
EventEngine::OnConnectCallback cb;
{
grpc_core::MutexLock lock(&connection_state->mu);
cb = std::move(connection_state->on_connected_user_callback);
connection_state->on_connected_user_callback =
CreateCrashingOnConnectCallback();
}
cb(status);
});
return EventEngine::ConnectionHandle::kInvalid;
}
if (success) return connection_state->connection_handle();
// Otherwise, we need to handle an error or a pending IO Event.
int last_error = WSAGetLastError();
if (last_error == ERROR_IO_PENDING) {
// Overlapped I/O operation is in progress.
return connection_state->connection_handle();
}
// Time to abort the connection.
// The on-connect callback won't run, so we must clean up its state.
connection_state->AbortOnConnect();
int erased_handles = 0;
{
grpc_core::MutexLock connection_handle_lock(&connection_mu_);
erased_handles =
known_connection_handles_.erase(connection_state->connection_handle());
}
CHECK_EQ(erased_handles, 1) << "Did not find connection handle "
<< connection_state->connection_handle()
<< " after a synchronous connection failure. "
"This should not be possible.";
connection_state->socket()->Shutdown(DEBUG_LOCATION, "ConnectEx");
if (!Cancel(connection_state->timer_handle())) {
// The deadline timer will run, or is running.
return EventEngine::ConnectionHandle::kInvalid;
}
connection_state->connection_handle =
ConnectionHandle{reinterpret_cast<intptr_t>(connection_state.get()),
aba_token_.fetch_add(1)};
grpc_core::MutexLock connection_handle_lock(&connection_mu_);
known_connection_handles_.insert(connection_state->connection_handle);
return connection_state->connection_handle;
// The deadline timer won't run, so we must clean up its state.
connection_state->AbortDeadlineTimer();
Run([connection_state = std::move(connection_state),
status = GRPC_WSA_ERROR(WSAGetLastError(), "ConnectEx")]() mutable {
EventEngine::OnConnectCallback cb;
{
grpc_core::MutexLock lock(&connection_state->mu());
cb = connection_state->TakeCallback();
}
connection_state.reset();
cb(std::move(status));
});
return EventEngine::ConnectionHandle::kInvalid;
}
bool WindowsEventEngine::CancelConnect(EventEngine::ConnectionHandle handle) {
@ -437,17 +547,20 @@ bool WindowsEventEngine::CancelConnect(EventEngine::ConnectionHandle handle) {
// Erase the connection handle, which may be unknown
{
grpc_core::MutexLock lock(&connection_mu_);
if (!known_connection_handles_.contains(handle)) {
if (known_connection_handles_.erase(handle) != 1) {
GRPC_EVENT_ENGINE_TRACE(
"Unknown connection handle: %s",
HandleToString<EventEngine::ConnectionHandle>(handle).c_str());
return false;
}
known_connection_handles_.erase(handle);
}
auto* connection_state = reinterpret_cast<ConnectionState*>(handle.keys[0]);
grpc_core::MutexLock state_lock(&connection_state->mu);
if (!Cancel(connection_state->timer_handle)) return false;
grpc_core::MutexLock state_lock(&connection_state->mu());
// The connection cannot be cancelled if the deadline timer is already firing.
if (!Cancel(connection_state->timer_handle())) return false;
// The deadline timer was cancelled, so we must clean up its state.
connection_state->AbortDeadlineTimer();
// The on-connect callback will run when the socket shutdown event occurs.
return CancelConnectInternalStateLocked(connection_state);
}
@ -456,20 +569,21 @@ bool WindowsEventEngine::CancelConnectFromDeadlineTimer(
// Erase the connection handle, which is guaranteed to exist.
{
grpc_core::MutexLock lock(&connection_mu_);
CHECK(known_connection_handles_.erase(
connection_state->connection_handle) == 1);
if (known_connection_handles_.erase(
connection_state->connection_handle()) != 1) {
return false;
}
}
return CancelConnectInternalStateLocked(connection_state);
}
bool WindowsEventEngine::CancelConnectInternalStateLocked(
ConnectionState* connection_state) {
connection_state->socket->Shutdown(DEBUG_LOCATION, "CancelConnect");
connection_state->socket()->Shutdown(DEBUG_LOCATION, "CancelConnect");
// Release the connection_state shared_ptr owned by the connected callback.
delete connection_state->on_connected;
GRPC_EVENT_ENGINE_TRACE("Successfully cancelled connection %s",
HandleToString<EventEngine::ConnectionHandle>(
connection_state->connection_handle)
connection_state->connection_handle())
.c_str());
return true;
}

@ -44,8 +44,6 @@
namespace grpc_event_engine {
namespace experimental {
// TODO(ctiller): KeepsGrpcInitialized is an interim measure to ensure that
// EventEngine is shut down before we shut down iomgr.
class WindowsEventEngine : public EventEngine,
public grpc_core::KeepsGrpcInitialized {
public:
@ -105,24 +103,133 @@ class WindowsEventEngine : public EventEngine,
IOCP* poller() { return &iocp_; }
private:
// State of an active connection.
// Managed by a shared_ptr, owned exclusively by the timeout callback and the
// OnConnectCompleted callback herein.
struct ConnectionState {
// everything is guarded by mu;
grpc_core::Mutex mu
// The state of an active connection.
//
// This object is managed by a shared_ptr, which is owned by:
// 1) the deadline timer callback, and
// 2) the OnConnectCompleted callback.
class ConnectionState : public std::enable_shared_from_this<ConnectionState> {
public:
ConnectionState(std::shared_ptr<WindowsEventEngine> engine,
std::unique_ptr<WinSocket> socket,
EventEngine::ResolvedAddress address,
MemoryAllocator allocator,
EventEngine::OnConnectCallback on_connect_user_callback);
// Starts the deadline timer, and sets up the socket to notify on writes.
//
// This cannot be done in the constructor since shared_from_this is required
// for the callbacks to hold a ref to this object.
void Start(Duration timeout) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Returns the user's callback and resets it to nullptr to ensure it only
// runs once.
OnConnectCallback TakeCallback() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Create an Endpoint, transferring held object ownership to the endpoint.
//
// This can only be called once, and the connection state is no longer valid
// after an endpoint has been created. Callers must guarantee that the
// deadline timer callback will not be run.
std::unique_ptr<WindowsEndpoint> FinishConnectingAndMakeEndpoint(
ThreadPool* thread_pool) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Release all refs to the on-connect callback.
void AbortOnConnect() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Release all refs to the deadline timer callback.
void AbortDeadlineTimer() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// TODO(hork): this is unsafe. Whatever needs the socket should likely
// delegate responsibility to this object.
WinSocket* socket() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return socket_.get();
}
const EventEngine::ConnectionHandle& connection_handle() {
return connection_handle_;
}
const EventEngine::TaskHandle& timer_handle() { return timer_handle_; }
grpc_core::Mutex& mu() ABSL_LOCK_RETURNED(mu_) { return mu_; }
private:
// Required for the custom operator<< overload to see the private
// ConnectionState internal state.
friend std::ostream& operator<<(std::ostream& out,
const ConnectionState& connection_state)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(connection_state.mu_);
// Stateful closure for the endpoint's on-connect callback.
//
// Once created, this closure must be Run or deleted to release the held
// refs.
class OnConnectedCallback : public EventEngine::Closure {
public:
OnConnectedCallback(WindowsEventEngine* engine,
std::shared_ptr<ConnectionState> connection_state)
: engine_(engine), connection_state_(std::move(connection_state)) {}
// Runs the WindowsEventEngine's OnConnectCompleted if the deadline timer
// hasn't fired first.
void Run() override;
private:
WindowsEventEngine* engine_;
std::shared_ptr<ConnectionState> connection_state_;
};
// Stateful closure for the deadline timer.
//
// Once created, this closure must be Run or deleted to release the held
// refs.
class DeadlineTimerCallback : public EventEngine::Closure {
public:
DeadlineTimerCallback(WindowsEventEngine* engine,
std::shared_ptr<ConnectionState> connection_state)
: engine_(engine), connection_state_(std::move(connection_state)) {}
// Runs the WindowsEventEngine's OnDeadlineTimerFired if the deadline
// timer hasn't fired first.
void Run() override;
private:
WindowsEventEngine* engine_;
std::shared_ptr<ConnectionState> connection_state_;
};
// everything is guarded by mu_;
grpc_core::Mutex mu_
ABSL_ACQUIRED_BEFORE(WindowsEventEngine::connection_mu_);
EventEngine::ConnectionHandle connection_handle ABSL_GUARDED_BY(mu);
EventEngine::TaskHandle timer_handle ABSL_GUARDED_BY(mu) =
// Endpoint connection state.
std::unique_ptr<WinSocket> socket_ ABSL_GUARDED_BY(mu_);
EventEngine::ResolvedAddress address_ ABSL_GUARDED_BY(mu_);
MemoryAllocator allocator_ ABSL_GUARDED_BY(mu_);
EventEngine::OnConnectCallback on_connect_user_callback_
ABSL_GUARDED_BY(mu_);
// This guarantees the EventEngine survives long enough to execute these
// deadline timer or on-connect callbacks.
std::shared_ptr<WindowsEventEngine> engine_ ABSL_GUARDED_BY(mu_);
// Owned closures. These hold refs to this object.
std::unique_ptr<OnConnectedCallback> on_connected_cb_ ABSL_GUARDED_BY(mu_);
std::unique_ptr<DeadlineTimerCallback> deadline_timer_cb_
ABSL_GUARDED_BY(mu_);
// Their respective method handles.
EventEngine::ConnectionHandle connection_handle_ ABSL_GUARDED_BY(mu_) =
EventEngine::ConnectionHandle::kInvalid;
EventEngine::TaskHandle timer_handle_ ABSL_GUARDED_BY(mu_) =
EventEngine::TaskHandle::kInvalid;
EventEngine::OnConnectCallback on_connected_user_callback
ABSL_GUARDED_BY(mu);
EventEngine::Closure* on_connected ABSL_GUARDED_BY(mu);
std::unique_ptr<WinSocket> socket ABSL_GUARDED_BY(mu);
EventEngine::ResolvedAddress address ABSL_GUARDED_BY(mu);
MemoryAllocator allocator ABSL_GUARDED_BY(mu);
// Flag to ensure that only one of the event closures will complete its
// responsibilities.
bool has_run_ ABSL_GUARDED_BY(mu_) = false;
};
// Required for the custom operator<< overload to see the private
// ConnectionState type.
friend std::ostream& operator<<(std::ostream& out,
const ConnectionState& connection_state);
struct TimerClosure;
// A poll worker which schedules itself unless kicked
class IOCPWorkClosure : public EventEngine::Closure {
public:
@ -137,25 +244,29 @@ class WindowsEventEngine : public EventEngine,
IOCP* iocp_;
};
// Called via IOCP notifications when a connection is ready to be processed.
// Either this or the deadline timer will run, never both.
void OnConnectCompleted(std::shared_ptr<ConnectionState> state);
// CancelConnect called from within the deadline timer.
// In this case, the connection_state->mu is already locked, and timer
// cancellation is not possible.
// Called after a timeout when no connection has been established.
// Either this or the on-connect callback will run, never both.
void OnDeadlineTimerFired(std::shared_ptr<ConnectionState> state);
// CancelConnect, called from within the deadline timer.
// Timer cancellation is not possible.
bool CancelConnectFromDeadlineTimer(ConnectionState* connection_state)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(connection_state->mu);
ABSL_EXCLUSIVE_LOCKS_REQUIRED(connection_state->mu());
// Completes the connection cancellation logic after checking handle validity
// and optionally cancelling deadline timers.
// Completes the connection cancellation logic after checking handle
// validity and optionally cancelling deadline timers.
bool CancelConnectInternalStateLocked(ConnectionState* connection_state)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(connection_state->mu);
ABSL_EXCLUSIVE_LOCKS_REQUIRED(connection_state->mu());
struct TimerClosure;
EventEngine::TaskHandle RunAfterInternal(Duration when,
absl::AnyInvocable<void()> cb);
grpc_core::Mutex task_mu_;
TaskHandleSet known_handles_ ABSL_GUARDED_BY(task_mu_);
grpc_core::Mutex connection_mu_ ABSL_ACQUIRED_AFTER(ConnectionState::mu);
grpc_core::Mutex connection_mu_;
grpc_core::CondVar connection_cv_;
ConnectionHandleSet known_connection_handles_ ABSL_GUARDED_BY(connection_mu_);
std::atomic<intptr_t> aba_token_{0};

@ -51,4 +51,4 @@ std::ostream& operator<<(std::ostream& out, const DumpArgs& args) {
}
} // namespace dump_args_detail
} // namespace grpc_core
} // namespace grpc_core

@ -17,6 +17,7 @@
#include <stdlib.h>
#include <algorithm>
#include <chrono>
#include <memory>
#include <random>
#include <string>
@ -38,6 +39,7 @@
#include "src/core/lib/event_engine/channel_args_endpoint_config.h"
#include "src/core/lib/event_engine/tcp_socket_utils.h"
#include "src/core/lib/gprpp/crash.h"
#include "src/core/lib/gprpp/notification.h"
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/resource_quota/memory_quota.h"
@ -78,10 +80,19 @@ std::string GetNextSendMessage() {
}
void WaitForSingleOwner(std::shared_ptr<EventEngine> engine) {
WaitForSingleOwnerWithTimeout(std::move(engine), std::chrono::hours{24});
}
void WaitForSingleOwnerWithTimeout(std::shared_ptr<EventEngine> engine,
EventEngine::Duration timeout) {
int n = 0;
auto start = std::chrono::system_clock::now();
while (engine.use_count() > 1) {
++n;
if (n % 100 == 0) AsanAssertNoLeaks();
if (std::chrono::system_clock::now() - start > timeout) {
grpc_core::Crash("Timed out waiting for a single EventEngine owner");
}
GRPC_LOG_EVERY_N_SEC(2, GPR_INFO, "engine.use_count() = %ld",
engine.use_count());
absl::SleepFor(absl::Milliseconds(100));

@ -52,6 +52,14 @@ std::string GetNextSendMessage();
// Usage: WaitForSingleOwner(std::move(engine))
void WaitForSingleOwner(std::shared_ptr<EventEngine> engine);
// Waits until the use_count of the EventEngine shared_ptr has reached 1
// and returns.
// Callers must give up their ref, or this method will block forever.
// This version will CRASH after the given timeout
// Usage: WaitForSingleOwnerWithTimeout(std::move(engine), 30s)
void WaitForSingleOwnerWithTimeout(std::shared_ptr<EventEngine> engine,
EventEngine::Duration timeout);
// A helper method to exchange data between two endpoints. It is assumed
// that both endpoints are connected. The data (specified as a string) is
// written by the sender_endpoint and read by the receiver_endpoint. It

Loading…
Cancel
Save