Implement using virtual socket IO functions when set

Uses virtual socket IO functions when set on a channel.
Note that no socket options are set, nor is any binding
done by the library in this case, since the client defining
these is probably more suited to deal with this.
pull/76/merge
elcallio 8 years ago committed by Daniel Stenberg
parent c32103e112
commit 7e1e31c6cd
  1. 4
      ares__close_sockets.c
  2. 2
      ares_private.h
  3. 148
      ares_process.c

@ -48,14 +48,14 @@ void ares__close_sockets(ares_channel channel, struct server_state *server)
if (server->tcp_socket != ARES_SOCKET_BAD)
{
SOCK_STATE_CALLBACK(channel, server->tcp_socket, 0, 0);
sclose(server->tcp_socket);
ares__socket_close(channel, server->tcp_socket);
server->tcp_socket = ARES_SOCKET_BAD;
server->tcp_connection_generation = ++channel->tcp_connection_generation;
}
if (server->udp_socket != ARES_SOCKET_BAD)
{
SOCK_STATE_CALLBACK(channel, server->udp_socket, 0, 0);
sclose(server->udp_socket);
ares__socket_close(channel, server->udp_socket);
server->udp_socket = ARES_SOCKET_BAD;
}
}

@ -345,6 +345,8 @@ void ares__destroy_servers_state(ares_channel channel);
long ares__tvdiff(struct timeval t1, struct timeval t2);
#endif
void ares__socket_close(ares_channel, ares_socket_t);
#define ARES_SWAP_BYTE(a,b) \
{ unsigned char swapByte = *(a); *(a) = *(b); *(b) = swapByte; }

