From fed3559cfc258881fb25aec593be521f2a7813e1 Mon Sep 17 00:00:00 2001 From: Brad House Date: Mon, 12 Feb 2024 11:00:13 -0500 Subject: [PATCH] Add ares_queue_wait_empty() for use with EventThreads (#710) It may be useful to wait for the queue to be empty under certain conditions (mainly test cases), expose a function to efficiently do this and rework test cases to use it. Fix By: Brad House (@bradh352) --- docs/Makefile.inc | 1 + docs/ares_queue_wait_empty.3 | 44 ++++++ docs/ares_threadsafety.3 | 2 +- include/ares.h | 14 ++ src/lib/ares__threads.c | 289 ++++++++++++++++++++++++++++++++++- src/lib/ares__threads.h | 15 ++ src/lib/ares_cancel.c | 3 + src/lib/ares_destroy.c | 2 + src/lib/ares_event_thread.c | 13 +- src/lib/ares_private.h | 13 ++ src/lib/ares_process.c | 16 +- test/ares-test-mock-et.cc | 25 ++- test/ares-test.cc | 196 ++++++++++++------------ test/ares-test.h | 30 +++- 14 files changed, 541 insertions(+), 122 deletions(-) create mode 100644 docs/ares_queue_wait_empty.3 diff --git a/docs/Makefile.inc b/docs/Makefile.inc index e3800159..5d5bb185 100644 --- a/docs/Makefile.inc +++ b/docs/Makefile.inc @@ -105,6 +105,7 @@ MANPAGES = ares_cancel.3 \ ares_parse_uri_reply.3 \ ares_process.3 \ ares_query.3 \ + ares_queue_wait_empty.3 \ ares_reinit.3 \ ares_save_options.3 \ ares_search.3 \ diff --git a/docs/ares_queue_wait_empty.3 b/docs/ares_queue_wait_empty.3 new file mode 100644 index 00000000..59038fdc --- /dev/null +++ b/docs/ares_queue_wait_empty.3 @@ -0,0 +1,44 @@ +.\" +.\" SPDX-License-Identifier: MIT +.\" +.TH ARES_QUEUE_WAIT_EMPTY 3 "12 February 2024" +.SH NAME +ares_queue_wait_empty \- Wait until all queries are complete for channel +.SH SYNOPSIS +.nf +#include + +ares_status_t ares_queue_wait_empty(ares_channel_t *channel, + int timeout_ms); +.fi +.SH DESCRIPTION +The \fBares_queue_wait_empty(3)\fP function blocks until notified that there are +no longer any queries in queue, or the specified timeout has expired. + +The \fBchannel\fP parameter must be set to an initialized channel. + +The \fBtimeout_ms\fP parameter is the number of milliseconds to wait for the +queue to be empty or -1 for Infinite. + +.SH RETURN VALUES +\fIares_queue_wait_empty(3)\fP can return any of the following values: +.TP 14 +.B ARES_ENOTIMP +if not built with threading support +.TP 14 +.B ARES_ETIMEOUT +if requested timeout expired +.TP 14 +.B ARES_SUCCESS +when queue is empty. +.TP 14 + +.SH AVAILABILITY +This function was first introduced in c-ares version 1.27.0, and requires the +c-ares library to be built with threading support. + +.SH SEE ALSO +.BR ares_init_options (3), +.BR ares_threadsafety (3) +.SH AUTHOR +Copyright (C) 2024 The c-ares project and its members. diff --git a/docs/ares_threadsafety.3 b/docs/ares_threadsafety.3 index 119c79ff..a3c29d5f 100644 --- a/docs/ares_threadsafety.3 +++ b/docs/ares_threadsafety.3 @@ -1,7 +1,7 @@ .\" .\" SPDX-License-Identifier: MIT .\" -.TH ARES_REINIT 3 "26 November 2023" +.TH ARES_THREADSAFETY 3 "26 November 2023" .SH NAME ares_threadsafety \- Query if c-ares was built with thread-safety .SH SYNOPSIS diff --git a/include/ares.h b/include/ares.h index 99c4ec50..32c0191b 100644 --- a/include/ares.h +++ b/include/ares.h @@ -773,6 +773,20 @@ CARES_EXTERN int ares_inet_pton(int af, const char *src, void *dst); */ CARES_EXTERN ares_bool_t ares_threadsafety(void); + +/*! Block until notified that there are no longer any queries in queue, or + * the specified timeout has expired. + * + * \param[in] channel Initialized ares channel + * \param[in] timeout_ms Number of milliseconds to wait for the queue to be + * empty. -1 for Infinite. + * \return ARES_ENOTIMP if not built with threading support, ARES_ETIMEOUT + * if requested timeout expires, ARES_SUCCESS when queue is empty. + */ +CARES_EXTERN ares_status_t ares_queue_wait_empty(ares_channel_t *channel, + int timeout_ms); + + #ifdef __cplusplus } #endif diff --git a/src/lib/ares__threads.c b/src/lib/ares__threads.c index 6cb6fccc..028790ae 100644 --- a/src/lib/ares__threads.c +++ b/src/lib/ares__threads.c @@ -70,13 +70,77 @@ void ares__thread_mutex_unlock(ares__thread_mutex_t *mut) LeaveCriticalSection(&mut->mutex); } +struct ares__thread_cond { + CONDITION_VARIABLE cond; +}; + +ares__thread_cond_t *ares__thread_cond_create(void) +{ + ares__thread_cond_t *cond = ares_malloc_zero(sizeof(*cond)); + if (cond == NULL) { + return NULL; + } + InitializeConditionVariable(&cond->cond); + return cond; +} + +void ares__thread_cond_destroy(ares__thread_cond_t *cond) +{ + if (cond == NULL) { + return; + } + ares_free(cond); +} + +void ares__thread_cond_signal(ares__thread_cond_t *cond) +{ + if (cond == NULL) { + return; + } + WakeConditionVariable(&cond->cond); +} + +void ares__thread_cond_broadcast(ares__thread_cond_t *cond) +{ + if (cond == NULL) { + return; + } + WakeAllConditionVariable(&cond->cond); +} + +ares_status_t ares__thread_cond_wait(ares__thread_cond_t *cond, + ares__thread_mutex_t *mut) +{ + if (cond == NULL || mut == NULL) { + return ARES_EFORMERR; + } + + SleepConditionVariableCS(&cond->cond, &mut->mutex, INFINITE); + return ARES_SUCCESS; +} + +ares_status_t ares__thread_cond_timedwait(ares__thread_cond_t *cond, + ares__thread_mutex_t *mut, + unsigned long timeout_ms) +{ + if (cond == NULL || mut == NULL) { + return ARES_EFORMERR; + } + + if (!SleepConditionVariableCS(&cond->cond, &mut->mutex, timeout_ms)) { + return ARES_ETIMEOUT; + } + + return ARES_SUCCESS; +} + struct ares__thread { HANDLE thread; DWORD id; - void *(*func)(void *arg); - void *arg; - void *rv; + void *(*func)(void *arg); + void *arg; + void *rv; }; /* Wrap for pthread compatibility */ @@ -139,6 +203,16 @@ ares_status_t ares__thread_join(ares__thread_t *thread, void **rv) # else /* !WIN32 == PTHREAD */ # include +/* for clock_gettime() */ +# ifdef HAVE_TIME_H +# include +# endif + +/* for gettimeofday() */ +# ifdef HAVE_SYS_TIME_H +# include +# endif + struct ares__thread_mutex { pthread_mutex_t mutex; }; @@ -198,6 +272,98 @@ void ares__thread_mutex_unlock(ares__thread_mutex_t *mut) pthread_mutex_unlock(&mut->mutex); } +struct ares__thread_cond { + pthread_cond_t cond; +}; + +ares__thread_cond_t *ares__thread_cond_create(void) +{ + ares__thread_cond_t *cond = ares_malloc_zero(sizeof(*cond)); + if (cond == NULL) { + return NULL; + } + pthread_cond_init(&cond->cond, NULL); + return cond; +} + +void ares__thread_cond_destroy(ares__thread_cond_t *cond) +{ + if (cond == NULL) { + return; + } + pthread_cond_destroy(&cond->cond); + ares_free(cond); +} + +void ares__thread_cond_signal(ares__thread_cond_t *cond) +{ + if (cond == NULL) { + return; + } + pthread_cond_signal(&cond->cond); +} + +void ares__thread_cond_broadcast(ares__thread_cond_t *cond) +{ + if (cond == NULL) { + return; + } + pthread_cond_broadcast(&cond->cond); +} + +ares_status_t ares__thread_cond_wait(ares__thread_cond_t *cond, + ares__thread_mutex_t *mut) +{ + if (cond == NULL || mut == NULL) { + return ARES_EFORMERR; + } + + pthread_cond_wait(&cond->cond, &mut->mutex); + return ARES_SUCCESS; +} + +static void ares__timespec_timeout(struct timespec *ts, unsigned long add_ms) +{ +# if defined(HAVE_CLOCK_GETTIME) && defined(CLOCK_REALTIME) + clock_gettime(CLOCK_REALTIME, ts); +# elif defined(HAVE_GETTIMEOFDAY) + struct timeval tv; + gettimeofday(&tv, NULL); + ts->tv_sec = tv.tv_sec; + ts->tv_nsec = tv.tv_usec * 1000; +# else +# error cannot determine current system time +# endif + + ts->tv_sec += add_ms / 1000; + ts->tv_nsec += (add_ms % 1000) * 1000000; + + /* Normalize if needed */ + if (ts->tv_nsec >= 1000000000) { + ts->tv_sec += ts->tv_nsec / 1000000000; + ts->tv_nsec %= 1000000000; + } +} + +ares_status_t ares__thread_cond_timedwait(ares__thread_cond_t *cond, + ares__thread_mutex_t *mut, + unsigned long timeout_ms) +{ + struct timespec ts; + + if (cond == NULL || mut == NULL) { + return ARES_EFORMERR; + } + + ares__timespec_timeout(&ts, timeout_ms); + + if (pthread_cond_timedwait(&cond->cond, &mut->mutex, &ts) != 0) { + return ARES_ETIMEOUT; + } + + return ARES_SUCCESS; +} + struct ares__thread { pthread_t thread; }; @@ -274,6 +440,44 @@ void ares__thread_mutex_unlock(ares__thread_mutex_t *mut) (void)mut; } +ares__thread_cond_t *ares__thread_cond_create(void) +{ + return NULL; +} + +void ares__thread_cond_destroy(ares__thread_cond_t *cond) +{ + (void)cond; +} + +void ares__thread_cond_signal(ares__thread_cond_t *cond) +{ + (void)cond; +} + +void ares__thread_cond_broadcast(ares__thread_cond_t *cond) +{ + (void)cond; +} + +ares_status_t ares__thread_cond_wait(ares__thread_cond_t *cond, + ares__thread_mutex_t *mut) +{ + (void)cond; + (void)mut; + return ARES_ENOTIMP; +} + +ares_status_t ares__thread_cond_timedwait(ares__thread_cond_t *cond, + ares__thread_mutex_t *mut, + unsigned long timeout_ms) +{ + (void)cond; + (void)mut; + (void)timeout_ms; + return ARES_ENOTIMP; +} + ares_status_t ares__thread_create(ares__thread_t **thread, ares__thread_func_t func, void *arg) { @@ -299,6 +503,8 @@ ares_bool_t ares_threadsafety(void) ares_status_t ares__channel_threading_init(ares_channel_t *channel) { + ares_status_t status = ARES_SUCCESS; + /* Threading is optional! */ if (!ares_threadsafety()) { return ARES_SUCCESS; @@ -306,15 +512,29 @@ ares_status_t ares__channel_threading_init(ares_channel_t *channel) channel->lock = ares__thread_mutex_create(); if (channel->lock == NULL) { - return ARES_ENOMEM; + status = ARES_ENOMEM; + goto done; } - return ARES_SUCCESS; + + channel->cond_empty = ares__thread_cond_create(); + if (channel->cond_empty == NULL) { + status = ARES_ENOMEM; + goto done; + } + +done: + if (status != ARES_SUCCESS) { + ares__channel_threading_destroy(channel); + } + return status; } void ares__channel_threading_destroy(ares_channel_t *channel) { ares__thread_mutex_destroy(channel->lock); channel->lock = NULL; + ares__thread_cond_destroy(channel->cond_empty); + channel->cond_empty = NULL; } void ares__channel_lock(ares_channel_t *channel) @@ -326,3 +546,62 @@ void ares__channel_unlock(ares_channel_t *channel) { ares__thread_mutex_unlock(channel->lock); } + +/* Must not be holding a channel lock already, public function only */ +ares_status_t ares_queue_wait_empty(ares_channel_t *channel, int timeout_ms) +{ + ares_status_t status = ARES_SUCCESS; + struct timeval tout; + + if (!ares_threadsafety()) { + return ARES_ENOTIMP; + } + + if (channel == NULL) { + return ARES_EFORMERR; + } + + if (timeout_ms >= 0) { + tout = ares__tvnow(); + tout.tv_sec += timeout_ms / 1000; + tout.tv_usec += (timeout_ms % 1000) * 1000; + } + + ares__thread_mutex_lock(channel->lock); + while (ares__llist_len(channel->all_queries)) { + if (timeout_ms < 0) { + ares__thread_cond_wait(channel->cond_empty, channel->lock); + } else { + struct timeval tv_remaining; + struct timeval tv_now = ares__tvnow(); + unsigned long tms; + + ares__timeval_remaining(&tv_remaining, &tv_now, &tout); + tms = (unsigned long)((tv_remaining.tv_sec * 1000) + + (tv_remaining.tv_usec / 1000)); + if (tms == 0) { + status = ARES_ETIMEOUT; + } else { + status = + ares__thread_cond_timedwait(channel->cond_empty, channel->lock, tms); + } + } + } + ares__thread_mutex_unlock(channel->lock); + return status; +} + +void ares_queue_notify_empty(ares_channel_t *channel) +{ + if (channel == NULL) { + return; + } + + /* We are guaranteed to be holding a channel lock already */ + if (ares__llist_len(channel->all_queries)) { + return; + } + + /* Notify all waiters of the conditional */ + ares__thread_cond_broadcast(channel->cond_empty); +} diff --git a/src/lib/ares__threads.h b/src/lib/ares__threads.h index 03f67f08..39764296 100644 --- a/src/lib/ares__threads.h +++ b/src/lib/ares__threads.h @@ -34,6 +34,21 @@ void ares__thread_mutex_destroy(ares__thread_mutex_t *mut); void ares__thread_mutex_lock(ares__thread_mutex_t *mut); void ares__thread_mutex_unlock(ares__thread_mutex_t *mut); + +struct ares__thread_cond; +typedef struct ares__thread_cond ares__thread_cond_t; + +ares__thread_cond_t *ares__thread_cond_create(void); +void ares__thread_cond_destroy(ares__thread_cond_t *cond); +void ares__thread_cond_signal(ares__thread_cond_t *cond); +void ares__thread_cond_broadcast(ares__thread_cond_t *cond); +ares_status_t ares__thread_cond_wait(ares__thread_cond_t *cond, + ares__thread_mutex_t *mut); +ares_status_t ares__thread_cond_timedwait(ares__thread_cond_t *cond, + ares__thread_mutex_t *mut, + unsigned long timeout_ms); + + struct ares__thread; typedef struct ares__thread ares__thread_t; diff --git a/src/lib/ares_cancel.c b/src/lib/ares_cancel.c index 9841f9bf..0ee6124d 100644 --- a/src/lib/ares_cancel.c +++ b/src/lib/ares_cancel.c @@ -85,6 +85,9 @@ void ares_cancel(ares_channel_t *channel) ares__llist_destroy(list_copy); } + + ares_queue_notify_empty(channel); + done: ares__channel_unlock(channel); } diff --git a/src/lib/ares_destroy.c b/src/lib/ares_destroy.c index f2f0d9a7..14508457 100644 --- a/src/lib/ares_destroy.c +++ b/src/lib/ares_destroy.c @@ -57,6 +57,8 @@ void ares_destroy(ares_channel_t *channel) node = next; } + ares_queue_notify_empty(channel); + #ifndef NDEBUG /* Freeing the query should remove it from all the lists in which it sits, * so all query lists should be empty now. diff --git a/src/lib/ares_event_thread.c b/src/lib/ares_event_thread.c index 694753f9..6dd7b502 100644 --- a/src/lib/ares_event_thread.c +++ b/src/lib/ares_event_thread.c @@ -52,7 +52,13 @@ static void ares_event_destroy_cb(void *arg) /* See if a pending update already exists. We don't want to enqueue multiple * updates for the same event handle. Right now this is O(n) based on number * of updates already enqueued. In the future, it might make sense to make - * this O(1) with a hashtable. */ + * this O(1) with a hashtable. + * NOTE: in some cases a delete then re-add of the same fd, but really pointing + * to a different destination can happen due to a quick close of a + * connection then creation of a new one. So we need to look at the + * flags and ignore any delete events when finding a match since we + * need to process the delete always, it can't be combined with other + * updates. */ static ares_event_t *ares_event_update_find(ares_event_thread_t *e, ares_socket_t fd, const void *data) { @@ -62,12 +68,12 @@ static ares_event_t *ares_event_update_find(ares_event_thread_t *e, node = ares__llist_node_next(node)) { ares_event_t *ev = ares__llist_node_val(node); - if (fd != ARES_SOCKET_BAD && fd == ev->fd) { + if (fd != ARES_SOCKET_BAD && fd == ev->fd && ev->flags != 0) { return ev; } if (fd == ARES_SOCKET_BAD && ev->fd == ARES_SOCKET_BAD && - data == ev->data) { + data == ev->data && ev->flags != 0) { return ev; } } @@ -188,7 +194,6 @@ static void ares_event_thread_sockstate_cb(void *data, ares_socket_t socket_fd, /* Update channel fd */ ares__thread_mutex_lock(e->mutex); - ares_event_update(NULL, e, flags, ares_event_thread_process_fd, socket_fd, NULL, NULL, NULL); diff --git a/src/lib/ares_private.h b/src/lib/ares_private.h index fcddf05e..90f714f1 100644 --- a/src/lib/ares_private.h +++ b/src/lib/ares_private.h @@ -262,6 +262,9 @@ struct ares_channeldata { /* Thread safety lock */ ares__thread_mutex_t *lock; + /* Conditional to wake waiters when queue is empty */ + ares__thread_cond_t *cond_empty; + /* Server addresses and communications state. Sorted by least consecutive * failures, followed by the configuration order if failures are equal. */ ares__slist_t *servers; @@ -530,6 +533,16 @@ ares_status_t ares__dns_name_write(ares__buf_t *buf, ares__llist_t **list, ares_bool_t validate_hostname, const char *name); +/*! Check if the queue is empty, if so, wake any waiters. This is only + * effective if built with threading support. + * + * Must be holding a channel lock when calling this function. + * + * \param[in] channel Initialized ares channel object + */ +void ares_queue_notify_empty(ares_channel_t *channel); + + #define ARES_SWAP_BYTE(a, b) \ do { \ unsigned char swapByte = *(a); \ diff --git a/src/lib/ares_process.c b/src/lib/ares_process.c index d24add05..62e7f762 100644 --- a/src/lib/ares_process.c +++ b/src/lib/ares_process.c @@ -67,7 +67,7 @@ static ares_bool_t same_questions(const ares_dns_record_t *qrec, const ares_dns_record_t *arec); static ares_bool_t same_address(const struct sockaddr *sa, const struct ares_addr *aa); -static void end_query(const ares_channel_t *channel, struct query *query, +static void end_query(ares_channel_t *channel, struct query *query, ares_status_t status, const unsigned char *abuf, size_t alen); @@ -715,6 +715,7 @@ static ares_status_t process_answer(ares_channel_t *channel, default: break; } + server_increment_failures(server); ares__requeue_query(query, now); @@ -759,7 +760,7 @@ static void handle_conn_error(struct server_connection *conn, ares_status_t ares__requeue_query(struct query *query, struct timeval *now) { - const ares_channel_t *channel = query->channel; + ares_channel_t *channel = query->channel; size_t max_tries = ares__slist_len(channel->servers) * channel->tries; query->try_count++; @@ -1122,18 +1123,23 @@ static void ares_detach_query(struct query *query) query->node_all_queries = NULL; } -static void end_query(const ares_channel_t *channel, struct query *query, +static void end_query(ares_channel_t *channel, struct query *query, ares_status_t status, const unsigned char *abuf, size_t alen) { - (void)channel; - /* Invoke the callback. */ query->callback(query->arg, (int)status, (int)query->timeouts, /* due to prior design flaws, abuf isn't meant to be modified, * but bad prototypes, ugh. Lets cast off constfor compat. */ (unsigned char *)((void *)((size_t)abuf)), (int)alen); ares__free_query(query); + + /* Check and notify if no other queries are enqueued on the channel. This + * must come after the callback and freeing the query for 2 reasons. + * 1) The callback itself may enqueue a new query + * 2) Technically the current query isn't detached until it is free()'d. + */ + ares_queue_notify_empty(channel); } void ares__free_query(struct query *query) diff --git a/test/ares-test-mock-et.cc b/test/ares-test-mock-et.cc index a28637a7..eaab2c9e 100644 --- a/test/ares-test-mock-et.cc +++ b/test/ares-test-mock-et.cc @@ -213,9 +213,10 @@ class MockUDPEventThreadMaxQueriesTest MockUDPEventThreadMaxQueriesTest() : MockEventThreadOptsTest(1, std::get<0>(GetParam()), std::get<1>(GetParam()), false, FillOptions(&opts_), - ARES_OPT_UDP_MAX_QUERIES) {} + ARES_OPT_UDP_MAX_QUERIES|ARES_OPT_FLAGS) {} static struct ares_options* FillOptions(struct ares_options * opts) { memset(opts, 0, sizeof(struct ares_options)); + opts->flags = ARES_FLAG_STAYOPEN|ARES_FLAG_EDNS; opts->udp_max_queries = MAXUDPQUERIES_LIMIT; return opts; } @@ -306,7 +307,25 @@ TEST_P(CacheQueriesEventThreadTest, GetHostByNameCache) { } #define TCPPARALLELLOOKUPS 32 -TEST_P(MockTCPEventThreadTest, GetHostByNameParallelLookups) { + +class MockTCPEventThreadStayOpenTest + : public MockEventThreadOptsTest, + public ::testing::WithParamInterface> { + public: + MockTCPEventThreadStayOpenTest() + : MockEventThreadOptsTest(1, std::get<0>(GetParam()), std::get<1>(GetParam()), true /* tcp */, + FillOptions(&opts_), + ARES_OPT_FLAGS) {} + static struct ares_options* FillOptions(struct ares_options * opts) { + memset(opts, 0, sizeof(struct ares_options)); + opts->flags = ARES_FLAG_STAYOPEN|ARES_FLAG_EDNS; + return opts; + } + private: + struct ares_options opts_; +}; + +TEST_P(MockTCPEventThreadStayOpenTest, GetHostByNameParallelLookups) { DNSPacket rsp; rsp.set_response().set_aa() .add_question(new DNSQuestion("www.google.com", T_A)) @@ -1367,6 +1386,8 @@ INSTANTIATE_TEST_SUITE_P(AddressFamilies, CacheQueriesEventThreadTest, ::testing INSTANTIATE_TEST_SUITE_P(AddressFamilies, MockTCPEventThreadTest, ::testing::ValuesIn(ares::test::evsys_families), ares::test::PrintEvsysFamily); +INSTANTIATE_TEST_SUITE_P(AddressFamilies, MockTCPEventThreadStayOpenTest, ::testing::ValuesIn(ares::test::evsys_families), ares::test::PrintEvsysFamily); + INSTANTIATE_TEST_SUITE_P(AddressFamilies, MockExtraOptsEventThreadTest, ::testing::ValuesIn(ares::test::evsys_families_modes), ares::test::PrintEvsysFamilyMode); INSTANTIATE_TEST_SUITE_P(AddressFamilies, MockNoCheckRespEventThreadTest, ::testing::ValuesIn(ares::test::evsys_families_modes), ares::test::PrintEvsysFamilyMode); diff --git a/test/ares-test.cc b/test/ares-test.cc index 29621c61..53a299e0 100644 --- a/test/ares-test.cc +++ b/test/ares-test.cc @@ -54,6 +54,7 @@ extern "C" { #endif #include #include +#include #include #include @@ -337,96 +338,6 @@ void ProcessWork(ares_channel_t *channel, } } -void ProcessWorkEventThread(ares_channel_t *channel, - std::function()> get_extrafds, - std::function process_extra, - unsigned int cancel_ms) { - int nfds=0, count; - fd_set readers; - size_t retry_cnt = 1; - -#ifndef CARES_SYMBOL_HIDING - struct timeval tv_begin = ares__tvnow(); - struct timeval tv_cancel = tv_begin; - - if (cancel_ms) { - if (verbose) std::cerr << "ares_cancel will be called after " << cancel_ms << "ms" << std::endl; - tv_cancel.tv_sec += (cancel_ms / 1000); - tv_cancel.tv_usec += ((cancel_ms % 1000) * 1000); - } -#else - if (cancel_ms) { - std::cerr << "library built with symbol hiding, can't test with cancel support" << std::endl; - return; - } -#endif - - while (true) { -#ifndef CARES_SYMBOL_HIDING - struct timeval tv_now = ares__tvnow(); - struct timeval tv_remaining; -#endif - struct timeval tv; - - /* c-ares is using its own event thread, so we only need to monitor the - * extrafds passed in */ - FD_ZERO(&readers); - std::set extrafds = get_extrafds(); - for (ares_socket_t extrafd : extrafds) { - FD_SET(extrafd, &readers); - if (extrafd >= (ares_socket_t)nfds) { - nfds = (int)extrafd + 1; - } - } - - /* If ares_timeout returns NULL, it means there are no requests in queue, - * so we can break out, but lets loop one additional time just incase we - * have some weird multithreading issue where a result hasn't yet been - * delivered. This is really just an odd case, its not "normal" to try - * to determine if an event has been delivered by solely monitoring the - * channel, really we should know how many callbacks we expect and how - * many we get, but that's not easy to do in a test framework. */ - if (ares_timeout(channel, NULL, &tv) == NULL) { - if (retry_cnt == 0) - return; - retry_cnt--; - } else { - retry_cnt = 1; - } - -#ifndef CARES_SYMBOL_HIDING - if (cancel_ms) { - unsigned int remaining_ms; - ares__timeval_remaining(&tv_remaining, - &tv_now, - &tv_cancel); - remaining_ms = (unsigned int)((tv_remaining.tv_sec * 1000) + (tv_remaining.tv_usec / 1000)); - if (remaining_ms == 0) { - if (verbose) std::cerr << "Issuing ares_cancel()" << std::endl; - ares_cancel(channel); - cancel_ms = 0; /* Disable issuing cancel again */ - } - } -#endif - - /* We just always wait 50ms then recheck. Not doing any complex signalling. */ - tv.tv_sec = 0; - tv.tv_usec = 50000; - - count = select(nfds, &readers, nullptr, nullptr, &tv); - if (count < 0) { - fprintf(stderr, "select() failed, errno %d\n", errno); - return; - } - - // Let the provided callback process any activity on the extra FD. - for (ares_socket_t extrafd : extrafds) { - if (FD_ISSET(extrafd, &readers)) { - process_extra(extrafd); - } - } - } -} // static void LibraryTest::SetAllocFail(int nth) { @@ -600,6 +511,14 @@ MockServer::~MockServer() { free(tcp_data_); } +static unsigned short getaddrport(struct sockaddr_storage *addr) +{ + if (addr->ss_family == AF_INET) + return ntohs(((struct sockaddr_in *)(void *)addr)->sin_port); + + return ntohs(((struct sockaddr_in6 *)(void *)addr)->sin6_port); +} + void MockServer::ProcessPacket(ares_socket_t fd, struct sockaddr_storage *addr, ares_socklen_t addrlen, byte *data, int len) { @@ -658,7 +577,8 @@ void MockServer::ProcessPacket(ares_socket_t fd, struct sockaddr_storage *addr, if (verbose) { std::vector req(data, data + len); std::cerr << "received " << (fd == udpfd_ ? "UDP" : "TCP") << " request " << PacketToString(req) - << " on port " << (fd == udpfd_ ? udpport_ : tcpport_) << std::endl; + << " on port " << (fd == udpfd_ ? udpport_ : tcpport_) + << ":" << getaddrport(addr) << std::endl; std::cerr << "ProcessRequest(" << qid << ", '" << namestr << "', " << RRTypeToString(rrtype) << ")" << std::endl; } @@ -730,6 +650,7 @@ std::set MockServer::fds() const { return result; } + void MockServer::ProcessRequest(ares_socket_t fd, struct sockaddr_storage* addr, ares_socklen_t addrlen, int qid, const std::string& name, int rrtype) { // Before processing, let gMock know the request is happening. @@ -751,8 +672,11 @@ void MockServer::ProcessRequest(ares_socket_t fd, struct sockaddr_storage* addr, reply[0] = (byte)((qid >> 8) & 0xff); reply[1] = (byte)(qid & 0xff); } - if (verbose) std::cerr << "sending reply " << PacketToString(reply) - << " on port " << ((fd == udpfd_) ? udpport_ : tcpport_) << std::endl; + if (verbose) { + std::cerr << "sending reply " << PacketToString(reply) + << " on port " << ((fd == udpfd_) ? udpport_ : tcpport_) + << ":" << getaddrport(addr) << std::endl; + } // Prefix with 2-byte length if TCP. if (fd != udpfd_) { @@ -895,12 +819,86 @@ void MockChannelOptsTest::Process(unsigned int cancel_ms) { cancel_ms); } -void MockEventThreadOptsTest::Process(unsigned int cancel_ms) { - using namespace std::placeholders; - ProcessWorkEventThread(channel_, - std::bind(&MockEventThreadOptsTest::fds, this), - std::bind(&MockEventThreadOptsTest::ProcessFD, this, _1), - cancel_ms); +void MockEventThreadOptsTest::ProcessThread() { + std::set fds; + +#ifndef CARES_SYMBOL_HIDING + bool has_cancel_ms = false; + struct timeval tv_begin; + struct timeval tv_cancel; +#endif + + mutex.lock(); + + while (isup) { + int nfds = 0; + fd_set readers; +#ifndef CARES_SYMBOL_HIDING + struct timeval tv_now = ares__tvnow(); + struct timeval tv_remaining; + if (cancel_ms_ && !has_cancel_ms) { + tv_begin = ares__tvnow(); + tv_cancel = tv_begin; + if (verbose) std::cerr << "ares_cancel will be called after " << cancel_ms_ << "ms" << std::endl; + tv_cancel.tv_sec += (cancel_ms_ / 1000); + tv_cancel.tv_usec += ((cancel_ms_ % 1000) * 1000); + has_cancel_ms = true; + } +#else + if (cancel_ms_) { + std::cerr << "library built with symbol hiding, can't test with cancel support" << std::endl; + return; + } +#endif + struct timeval tv; + + /* c-ares is using its own event thread, so we only need to monitor the + * extrafds passed in */ + FD_ZERO(&readers); + fds = MockEventThreadOptsTest::fds(); + for (ares_socket_t fd : fds) { + FD_SET(fd, &readers); + if (fd >= (ares_socket_t)nfds) { + nfds = (int)fd + 1; + } + } + +#ifndef CARES_SYMBOL_HIDING + if (has_cancel_ms) { + unsigned int remaining_ms; + ares__timeval_remaining(&tv_remaining, + &tv_now, + &tv_cancel); + remaining_ms = (unsigned int)((tv_remaining.tv_sec * 1000) + (tv_remaining.tv_usec / 1000)); + if (remaining_ms == 0) { + if (verbose) std::cerr << "Issuing ares_cancel()" << std::endl; + ares_cancel(channel_); + cancel_ms_ = 0; /* Disable issuing cancel again */ + has_cancel_ms = false; + } + } +#endif + + /* We just always wait 20ms then recheck. Not doing any complex signaling. */ + tv.tv_sec = 0; + tv.tv_usec = 20000; + + mutex.unlock(); + if (select(nfds, &readers, nullptr, nullptr, &tv) < 0) { + fprintf(stderr, "select() failed, errno %d\n", errno); + return; + } + + // Let the provided callback process any activity on the extra FD. + for (ares_socket_t fd : fds) { + if (FD_ISSET(fd, &readers)) { + ProcessFD(fd); + } + } + mutex.lock(); + } + mutex.unlock(); + } std::ostream& operator<<(std::ostream& os, const HostResult& result) { diff --git a/test/ares-test.h b/test/ares-test.h index ebeec50b..cf2f4d2e 100644 --- a/test/ares-test.h +++ b/test/ares-test.h @@ -50,6 +50,7 @@ #include #include #include +#include #include #include @@ -90,10 +91,6 @@ void ProcessWork(ares_channel_t *cha std::function()> get_extrafds, std::function process_extra, unsigned int cancel_ms = 0); -void ProcessWorkEventThread(ares_channel_t *channel, - std::function()> get_extrafds, - std::function process_extra, - unsigned int cancel_ms); std::set NoExtraFDs(); const char *af_tostr(int af); @@ -355,9 +352,17 @@ public: struct ares_options *givenopts, int optmask) : MockChannelOptsTest(count, family, force_tcp, FillOptionsET(&evopts_, givenopts, evsys), optmask | ARES_OPT_EVENT_THREAD) { + cancel_ms_ = 0; + isup = true; + thread = std::thread(&MockEventThreadOptsTest::ProcessThread, this); + } + ~MockEventThreadOptsTest() + { + mutex.lock(); + isup = false; + mutex.unlock(); + thread.join(); } - - void Process(unsigned int cancel_ms = 0); static struct ares_options *FillOptionsET(struct ares_options *opts, struct ares_options *givenopts, ares_evsys_t evsys) { if (givenopts) { @@ -369,8 +374,20 @@ public: return opts; } + void Process(unsigned int cancel_ms = 0) { + mutex.lock(); + cancel_ms_ = cancel_ms; + mutex.unlock(); + ares_queue_wait_empty(channel_, -1); + } + private: + void ProcessThread(); struct ares_options evopts_; + unsigned int cancel_ms_; + bool isup; + std::mutex mutex; + std::thread thread; }; class MockEventThreadTest @@ -381,6 +398,7 @@ public: : MockEventThreadOptsTest(1, std::get<0>(GetParam()), std::get<1>(GetParam()), std::get<2>(GetParam()), nullptr, 0) { } + }; class MockUDPEventThreadTest : public MockEventThreadOptsTest,