[windows DNS] Simplify c-ares Windows code (#33965)

A set of simplifications to make this code easier to reason about:

- Replace `SockToPolledFdMap` with `std::map`

- Make the c-ares `close` callback do nothing. Instead, let the ares
wrapper code destroy polled fds as it normally does, and let everything
that hasn't been registered for I/O get destroyed in the
`GrpcPolledFdFactoryWindows` dtor.

- Get rid of `GrpcPolledFdWindowsWrapper`

- Move the `grpc_socket_notify_on_write` call to the `RegisterForOnWriteableLocked`
method. This makes for a nice invariant that no async callback is
pending *unless* a `RegisterForOnWriteableLocked` or
`RegisterForOnReadableLocked` callback is pending.

Related: internal issue b/293321613
pull/34188/head
apolcyn 1 year ago committed by GitHub
parent 395ff71b8d
commit 5d85d7d6e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      BUILD
  2. 280
      src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_windows.cc

@ -3112,6 +3112,7 @@ grpc_cc_library(
external_deps = [
"absl/base:core_headers",
"absl/container:flat_hash_set",
"absl/functional:any_invocable",
"absl/status",
"absl/status:statusor",
"absl/strings",

@ -22,8 +22,13 @@
#include <string.h>
#include <map>
#include <memory>
#include <unordered_set>
#include <ares.h>
#include "absl/functional/any_invocable.h"
#include "absl/strings/str_format.h"
#include <grpc/support/alloc.h>
@ -57,6 +62,8 @@ struct iovec {
namespace grpc_core {
namespace {
// c-ares reads and takes action on the error codes of the
// "virtual socket operations" in this file, via the WSAGetLastError
// APIs. If code in this file wants to set a specific WSA error that
@ -92,7 +99,7 @@ class WSAErrorContext {
// from c-ares and are used with the grpc windows poller, and it, e.g.,
// manufactures virtual socket error codes when it e.g. needs to tell the c-ares
// library to wait for an async read.
class GrpcPolledFdWindows {
class GrpcPolledFdWindows : public GrpcPolledFd {
public:
enum WriteState {
WRITE_IDLE,
@ -102,15 +109,15 @@ class GrpcPolledFdWindows {
};
GrpcPolledFdWindows(ares_socket_t as, Mutex* mu, int address_family,
int socket_type)
int socket_type,
absl::AnyInvocable<void()> on_shutdown_locked)
: mu_(mu),
read_buf_(grpc_empty_slice()),
write_buf_(grpc_empty_slice()),
tcp_write_state_(WRITE_IDLE),
name_(absl::StrFormat("c-ares socket: %" PRIdPTR, as)),
gotten_into_driver_list_(false),
address_family_(address_family),
socket_type_(socket_type) {
socket_type_(socket_type),
on_shutdown_locked_(std::move(on_shutdown_locked)) {
// Closure Initialization
GRPC_CLOSURE_INIT(&outer_read_closure_,
&GrpcPolledFdWindows::OnIocpReadable, this,
@ -124,11 +131,18 @@ class GrpcPolledFdWindows {
winsocket_ = grpc_winsocket_create(as, name_.c_str());
}
~GrpcPolledFdWindows() {
~GrpcPolledFdWindows() override {
GRPC_CARES_TRACE_LOG("fd:|%s| ~GrpcPolledFdWindows shutdown_called_: %d ",
GetName(), shutdown_called_);
CSliceUnref(read_buf_);
CSliceUnref(write_buf_);
GPR_ASSERT(read_closure_ == nullptr);
GPR_ASSERT(write_closure_ == nullptr);
if (!shutdown_called_) {
// This can happen if the socket was never seen by grpc ares wrapper
// code, i.e. if we never started I/O polling on it.
grpc_winsocket_shutdown(winsocket_);
}
grpc_winsocket_destroy(winsocket_);
}
@ -142,7 +156,7 @@ class GrpcPolledFdWindows {
write_closure_ = nullptr;
}
void RegisterForOnReadableLocked(grpc_closure* read_closure) {
void RegisterForOnReadableLocked(grpc_closure* read_closure) override {
GPR_ASSERT(read_closure_ == nullptr);
read_closure_ = read_closure;
GPR_ASSERT(GRPC_SLICE_LENGTH(read_buf_) == 0);
@ -193,23 +207,28 @@ class GrpcPolledFdWindows {
grpc_socket_notify_on_read(winsocket_, &outer_read_closure_);
}
void RegisterForOnWriteableLocked(grpc_closure* write_closure) {
void RegisterForOnWriteableLocked(grpc_closure* write_closure) override {
if (socket_type_ == SOCK_DGRAM) {
GRPC_CARES_TRACE_LOG("fd:|%s| RegisterForOnWriteableLocked called",
GetName());
} else {
GPR_ASSERT(socket_type_ == SOCK_STREAM);
GRPC_CARES_TRACE_LOG(
"fd:|%s| RegisterForOnWriteableLocked called tcp_write_state_: %d",
GetName(), tcp_write_state_);
"fd:|%s| RegisterForOnWriteableLocked called tcp_write_state_: %d "
"connect_done_: %d",
GetName(), tcp_write_state_, connect_done_);
}
GPR_ASSERT(write_closure_ == nullptr);
write_closure_ = write_closure;
if (connect_done_) {
ContinueRegisterForOnWriteableLocked();
} else {
GPR_ASSERT(pending_continue_register_for_on_writeable_locked_ == false);
if (!connect_done_) {
GPR_ASSERT(!pending_continue_register_for_on_writeable_locked_);
pending_continue_register_for_on_writeable_locked_ = true;
// Register an async OnTcpConnect callback here rather than when the
// connect was initiated, since we are now guaranteed to hold a ref of the
// c-ares wrapper before write_closure_ is called.
grpc_socket_notify_on_write(winsocket_, &on_tcp_connect_locked_);
} else {
ContinueRegisterForOnWriteableLocked();
}
}
@ -250,17 +269,20 @@ class GrpcPolledFdWindows {
}
}
bool IsFdStillReadableLocked() { return read_buf_has_data_; }
bool IsFdStillReadableLocked() override { return read_buf_has_data_; }
void ShutdownLocked(grpc_error_handle /* error */) {
void ShutdownLocked(grpc_error_handle /* error */) override {
GPR_ASSERT(!shutdown_called_);
shutdown_called_ = true;
on_shutdown_locked_();
grpc_winsocket_shutdown(winsocket_);
}
ares_socket_t GetWrappedAresSocketLocked() {
ares_socket_t GetWrappedAresSocketLocked() override {
return grpc_winsocket_wrapped_socket(winsocket_);
}
const char* GetName() const { return name_.c_str(); }
const char* GetName() const override { return name_.c_str(); }
ares_ssize_t RecvFrom(WSAErrorContext* wsa_error_ctx, void* data,
ares_socket_t data_len, int /* flags */,
@ -437,7 +459,9 @@ class GrpcPolledFdWindows {
GPR_ASSERT(!connect_done_);
connect_done_ = true;
GPR_ASSERT(wsa_connect_error_ == 0);
if (error.ok()) {
if (!error.ok() || shutdown_called_) {
wsa_connect_error_ = WSA_OPERATION_ABORTED;
} else {
DWORD transferred_bytes = 0;
DWORD flags;
BOOL wsa_success =
@ -454,10 +478,6 @@ class GrpcPolledFdWindows {
GetName(), wsa_connect_error_, msg);
gpr_free(msg);
}
} else {
// Spoof up an error code that will cause any future c-ares operations on
// this fd to abort.
wsa_connect_error_ = WSA_OPERATION_ABORTED;
}
if (pending_continue_register_for_on_readable_locked_) {
ContinueRegisterForOnReadableLocked();
@ -564,7 +584,7 @@ class GrpcPolledFdWindows {
return -1;
}
}
grpc_socket_notify_on_write(winsocket_, &on_tcp_connect_locked_);
// RegisterForOnWriteable will register for an async notification
return out;
}
@ -645,9 +665,6 @@ class GrpcPolledFdWindows {
ScheduleAndNullWriteClosure(error);
}
bool gotten_into_driver_list() const { return gotten_into_driver_list_; }
void set_gotten_into_driver_list() { gotten_into_driver_list_ = true; }
private:
Mutex* mu_;
char recv_from_source_addr_[200];
@ -660,73 +677,48 @@ class GrpcPolledFdWindows {
grpc_closure outer_read_closure_;
grpc_closure outer_write_closure_;
grpc_winsocket* winsocket_;
// tcp_write_state_ is only used on TCP GrpcPolledFds
WriteState tcp_write_state_;
const std::string name_;
bool gotten_into_driver_list_;
bool shutdown_called_ = false;
int address_family_;
int socket_type_;
// State related to TCP sockets
grpc_closure on_tcp_connect_locked_;
bool connect_done_ = false;
int wsa_connect_error_ = 0;
WriteState tcp_write_state_ = WRITE_IDLE;
// We don't run register_for_{readable,writeable} logic until
// a socket is connected. In the interim, we queue readable/writeable
// registrations with the following state.
bool pending_continue_register_for_on_readable_locked_ = false;
bool pending_continue_register_for_on_writeable_locked_ = false;
absl::AnyInvocable<void()> on_shutdown_locked_;
};
struct SockToPolledFdEntry {
SockToPolledFdEntry(SOCKET s, GrpcPolledFdWindows* fd)
: socket(s), polled_fd(fd) {}
SOCKET socket;
GrpcPolledFdWindows* polled_fd;
SockToPolledFdEntry* next = nullptr;
};
// A SockToPolledFdMap maps ares_socket_t values (SOCKETs on Windows)
// to GrpcPolledFdWindows objects, and is used to find the appropriate
// GrpcPolledFdWindows to handle a virtual socket call when c-ares makes that
// socket call on the ares_socket_t type. Instances are owned by and one-to-one
// with a GrpcPolledFdWindows factory and event driver
class SockToPolledFdMap {
class GrpcPolledFdFactoryWindows : public GrpcPolledFdFactory {
public:
explicit SockToPolledFdMap(Mutex* mu) : mu_(mu) {}
explicit GrpcPolledFdFactoryWindows(Mutex* mu) : mu_(mu) {}
~SockToPolledFdMap() { GPR_ASSERT(head_ == nullptr); }
void AddNewSocket(SOCKET s, GrpcPolledFdWindows* polled_fd) {
SockToPolledFdEntry* new_node = new SockToPolledFdEntry(s, polled_fd);
new_node->next = head_;
head_ = new_node;
~GrpcPolledFdFactoryWindows() override {
// We might still have socket -> polled fd mappings for sockets that
// were never seen by the grpc ares wrapper code, i.e. sockets that we
// never initiated I/O polling for.
for (auto& it : sockets_) {
delete it.second;
}
}
GrpcPolledFdWindows* LookupPolledFd(SOCKET s) {
for (SockToPolledFdEntry* node = head_; node != nullptr;
node = node->next) {
if (node->socket == s) {
GPR_ASSERT(node->polled_fd != nullptr);
return node->polled_fd;
}
}
abort();
GrpcPolledFd* NewGrpcPolledFdLocked(
ares_socket_t as, grpc_pollset_set* /* driver_pollset_set */) override {
auto it = sockets_.find(as);
GPR_ASSERT(it != sockets_.end());
return it->second;
}
void RemoveEntry(SOCKET s) {
GPR_ASSERT(head_ != nullptr);
SockToPolledFdEntry** prev = &head_;
for (SockToPolledFdEntry* node = head_; node != nullptr;
node = node->next) {
if (node->socket == s) {
*prev = node->next;
delete node;
return;
}
prev = &node->next;
}
abort();
void ConfigureAresChannelLocked(ares_channel channel) override {
ares_set_socket_functions(channel, &kCustomSockFuncs, this);
}
private:
// These virtual socket functions are called from within the c-ares
// library. These methods generally dispatch those socket calls to the
// appropriate methods. The virtual "socket" and "close" methods are
@ -738,7 +730,8 @@ class SockToPolledFdMap {
GRPC_CARES_TRACE_LOG("Socket called with invalid socket type:%d", type);
return INVALID_SOCKET;
}
SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
GrpcPolledFdFactoryWindows* self =
static_cast<GrpcPolledFdFactoryWindows*>(user_data);
SOCKET s = WSASocket(af, type, protocol, nullptr, 0,
grpc_get_default_wsa_socket_flags());
if (s == INVALID_SOCKET) {
@ -753,131 +746,68 @@ class SockToPolledFdMap {
StatusToString(error).c_str());
return INVALID_SOCKET;
}
GrpcPolledFdWindows* polled_fd =
new GrpcPolledFdWindows(s, map->mu_, af, type);
auto on_shutdown_locked = [self, s]() {
// grpc_winsocket_shutdown calls closesocket which invalidates our
// socket -> polled_fd mapping because the socket handle can be henceforth
// reused.
self->sockets_.erase(s);
};
auto polled_fd = new GrpcPolledFdWindows(s, self->mu_, af, type,
std::move(on_shutdown_locked));
GRPC_CARES_TRACE_LOG(
"fd:|%s| created with params af:%d type:%d protocol:%d",
polled_fd->GetName(), af, type, protocol);
map->AddNewSocket(s, polled_fd);
GPR_ASSERT(self->sockets_.insert({s, polled_fd}).second);
return s;
}
static int Connect(ares_socket_t as, const struct sockaddr* target,
ares_socklen_t target_len, void* user_data) {
WSAErrorContext wsa_error_ctx;
SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as);
return polled_fd->Connect(&wsa_error_ctx, target, target_len);
GrpcPolledFdFactoryWindows* self =
static_cast<GrpcPolledFdFactoryWindows*>(user_data);
auto it = self->sockets_.find(as);
GPR_ASSERT(it != self->sockets_.end());
return it->second->Connect(&wsa_error_ctx, target, target_len);
}
static ares_ssize_t SendV(ares_socket_t as, const struct iovec* iov,
int iovec_count, void* user_data) {
WSAErrorContext wsa_error_ctx;
SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as);
return polled_fd->SendV(&wsa_error_ctx, iov, iovec_count);
GrpcPolledFdFactoryWindows* self =
static_cast<GrpcPolledFdFactoryWindows*>(user_data);
auto it = self->sockets_.find(as);
GPR_ASSERT(it != self->sockets_.end());
return it->second->SendV(&wsa_error_ctx, iov, iovec_count);
}
static ares_ssize_t RecvFrom(ares_socket_t as, void* data, size_t data_len,
int flags, struct sockaddr* from,
ares_socklen_t* from_len, void* user_data) {
WSAErrorContext wsa_error_ctx;
SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as);
return polled_fd->RecvFrom(&wsa_error_ctx, data, data_len, flags, from,
from_len);
}
static int CloseSocket(SOCKET s, void* user_data) {
SockToPolledFdMap* map = static_cast<SockToPolledFdMap*>(user_data);
GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(s);
map->RemoveEntry(s);
// See https://github.com/grpc/grpc/pull/20284, this trace log is
// intentionally placed to attempt to trigger a crash in case of a
// use after free on polled_fd.
GRPC_CARES_TRACE_LOG("CloseSocket called for socket: %s",
polled_fd->GetName());
// If a gRPC polled fd has not made it in to the driver's list yet, then
// the driver has not and will never see this socket.
if (!polled_fd->gotten_into_driver_list()) {
polled_fd->ShutdownLocked(GRPC_ERROR_CREATE(
"Shut down c-ares fd before without it ever having made it into the "
"driver's list"));
}
delete polled_fd;
return 0;
}
GrpcPolledFdFactoryWindows* self =
static_cast<GrpcPolledFdFactoryWindows*>(user_data);
auto it = self->sockets_.find(as);
GPR_ASSERT(it != self->sockets_.end());
return it->second->RecvFrom(&wsa_error_ctx, data, data_len, flags, from,
from_len);
}
static int CloseSocket(SOCKET /* s */, void* /* user_data */) { return 0; }
const struct ares_socket_functions kCustomSockFuncs = {
&GrpcPolledFdFactoryWindows::Socket /* socket */,
&GrpcPolledFdFactoryWindows::CloseSocket /* close */,
&GrpcPolledFdFactoryWindows::Connect /* connect */,
&GrpcPolledFdFactoryWindows::RecvFrom /* recvfrom */,
&GrpcPolledFdFactoryWindows::SendV /* sendv */,
};
private:
Mutex* mu_;
SockToPolledFdEntry* head_ = nullptr;
};
const struct ares_socket_functions custom_ares_sock_funcs = {
&SockToPolledFdMap::Socket /* socket */,
&SockToPolledFdMap::CloseSocket /* close */,
&SockToPolledFdMap::Connect /* connect */,
&SockToPolledFdMap::RecvFrom /* recvfrom */,
&SockToPolledFdMap::SendV /* sendv */,
std::map<SOCKET, GrpcPolledFdWindows*> sockets_;
};
// A thin wrapper over a GrpcPolledFdWindows object but with a shorter
// lifetime. This object releases its GrpcPolledFdWindows upon destruction,
// so that c-ares can close it via usual socket teardown.
class GrpcPolledFdWindowsWrapper : public GrpcPolledFd {
public:
explicit GrpcPolledFdWindowsWrapper(GrpcPolledFdWindows* wrapped)
: wrapped_(wrapped) {}
~GrpcPolledFdWindowsWrapper() {}
void RegisterForOnReadableLocked(grpc_closure* read_closure) override {
wrapped_->RegisterForOnReadableLocked(read_closure);
}
void RegisterForOnWriteableLocked(grpc_closure* write_closure) override {
wrapped_->RegisterForOnWriteableLocked(write_closure);
}
bool IsFdStillReadableLocked() override {
return wrapped_->IsFdStillReadableLocked();
}
void ShutdownLocked(grpc_error_handle error) override {
wrapped_->ShutdownLocked(error);
}
ares_socket_t GetWrappedAresSocketLocked() override {
return wrapped_->GetWrappedAresSocketLocked();
}
const char* GetName() const override { return wrapped_->GetName(); }
private:
GrpcPolledFdWindows* const wrapped_;
};
class GrpcPolledFdFactoryWindows : public GrpcPolledFdFactory {
public:
explicit GrpcPolledFdFactoryWindows(Mutex* mu) : sock_to_polled_fd_map_(mu) {}
GrpcPolledFd* NewGrpcPolledFdLocked(
ares_socket_t as, grpc_pollset_set* /* driver_pollset_set */) override {
GrpcPolledFdWindows* polled_fd = sock_to_polled_fd_map_.LookupPolledFd(as);
// Set a flag so that the virtual socket "close" method knows it
// doesn't need to call ShutdownLocked, since now the driver will.
polled_fd->set_gotten_into_driver_list();
return new GrpcPolledFdWindowsWrapper(polled_fd);
}
void ConfigureAresChannelLocked(ares_channel channel) override {
ares_set_socket_functions(channel, &custom_ares_sock_funcs,
&sock_to_polled_fd_map_);
}
private:
SockToPolledFdMap sock_to_polled_fd_map_;
};
} // namespace
std::unique_ptr<GrpcPolledFdFactory> NewGrpcPolledFdFactory(Mutex* mu) {
return std::make_unique<GrpcPolledFdFactoryWindows>(mu);

Loading…
Cancel
Save