diff --git a/ares__close_sockets.c b/ares__close_sockets.c index 6c66483b..f07904e8 100644 --- a/ares__close_sockets.c +++ b/ares__close_sockets.c @@ -48,14 +48,14 @@ void ares__close_sockets(ares_channel channel, struct server_state *server) if (server->tcp_socket != ARES_SOCKET_BAD) { SOCK_STATE_CALLBACK(channel, server->tcp_socket, 0, 0); - sclose(server->tcp_socket); + ares__socket_close(channel, server->tcp_socket); server->tcp_socket = ARES_SOCKET_BAD; server->tcp_connection_generation = ++channel->tcp_connection_generation; } if (server->udp_socket != ARES_SOCKET_BAD) { SOCK_STATE_CALLBACK(channel, server->udp_socket, 0, 0); - sclose(server->udp_socket); + ares__socket_close(channel, server->udp_socket); server->udp_socket = ARES_SOCKET_BAD; } } diff --git a/ares_private.h b/ares_private.h index a7988032..7359d5eb 100644 --- a/ares_private.h +++ b/ares_private.h @@ -345,6 +345,8 @@ void ares__destroy_servers_state(ares_channel channel); long ares__tvdiff(struct timeval t1, struct timeval t2); #endif +void ares__socket_close(ares_channel, ares_socket_t); + #define ARES_SWAP_BYTE(a,b) \ { unsigned char swapByte = *(a); *(a) = *(b); *(b) = swapByte; } diff --git a/ares_process.c b/ares_process.c index 1d1e7b8b..e9648965 100644 --- a/ares_process.c +++ b/ares_process.c @@ -1,6 +1,6 @@ /* Copyright 1998 by the Massachusetts Institute of Technology. - * Copyright (C) 2004-2016 by Daniel Stenberg + * Copyright (C) 2004-2017 by Daniel Stenberg * * Permission to use, copy, modify, and distribute this * software and its documentation for any purpose and without @@ -175,6 +175,26 @@ static int try_again(int errnum) return 0; } +static ssize_t socket_writev(ares_channel channel, ares_socket_t s, const struct iovec * vec, int len) +{ + if (channel->sock_funcs) + return channel->sock_funcs->asendv(s, vec, len, channel->sock_func_cb_data); + + return writev(s, vec, len); +} + +static ssize_t socket_write(ares_channel channel, ares_socket_t s, const void * data, size_t len) +{ + if (channel->sock_funcs) + { + struct iovec vec; + vec.iov_base = (void*)data; + vec.iov_len = len; + return channel->sock_funcs->asendv(s, &vec, 1, channel->sock_func_cb_data); + } + return swrite(s, data, len); +} + /* If any TCP sockets select true for writing, write out queued data * we have for them. */ @@ -238,7 +258,7 @@ static void write_tcp_data(ares_channel channel, vec[n].iov_len = sendreq->len; n++; } - wcount = (ssize_t)writev(server->tcp_socket, vec, (int)n); + wcount = socket_writev(channel, server->tcp_socket, vec, (int)n); ares_free(vec); if (wcount < 0) { @@ -255,7 +275,7 @@ static void write_tcp_data(ares_channel channel, /* Can't allocate iovecs; just send the first request. */ sendreq = server->qhead; - scount = swrite(server->tcp_socket, sendreq->data, sendreq->len); + scount = socket_write(channel, server->tcp_socket, sendreq->data, sendreq->len); if (scount < 0) { if (!try_again(SOCKERRNO)) @@ -299,6 +319,38 @@ static void advance_tcp_send_queue(ares_channel channel, int whichserver, } } +static ssize_t socket_recvfrom(ares_channel channel, + ares_socket_t s, + void * data, + size_t data_len, + int flags, + struct sockaddr *from, + socklen_t *from_len) +{ + if (channel->sock_funcs) + return channel->sock_funcs->arecvfrom(s, data, data_len, + flags, from, from_len, + channel->sock_func_cb_data); + +#ifdef HAVE_RECVFROM + return recvfrom(s, data, data_len, flags, from, from_len); +#else + return sread(s, data, data_len); +#endif +} + +static ssize_t socket_recv(ares_channel channel, + ares_socket_t s, + void * data, + size_t data_len) +{ + if (channel->sock_funcs) + return channel->sock_funcs->arecvfrom(s, data, data_len, 0, 0, 0, + channel->sock_func_cb_data); + + return sread(s, data, data_len); +} + /* If any TCP socket selects true for reading, read some data, * allocate a buffer if we finish reading the length word, and process * a packet if we finish reading one. @@ -343,9 +395,9 @@ static void read_tcp_data(ares_channel channel, fd_set *read_fds, /* We haven't yet read a length word, so read that (or * what's left to read of it). */ - count = sread(server->tcp_socket, - server->tcp_lenbuf + server->tcp_lenbuf_pos, - 2 - server->tcp_lenbuf_pos); + count = socket_recv(channel, server->tcp_socket, + server->tcp_lenbuf + server->tcp_lenbuf_pos, + 2 - server->tcp_lenbuf_pos); if (count <= 0) { if (!(count == -1 && try_again(SOCKERRNO))) @@ -373,9 +425,9 @@ static void read_tcp_data(ares_channel channel, fd_set *read_fds, else { /* Read data into the allocated buffer. */ - count = sread(server->tcp_socket, - server->tcp_buffer + server->tcp_buffer_pos, - server->tcp_length - server->tcp_buffer_pos); + count = socket_recv(channel, server->tcp_socket, + server->tcp_buffer + server->tcp_buffer_pos, + server->tcp_length - server->tcp_buffer_pos); if (count <= 0) { if (!(count == -1 && try_again(SOCKERRNO))) @@ -453,16 +505,12 @@ static void read_udp_packets(ares_channel channel, fd_set *read_fds, count = 0; else { -#ifdef HAVE_RECVFROM if (server->addr.family == AF_INET) fromlen = sizeof(from.sa4); else fromlen = sizeof(from.sa6); - count = (ssize_t)recvfrom(server->udp_socket, (void *)buf, - sizeof(buf), 0, &from.sa, &fromlen); -#else - count = sread(server->udp_socket, buf, sizeof(buf)); -#endif + count = socket_recvfrom(channel, server->udp_socket, (void *)buf, + sizeof(buf), 0, &from.sa, &fromlen); } if (count == -1 && try_again(SOCKERRNO)) @@ -812,7 +860,7 @@ void ares__send_query(ares_channel channel, struct query *query, return; } } - if (swrite(server->udp_socket, query->qbuf, query->qlen) == -1) + if (socket_write(channel, server->udp_socket, query->qbuf, query->qlen) == -1) { /* FIXME: Handle EAGAIN here since it likely can happen. */ skip_server(channel, query, query->server); @@ -904,6 +952,10 @@ static int configure_socket(ares_socket_t s, int family, ares_channel channel) struct sockaddr_in6 sa6; } local; + /* do not set options for user-managed sockets */ + if (channel->sock_funcs) + return 0; + (void)setsocknonblock(s, TRUE); #if defined(FD_CLOEXEC) && !defined(MSDOS) @@ -959,6 +1011,30 @@ static int configure_socket(ares_socket_t s, int family, ares_channel channel) return 0; } +static int open_socket(ares_channel channel, int af, int type, int protocol) +{ + if (channel->sock_funcs != 0) + return channel->sock_funcs->asocket(af, + type, + protocol, + channel->sock_func_cb_data); + + return socket(af, type, protocol); +} + +static int connect_socket(ares_channel channel, ares_socket_t sockfd, + const struct sockaddr * addr, + socklen_t addrlen) +{ + if (channel->sock_funcs != 0) + return channel->sock_funcs->aconnect(sockfd, + addr, + addrlen, + channel->sock_func_cb_data); + + return connect(sockfd, addr, addrlen); +} + static int open_tcp_socket(ares_channel channel, struct server_state *server) { ares_socket_t s; @@ -1003,14 +1079,14 @@ static int open_tcp_socket(ares_channel channel, struct server_state *server) } /* Acquire a socket. */ - s = socket(server->addr.family, SOCK_STREAM, 0); + s = open_socket(channel, server->addr.family, SOCK_STREAM, 0); if (s == ARES_SOCKET_BAD) return -1; /* Configure it. */ if (configure_socket(s, server->addr.family, channel) < 0) { - sclose(s); + ares__socket_close(channel, s); return -1; } @@ -1022,10 +1098,12 @@ static int open_tcp_socket(ares_channel channel, struct server_state *server) * so batching isn't very interesting. */ opt = 1; - if (setsockopt(s, IPPROTO_TCP, TCP_NODELAY, - (void *)&opt, sizeof(opt)) == -1) + if (channel->sock_funcs == 0 + && + setsockopt(s, IPPROTO_TCP, TCP_NODELAY, + (void *)&opt, sizeof(opt)) == -1) { - sclose(s); + ares__socket_close(channel, s); return -1; } #endif @@ -1036,19 +1114,19 @@ static int open_tcp_socket(ares_channel channel, struct server_state *server) channel->sock_config_cb_data); if (err < 0) { - sclose(s); + ares__socket_close(channel, s); return err; } } /* Connect to the server. */ - if (connect(s, sa, salen) == -1) + if (connect_socket(channel, s, sa, salen) == -1) { int err = SOCKERRNO; if (err != EINPROGRESS && err != EWOULDBLOCK) { - sclose(s); + ares__socket_close(channel, s); return -1; } } @@ -1059,7 +1137,7 @@ static int open_tcp_socket(ares_channel channel, struct server_state *server) channel->sock_create_cb_data); if (err < 0) { - sclose(s); + ares__socket_close(channel, s); return err; } } @@ -1114,14 +1192,14 @@ static int open_udp_socket(ares_channel channel, struct server_state *server) } /* Acquire a socket. */ - s = socket(server->addr.family, SOCK_DGRAM, 0); + s = open_socket(channel, server->addr.family, SOCK_DGRAM, 0); if (s == ARES_SOCKET_BAD) return -1; /* Set the socket non-blocking. */ if (configure_socket(s, server->addr.family, channel) < 0) { - sclose(s); + ares__socket_close(channel, s); return -1; } @@ -1131,19 +1209,19 @@ static int open_udp_socket(ares_channel channel, struct server_state *server) channel->sock_config_cb_data); if (err < 0) { - sclose(s); + ares__socket_close(channel, s); return err; } } /* Connect to the server. */ - if (connect(s, sa, salen) == -1) + if (connect_socket(channel, s, sa, salen) == -1) { int err = SOCKERRNO; if (err != EINPROGRESS && err != EWOULDBLOCK) { - sclose(s); + ares__socket_close(channel, s); return -1; } } @@ -1154,7 +1232,7 @@ static int open_udp_socket(ares_channel channel, struct server_state *server) channel->sock_create_cb_data); if (err < 0) { - sclose(s); + ares__socket_close(channel, s); return err; } } @@ -1357,3 +1435,11 @@ void ares__free_query(struct query *query) ares_free(query->server_info); ares_free(query); } + +void ares__socket_close(ares_channel channel, ares_socket_t s) +{ + if (channel->sock_funcs) + channel->sock_funcs->aclose(s, channel->sock_func_cb_data); + else + sclose(s); +}