diff --git a/docs/ares_init_options.3 b/docs/ares_init_options.3 index 77d163e4..694beb5e 100644 --- a/docs/ares_init_options.3 +++ b/docs/ares_init_options.3 @@ -130,6 +130,14 @@ v1.22, this is on by default if flags are otherwise not set. .B ARES_FLAG_NO_DFLT_SVR Do not attempt to add a default local named server if there are no other servers available. Instead, fail initialization with \fIARES_ENOSERVER\fP. +.TP 23 +.B ARES_FLAG_DNS0x20 +Enable support for DNS 0x20 as per https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00 +which adds additional entropy to the request by randomizing the case of the +query name. Integrators need to ensure they treat DNS name responses as +case-insensitive. In rare circumstances this may cause the inability to lookup +certain domains if the upstream server or the authoritative server for the +domain is non-compliant. .RE .TP 18 .B ARES_OPT_TIMEOUT diff --git a/include/ares.h b/include/ares.h index ce549587..afff1759 100644 --- a/include/ares.h +++ b/include/ares.h @@ -238,6 +238,7 @@ typedef enum { #define ARES_FLAG_NOCHECKRESP (1 << 7) #define ARES_FLAG_EDNS (1 << 8) #define ARES_FLAG_NO_DFLT_SVR (1 << 9) +#define ARES_FLAG_DNS0x20 (1 << 10) /* Option mask values */ #define ARES_OPT_FLAGS (1 << 0) diff --git a/include/ares_dns_record.h b/include/ares_dns_record.h index 1084a52e..c8ef15b1 100644 --- a/include/ares_dns_record.h +++ b/include/ares_dns_record.h @@ -600,6 +600,15 @@ CARES_EXTERN void ares_dns_record_destroy(ares_dns_record_t *dnsrec); CARES_EXTERN unsigned short ares_dns_record_get_id(const ares_dns_record_t *dnsrec); +/*! Overwrite the DNS query id + * + * \param[in] dnsrec Initialized record object + * \param[in] id DNS query id + * \return ARES_TRUE on success, ARES_FALSE on usage error + */ +CARES_EXTERN ares_bool_t + ares_dns_record_set_id(ares_dns_record_t *dnsrec, unsigned short id); + /*! Get the DNS Record Flags * * \param[in] dnsrec Initialized record object diff --git a/src/lib/ares_dns_name.c b/src/lib/ares_dns_name.c index 749ed2c7..076d2664 100644 --- a/src/lib/ares_dns_name.c +++ b/src/lib/ares_dns_name.c @@ -111,7 +111,10 @@ static const ares_nameoffset_t *ares__nameoffset_find(ares__llist_t *list, prefix_len = name_len - val->name_len; - if (strcasecmp(val->name, name + prefix_len) != 0) { + /* Due to DNS 0x20, lets not inadvertently mangle things, use case-sensitive + * matching instead of case-insensitive. This may result in slightly + * larger DNS queries overall. */ + if (strcmp(val->name, name + prefix_len) != 0) { continue; } diff --git a/src/lib/ares_dns_record.c b/src/lib/ares_dns_record.c index dfaac572..1593c0d3 100644 --- a/src/lib/ares_dns_record.c +++ b/src/lib/ares_dns_record.c @@ -66,6 +66,15 @@ unsigned short ares_dns_record_get_id(const ares_dns_record_t *dnsrec) return dnsrec->id; } +ares_bool_t ares_dns_record_set_id(ares_dns_record_t *dnsrec, unsigned short id) +{ + if (dnsrec == NULL) { + return ARES_FALSE; + } + dnsrec->id = id; + return ARES_TRUE; +} + unsigned short ares_dns_record_get_flags(const ares_dns_record_t *dnsrec) { if (dnsrec == NULL) { diff --git a/src/lib/ares_private.h b/src/lib/ares_private.h index aba1133d..3da8e78a 100644 --- a/src/lib/ares_private.h +++ b/src/lib/ares_private.h @@ -254,9 +254,8 @@ struct query { /* connection handle query is associated with */ struct server_connection *conn; - /* Arguments passed to ares_send() */ - unsigned char *qbuf; - size_t qlen; + /* Query */ + ares_dns_record_t *query; ares_callback_dnsrec callback; void *arg; diff --git a/src/lib/ares_process.c b/src/lib/ares_process.c index 7251a11c..562d6b5e 100644 --- a/src/lib/ares_process.c +++ b/src/lib/ares_process.c @@ -60,7 +60,7 @@ static ares_status_t process_answer(ares_channel_t *channel, static void handle_conn_error(struct server_connection *conn, ares_bool_t critical_failure); -static ares_bool_t same_questions(const ares_dns_record_t *qrec, +static ares_bool_t same_questions(const struct query *query, const ares_dns_record_t *arec); static ares_bool_t same_address(const struct sockaddr *sa, const struct ares_addr *aa); @@ -619,22 +619,19 @@ static void process_timeouts(ares_channel_t *channel, const ares_timeval_t *now) } } -static ares_status_t rewrite_without_edns(ares_dns_record_t *qdnsrec, - struct query *query) +static ares_status_t rewrite_without_edns(struct query *query) { - ares_status_t status; + ares_status_t status = ARES_SUCCESS; size_t i; ares_bool_t found_opt_rr = ARES_FALSE; - unsigned char *msg = NULL; - size_t msglen = 0; /* Find and remove the OPT RR record */ - for (i = 0; i < ares_dns_record_rr_cnt(qdnsrec, ARES_SECTION_ADDITIONAL); + for (i = 0; i < ares_dns_record_rr_cnt(query->query, ARES_SECTION_ADDITIONAL); i++) { const ares_dns_rr_t *rr; - rr = ares_dns_record_rr_get(qdnsrec, ARES_SECTION_ADDITIONAL, i); + rr = ares_dns_record_rr_get(query->query, ARES_SECTION_ADDITIONAL, i); if (ares_dns_rr_get_type(rr) == ARES_REC_TYPE_OPT) { - ares_dns_record_rr_del(qdnsrec, ARES_SECTION_ADDITIONAL, i); + ares_dns_record_rr_del(query->query, ARES_SECTION_ADDITIONAL, i); found_opt_rr = ARES_TRUE; break; } @@ -645,16 +642,6 @@ static ares_status_t rewrite_without_edns(ares_dns_record_t *qdnsrec, goto done; } - /* Rewrite the DNS message */ - status = ares_dns_write(qdnsrec, &msg, &msglen); - if (status != ARES_SUCCESS) { - goto done; /* LCOV_EXCL_LINE: OutOfMemory */ - } - - ares_free(query->qbuf); - query->qbuf = msg; - query->qlen = msglen; - done: return status; } @@ -672,7 +659,6 @@ static ares_status_t process_answer(ares_channel_t *channel, * invalidating the connection all-together */ struct server_state *server = conn->server; ares_dns_record_t *rdnsrec = NULL; - ares_dns_record_t *qdnsrec = NULL; ares_status_t status; ares_bool_t is_cached = ARES_FALSE; @@ -695,16 +681,9 @@ static ares_status_t process_answer(ares_channel_t *channel, goto cleanup; } - /* Parse the question we sent as we use it to compare */ - status = ares_dns_parse(query->qbuf, query->qlen, 0, &qdnsrec); - if (status != ARES_SUCCESS) { - end_query(channel, server, query, status, NULL); - goto cleanup; - } - /* Both the query id and the questions must be the same. We will drop any * replies that aren't for the same query as this is considered invalid. */ - if (!same_questions(qdnsrec, rdnsrec)) { + if (!same_questions(query, rdnsrec)) { /* Possible qid conflict due to delayed response, that's ok */ status = ARES_SUCCESS; goto cleanup; @@ -721,8 +700,8 @@ static ares_status_t process_answer(ares_channel_t *channel, * protocol extension is not understood by the responder. We must retry the * query without EDNS enabled. */ if (ares_dns_record_get_rcode(rdnsrec) == ARES_RCODE_FORMERR && - ares_dns_has_opt_rr(qdnsrec) && !ares_dns_has_opt_rr(rdnsrec)) { - status = rewrite_without_edns(qdnsrec, query); + ares_dns_has_opt_rr(query->query) && !ares_dns_has_opt_rr(rdnsrec)) { + status = rewrite_without_edns(query); if (status != ARES_SUCCESS) { end_query(channel, server, query, status, NULL); goto cleanup; @@ -793,7 +772,6 @@ cleanup: ares_dns_record_destroy(rdnsrec); } - ares_dns_record_destroy(qdnsrec); return status; } @@ -934,12 +912,49 @@ static ares_status_t ares__append_tcpbuf(struct server_state *server, const struct query *query) { ares_status_t status; + unsigned char *qbuf = NULL; + size_t qbuf_len = 0; + + status = ares_dns_write(query->query, &qbuf, &qbuf_len); + if (status != ARES_SUCCESS) { + goto done; + } + + status = ares__buf_append_be16(server->tcp_send, (unsigned short)qbuf_len); + if (status != ARES_SUCCESS) { + goto done; /* LCOV_EXCL_LINE: OutOfMemory */ + } + + status = ares__buf_append(server->tcp_send, qbuf, qbuf_len); + +done: + ares_free(qbuf); + return status; +} + + +static ares_status_t ares__write_udpbuf(ares_channel_t *channel, + ares_socket_t fd, + const struct query *query) +{ + ares_status_t status; + unsigned char *qbuf = NULL; + size_t qbuf_len = 0; - status = ares__buf_append_be16(server->tcp_send, (unsigned short)query->qlen); + status = ares_dns_write(query->query, &qbuf, &qbuf_len); if (status != ARES_SUCCESS) { - return status; /* LCOV_EXCL_LINE: OutOfMemory */ + goto done; } - return ares__buf_append(server->tcp_send, query->qbuf, query->qlen); + + if (ares__socket_write(channel, fd, qbuf, qbuf_len) == -1) { + status = ARES_ESERVFAIL; + } else { + status = ARES_SUCCESS; + } + +done: + ares_free(qbuf); + return status; } static size_t ares__calc_query_timeout(const struct query *query, @@ -1108,7 +1123,15 @@ ares_status_t ares__send_query(struct query *query, const ares_timeval_t *now) } conn = ares__llist_node_val(node); - if (ares__socket_write(channel, conn->fd, query->qbuf, query->qlen) == -1) { + + status = ares__write_udpbuf(channel, conn->fd, query); + if (status != ARES_SUCCESS) { + if (status == ARES_ENOMEM) { + /* Not retryable */ + end_query(channel, server, query, status, NULL); + return status; + } + /* FIXME: Handle EAGAIN here since it likely can happen. */ server_increment_failures(server, query->using_tcp); status = ares__requeue_query(query, now); @@ -1168,11 +1191,13 @@ ares_status_t ares__send_query(struct query *query, const ares_timeval_t *now) return ARES_SUCCESS; } -static ares_bool_t same_questions(const ares_dns_record_t *qrec, +static ares_bool_t same_questions(const struct query *query, const ares_dns_record_t *arec) { - size_t i; - ares_bool_t rv = ARES_FALSE; + size_t i; + ares_bool_t rv = ARES_FALSE; + const ares_dns_record_t *qrec = query->query; + const ares_channel_t *channel = query->channel; if (ares_dns_record_query_cnt(qrec) != ares_dns_record_query_cnt(arec)) { @@ -1198,9 +1223,26 @@ static ares_bool_t same_questions(const ares_dns_record_t *qrec, aname == NULL) { goto done; } - if (strcasecmp(qname, aname) != 0 || qtype != atype || qclass != aclass) { + + if (qtype != atype || qclass != aclass) { goto done; } + + if (channel->flags & ARES_FLAG_DNS0x20 && !query->using_tcp) { + /* NOTE: for DNS 0x20, part of the protection is to use a case-sensitive + * comparison of the DNS query name. This expects the upstream DNS + * server to preserve the case of the name in the response packet. + * https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00 + */ + if (strcmp(qname, aname) != 0) { + goto done; + } + } else { + /* without DNS0x20 use case-insensitive matching */ + if (strcasecmp(qname, aname) != 0) { + goto done; + } + } } rv = ARES_TRUE; @@ -1275,7 +1317,7 @@ void ares__free_query(struct query *query) query->callback = NULL; query->arg = NULL; /* Deallocate the memory associated with the query */ - ares_free(query->qbuf); + ares_dns_record_destroy(query->query); ares_free(query); } diff --git a/src/lib/ares_qcache.c b/src/lib/ares_qcache.c index 80eedaf1..bf4e95da 100644 --- a/src/lib/ares_qcache.c +++ b/src/lib/ares_qcache.c @@ -299,37 +299,18 @@ static unsigned int ares__qcache_soa_minimum(ares_dns_record_t *dnsrec) return 0; } -static char *ares__qcache_calc_key_frombuf(const unsigned char *qbuf, - size_t qlen) -{ - ares_status_t status; - ares_dns_record_t *dnsrec = NULL; - char *key = NULL; - - status = ares_dns_parse(qbuf, qlen, 0, &dnsrec); - if (status != ARES_SUCCESS) { - goto done; - } - - key = ares__qcache_calc_key(dnsrec); - -done: - ares_dns_record_destroy(dnsrec); - return key; -} - /* On success, takes ownership of dnsrec */ static ares_status_t ares__qcache_insert(ares__qcache_t *qcache, - ares_dns_record_t *dnsrec, - const unsigned char *qbuf, size_t qlen, + ares_dns_record_t *qresp, + ares_dns_record_t *qreq, const ares_timeval_t *now) { ares__qcache_entry_t *entry; unsigned int ttl; - ares_dns_rcode_t rcode = ares_dns_record_get_rcode(dnsrec); - ares_dns_flags_t flags = ares_dns_record_get_flags(dnsrec); + ares_dns_rcode_t rcode = ares_dns_record_get_rcode(qresp); + ares_dns_flags_t flags = ares_dns_record_get_flags(qresp); - if (qcache == NULL || dnsrec == NULL) { + if (qcache == NULL || qresp == NULL) { return ARES_EFORMERR; } @@ -345,9 +326,9 @@ static ares_status_t ares__qcache_insert(ares__qcache_t *qcache, /* Look at SOA for NXDOMAIN for minimum */ if (rcode == ARES_RCODE_NXDOMAIN) { - ttl = ares__qcache_soa_minimum(dnsrec); + ttl = ares__qcache_soa_minimum(qresp); } else { - ttl = ares__qcache_calc_minttl(dnsrec); + ttl = ares__qcache_calc_minttl(qresp); } if (ttl > qcache->max_ttl) { @@ -364,7 +345,7 @@ static ares_status_t ares__qcache_insert(ares__qcache_t *qcache, goto fail; /* LCOV_EXCL_LINE: OutOfMemory */ } - entry->dnsrec = dnsrec; + entry->dnsrec = qresp; entry->expire_ts = now->sec + (time_t)ttl; entry->insert_ts = now->sec; @@ -372,7 +353,7 @@ static ares_status_t ares__qcache_insert(ares__qcache_t *qcache, * request had, so we have to re-parse the request in order to generate the * key for caching, but we'll only do this once we know for sure we really * want to cache it */ - entry->key = ares__qcache_calc_key_frombuf(qbuf, qlen); + entry->key = ares__qcache_calc_key(qreq); if (entry->key == NULL) { goto fail; /* LCOV_EXCL_LINE: OutOfMemory */ } @@ -444,6 +425,6 @@ ares_status_t ares_qcache_insert(ares_channel_t *channel, const struct query *query, ares_dns_record_t *dnsrec) { - return ares__qcache_insert(channel->qcache, dnsrec, query->qbuf, query->qlen, + return ares__qcache_insert(channel->qcache, dnsrec, query->query, now); } diff --git a/src/lib/ares_send.c b/src/lib/ares_send.c index 7d36fbdf..94e7e12b 100644 --- a/src/lib/ares_send.c +++ b/src/lib/ares_send.c @@ -43,13 +43,72 @@ static unsigned short generate_unique_qid(ares_channel_t *channel) return id; } + +/* https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00 */ +static ares_status_t ares_apply_dns0x20(ares_channel_t *channel, + ares_dns_record_t *dnsrec) +{ + ares_status_t status = ARES_SUCCESS; + const char *name = NULL; + char dns0x20name[256]; + unsigned char randdata[256 / 8]; + size_t len; + size_t remaining_bits; + size_t total_bits; + size_t i; + + status = ares_dns_record_query_get(dnsrec, 0, &name, NULL, NULL); + if (status != ARES_SUCCESS) { + goto done; + } + + len = ares_strlen(name); + if (len == 0 || len >= sizeof(dns0x20name)) { + status = ARES_EBADNAME; + goto done; + } + + memset(dns0x20name, 0, sizeof(dns0x20name)); + + /* Fetch the minimum amount of random data we'd need for the string, which + * is 1 bit per byte */ + total_bits = ((len + 7) / 8) * 8; + remaining_bits = total_bits; + ares__rand_bytes(channel->rand_state, randdata, total_bits / 8); + + /* Randomly apply 0x20 to name */ + for (i=0; ichannel = channel; - - status = ares_dns_write(dnsrec, &query->qbuf, &query->qlen); - if (status != ARES_SUCCESS) { - ares_free(query); - callback(arg, status, 0, NULL); - return status; - } - + query->channel = channel; query->qid = id; query->timeout.sec = 0; query->timeout.usec = 0; + query->using_tcp = (channel->flags & ARES_FLAG_USEVC)?ARES_TRUE:ARES_FALSE; - /* Ignore first 2 bytes, assign our own query id */ - query->qbuf[0] = (unsigned char)((id >> 8) & 0xFF); - query->qbuf[1] = (unsigned char)(id & 0xFF); + /* Duplicate Query */ + query->query = ares_dns_record_duplicate(dnsrec); + if (query->query == NULL) { + ares_free(query); + callback(arg, ARES_ENOMEM, 0, NULL); + return ARES_ENOMEM; + } + + ares_dns_record_set_id(query->query, id); + + if (channel->flags & ARES_FLAG_DNS0x20 && !query->using_tcp) { + status = ares_apply_dns0x20(channel, query->query); + if (status != ARES_SUCCESS) { + /* LCOV_EXCL_START: OutOfMemory */ + callback(arg, status, 0, NULL); + ares__free_query(query); + return status; + /* LCOV_EXCL_STOP */ + } + } /* Fill in query arguments. */ query->callback = callback; @@ -103,9 +172,6 @@ ares_status_t ares_send_nolock(ares_channel_t *channel, /* Initialize query status. */ query->try_count = 0; - packetsz = (channel->flags & ARES_FLAG_EDNS) ? channel->ednspsz : PACKETSZ; - query->using_tcp = - (channel->flags & ARES_FLAG_USEVC) || query->qlen > packetsz; query->error_status = ARES_SUCCESS; query->timeouts = 0; diff --git a/test/ares-test-mock-et.cc b/test/ares-test-mock-et.cc index 32e1144a..430a471b 100644 --- a/test/ares-test-mock-et.cc +++ b/test/ares-test-mock-et.cc @@ -42,46 +42,6 @@ using testing::DoAll; namespace ares { namespace test { -TEST_P(MockEventThreadTest, Basic) { - std::vector reply = { - 0x00, 0x00, // qid - 0x84, // response + query + AA + not-TC + not-RD - 0x00, // not-RA + not-Z + not-AD + not-CD + rc=NoError - 0x00, 0x01, // 1 question - 0x00, 0x01, // 1 answer RRs - 0x00, 0x00, // 0 authority RRs - 0x00, 0x00, // 0 additional RRs - // Question - 0x03, 'w', 'w', 'w', - 0x06, 'g', 'o', 'o', 'g', 'l', 'e', - 0x03, 'c', 'o', 'm', - 0x00, - 0x00, 0x01, // type A - 0x00, 0x01, // class IN - // Answer - 0x03, 'w', 'w', 'w', - 0x06, 'g', 'o', 'o', 'g', 'l', 'e', - 0x03, 'c', 'o', 'm', - 0x00, - 0x00, 0x01, // type A - 0x00, 0x01, // class IN - 0x00, 0x00, 0x01, 0x00, // TTL - 0x00, 0x04, // rdata length - 0x01, 0x02, 0x03, 0x04 - }; - - ON_CALL(server_, OnRequest("www.google.com", T_A)) - .WillByDefault(SetReplyData(&server_, reply)); - - HostResult result; - ares_gethostbyname(channel_, "www.google.com.", AF_INET, HostCallback, &result); - Process(); - EXPECT_TRUE(result.done_); - std::stringstream ss; - ss << result.host_; - EXPECT_EQ("{'www.google.com' aliases=[] addrs=[1.2.3.4]}", ss.str()); -} - // UDP only so mock server doesn't get confused by concatenated requests TEST_P(MockUDPEventThreadTest, GetHostByNameParallelLookups) { DNSPacket rsp1; diff --git a/test/ares-test-mock.cc b/test/ares-test-mock.cc index 79874579..0f3b7421 100644 --- a/test/ares-test-mock.cc +++ b/test/ares-test-mock.cc @@ -40,7 +40,25 @@ using testing::DoAll; namespace ares { namespace test { -TEST_P(MockChannelTest, Basic) { +class NoDNS0x20MockTest + : public MockChannelOptsTest, + public ::testing::WithParamInterface { + public: + NoDNS0x20MockTest() + : MockChannelOptsTest(1, GetParam(), false, + FillOptions(&opts_), + ARES_OPT_FLAGS) {} + static struct ares_options* FillOptions(struct ares_options * opts) { + memset(opts, 0, sizeof(struct ares_options)); + opts->flags = ARES_FLAG_EDNS; + return opts; + } + private: + struct ares_options opts_; +}; + + +TEST_P(NoDNS0x20MockTest, Basic) { std::vector reply = { 0x00, 0x00, // qid 0x84, // response + query + AA + not-TC + not-RD @@ -80,6 +98,48 @@ TEST_P(MockChannelTest, Basic) { EXPECT_EQ("{'www.google.com' aliases=[] addrs=[1.2.3.4]}", ss.str()); } +TEST_P(MockUDPChannelTest, DNS0x20BadReply) { + std::vector reply = { + 0x00, 0x00, // qid + 0x84, // response + query + AA + not-TC + not-RD + 0x00, // not-RA + not-Z + not-AD + not-CD + rc=NoError + 0x00, 0x01, // 1 question + 0x00, 0x01, // 1 answer RRs + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + // Question + 0x03, 'w', 'w', 'w', + 0x1D, 's', 'o', 'm', 'e', 'l', 'o', 'n', 'g', 'd', 'o', 'm', 'a', 'i', 'n', 'n', 'a', 'm', 'e', 'b', 'e', 'c', 'a', 'u', 's', 'e', 'p', 'r', 'n', 'g', + 0x03, 'c', 'o', 'm', + 0x00, + 0x00, 0x01, // type A + 0x00, 0x01, // class IN + // Answer + 0x03, 'w', 'w', 'w', + 0x1D, 's', 'o', 'm', 'e', 'l', 'o', 'n', 'g', 'd', 'o', 'm', 'a', 'i', 'n', 'n', 'a', 'm', 'e', 'b', 'e', 'c', 'a', 'u', 's', 'e', 'p', 'r', 'n', 'g', + 0x03, 'c', 'o', 'm', + 0x00, + 0x00, 0x01, // type A + 0x00, 0x01, // class IN + 0x00, 0x00, 0x01, 0x00, // TTL + 0x00, 0x04, // rdata length + 0x01, 0x02, 0x03, 0x04 + }; + + ON_CALL(server_, OnRequest("www.somelongdomainnamebecauseprng.com", T_A)) + .WillByDefault(SetReplyData(&server_, reply)); + + /* Reply will be thrown out due to mismatched case for DNS 0x20 in response, + * its technically possible this test case may not fail if somehow the + * PRNG returns all lowercase domain name so we need to make this domain + * fairly long to make sure those odds are very very very low */ + HostResult result; + ares_gethostbyname(channel_, "www.somelongdomainnamebecauseprng.com.", AF_INET, HostCallback, &result); + Process(); + EXPECT_TRUE(result.done_); + EXPECT_EQ(ARES_ETIMEOUT, result.status_); +} + // UDP only so mock server doesn't get confused by concatenated requests TEST_P(MockUDPChannelTest, GetHostByNameParallelLookups) { DNSPacket rsp1; @@ -1797,6 +1857,8 @@ std::string PrintFamily(const testing::TestParamInfo &info) return name; } +INSTANTIATE_TEST_SUITE_P(AddressFamilies, NoDNS0x20MockTest, ::testing::ValuesIn(ares::test::families), PrintFamily); + INSTANTIATE_TEST_SUITE_P(AddressFamilies, MockChannelTest, ::testing::ValuesIn(ares::test::families_modes), PrintFamilyMode); INSTANTIATE_TEST_SUITE_P(AddressFamilies, MockUDPChannelTest, ::testing::ValuesIn(ares::test::families), PrintFamily); diff --git a/test/ares-test-parse-a.cc b/test/ares-test-parse-a.cc index 7e045dcc..a91ee56a 100644 --- a/test/ares-test-parse-a.cc +++ b/test/ares-test-parse-a.cc @@ -297,7 +297,7 @@ TEST_F(LibraryTest, ParseAReplyErrors) { ASSERT_NE(nullptr, host); std::stringstream ss; ss << HostEnt(host); - EXPECT_EQ("{'Axample.com' aliases=[] addrs=[2.3.4.5]}", ss.str()); + EXPECT_EQ("{'axample.com' aliases=[] addrs=[2.3.4.5]}", ss.str()); ares_free_hostent(host); host = nullptr; diff --git a/test/ares-test-parse-aaaa.cc b/test/ares-test-parse-aaaa.cc index 17255b0d..8443bea4 100644 --- a/test/ares-test-parse-aaaa.cc +++ b/test/ares-test-parse-aaaa.cc @@ -148,7 +148,7 @@ TEST_F(LibraryTest, ParseAaaaReplyErrors) { ASSERT_NE(nullptr, host); std::stringstream ss; ss << HostEnt(host); - EXPECT_EQ("{'Axample.com' aliases=[] addrs=[0101:0101:0202:0202:0303:0303:0404:0404]}", ss.str()); + EXPECT_EQ("{'axample.com' aliases=[] addrs=[0101:0101:0202:0202:0303:0303:0404:0404]}", ss.str()); ares_free_hostent(host); host = nullptr; diff --git a/test/ares-test.cc b/test/ares-test.cc index 639609be..f96fb5b3 100644 --- a/test/ares-test.cc +++ b/test/ares-test.cc @@ -58,6 +58,7 @@ extern "C" { #include #include +#include #ifdef WIN32 #define BYTE_CAST (char *) @@ -431,6 +432,7 @@ void DefaultChannelModeTest::Process(unsigned int cancel_ms) { MockServer::MockServer(int family, unsigned short port) : udpport_(port), tcpport_(port), qid_(-1) { + reply_ = nullptr; // Create a TCP socket to receive data on. tcp_data_ = NULL; tcp_data_len_ = 0; @@ -570,21 +572,22 @@ void MockServer::ProcessPacket(ares_socket_t fd, struct sockaddr_storage *addr, } if (enclen > qlen) { std::cerr << "(error, encoded name len " << enclen << "bigger than remaining data " << qlen << " bytes)" << std::endl; + ares_free_string(name); return; } qlen -= (int)enclen; question += enclen; - std::string namestr(name); - ares_free_string(name); if (qlen < 4) { std::cerr << "Unexpected question size (" << qlen << " bytes after name)" << std::endl; + ares_free_string(name); return; } if (DNS_QUESTION_CLASS(question) != C_IN) { std::cerr << "Unexpected question class (" << DNS_QUESTION_CLASS(question) << ")" << std::endl; + ares_free_string(name); return; } int rrtype = DNS_QUESTION_TYPE(question); @@ -595,11 +598,11 @@ void MockServer::ProcessPacket(ares_socket_t fd, struct sockaddr_storage *addr, std::cerr << "received " << (fd == udpfd_ ? "UDP" : "TCP") << " request " << reqstr << " on port " << (fd == udpfd_ ? udpport_ : tcpport_) << ":" << getaddrport(addr) << std::endl; - std::cerr << "ProcessRequest(" << qid << ", '" << namestr + std::cerr << "ProcessRequest(" << qid << ", '" << name << "', " << RRTypeToString(rrtype) << ")" << std::endl; } - ProcessRequest(fd, addr, addrlen, reqstr, qid, namestr, rrtype); - + ProcessRequest(fd, addr, addrlen, reqstr, qid, name, rrtype); + ares_free_string(name); } void MockServer::ProcessFD(ares_socket_t fd) { @@ -667,24 +670,32 @@ std::set MockServer::fds() const { return result; } - void MockServer::ProcessRequest(ares_socket_t fd, struct sockaddr_storage* addr, ares_socklen_t addrlen, const std::string &reqstr, - int qid, const std::string& name, int rrtype) { + int qid, const char *name, int rrtype) { + + /* DNS 0x20 will mix case, do case-insensitive matching of name in request */ + char lower_name[256]; + arestest_strtolower(lower_name, name, sizeof(lower_name)); + // Before processing, let gMock know the request is happening. - OnRequest(name, rrtype); + OnRequest(lower_name, rrtype); // If we are expecting a specific request then check it matches here. if (expected_request_.length() > 0) { ASSERT_EQ(expected_request_, reqstr); } - if (reply_.size() == 0) { + if (reply_ != nullptr) { + exact_reply_ = reply_->data(name); + } + + if (exact_reply_.size() == 0) { return; } // Make a local copy of the current pending reply. - std::vector reply = reply_; + std::vector reply = exact_reply_; if (qid_ >= 0) { // Use the explicitly specified query ID. @@ -788,6 +799,12 @@ MockChannelOptsTest::MockChannelOptsTest(int count, optmask |= ARES_OPT_QUERY_CACHE; } + /* Enable DNS0x20 by default. Need to also turn on default flag of EDNS */ + if (!(optmask & ARES_OPT_FLAGS)) { + optmask |= ARES_OPT_FLAGS; + opts.flags = ARES_FLAG_DNS0x20|ARES_FLAG_EDNS; + } + EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask)); EXPECT_NE(nullptr, channel_); @@ -961,8 +978,13 @@ HostEnt::HostEnt(const struct hostent *hostent) : addrtype_(-1) { if (!hostent) return; - if (hostent->h_name) - name_ = hostent->h_name; + if (hostent->h_name) { + // DNS 0x20 may mix case, output as all lower for checks as the mixed case + // is really more of an internal thing + char lowername[256]; + arestest_strtolower(lowername, hostent->h_name, sizeof(lowername)); + name_ = lowername; + } if (hostent->h_aliases) { char** palias = hostent->h_aliases; diff --git a/test/ares-test.h b/test/ares-test.h index 2d204b76..70f1134c 100644 --- a/test/ares-test.h +++ b/test/ares-test.h @@ -218,6 +218,7 @@ protected: ares_channel_t *channel_; }; + // Mock DNS server to allow responses to be scripted by tests. class MockServer { public: @@ -232,12 +233,14 @@ public: // with the value from the request. void SetReplyData(const std::vector &reply) { - reply_ = reply; + exact_reply_ = reply; + reply_ = nullptr; } void SetReply(const DNSPacket *reply) { - SetReplyData(reply->data()); + reply_ = reply; + exact_reply_.clear(); } // Set the reply to be sent next as well as the request (in string form) that @@ -246,7 +249,7 @@ public: void SetReplyExpRequest(const DNSPacket *reply, const std::string &request) { expected_request_ = request; - SetReply(reply); + reply_ = reply; } void SetReplyQID(int qid) @@ -256,6 +259,8 @@ public: void Disconnect() { + reply_ = nullptr; + exact_reply_.clear(); for (ares_socket_t fd : connfds_) { sclose(fd); } @@ -285,7 +290,7 @@ public: private: void ProcessRequest(ares_socket_t fd, struct sockaddr_storage *addr, ares_socklen_t addrlen, const std::string &reqstr, - int qid, const std::string &name, int rrtype); + int qid, const char *name, int rrtype); void ProcessPacket(ares_socket_t fd, struct sockaddr_storage *addr, ares_socklen_t addrlen, byte *data, int len); unsigned short udpport_; @@ -293,7 +298,8 @@ private: ares_socket_t udpfd_; ares_socket_t tcpfd_; std::set connfds_; - std::vector reply_; + std::vector exact_reply_; + const DNSPacket *reply_; std::string expected_request_; int qid_; unsigned char *tcp_data_; diff --git a/test/dns-proto.cc b/test/dns-proto.cc index 5801c075..f9c7d4d5 100644 --- a/test/dns-proto.cc +++ b/test/dns-proto.cc @@ -34,6 +34,33 @@ #include #include +#include + +#if defined(_WIN32) && !defined(strcasecmp) +# define strcasecmp(a,b) stricmp(a,b) +#endif + +void arestest_strtolower(char *dest, const char *src, size_t dest_size) +{ + size_t len; + + if (dest == NULL) + return; + + memset(dest, 0, dest_size); + + if (src == NULL) + return; + + len = strlen(src); + if (len >= dest_size) + return; + + for (size_t i = 0; i& packet, } *len -= (int)enclen; *data += enclen; - ss << "'" << name << "' "; + + // DNS 0x20 may mix case, output as all lower for checks as the mixed case + // is really more of an internal thing + char lowername[256]; + arestest_strtolower(lowername, name, sizeof(lowername)); ares_free_string(name); + + ss << "'" << lowername << "' "; if (*len < NS_QFIXEDSZ) { ss << "(too short, len left " << *len << ")"; return ss.str(); @@ -498,7 +531,7 @@ void PushInt16(std::vector* data, int value) { data->push_back((byte)value & 0x00ff); } -std::vector EncodeString(const std::string& name) { +std::vector EncodeString(const std::string &name) { std::vector data; std::stringstream ss(name); std::string label; @@ -515,9 +548,14 @@ std::vector EncodeString(const std::string& name) { return data; } -std::vector DNSQuestion::data() const { +std::vector DNSQuestion::data(const char *request_name) const { std::vector data; - std::vector encname = EncodeString(name_); + std::vector encname; + if (request_name != nullptr && strcasecmp(request_name, name_.c_str()) == 0) { + encname = EncodeString(request_name); + } else { + encname = EncodeString(name_); + } data.insert(data.end(), encname.begin(), encname.end()); PushInt16(&data, rrtype_); PushInt16(&data, qclass_); @@ -641,7 +679,7 @@ std::vector DNSNaptrRR::data() const { return data; } -std::vector DNSPacket::data() const { +std::vector DNSPacket::data(const char *request_name) const { std::vector data; PushInt16(&data, qid_); byte b = 0x00; @@ -669,7 +707,7 @@ std::vector DNSPacket::data() const { PushInt16(&data, count); for (const std::unique_ptr& question : questions_) { - std::vector qdata = question->data(); + std::vector qdata = question->data(request_name); data.insert(data.end(), qdata.begin(), qdata.end()); } for (const std::unique_ptr& rr : answers_) { diff --git a/test/dns-proto.h b/test/dns-proto.h index 618673ee..88e0a23c 100644 --- a/test/dns-proto.h +++ b/test/dns-proto.h @@ -36,6 +36,8 @@ #include #include +extern "C" void arestest_strtolower(char *dest, const char *src, size_t dest_size); + namespace ares { typedef unsigned char byte; @@ -81,7 +83,13 @@ struct DNSQuestion { { } - virtual std::vector data() const; + virtual std::vector data(const char *request_name) const; + + virtual std::vector data() const + { + return data(nullptr); + } + std::string name_; int rrtype_; int qclass_; @@ -375,7 +383,12 @@ struct DNSPacket { } // Return the encoded packet. - std::vector data() const; + std::vector data(const char *request_name) const; + std::vector data() const + { + return data(nullptr); + } + int qid_; bool response_;