@ -1,6 +1,6 @@
/* Copyright 1998 by the Massachusetts Institute of Technology.
* Copyright (C) 2004-2016 by Daniel Stenberg
* Copyright (C) 2004-2017 by Daniel Stenberg
*
* Permission to use, copy, modify, and distribute this
* software and its documentation for any purpose and without
@ -175,6 +175,26 @@ static int try_again(int errnum)
return 0;
}
static ssize_t socket_writev(ares_channel channel, ares_socket_t s, const struct iovec * vec, int len)
{
if (channel->sock_funcs)
return channel->sock_funcs->asendv(s, vec, len, channel->sock_func_cb_data);
return writev(s, vec, len);
}
static ssize_t socket_write(ares_channel channel, ares_socket_t s, const void * data, size_t len)
{
if (channel->sock_funcs)
{
struct iovec vec;
vec.iov_base = (void*)data;
vec.iov_len = len;
return channel->sock_funcs->asendv(s, &vec, 1, channel->sock_func_cb_data);
}
return swrite(s, data, len);
}
/* If any TCP sockets select true for writing, write out queued data
* we have for them.
*/
@ -238,7 +258,7 @@ static void write_tcp_data(ares_channel channel,
vec[n].iov_len = sendreq->len;
n++;
}
wcount = (ssize_t)writev(server->tcp_socket, vec, (int)n);
wcount = socket_writev(channel, server->tcp_socket, vec, (int)n);
ares_free(vec);
if (wcount < 0)
{
@ -255,7 +275,7 @@ static void write_tcp_data(ares_channel channel,
/* Can't allocate iovecs; just send the first request. */
sendreq = server->qhead;
scount = swrite(server->tcp_socket, sendreq->data, sendreq->len);
scount = socket_write(channel, server->tcp_socket, sendreq->data, sendreq->len);
if (scount < 0)
{
if (!try_again(SOCKERRNO))
@ -299,6 +319,38 @@ static void advance_tcp_send_queue(ares_channel channel, int whichserver,
}
}
static ssize_t socket_recvfrom(ares_channel channel,
ares_socket_t s,
void * data,
size_t data_len,
int flags,
struct sockaddr *from,
socklen_t *from_len)
{
if (channel->sock_funcs)
return channel->sock_funcs->arecvfrom(s, data, data_len,
flags, from, from_len,
channel->sock_func_cb_data);
#ifdef HAVE_RECVFROM
return recvfrom(s, data, data_len, flags, from, from_len);
#else
return sread(s, data, data_len);
#endif
}
static ssize_t socket_recv(ares_channel channel,
ares_socket_t s,
void * data,
size_t data_len)
{
if (channel->sock_funcs)
return channel->sock_funcs->arecvfrom(s, data, data_len, 0, 0, 0,
channel->sock_func_cb_data);
return sread(s, data, data_len);
}
/* If any TCP socket selects true for reading, read some data,
* allocate a buffer if we finish reading the length word, and process
* a packet if we finish reading one.
@ -343,9 +395,9 @@ static void read_tcp_data(ares_channel channel, fd_set *read_fds,
/* We haven't yet read a length word, so read that (or
* what's left to read of it).
*/
count = sread(server->tcp_socket,
server->tcp_lenbuf + server->tcp_lenbuf_pos,
2 - server->tcp_lenbuf_pos);
count = socket_recv(channel, server->tcp_socket,
server->tcp_lenbuf + server->tcp_lenbuf_pos,
2 - server->tcp_lenbuf_pos);
if (count <= 0)
{
if (!(count == -1 && try_again(SOCKERRNO)))
@ -373,9 +425,9 @@ static void read_tcp_data(ares_channel channel, fd_set *read_fds,
else
{
/* Read data into the allocated buffer. */
count = sread(server->tcp_socket,
server->tcp_buffer + server->tcp_buffer_pos,
server->tcp_length - server->tcp_buffer_pos);
count = socket_recv(channel, server->tcp_socket,
server->tcp_buffer + server->tcp_buffer_pos,
server->tcp_length - server->tcp_buffer_pos);
if (count <= 0)
{
if (!(count == -1 && try_again(SOCKERRNO)))
@ -453,16 +505,12 @@ static void read_udp_packets(ares_channel channel, fd_set *read_fds,
count = 0;
else {
#ifdef HAVE_RECVFROM
if (server->addr.family == AF_INET)
fromlen = sizeof(from.sa4);
else
fromlen = sizeof(from.sa6);
count = (ssize_t)recvfrom(server->udp_socket, (void *)buf,
sizeof(buf), 0, &from.sa, &fromlen);
#else
count = sread(server->udp_socket, buf, sizeof(buf));
#endif
count = socket_recvfrom(channel, server->udp_socket, (void *)buf,
sizeof(buf), 0, &from.sa, &fromlen);
}
if (count == -1 && try_again(SOCKERRNO))
@ -812,7 +860,7 @@ void ares__send_query(ares_channel channel, struct query *query,
return;
}
}
if (swrite(server->udp_socket, query->qbuf, query->qlen) == -1)
if (socket_write(channel, server->udp_socket, query->qbuf, query->qlen) == -1)
{
/* FIXME: Handle EAGAIN here since it likely can happen. */
skip_server(channel, query, query->server);
@ -904,6 +952,10 @@ static int configure_socket(ares_socket_t s, int family, ares_channel channel)
struct sockaddr_in6 sa6;
} local;
/* do not set options for user-managed sockets */
if (channel->sock_funcs)
return 0;
(void)setsocknonblock(s, TRUE);
#if defined(FD_CLOEXEC) && !defined(MSDOS)
@ -959,6 +1011,30 @@ static int configure_socket(ares_socket_t s, int family, ares_channel channel)
return 0;
}
static int open_socket(ares_channel channel, int af, int type, int protocol)
{
if (channel->sock_funcs != 0)
return channel->sock_funcs->asocket(af,
type,
protocol,
channel->sock_func_cb_data);
return socket(af, type, protocol);
}
static int connect_socket(ares_channel channel, ares_socket_t sockfd,
const struct sockaddr * addr,
socklen_t addrlen)
{
if (channel->sock_funcs != 0)
return channel->sock_funcs->aconnect(sockfd,
addr,
addrlen,
channel->sock_func_cb_data);
return connect(sockfd, addr, addrlen);
}
static int open_tcp_socket(ares_channel channel, struct server_state *server)
{
ares_socket_t s;
@ -1003,14 +1079,14 @@ static int open_tcp_socket(ares_channel channel, struct server_state *server)
}
/* Acquire a socket. */
s = socket(server->addr.family, SOCK_STREAM, 0);
s = open_socket(channel, server->addr.family, SOCK_STREAM, 0);
if (s == ARES_SOCKET_BAD)
return -1;
/* Configure it. */
if (configure_socket(s, server->addr.family, channel) < 0)
{
sclose(s);
ares__socket_close(channel, s);
return -1;
}
@ -1022,10 +1098,12 @@ static int open_tcp_socket(ares_channel channel, struct server_state *server)
* so batching isn't very interesting.
*/
opt = 1;
if (setsockopt(s, IPPROTO_TCP, TCP_NODELAY,
(void *)&opt, sizeof(opt)) == -1)
if (channel->sock_funcs == 0
&&
setsockopt(s, IPPROTO_TCP, TCP_NODELAY,
(void *)&opt, sizeof(opt)) == -1)
{
sclose(s);
ares__socket_close(channel, s);
return -1;
}
#endif
@ -1036,19 +1114,19 @@ static int open_tcp_socket(ares_channel channel, struct server_state *server)
channel->sock_config_cb_data);
if (err < 0)
{
sclose(s);
ares__socket_close(channel, s);
return err;
}
}
/* Connect to the server. */
if (connect(s, sa, salen) == -1)
if (connect_socket(channel, s, sa, salen) == -1)
{
int err = SOCKERRNO;
if (err != EINPROGRESS && err != EWOULDBLOCK)
{
sclose(s);
ares__socket_close(channel, s);
return -1;
}
}
@ -1059,7 +1137,7 @@ static int open_tcp_socket(ares_channel channel, struct server_state *server)
channel->sock_create_cb_data);
if (err < 0)
{
sclose(s);
ares__socket_close(channel, s);
return err;
}
}
@ -1114,14 +1192,14 @@ static int open_udp_socket(ares_channel channel, struct server_state *server)
}
/* Acquire a socket. */
s = socket(server->addr.family, SOCK_DGRAM, 0);
s = open_socket(channel, server->addr.family, SOCK_DGRAM, 0);
if (s == ARES_SOCKET_BAD)
return -1;
/* Set the socket non-blocking. */
if (configure_socket(s, server->addr.family, channel) < 0)
{
sclose(s);
ares__socket_close(channel, s);
return -1;
}
@ -1131,19 +1209,19 @@ static int open_udp_socket(ares_channel channel, struct server_state *server)
channel->sock_config_cb_data);
if (err < 0)
{
sclose(s);
ares__socket_close(channel, s);
return err;
}
}
/* Connect to the server. */
if (connect(s, sa, salen) == -1)
if (connect_socket(channel, s, sa, salen) == -1)
{
int err = SOCKERRNO;
if (err != EINPROGRESS && err != EWOULDBLOCK)
{
sclose(s);
ares__socket_close(channel, s);
return -1;
}
}
@ -1154,7 +1232,7 @@ static int open_udp_socket(ares_channel channel, struct server_state *server)
channel->sock_create_cb_data);
if (err < 0)
{
sclose(s);
ares__socket_close(channel, s);
return err;
}
}
@ -1357,3 +1435,11 @@ void ares__free_query(struct query *query)
ares_free(query->server_info);
ares_free(query);
}
void ares__socket_close(ares_channel channel, ares_socket_t s)
{
if (channel->sock_funcs)
channel->sock_funcs->aclose(s, channel->sock_func_cb_data);
else
sclose(s);
}

Loading…
Cancel
Save