diff --git a/BUILD b/BUILD index ce012ea4d0f..97c1ee534cd 100644 --- a/BUILD +++ b/BUILD @@ -3112,6 +3112,7 @@ grpc_cc_library( external_deps = [ "absl/base:core_headers", "absl/container:flat_hash_set", + "absl/functional:any_invocable", "absl/status", "absl/status:statusor", "absl/strings", diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_windows.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_windows.cc index a44132d9df2..ad1d73fb45a 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_windows.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_windows.cc @@ -22,8 +22,13 @@ #include +#include +#include +#include + #include +#include "absl/functional/any_invocable.h" #include "absl/strings/str_format.h" #include @@ -57,6 +62,8 @@ struct iovec { namespace grpc_core { +namespace { + // c-ares reads and takes action on the error codes of the // "virtual socket operations" in this file, via the WSAGetLastError // APIs. If code in this file wants to set a specific WSA error that @@ -92,7 +99,7 @@ class WSAErrorContext { // from c-ares and are used with the grpc windows poller, and it, e.g., // manufactures virtual socket error codes when it e.g. needs to tell the c-ares // library to wait for an async read. -class GrpcPolledFdWindows { +class GrpcPolledFdWindows : public GrpcPolledFd { public: enum WriteState { WRITE_IDLE, @@ -102,15 +109,15 @@ class GrpcPolledFdWindows { }; GrpcPolledFdWindows(ares_socket_t as, Mutex* mu, int address_family, - int socket_type) + int socket_type, + absl::AnyInvocable on_shutdown_locked) : mu_(mu), read_buf_(grpc_empty_slice()), write_buf_(grpc_empty_slice()), - tcp_write_state_(WRITE_IDLE), name_(absl::StrFormat("c-ares socket: %" PRIdPTR, as)), - gotten_into_driver_list_(false), address_family_(address_family), - socket_type_(socket_type) { + socket_type_(socket_type), + on_shutdown_locked_(std::move(on_shutdown_locked)) { // Closure Initialization GRPC_CLOSURE_INIT(&outer_read_closure_, &GrpcPolledFdWindows::OnIocpReadable, this, @@ -124,11 +131,18 @@ class GrpcPolledFdWindows { winsocket_ = grpc_winsocket_create(as, name_.c_str()); } - ~GrpcPolledFdWindows() { + ~GrpcPolledFdWindows() override { + GRPC_CARES_TRACE_LOG("fd:|%s| ~GrpcPolledFdWindows shutdown_called_: %d ", + GetName(), shutdown_called_); CSliceUnref(read_buf_); CSliceUnref(write_buf_); GPR_ASSERT(read_closure_ == nullptr); GPR_ASSERT(write_closure_ == nullptr); + if (!shutdown_called_) { + // This can happen if the socket was never seen by grpc ares wrapper + // code, i.e. if we never started I/O polling on it. + grpc_winsocket_shutdown(winsocket_); + } grpc_winsocket_destroy(winsocket_); } @@ -142,7 +156,7 @@ class GrpcPolledFdWindows { write_closure_ = nullptr; } - void RegisterForOnReadableLocked(grpc_closure* read_closure) { + void RegisterForOnReadableLocked(grpc_closure* read_closure) override { GPR_ASSERT(read_closure_ == nullptr); read_closure_ = read_closure; GPR_ASSERT(GRPC_SLICE_LENGTH(read_buf_) == 0); @@ -193,23 +207,28 @@ class GrpcPolledFdWindows { grpc_socket_notify_on_read(winsocket_, &outer_read_closure_); } - void RegisterForOnWriteableLocked(grpc_closure* write_closure) { + void RegisterForOnWriteableLocked(grpc_closure* write_closure) override { if (socket_type_ == SOCK_DGRAM) { GRPC_CARES_TRACE_LOG("fd:|%s| RegisterForOnWriteableLocked called", GetName()); } else { GPR_ASSERT(socket_type_ == SOCK_STREAM); GRPC_CARES_TRACE_LOG( - "fd:|%s| RegisterForOnWriteableLocked called tcp_write_state_: %d", - GetName(), tcp_write_state_); + "fd:|%s| RegisterForOnWriteableLocked called tcp_write_state_: %d " + "connect_done_: %d", + GetName(), tcp_write_state_, connect_done_); } GPR_ASSERT(write_closure_ == nullptr); write_closure_ = write_closure; - if (connect_done_) { - ContinueRegisterForOnWriteableLocked(); - } else { - GPR_ASSERT(pending_continue_register_for_on_writeable_locked_ == false); + if (!connect_done_) { + GPR_ASSERT(!pending_continue_register_for_on_writeable_locked_); pending_continue_register_for_on_writeable_locked_ = true; + // Register an async OnTcpConnect callback here rather than when the + // connect was initiated, since we are now guaranteed to hold a ref of the + // c-ares wrapper before write_closure_ is called. + grpc_socket_notify_on_write(winsocket_, &on_tcp_connect_locked_); + } else { + ContinueRegisterForOnWriteableLocked(); } } @@ -250,17 +269,20 @@ class GrpcPolledFdWindows { } } - bool IsFdStillReadableLocked() { return read_buf_has_data_; } + bool IsFdStillReadableLocked() override { return read_buf_has_data_; } - void ShutdownLocked(grpc_error_handle /* error */) { + void ShutdownLocked(grpc_error_handle /* error */) override { + GPR_ASSERT(!shutdown_called_); + shutdown_called_ = true; + on_shutdown_locked_(); grpc_winsocket_shutdown(winsocket_); } - ares_socket_t GetWrappedAresSocketLocked() { + ares_socket_t GetWrappedAresSocketLocked() override { return grpc_winsocket_wrapped_socket(winsocket_); } - const char* GetName() const { return name_.c_str(); } + const char* GetName() const override { return name_.c_str(); } ares_ssize_t RecvFrom(WSAErrorContext* wsa_error_ctx, void* data, ares_socket_t data_len, int /* flags */, @@ -437,7 +459,9 @@ class GrpcPolledFdWindows { GPR_ASSERT(!connect_done_); connect_done_ = true; GPR_ASSERT(wsa_connect_error_ == 0); - if (error.ok()) { + if (!error.ok() || shutdown_called_) { + wsa_connect_error_ = WSA_OPERATION_ABORTED; + } else { DWORD transferred_bytes = 0; DWORD flags; BOOL wsa_success = @@ -454,10 +478,6 @@ class GrpcPolledFdWindows { GetName(), wsa_connect_error_, msg); gpr_free(msg); } - } else { - // Spoof up an error code that will cause any future c-ares operations on - // this fd to abort. - wsa_connect_error_ = WSA_OPERATION_ABORTED; } if (pending_continue_register_for_on_readable_locked_) { ContinueRegisterForOnReadableLocked(); @@ -564,7 +584,7 @@ class GrpcPolledFdWindows { return -1; } } - grpc_socket_notify_on_write(winsocket_, &on_tcp_connect_locked_); + // RegisterForOnWriteable will register for an async notification return out; } @@ -645,9 +665,6 @@ class GrpcPolledFdWindows { ScheduleAndNullWriteClosure(error); } - bool gotten_into_driver_list() const { return gotten_into_driver_list_; } - void set_gotten_into_driver_list() { gotten_into_driver_list_ = true; } - private: Mutex* mu_; char recv_from_source_addr_[200]; @@ -660,73 +677,48 @@ class GrpcPolledFdWindows { grpc_closure outer_read_closure_; grpc_closure outer_write_closure_; grpc_winsocket* winsocket_; - // tcp_write_state_ is only used on TCP GrpcPolledFds - WriteState tcp_write_state_; const std::string name_; - bool gotten_into_driver_list_; + bool shutdown_called_ = false; int address_family_; int socket_type_; + // State related to TCP sockets grpc_closure on_tcp_connect_locked_; bool connect_done_ = false; int wsa_connect_error_ = 0; + WriteState tcp_write_state_ = WRITE_IDLE; // We don't run register_for_{readable,writeable} logic until // a socket is connected. In the interim, we queue readable/writeable // registrations with the following state. bool pending_continue_register_for_on_readable_locked_ = false; bool pending_continue_register_for_on_writeable_locked_ = false; + absl::AnyInvocable on_shutdown_locked_; }; -struct SockToPolledFdEntry { - SockToPolledFdEntry(SOCKET s, GrpcPolledFdWindows* fd) - : socket(s), polled_fd(fd) {} - SOCKET socket; - GrpcPolledFdWindows* polled_fd; - SockToPolledFdEntry* next = nullptr; -}; - -// A SockToPolledFdMap can make ares_socket_t types (SOCKET's on windows) -// to GrpcPolledFdWindow's, and is used to find the appropriate -// GrpcPolledFdWindows to handle a virtual socket call when c-ares makes that -// socket call on the ares_socket_t type. Instances are owned by and one-to-one -// with a GrpcPolledFdWindows factory and event driver -class SockToPolledFdMap { +class GrpcPolledFdFactoryWindows : public GrpcPolledFdFactory { public: - explicit SockToPolledFdMap(Mutex* mu) : mu_(mu) {} + explicit GrpcPolledFdFactoryWindows(Mutex* mu) : mu_(mu) {} - ~SockToPolledFdMap() { GPR_ASSERT(head_ == nullptr); } - - void AddNewSocket(SOCKET s, GrpcPolledFdWindows* polled_fd) { - SockToPolledFdEntry* new_node = new SockToPolledFdEntry(s, polled_fd); - new_node->next = head_; - head_ = new_node; + ~GrpcPolledFdFactoryWindows() override { + // We might still have a socket -> polled fd mappings if the socket + // was never seen by the grpc ares wrapper code, i.e. if we never + // initiated I/O polling for them. + for (auto& it : sockets_) { + delete it.second; + } } - GrpcPolledFdWindows* LookupPolledFd(SOCKET s) { - for (SockToPolledFdEntry* node = head_; node != nullptr; - node = node->next) { - if (node->socket == s) { - GPR_ASSERT(node->polled_fd != nullptr); - return node->polled_fd; - } - } - abort(); + GrpcPolledFd* NewGrpcPolledFdLocked( + ares_socket_t as, grpc_pollset_set* /* driver_pollset_set */) override { + auto it = sockets_.find(as); + GPR_ASSERT(it != sockets_.end()); + return it->second; } - void RemoveEntry(SOCKET s) { - GPR_ASSERT(head_ != nullptr); - SockToPolledFdEntry** prev = &head_; - for (SockToPolledFdEntry* node = head_; node != nullptr; - node = node->next) { - if (node->socket == s) { - *prev = node->next; - delete node; - return; - } - prev = &node->next; - } - abort(); + void ConfigureAresChannelLocked(ares_channel channel) override { + ares_set_socket_functions(channel, &kCustomSockFuncs, this); } + private: // These virtual socket functions are called from within the c-ares // library. These methods generally dispatch those socket calls to the // appropriate methods. The virtual "socket" and "close" methods are @@ -738,7 +730,8 @@ class SockToPolledFdMap { GRPC_CARES_TRACE_LOG("Socket called with invalid socket type:%d", type); return INVALID_SOCKET; } - SockToPolledFdMap* map = static_cast(user_data); + GrpcPolledFdFactoryWindows* self = + static_cast(user_data); SOCKET s = WSASocket(af, type, protocol, nullptr, 0, grpc_get_default_wsa_socket_flags()); if (s == INVALID_SOCKET) { @@ -753,131 +746,68 @@ class SockToPolledFdMap { StatusToString(error).c_str()); return INVALID_SOCKET; } - GrpcPolledFdWindows* polled_fd = - new GrpcPolledFdWindows(s, map->mu_, af, type); + auto on_shutdown_locked = [self, s]() { + // grpc_winsocket_shutdown calls closesocket which invalidates our + // socket -> polled_fd mapping because the socket handle can be henceforth + // reused. + self->sockets_.erase(s); + }; + auto polled_fd = new GrpcPolledFdWindows(s, self->mu_, af, type, + std::move(on_shutdown_locked)); GRPC_CARES_TRACE_LOG( "fd:|%s| created with params af:%d type:%d protocol:%d", polled_fd->GetName(), af, type, protocol); - map->AddNewSocket(s, polled_fd); + GPR_ASSERT(self->sockets_.insert({s, polled_fd}).second); return s; } static int Connect(ares_socket_t as, const struct sockaddr* target, ares_socklen_t target_len, void* user_data) { WSAErrorContext wsa_error_ctx; - SockToPolledFdMap* map = static_cast(user_data); - GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as); - return polled_fd->Connect(&wsa_error_ctx, target, target_len); + GrpcPolledFdFactoryWindows* self = + static_cast(user_data); + auto it = self->sockets_.find(as); + GPR_ASSERT(it != self->sockets_.end()); + return it->second->Connect(&wsa_error_ctx, target, target_len); } static ares_ssize_t SendV(ares_socket_t as, const struct iovec* iov, int iovec_count, void* user_data) { WSAErrorContext wsa_error_ctx; - SockToPolledFdMap* map = static_cast(user_data); - GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as); - return polled_fd->SendV(&wsa_error_ctx, iov, iovec_count); + GrpcPolledFdFactoryWindows* self = + static_cast(user_data); + auto it = self->sockets_.find(as); + GPR_ASSERT(it != self->sockets_.end()); + return it->second->SendV(&wsa_error_ctx, iov, iovec_count); } static ares_ssize_t RecvFrom(ares_socket_t as, void* data, size_t data_len, int flags, struct sockaddr* from, ares_socklen_t* from_len, void* user_data) { WSAErrorContext wsa_error_ctx; - SockToPolledFdMap* map = static_cast(user_data); - GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as); - return polled_fd->RecvFrom(&wsa_error_ctx, data, data_len, flags, from, - from_len); - } - - static int CloseSocket(SOCKET s, void* user_data) { - SockToPolledFdMap* map = static_cast(user_data); - GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(s); - map->RemoveEntry(s); - // See https://github.com/grpc/grpc/pull/20284, this trace log is - // intentionally placed to attempt to trigger a crash in case of a - // use after free on polled_fd. - GRPC_CARES_TRACE_LOG("CloseSocket called for socket: %s", - polled_fd->GetName()); - // If a gRPC polled fd has not made it in to the driver's list yet, then - // the driver has not and will never see this socket. - if (!polled_fd->gotten_into_driver_list()) { - polled_fd->ShutdownLocked(GRPC_ERROR_CREATE( - "Shut down c-ares fd before without it ever having made it into the " - "driver's list")); - } - delete polled_fd; - return 0; - } + GrpcPolledFdFactoryWindows* self = + static_cast(user_data); + auto it = self->sockets_.find(as); + GPR_ASSERT(it != self->sockets_.end()); + return it->second->RecvFrom(&wsa_error_ctx, data, data_len, flags, from, + from_len); + } + + static int CloseSocket(SOCKET /* s */, void* /* user_data */) { return 0; } + + const struct ares_socket_functions kCustomSockFuncs = { + &GrpcPolledFdFactoryWindows::Socket /* socket */, + &GrpcPolledFdFactoryWindows::CloseSocket /* close */, + &GrpcPolledFdFactoryWindows::Connect /* connect */, + &GrpcPolledFdFactoryWindows::RecvFrom /* recvfrom */, + &GrpcPolledFdFactoryWindows::SendV /* sendv */, + }; - private: Mutex* mu_; - SockToPolledFdEntry* head_ = nullptr; -}; - -const struct ares_socket_functions custom_ares_sock_funcs = { - &SockToPolledFdMap::Socket /* socket */, - &SockToPolledFdMap::CloseSocket /* close */, - &SockToPolledFdMap::Connect /* connect */, - &SockToPolledFdMap::RecvFrom /* recvfrom */, - &SockToPolledFdMap::SendV /* sendv */, + std::map sockets_; }; -// A thin wrapper over a GrpcPolledFdWindows object but with a shorter -// lifetime. This object releases it's GrpcPolledFdWindows upon destruction, -// so that c-ares can close it via usual socket teardown. -class GrpcPolledFdWindowsWrapper : public GrpcPolledFd { - public: - explicit GrpcPolledFdWindowsWrapper(GrpcPolledFdWindows* wrapped) - : wrapped_(wrapped) {} - - ~GrpcPolledFdWindowsWrapper() {} - - void RegisterForOnReadableLocked(grpc_closure* read_closure) override { - wrapped_->RegisterForOnReadableLocked(read_closure); - } - - void RegisterForOnWriteableLocked(grpc_closure* write_closure) override { - wrapped_->RegisterForOnWriteableLocked(write_closure); - } - - bool IsFdStillReadableLocked() override { - return wrapped_->IsFdStillReadableLocked(); - } - - void ShutdownLocked(grpc_error_handle error) override { - wrapped_->ShutdownLocked(error); - } - - ares_socket_t GetWrappedAresSocketLocked() override { - return wrapped_->GetWrappedAresSocketLocked(); - } - - const char* GetName() const override { return wrapped_->GetName(); } - - private: - GrpcPolledFdWindows* const wrapped_; -}; - -class GrpcPolledFdFactoryWindows : public GrpcPolledFdFactory { - public: - explicit GrpcPolledFdFactoryWindows(Mutex* mu) : sock_to_polled_fd_map_(mu) {} - - GrpcPolledFd* NewGrpcPolledFdLocked( - ares_socket_t as, grpc_pollset_set* /* driver_pollset_set */) override { - GrpcPolledFdWindows* polled_fd = sock_to_polled_fd_map_.LookupPolledFd(as); - // Set a flag so that the virtual socket "close" method knows it - // doesn't need to call ShutdownLocked, since now the driver will. - polled_fd->set_gotten_into_driver_list(); - return new GrpcPolledFdWindowsWrapper(polled_fd); - } - - void ConfigureAresChannelLocked(ares_channel channel) override { - ares_set_socket_functions(channel, &custom_ares_sock_funcs, - &sock_to_polled_fd_map_); - } - - private: - SockToPolledFdMap sock_to_polled_fd_map_; -}; +} // namespace std::unique_ptr NewGrpcPolledFdFactory(Mutex* mu) { return std::make_unique(mu);