diff --git a/test/ares-test-init.cc b/test/ares-test-init.cc index 0d4bd278..7cec5834 100644 --- a/test/ares-test-init.cc +++ b/test/ares-test-init.cc @@ -222,6 +222,31 @@ TEST_F(LibraryTest, FailChannelInit) { ares_library_cleanup(); } +TEST_F(LibraryTest, EnvInit) { + ares_channel channel = nullptr; + EnvValue v1("LOCALDOMAIN", "this.is.local"); + EnvValue v2("RES_OPTIONS", "options debug ndots:3 retry:3 rotate retrans:2"); + EXPECT_EQ(ARES_SUCCESS, ares_init(&channel)); + ares_destroy(channel); +} + +TEST_F(LibraryTest, EnvInitAllocFail) { + ares_channel channel; + EnvValue v1("LOCALDOMAIN", "this.is.local"); + EnvValue v2("RES_OPTIONS", "options debug ndots:3 retry:3 rotate retrans:2"); + for (int ii = 1; ii <= 10; ii++) { + ClearFails(); + SetAllocFail(ii); + channel = nullptr; + int rc = ares_init(&channel); + if (rc == ARES_SUCCESS) { + ares_destroy(channel); + } else { + EXPECT_EQ(ARES_ENOMEM, rc); + } + } +} + TEST_F(DefaultChannelTest, SetAddresses) { ares_set_local_ip4(channel_, 0x01020304); byte addr6[16] = {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, diff --git a/test/ares-test-mock.cc b/test/ares-test-mock.cc index 2a235acb..52293ba5 100644 --- a/test/ares-test-mock.cc +++ b/test/ares-test-mock.cc @@ -89,6 +89,27 @@ TEST_P(MockUDPChannelTest, ParallelLookups) { EXPECT_EQ("{'www.google.com' aliases=[] addrs=[2.3.4.5]}", ss3.str()); } +// UDP to TCP specific test +TEST_P(MockUDPChannelTest, TruncationRetry) { + DNSPacket rsptruncated; + rsptruncated.set_response().set_aa().set_tc() + .add_question(new DNSQuestion("www.google.com", ns_t_a)); + DNSPacket rspok; + rspok.set_response() + .add_question(new DNSQuestion("www.google.com", ns_t_a)) + .add_answer(new DNSARR("www.google.com", 100, {1, 2, 3, 4})); + EXPECT_CALL(server_, OnRequest("www.google.com", ns_t_a)) + .WillOnce(SetReply(&server_, &rsptruncated)) + .WillOnce(SetReply(&server_, &rspok)); + HostResult result; + ares_gethostbyname(channel_, "www.google.com.", AF_INET, HostCallback, &result); + Process(); + EXPECT_TRUE(result.done_); + std::stringstream ss; + ss << result.host_; + EXPECT_EQ("{'www.google.com' aliases=[] addrs=[1.2.3.4]}", ss.str()); +} + static int sock_cb_count = 0; static int SocketCallback(ares_socket_t fd, int type, void *data) { if (verbose) std::cerr << "SocketCallback(" << fd << ") invoked" << std::endl; @@ -119,6 +140,18 @@ TEST_P(MockChannelTest, SockCallback) { } // TCP only to prevent retries +TEST_P(MockTCPChannelTest, MalformedResponse) { + std::vector one = {0x01}; + EXPECT_CALL(server_, OnRequest("www.google.com", ns_t_a)) + .WillOnce(SetReplyData(&server_, one)); + + HostResult result; + ares_gethostbyname(channel_, "www.google.com.", AF_INET, HostCallback, &result); + Process(); + EXPECT_TRUE(result.done_); + EXPECT_EQ(ARES_ETIMEOUT, result.status_); +} + TEST_P(MockTCPChannelTest, FormErrResponse) { DNSPacket rsp; rsp.set_response().set_aa() @@ -192,20 +225,27 @@ TEST_P(MockTCPChannelTest, YXDomainResponse) { EXPECT_EQ(ARES_ENODATA, result.status_); } -class MockNoCheckRespChannelTest +class MockFlagsChannelOptsTest : public MockChannelOptsTest, - public ::testing::WithParamInterface { + public ::testing::WithParamInterface< std::pair > { public: - MockNoCheckRespChannelTest() : MockChannelOptsTest(GetParam(), true, FillOptions(&opts_), ARES_OPT_FLAGS) {} - static struct ares_options* FillOptions(struct ares_options * opts) { + MockFlagsChannelOptsTest(int flags) + : MockChannelOptsTest(GetParam().first, GetParam().second, + FillOptions(&opts_, flags), ARES_OPT_FLAGS) {} + static struct ares_options* FillOptions(struct ares_options * opts, int flags) { memset(opts, 0, sizeof(struct ares_options)); - opts->flags = ARES_FLAG_NOCHECKRESP; + opts->flags = flags; return opts; } private: struct ares_options opts_; }; +class MockNoCheckRespChannelTest : public MockFlagsChannelOptsTest { + public: + MockNoCheckRespChannelTest() : MockFlagsChannelOptsTest(ARES_FLAG_NOCHECKRESP) {} +}; + TEST_P(MockNoCheckRespChannelTest, ServFailResponse) { DNSPacket rsp; rsp.set_response().set_aa() @@ -248,6 +288,31 @@ TEST_P(MockNoCheckRespChannelTest, RefusedResponse) { EXPECT_EQ(ARES_EREFUSED, result.status_); } +class MockEDNSChannelTest : public MockFlagsChannelOptsTest { + public: + MockEDNSChannelTest() : MockFlagsChannelOptsTest(ARES_FLAG_EDNS) {} +}; + +TEST_P(MockEDNSChannelTest, RetryWithoutEDNS) { + DNSPacket rspfail; + rspfail.set_response().set_aa().set_rcode(ns_r_servfail) + .add_question(new DNSQuestion("www.google.com", ns_t_a)); + DNSPacket rspok; + rspok.set_response() + .add_question(new DNSQuestion("www.google.com", ns_t_a)) + .add_answer(new DNSARR("www.google.com", 100, {1, 2, 3, 4})); + EXPECT_CALL(server_, OnRequest("www.google.com", ns_t_a)) + .WillOnce(SetReply(&server_, &rspfail)) + .WillOnce(SetReply(&server_, &rspok)); + HostResult result; + ares_gethostbyname(channel_, "www.google.com.", AF_INET, HostCallback, &result); + Process(); + EXPECT_TRUE(result.done_); + std::stringstream ss; + ss << result.host_; + EXPECT_EQ("{'www.google.com' aliases=[] addrs=[1.2.3.4]}", ss.str()); +} + TEST_P(MockChannelTest, SearchDomains) { DNSPacket nofirst; nofirst.set_response().set_aa().set_rcode(ns_r_nxdomain) @@ -619,7 +684,7 @@ TEST_P(MockUDPChannelTest, SearchDomainsAllocFail) { ares_gethostbyname(channel_, "www", AF_INET, HostCallback, result); Process(); EXPECT_TRUE(result->done_); - if (result->status_ != ARES_ENOMEM) { + if (result->status_ == ARES_SUCCESS) { std::stringstream ss; ss << result->host_; EXPECT_EQ("{'www.third.gov' aliases=[] addrs=[2.3.4.5]}", ss.str()) << " failed alloc #" << ii; @@ -801,7 +866,16 @@ INSTANTIATE_TEST_CASE_P(AddressFamilies, MockTCPChannelTest, ::testing::Values(AF_INET, AF_INET6)); INSTANTIATE_TEST_CASE_P(AddressFamilies, MockNoCheckRespChannelTest, - ::testing::Values(AF_INET, AF_INET6)); + ::testing::Values(std::make_pair(AF_INET, false), + std::make_pair(AF_INET, true), + std::make_pair(AF_INET6, false), + std::make_pair(AF_INET6, true))); + +INSTANTIATE_TEST_CASE_P(AddressFamilies, MockEDNSChannelTest, + ::testing::Values(std::make_pair(AF_INET, false), + std::make_pair(AF_INET, true), + std::make_pair(AF_INET6, false), + std::make_pair(AF_INET6, true))); } // namespace test } // namespace ares diff --git a/test/ares-test.cc b/test/ares-test.cc index b0bbe5f7..c55e34d0 100644 --- a/test/ares-test.cc +++ b/test/ares-test.cc @@ -68,6 +68,7 @@ void ProcessWork(ares_channel channel, // static void LibraryTest::SetAllocFail(int nth) { assert(nth > 0); + assert(nth <= (int)(8 * sizeof(fails_))); fails_ |= (1 << (nth - 1)); } @@ -131,29 +132,29 @@ void DefaultChannelModeTest::Process() { ProcessWork(channel_, NoExtraFDs, nullptr); } -MockServer::MockServer(int family, bool tcp, int port) : tcp_(tcp), port_(port), qid_(-1) { - if (tcp_) { - // Create a TCP socket to receive data on. - sockfd_ = socket(family, SOCK_STREAM, 0); - EXPECT_NE(-1, sockfd_); - int optval = 1; - setsockopt(sockfd_, SOL_SOCKET, SO_REUSEADDR, - (const void *)&optval , sizeof(int)); - } else { - // Create a UDP socket to receive data on. - sockfd_ = socket(family, SOCK_DGRAM, 0); - EXPECT_NE(-1, sockfd_); - } +MockServer::MockServer(int family, int port) : port_(port), qid_(-1) { + // Create a TCP socket to receive data on. + tcpfd_ = socket(family, SOCK_STREAM, 0); + EXPECT_NE(-1, tcpfd_); + int optval = 1; + setsockopt(tcpfd_, SOL_SOCKET, SO_REUSEADDR, + (const void *)&optval , sizeof(int)); + + // Create a UDP socket to receive data on. + udpfd_ = socket(family, SOCK_DGRAM, 0); + EXPECT_NE(-1, udpfd_); - // Bind it to the given port. + // Bind the sockets to the given port. if (family == AF_INET) { struct sockaddr_in addr; memset(&addr, 0, sizeof(addr)); addr.sin_family = AF_INET; addr.sin_addr.s_addr = htonl(INADDR_ANY); addr.sin_port = htons(port_); - int rc = bind(sockfd_, (struct sockaddr*)&addr, sizeof(addr)); - EXPECT_EQ(0, rc) << "Failed to bind AF_INET to port " << port_; + int tcprc = bind(tcpfd_, (struct sockaddr*)&addr, sizeof(addr)); + EXPECT_EQ(0, tcprc) << "Failed to bind AF_INET to TCP port " << port_; + int udprc = bind(udpfd_, (struct sockaddr*)&addr, sizeof(addr)); + EXPECT_EQ(0, udprc) << "Failed to bind AF_INET to UDP port " << port_; } else { EXPECT_EQ(AF_INET6, family); struct sockaddr_in6 addr; @@ -161,31 +162,31 @@ MockServer::MockServer(int family, bool tcp, int port) : tcp_(tcp), port_(port), addr.sin6_family = AF_INET6; addr.sin6_addr = in6addr_any; addr.sin6_port = htons(port_); - int rc = bind(sockfd_, (struct sockaddr*)&addr, sizeof(addr)); - EXPECT_EQ(0, rc) << "Failed to bind AF_INET6 to port " << port_; + int tcprc = bind(tcpfd_, (struct sockaddr*)&addr, sizeof(addr)); + EXPECT_EQ(0, tcprc) << "Failed to bind AF_INET6 to UDP port " << port_; + int udprc = bind(udpfd_, (struct sockaddr*)&addr, sizeof(addr)); + EXPECT_EQ(0, udprc) << "Failed to bind AF_INET6 to UDP port " << port_; } // For TCP, also need to listen for connections. - if (tcp_) { - EXPECT_EQ(0, listen(sockfd_, 5)) << "Failed to listen for TCP connections"; - } + EXPECT_EQ(0, listen(tcpfd_, 5)) << "Failed to listen for TCP connections"; } MockServer::~MockServer() { for (int fd : connfds_) { close(fd); } - close(sockfd_); - sockfd_ = -1; + close(tcpfd_); + close(udpfd_); } void MockServer::Process(int fd) { - if (fd != sockfd_ && connfds_.find(fd) == connfds_.end()) { + if (fd != tcpfd_ && fd != udpfd_ && connfds_.find(fd) == connfds_.end()) { std::cerr << "Asked to process unknown fd " << fd << std::endl; return; } - if (tcp_ && fd == sockfd_) { - int connfd = accept(sockfd_, NULL, NULL); + if (fd == tcpfd_) { + int connfd = accept(tcpfd_, NULL, NULL); if (connfd < 0) { std::cerr << "Error accepting connection on fd " << fd << std::endl; } else { @@ -201,7 +202,7 @@ void MockServer::Process(int fd) { int len = recvfrom(fd, buffer, sizeof(buffer), 0, (struct sockaddr *)&addr, &addrlen); byte* data = buffer; - if (tcp_) { + if (fd != udpfd_) { if (len == 0) { connfds_.erase(std::find(connfds_.begin(), connfds_.end(), fd)); close(fd); @@ -213,10 +214,10 @@ void MockServer::Process(int fd) { } int tcplen = (data[0] << 8) + data[1]; data += 2; - if (tcplen + 2 != len) { + len -= 2; + if (tcplen != len) { std::cerr << "Warning: TCP length " << tcplen - << " doesn't match data length (" << len - << " - 2)" << std::endl; + << " doesn't match remaining data length " << len << std::endl; } } @@ -268,14 +269,19 @@ void MockServer::Process(int fd) { } int rrtype = DNS_QUESTION_TYPE(question); - if (verbose) std::cerr << "ProcessRequest(" << qid << ", '" << namestr - << "', " << RRTypeToString(rrtype) << ")" << std::endl; + if (verbose) { + std::vector req(data, data + len); + std::cerr << "received " << (fd == udpfd_ ? "UDP" : "TCP") << " request " << PacketToString(req) << std::endl; + std::cerr << "ProcessRequest(" << qid << ", '" << namestr + << "', " << RRTypeToString(rrtype) << ")" << std::endl; + } ProcessRequest(fd, &addr, addrlen, qid, namestr, rrtype); } std::set MockServer::fds() const { std::set result = connfds_; - result.insert(sockfd_); + result.insert(tcpfd_); + result.insert(udpfd_); return result; } @@ -284,24 +290,26 @@ void MockServer::ProcessRequest(int fd, struct sockaddr_storage* addr, int addrl // Before processing, let gMock know the request is happening. OnRequest(name, rrtype); - // Make a local copy of the current pending reply. - if (reply_.size() < 2) { - if (verbose) std::cerr << "Skipping reply as not-present/too-short" << std::endl; + if (reply_.size() == 0) { return; } + + // Make a local copy of the current pending reply. std::vector reply = reply_; if (qid_ >= 0) { // Use the explicitly specified query ID. qid = qid_; } - // Overwrite the query ID. - reply[0] = (byte)((qid >> 8) & 0xff); - reply[1] = (byte)(qid & 0xff); + if (reply.size() >= 2) { + // Overwrite the query ID if space to do so. + reply[0] = (byte)((qid >> 8) & 0xff); + reply[1] = (byte)(qid & 0xff); + } if (verbose) std::cerr << "sending reply " << PacketToString(reply) << std::endl; // Prefix with 2-byte length if TCP. - if (tcp_) { + if (fd != udpfd_) { int len = reply.size(); std::vector vlen = {(byte)((len & 0xFF00) >> 8), (byte)(len & 0xFF)}; reply.insert(reply.begin(), vlen.begin(), vlen.end()); @@ -321,7 +329,7 @@ MockChannelOptsTest::MockChannelOptsTest(int family, bool force_tcp, struct ares_options* givenopts, int optmask) - : server_(family, force_tcp, mock_port), channel_(nullptr) { + : server_(family, mock_port), channel_(nullptr) { // Set up channel options. struct ares_options opts; if (givenopts) { diff --git a/test/ares-test.h b/test/ares-test.h index 4618de2d..21c900fe 100644 --- a/test/ares-test.h +++ b/test/ares-test.h @@ -114,7 +114,7 @@ class DefaultChannelModeTest // Mock DNS server to allow responses to be scripted by tests. class MockServer { public: - MockServer(int family, bool tcp, int port); + MockServer(int family, int port); ~MockServer(); // Mock method indicating the processing of a particular @@ -140,9 +140,9 @@ class MockServer { void ProcessRequest(int fd, struct sockaddr_storage* addr, int addrlen, int qid, const std::string& name, int rrtype); - bool tcp_; int port_; - int sockfd_; + int udpfd_; + int tcpfd_; std::set connfds_; std::vector reply_; int qid_; diff --git a/test/dns-proto.cc b/test/dns-proto.cc index 40366eed..8814d28f 100644 --- a/test/dns-proto.cc +++ b/test/dns-proto.cc @@ -232,7 +232,11 @@ std::string QuestionToString(const std::vector& packet, char *name = nullptr; long enclen; - ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen); + int rc = ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen); + if (rc != ARES_SUCCESS) { + ss << "(error from ares_expand_name)"; + return ss.str(); + } if (enclen > *len) { ss << "(error, encoded name len " << enclen << "bigger than remaining data " << *len << " bytes)"; return ss.str(); @@ -264,7 +268,11 @@ std::string RRToString(const std::vector& packet, char *name = nullptr; long enclen; - ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen); + int rc = ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen); + if (rc != ARES_SUCCESS) { + ss << "(error from ares_expand_name)"; + return ss.str(); + } if (enclen > *len) { ss << "(error, encoded name len " << enclen << "bigger than remaining data " << *len << " bytes)"; return ss.str(); @@ -275,7 +283,7 @@ std::string RRToString(const std::vector& packet, free(name); name = nullptr; - if (*len < NS_QFIXEDSZ) { + if (*len < NS_RRFIXEDSZ) { ss << "(too short, len left " << *len << ")"; return ss.str(); } @@ -293,92 +301,141 @@ std::string RRToString(const std::vector& packet, *data += NS_RRFIXEDSZ; *len -= NS_RRFIXEDSZ; - switch (rrtype) { - case ns_t_a: - case ns_t_aaaa: - ss << " " << AddressToString(*data, rdatalen); - break; - case ns_t_txt: { - const byte* p = *data; - while (p < (*data + rdatalen)) { - int len = *p++; - std::string txt(p, p + len); - ss << " " << len << ":'" << txt << "'"; - p += len; + if (*len < rdatalen) { + ss << "(RR too long at " << rdatalen << ", len left " << *len << ")"; + } else { + switch (rrtype) { + case ns_t_a: + case ns_t_aaaa: + ss << " " << AddressToString(*data, rdatalen); + break; + case ns_t_txt: { + const byte* p = *data; + while (p < (*data + rdatalen)) { + int len = *p++; + if ((p + len) <= (*data + rdatalen)) { + std::string txt(p, p + len); + ss << " " << len << ":'" << txt << "'"; + } else { + ss << "(string too long)"; + } + p += len; + } + break; + } + case ns_t_cname: + case ns_t_ns: + case ns_t_ptr: { + int rc = ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen); + if (rc != ARES_SUCCESS) { + ss << "(error from ares_expand_name)"; + break; + } + ss << " '" << name << "'"; + free(name); + break; + } + case ns_t_mx: + if (rdatalen > 2) { + int rc = ares_expand_name(*data + 2, packet.data(), packet.size(), &name, &enclen); + if (rc != ARES_SUCCESS) { + ss << "(error from ares_expand_name)"; + break; + } + ss << " " << DNS__16BIT(*data) << " '" << name << "'"; + free(name); + } else { + ss << "(RR too short)"; + } + break; + case ns_t_srv: { + if (rdatalen > 6) { + const byte* p = *data; + unsigned long prio = DNS__16BIT(p); + unsigned long weight = DNS__16BIT(p + 2); + unsigned long port = DNS__16BIT(p + 4); + p += 6; + int rc = ares_expand_name(p, packet.data(), packet.size(), &name, &enclen); + if (rc != ARES_SUCCESS) { + ss << "(error from ares_expand_name)"; + break; + } + ss << prio << " " << weight << " " << port << " '" << name << "'"; + free(name); + } else { + ss << "(RR too short)"; + } + break; + } + case ns_t_soa: { + const byte* p = *data; + int rc = ares_expand_name(p, packet.data(), packet.size(), &name, &enclen); + if (rc != ARES_SUCCESS) { + ss << "(error from ares_expand_name)"; + break; + } + ss << " '" << name << "'"; + free(name); + p += enclen; + rc = ares_expand_name(p, packet.data(), packet.size(), &name, &enclen); + if (rc != ARES_SUCCESS) { + ss << "(error from ares_expand_name)"; + break; + } + ss << " '" << name << "'"; + free(name); + p += enclen; + if ((p + 20) <= (*data + rdatalen)) { + unsigned long serial = DNS__32BIT(p); + unsigned long refresh = DNS__32BIT(p + 4); + unsigned long retry = DNS__32BIT(p + 8); + unsigned long expire = DNS__32BIT(p + 12); + unsigned long minimum = DNS__32BIT(p + 16); + ss << " " << serial << " " << refresh << " " << retry << " " << expire << " " << minimum; + } else { + ss << "(RR too short)"; + } + break; + } + case ns_t_naptr: { + if (rdatalen > 7) { + const byte* p = *data; + unsigned long order = DNS__16BIT(p); + unsigned long pref = DNS__16BIT(p + 2); + p += 4; + ss << order << " " << pref; + + int len = *p++; + std::string flags(p, p + len); + ss << " " << flags; + p += len; + + len = *p++; + std::string service(p, p + len); + ss << " '" << service << "'"; + p += len; + + len = *p++; + std::string regexp(p, p + len); + ss << " '" << regexp << "'"; + p += len; + + int rc = ares_expand_name(p, packet.data(), packet.size(), &name, &enclen); + if (rc != ARES_SUCCESS) { + ss << "(error from ares_expand_name)"; + break; + } + ss << " '" << name << "'"; + free(name); + } else { + ss << "(RR too short)"; + } + break; + } + default: + ss << " " << HexDump(*data, rdatalen); + break; } - break; - } - case ns_t_cname: - case ns_t_ns: - case ns_t_ptr: - ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen); - ss << " '" << name << "'"; - free(name); - break; - case ns_t_mx: - ares_expand_name(*data + 2, packet.data(), packet.size(), &name, &enclen); - ss << " " << DNS__16BIT(*data) << " '" << name << "'"; - free(name); - break; - case ns_t_srv: { - const byte* p = *data; - unsigned long prio = DNS__16BIT(p); - unsigned long weight = DNS__16BIT(p + 2); - unsigned long port = DNS__16BIT(p + 4); - p += 6; - ares_expand_name(p, packet.data(), packet.size(), &name, &enclen); - ss << prio << " " << weight << " " << port << " '" << name << "'"; - free(name); - break; - } - case ns_t_soa: { - const byte* p = *data; - ares_expand_name(p, packet.data(), packet.size(), &name, &enclen); - ss << " '" << name << "'"; - free(name); - p += enclen; - ares_expand_name(p, packet.data(), packet.size(), &name, &enclen); - ss << " '" << name << "'"; - free(name); - p += enclen; - unsigned long serial = DNS__32BIT(p); - unsigned long refresh = DNS__32BIT(p + 4); - unsigned long retry = DNS__32BIT(p + 8); - unsigned long expire = DNS__32BIT(p + 12); - unsigned long minimum = DNS__32BIT(p + 16); - ss << " " << serial << " " << refresh << " " << retry << " " << expire << " " << minimum; - break; - } - case ns_t_naptr: { - const byte* p = *data; - unsigned long order = DNS__16BIT(p); - unsigned long pref = DNS__16BIT(p + 2); - p += 4; - ss << order << " " << pref; - - int len = *p++; - std::string flags(p, p + len); - ss << " " << flags; - p += len; - - len = *p++; - std::string service(p, p + len); - ss << " '" << service << "'"; - p += len; - - len = *p++; - std::string regexp(p, p + len); - ss << " '" << regexp << "'"; - p += len; - - ares_expand_name(p, packet.data(), packet.size(), &name, &enclen); - ss << " '" << name << "'"; - free(name); - break; - } - default: - ss << " " << HexDump(*data, rdatalen); - break; } *data += rdatalen; *len -= rdatalen;