[EventEngine] Fix use-after-free in the IOCP read callback (#33447)

In rare cases across gRPC's end2end tests (only in Windows RBE with
debug builds, interestingly enough), the endpoint may be destroyed
before the final IOCP read callback has run. This edge case is only
triggered when all the following are true:

* the previous read operation received data (not an error, not
0-length), so more data is expected in the stream.
* the socket has not yet shut down
* the application destroyed its endpoint before (or during) the IOCP
callback execution.
* the Read operation has not yet called the client's on_read callback.

This is a valid scenario, and the engine implementation is expected to
call the application callbacks with error statuses when it occurs. This
PR fixes two associated bugs.
pull/33457/head
AJ Heller 1 year ago committed by GitHub
parent e5035063e8
commit 6c6faa9cab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 72
      src/core/lib/event_engine/windows/windows_endpoint.cc
  2. 13
      src/core/lib/event_engine/windows/windows_endpoint.h

@ -55,9 +55,8 @@ WindowsEndpoint::WindowsEndpoint(
std::shared_ptr<EventEngine> engine)
: peer_address_(peer_address),
allocator_(std::move(allocator)),
thread_pool_(thread_pool),
io_state_(std::make_shared<AsyncIOState>(this, std::move(socket),
std::move(engine))) {
io_state_(std::make_shared<AsyncIOState>(
this, std::move(socket), std::move(engine), thread_pool)) {
char addr[EventEngine::ResolvedAddress::MAX_SIZE_BYTES];
int addr_len = sizeof(addr);
if (getsockname(io_state_->socket->raw_socket(),
@ -77,9 +76,9 @@ WindowsEndpoint::~WindowsEndpoint() {
GRPC_EVENT_ENGINE_ENDPOINT_TRACE("~WindowsEndpoint::%p", this);
}
absl::Status WindowsEndpoint::DoTcpRead(SliceBuffer* buffer) {
GRPC_EVENT_ENGINE_ENDPOINT_TRACE("WindowsEndpoint::%p reading", this);
if (io_state_->socket->IsShutdown()) {
absl::Status WindowsEndpoint::AsyncIOState::DoTcpRead(SliceBuffer* buffer) {
GRPC_EVENT_ENGINE_ENDPOINT_TRACE("WindowsEndpoint::%p reading", endpoint);
if (socket->IsShutdown()) {
return absl::UnavailableError("Socket is shutting down.");
}
// Prepare the WSABUF struct
@ -94,25 +93,25 @@ absl::Status WindowsEndpoint::DoTcpRead(SliceBuffer* buffer) {
DWORD flags = 0;
// First try a synchronous, non-blocking read.
int status =
WSARecv(io_state_->socket->raw_socket(), wsa_buffers,
(DWORD)buffer->Count(), &bytes_read, &flags, nullptr, nullptr);
WSARecv(socket->raw_socket(), wsa_buffers, (DWORD)buffer->Count(),
&bytes_read, &flags, nullptr, nullptr);
int wsa_error = status == 0 ? 0 : WSAGetLastError();
if (wsa_error != WSAEWOULDBLOCK) {
// Data or some error was returned immediately.
io_state_->socket->read_info()->SetResult(
socket->read_info()->SetResult(
{/*wsa_error=*/wsa_error, /*bytes_read=*/bytes_read});
thread_pool_->Run(&io_state_->handle_read_event);
thread_pool->Run(&handle_read_event);
return absl::OkStatus();
}
// If the endpoint has already received some data, and the next call would
// block, return the data in case that is all the data the reader expects.
if (io_state_->handle_read_event.MaybeFinishIfDataHasAlreadyBeenRead()) {
if (handle_read_event.MaybeFinishIfDataHasAlreadyBeenRead()) {
return absl::OkStatus();
}
// Otherwise, let's retry, by queuing a read.
status = WSARecv(io_state_->socket->raw_socket(), wsa_buffers,
(DWORD)buffer->Count(), &bytes_read, &flags,
io_state_->socket->read_info()->overlapped(), nullptr);
status =
WSARecv(socket->raw_socket(), wsa_buffers, (DWORD)buffer->Count(),
&bytes_read, &flags, socket->read_info()->overlapped(), nullptr);
wsa_error = status == 0 ? 0 : WSAGetLastError();
if (wsa_error != 0 && wsa_error != WSA_IO_PENDING) {
// Async read returned immediately with an error
@ -120,14 +119,14 @@ absl::Status WindowsEndpoint::DoTcpRead(SliceBuffer* buffer) {
wsa_error,
absl::StrFormat("WindowsEndpont::%p Read failed", this).c_str());
}
io_state_->socket->NotifyOnRead(&io_state_->handle_read_event);
socket->NotifyOnRead(&handle_read_event);
return absl::OkStatus();
}
bool WindowsEndpoint::Read(absl::AnyInvocable<void(absl::Status)> on_read,
SliceBuffer* buffer, const ReadArgs* /* args */) {
if (io_state_->socket->IsShutdown()) {
thread_pool_->Run([on_read = std::move(on_read)]() mutable {
io_state_->thread_pool->Run([on_read = std::move(on_read)]() mutable {
on_read(absl::UnavailableError("Socket is shutting down."));
});
return false;
@ -141,10 +140,10 @@ bool WindowsEndpoint::Read(absl::AnyInvocable<void(absl::Status)> on_read,
buffer->AppendIndexed(Slice(allocator_.MakeSlice(min_read_size)));
}
io_state_->handle_read_event.Prime(io_state_, buffer, std::move(on_read));
auto status = DoTcpRead(buffer);
auto status = io_state_->DoTcpRead(buffer);
if (!status.ok()) {
// The read could not be completed.
thread_pool_->Run(
io_state_->thread_pool->Run(
[cb = io_state_->handle_read_event.ResetAndReturnCallback(),
status]() mutable { cb(status); });
}
@ -155,9 +154,10 @@ bool WindowsEndpoint::Write(absl::AnyInvocable<void(absl::Status)> on_writable,
SliceBuffer* data, const WriteArgs* /* args */) {
GRPC_EVENT_ENGINE_ENDPOINT_TRACE("WindowsEndpoint::%p writing", this);
if (io_state_->socket->IsShutdown()) {
thread_pool_->Run([on_writable = std::move(on_writable)]() mutable {
on_writable(absl::UnavailableError("Socket is shutting down."));
});
io_state_->thread_pool->Run(
[on_writable = std::move(on_writable)]() mutable {
on_writable(absl::UnavailableError("Socket is shutting down."));
});
return false;
}
if (grpc_event_engine_endpoint_data_trace.enabled()) {
@ -183,7 +183,7 @@ bool WindowsEndpoint::Write(absl::AnyInvocable<void(absl::Status)> on_writable,
if (status == 0) {
if (bytes_sent == data->Length()) {
// Write completed, exiting early
thread_pool_->Run(
io_state_->thread_pool->Run(
[cb = std::move(on_writable)]() mutable { cb(absl::OkStatus()); });
return false;
}
@ -204,9 +204,10 @@ bool WindowsEndpoint::Write(absl::AnyInvocable<void(absl::Status)> on_writable,
// then we can avoid doing an async write operation at all.
int wsa_error = WSAGetLastError();
if (wsa_error != WSAEWOULDBLOCK) {
thread_pool_->Run([cb = std::move(on_writable), wsa_error]() mutable {
cb(GRPC_WSA_ERROR(wsa_error, "WSASend"));
});
io_state_->thread_pool->Run(
[cb = std::move(on_writable), wsa_error]() mutable {
cb(GRPC_WSA_ERROR(wsa_error, "WSASend"));
});
return false;
}
}
@ -218,9 +219,10 @@ bool WindowsEndpoint::Write(absl::AnyInvocable<void(absl::Status)> on_writable,
if (status != 0) {
int wsa_error = WSAGetLastError();
if (wsa_error != WSA_IO_PENDING) {
thread_pool_->Run([cb = std::move(on_writable), wsa_error]() mutable {
cb(GRPC_WSA_ERROR(wsa_error, "WSASend"));
});
io_state_->thread_pool->Run(
[cb = std::move(on_writable), wsa_error]() mutable {
cb(GRPC_WSA_ERROR(wsa_error, "WSASend"));
});
return false;
}
}
@ -249,18 +251,18 @@ void AbortOnEvent(absl::Status) {
absl::AnyInvocable<void(absl::Status)>
WindowsEndpoint::HandleReadClosure::ResetAndReturnCallback() {
auto cb = std::move(cb_);
io_state_.reset();
cb_ = &AbortOnEvent;
buffer_ = nullptr;
io_state_.reset();
return cb;
}
absl::AnyInvocable<void(absl::Status)>
WindowsEndpoint::HandleWriteClosure::ResetAndReturnCallback() {
auto cb = std::move(cb_);
io_state_.reset();
cb_ = &AbortOnEvent;
buffer_ = nullptr;
io_state_.reset();
return cb;
}
@ -316,7 +318,7 @@ void WindowsEndpoint::HandleReadClosure::Run() {
}
// Doing another read. Let's keep the AsyncIOState alive a bit longer.
io_state_ = std::move(io_state);
status = io_state_->endpoint->DoTcpRead(buffer_);
status = io_state_->DoTcpRead(buffer_);
if (!status.ok()) {
return ResetAndReturnCallback()(status);
}
@ -326,7 +328,7 @@ bool WindowsEndpoint::HandleReadClosure::MaybeFinishIfDataHasAlreadyBeenRead() {
if (last_read_buffer_.Length() > 0) {
buffer_->Swap(last_read_buffer_);
// Captures io_state_ to ensure it remains alive until the callback is run.
io_state_->endpoint->thread_pool_->Run(
io_state_->thread_pool->Run(
[cb = ResetAndReturnCallback()]() mutable { cb(absl::OkStatus()); });
return true;
}
@ -361,10 +363,12 @@ void WindowsEndpoint::HandleWriteClosure::Run() {
WindowsEndpoint::AsyncIOState::AsyncIOState(WindowsEndpoint* endpoint,
std::unique_ptr<WinSocket> socket,
std::shared_ptr<EventEngine> engine)
std::shared_ptr<EventEngine> engine,
ThreadPool* thread_pool)
: endpoint(endpoint),
socket(std::move(socket)),
engine(std::move(engine)) {}
engine(std::move(engine)),
thread_pool(thread_pool) {}
WindowsEndpoint::AsyncIOState::~AsyncIOState() {
socket->Shutdown(DEBUG_LOCATION, "~AsyncIOState");

@ -92,25 +92,26 @@ class WindowsEndpoint : public EventEngine::Endpoint {
// events are complete.
struct AsyncIOState {
AsyncIOState(WindowsEndpoint* endpoint, std::unique_ptr<WinSocket> socket,
std::shared_ptr<EventEngine> engine);
std::shared_ptr<EventEngine> engine, ThreadPool* thread_pool);
~AsyncIOState();
// Perform the low-level calls and execute the HandleReadClosure
// asynchronously.
absl::Status DoTcpRead(SliceBuffer* buffer);
WindowsEndpoint* const endpoint;
std::unique_ptr<WinSocket> socket;
HandleReadClosure handle_read_event;
HandleWriteClosure handle_write_event;
std::shared_ptr<EventEngine> engine;
ThreadPool* thread_pool;
};
// Perform the low-level calls and execute the HandleReadClosure
// asynchronously.
absl::Status DoTcpRead(SliceBuffer* buffer);
EventEngine::ResolvedAddress peer_address_;
std::string peer_address_string_;
EventEngine::ResolvedAddress local_address_;
std::string local_address_string_;
MemoryAllocator allocator_;
ThreadPool* thread_pool_;
std::shared_ptr<AsyncIOState> io_state_;
};

Loading…
Cancel
Save