test: more tests, especially fallback processing

- Make mock server listen on UDP + TCP in parallel.
 - Test UDP->TCP fallback on truncation
 - Test EDNS->no-EDNS fallback
 - Test some environment init options
 - Test nonsense reply

test: short response
pull/34/head
David Drysdale 9 years ago
parent a05f240dea
commit 25adcc413e
  1. 25
      test/ares-test-init.cc
  2. 88
      test/ares-test-mock.cc
  3. 90
      test/ares-test.cc
  4. 6
      test/ares-test.h
  5. 233
      test/dns-proto.cc

@ -222,6 +222,31 @@ TEST_F(LibraryTest, FailChannelInit) {
ares_library_cleanup();
}
TEST_F(LibraryTest, EnvInit) {
ares_channel channel = nullptr;
EnvValue v1("LOCALDOMAIN", "this.is.local");
EnvValue v2("RES_OPTIONS", "options debug ndots:3 retry:3 rotate retrans:2");
EXPECT_EQ(ARES_SUCCESS, ares_init(&channel));
ares_destroy(channel);
}
TEST_F(LibraryTest, EnvInitAllocFail) {
ares_channel channel;
EnvValue v1("LOCALDOMAIN", "this.is.local");
EnvValue v2("RES_OPTIONS", "options debug ndots:3 retry:3 rotate retrans:2");
for (int ii = 1; ii <= 10; ii++) {
ClearFails();
SetAllocFail(ii);
channel = nullptr;
int rc = ares_init(&channel);
if (rc == ARES_SUCCESS) {
ares_destroy(channel);
} else {
EXPECT_EQ(ARES_ENOMEM, rc);
}
}
}
TEST_F(DefaultChannelTest, SetAddresses) {
ares_set_local_ip4(channel_, 0x01020304);
byte addr6[16] = {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,

@ -89,6 +89,27 @@ TEST_P(MockUDPChannelTest, ParallelLookups) {
EXPECT_EQ("{'www.google.com' aliases=[] addrs=[2.3.4.5]}", ss3.str());
}
// UDP to TCP specific test
TEST_P(MockUDPChannelTest, TruncationRetry) {
DNSPacket rsptruncated;
rsptruncated.set_response().set_aa().set_tc()
.add_question(new DNSQuestion("www.google.com", ns_t_a));
DNSPacket rspok;
rspok.set_response()
.add_question(new DNSQuestion("www.google.com", ns_t_a))
.add_answer(new DNSARR("www.google.com", 100, {1, 2, 3, 4}));
EXPECT_CALL(server_, OnRequest("www.google.com", ns_t_a))
.WillOnce(SetReply(&server_, &rsptruncated))
.WillOnce(SetReply(&server_, &rspok));
HostResult result;
ares_gethostbyname(channel_, "www.google.com.", AF_INET, HostCallback, &result);
Process();
EXPECT_TRUE(result.done_);
std::stringstream ss;
ss << result.host_;
EXPECT_EQ("{'www.google.com' aliases=[] addrs=[1.2.3.4]}", ss.str());
}
static int sock_cb_count = 0;
static int SocketCallback(ares_socket_t fd, int type, void *data) {
if (verbose) std::cerr << "SocketCallback(" << fd << ") invoked" << std::endl;
@ -119,6 +140,18 @@ TEST_P(MockChannelTest, SockCallback) {
}
// TCP only to prevent retries
TEST_P(MockTCPChannelTest, MalformedResponse) {
std::vector<byte> one = {0x01};
EXPECT_CALL(server_, OnRequest("www.google.com", ns_t_a))
.WillOnce(SetReplyData(&server_, one));
HostResult result;
ares_gethostbyname(channel_, "www.google.com.", AF_INET, HostCallback, &result);
Process();
EXPECT_TRUE(result.done_);
EXPECT_EQ(ARES_ETIMEOUT, result.status_);
}
TEST_P(MockTCPChannelTest, FormErrResponse) {
DNSPacket rsp;
rsp.set_response().set_aa()
@ -192,20 +225,27 @@ TEST_P(MockTCPChannelTest, YXDomainResponse) {
EXPECT_EQ(ARES_ENODATA, result.status_);
}
class MockNoCheckRespChannelTest
class MockFlagsChannelOptsTest
: public MockChannelOptsTest,
public ::testing::WithParamInterface<int> {
public ::testing::WithParamInterface< std::pair<int, bool> > {
public:
MockNoCheckRespChannelTest() : MockChannelOptsTest(GetParam(), true, FillOptions(&opts_), ARES_OPT_FLAGS) {}
static struct ares_options* FillOptions(struct ares_options * opts) {
MockFlagsChannelOptsTest(int flags)
: MockChannelOptsTest(GetParam().first, GetParam().second,
FillOptions(&opts_, flags), ARES_OPT_FLAGS) {}
static struct ares_options* FillOptions(struct ares_options * opts, int flags) {
memset(opts, 0, sizeof(struct ares_options));
opts->flags = ARES_FLAG_NOCHECKRESP;
opts->flags = flags;
return opts;
}
private:
struct ares_options opts_;
};
class MockNoCheckRespChannelTest : public MockFlagsChannelOptsTest {
public:
MockNoCheckRespChannelTest() : MockFlagsChannelOptsTest(ARES_FLAG_NOCHECKRESP) {}
};
TEST_P(MockNoCheckRespChannelTest, ServFailResponse) {
DNSPacket rsp;
rsp.set_response().set_aa()
@ -248,6 +288,31 @@ TEST_P(MockNoCheckRespChannelTest, RefusedResponse) {
EXPECT_EQ(ARES_EREFUSED, result.status_);
}
class MockEDNSChannelTest : public MockFlagsChannelOptsTest {
public:
MockEDNSChannelTest() : MockFlagsChannelOptsTest(ARES_FLAG_EDNS) {}
};
TEST_P(MockEDNSChannelTest, RetryWithoutEDNS) {
DNSPacket rspfail;
rspfail.set_response().set_aa().set_rcode(ns_r_servfail)
.add_question(new DNSQuestion("www.google.com", ns_t_a));
DNSPacket rspok;
rspok.set_response()
.add_question(new DNSQuestion("www.google.com", ns_t_a))
.add_answer(new DNSARR("www.google.com", 100, {1, 2, 3, 4}));
EXPECT_CALL(server_, OnRequest("www.google.com", ns_t_a))
.WillOnce(SetReply(&server_, &rspfail))
.WillOnce(SetReply(&server_, &rspok));
HostResult result;
ares_gethostbyname(channel_, "www.google.com.", AF_INET, HostCallback, &result);
Process();
EXPECT_TRUE(result.done_);
std::stringstream ss;
ss << result.host_;
EXPECT_EQ("{'www.google.com' aliases=[] addrs=[1.2.3.4]}", ss.str());
}
TEST_P(MockChannelTest, SearchDomains) {
DNSPacket nofirst;
nofirst.set_response().set_aa().set_rcode(ns_r_nxdomain)
@ -619,7 +684,7 @@ TEST_P(MockUDPChannelTest, SearchDomainsAllocFail) {
ares_gethostbyname(channel_, "www", AF_INET, HostCallback, result);
Process();
EXPECT_TRUE(result->done_);
if (result->status_ != ARES_ENOMEM) {
if (result->status_ == ARES_SUCCESS) {
std::stringstream ss;
ss << result->host_;
EXPECT_EQ("{'www.third.gov' aliases=[] addrs=[2.3.4.5]}", ss.str()) << " failed alloc #" << ii;
@ -801,7 +866,16 @@ INSTANTIATE_TEST_CASE_P(AddressFamilies, MockTCPChannelTest,
::testing::Values(AF_INET, AF_INET6));
INSTANTIATE_TEST_CASE_P(AddressFamilies, MockNoCheckRespChannelTest,
::testing::Values(AF_INET, AF_INET6));
::testing::Values(std::make_pair<int, bool>(AF_INET, false),
std::make_pair<int, bool>(AF_INET, true),
std::make_pair<int, bool>(AF_INET6, false),
std::make_pair<int, bool>(AF_INET6, true)));
INSTANTIATE_TEST_CASE_P(AddressFamilies, MockEDNSChannelTest,
::testing::Values(std::make_pair<int, bool>(AF_INET, false),
std::make_pair<int, bool>(AF_INET, true),
std::make_pair<int, bool>(AF_INET6, false),
std::make_pair<int, bool>(AF_INET6, true)));
} // namespace test
} // namespace ares

@ -68,6 +68,7 @@ void ProcessWork(ares_channel channel,
// static
void LibraryTest::SetAllocFail(int nth) {
assert(nth > 0);
assert(nth <= (int)(8 * sizeof(fails_)));
fails_ |= (1 << (nth - 1));
}
@ -131,29 +132,29 @@ void DefaultChannelModeTest::Process() {
ProcessWork(channel_, NoExtraFDs, nullptr);
}
MockServer::MockServer(int family, bool tcp, int port) : tcp_(tcp), port_(port), qid_(-1) {
if (tcp_) {
// Create a TCP socket to receive data on.
sockfd_ = socket(family, SOCK_STREAM, 0);
EXPECT_NE(-1, sockfd_);
int optval = 1;
setsockopt(sockfd_, SOL_SOCKET, SO_REUSEADDR,
(const void *)&optval , sizeof(int));
} else {
// Create a UDP socket to receive data on.
sockfd_ = socket(family, SOCK_DGRAM, 0);
EXPECT_NE(-1, sockfd_);
}
MockServer::MockServer(int family, int port) : port_(port), qid_(-1) {
// Create a TCP socket to receive data on.
tcpfd_ = socket(family, SOCK_STREAM, 0);
EXPECT_NE(-1, tcpfd_);
int optval = 1;
setsockopt(tcpfd_, SOL_SOCKET, SO_REUSEADDR,
(const void *)&optval , sizeof(int));
// Create a UDP socket to receive data on.
udpfd_ = socket(family, SOCK_DGRAM, 0);
EXPECT_NE(-1, udpfd_);
// Bind it to the given port.
// Bind the sockets to the given port.
if (family == AF_INET) {
struct sockaddr_in addr;
memset(&addr, 0, sizeof(addr));
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = htonl(INADDR_ANY);
addr.sin_port = htons(port_);
int rc = bind(sockfd_, (struct sockaddr*)&addr, sizeof(addr));
EXPECT_EQ(0, rc) << "Failed to bind AF_INET to port " << port_;
int tcprc = bind(tcpfd_, (struct sockaddr*)&addr, sizeof(addr));
EXPECT_EQ(0, tcprc) << "Failed to bind AF_INET to TCP port " << port_;
int udprc = bind(udpfd_, (struct sockaddr*)&addr, sizeof(addr));
EXPECT_EQ(0, udprc) << "Failed to bind AF_INET to UDP port " << port_;
} else {
EXPECT_EQ(AF_INET6, family);
struct sockaddr_in6 addr;
@ -161,31 +162,31 @@ MockServer::MockServer(int family, bool tcp, int port) : tcp_(tcp), port_(port),
addr.sin6_family = AF_INET6;
addr.sin6_addr = in6addr_any;
addr.sin6_port = htons(port_);
int rc = bind(sockfd_, (struct sockaddr*)&addr, sizeof(addr));
EXPECT_EQ(0, rc) << "Failed to bind AF_INET6 to port " << port_;
int tcprc = bind(tcpfd_, (struct sockaddr*)&addr, sizeof(addr));
EXPECT_EQ(0, tcprc) << "Failed to bind AF_INET6 to UDP port " << port_;
int udprc = bind(udpfd_, (struct sockaddr*)&addr, sizeof(addr));
EXPECT_EQ(0, udprc) << "Failed to bind AF_INET6 to UDP port " << port_;
}
// For TCP, also need to listen for connections.
if (tcp_) {
EXPECT_EQ(0, listen(sockfd_, 5)) << "Failed to listen for TCP connections";
}
EXPECT_EQ(0, listen(tcpfd_, 5)) << "Failed to listen for TCP connections";
}
MockServer::~MockServer() {
for (int fd : connfds_) {
close(fd);
}
close(sockfd_);
sockfd_ = -1;
close(tcpfd_);
close(udpfd_);
}
void MockServer::Process(int fd) {
if (fd != sockfd_ && connfds_.find(fd) == connfds_.end()) {
if (fd != tcpfd_ && fd != udpfd_ && connfds_.find(fd) == connfds_.end()) {
std::cerr << "Asked to process unknown fd " << fd << std::endl;
return;
}
if (tcp_ && fd == sockfd_) {
int connfd = accept(sockfd_, NULL, NULL);
if (fd == tcpfd_) {
int connfd = accept(tcpfd_, NULL, NULL);
if (connfd < 0) {
std::cerr << "Error accepting connection on fd " << fd << std::endl;
} else {
@ -201,7 +202,7 @@ void MockServer::Process(int fd) {
int len = recvfrom(fd, buffer, sizeof(buffer), 0,
(struct sockaddr *)&addr, &addrlen);
byte* data = buffer;
if (tcp_) {
if (fd != udpfd_) {
if (len == 0) {
connfds_.erase(std::find(connfds_.begin(), connfds_.end(), fd));
close(fd);
@ -213,10 +214,10 @@ void MockServer::Process(int fd) {
}
int tcplen = (data[0] << 8) + data[1];
data += 2;
if (tcplen + 2 != len) {
len -= 2;
if (tcplen != len) {
std::cerr << "Warning: TCP length " << tcplen
<< " doesn't match data length (" << len
<< " - 2)" << std::endl;
<< " doesn't match remaining data length " << len << std::endl;
}
}
@ -268,14 +269,19 @@ void MockServer::Process(int fd) {
}
int rrtype = DNS_QUESTION_TYPE(question);
if (verbose) std::cerr << "ProcessRequest(" << qid << ", '" << namestr
<< "', " << RRTypeToString(rrtype) << ")" << std::endl;
if (verbose) {
std::vector<byte> req(data, data + len);
std::cerr << "received " << (fd == udpfd_ ? "UDP" : "TCP") << " request " << PacketToString(req) << std::endl;
std::cerr << "ProcessRequest(" << qid << ", '" << namestr
<< "', " << RRTypeToString(rrtype) << ")" << std::endl;
}
ProcessRequest(fd, &addr, addrlen, qid, namestr, rrtype);
}
std::set<int> MockServer::fds() const {
std::set<int> result = connfds_;
result.insert(sockfd_);
result.insert(tcpfd_);
result.insert(udpfd_);
return result;
}
@ -284,24 +290,26 @@ void MockServer::ProcessRequest(int fd, struct sockaddr_storage* addr, int addrl
// Before processing, let gMock know the request is happening.
OnRequest(name, rrtype);
// Make a local copy of the current pending reply.
if (reply_.size() < 2) {
if (verbose) std::cerr << "Skipping reply as not-present/too-short" << std::endl;
if (reply_.size() == 0) {
return;
}
// Make a local copy of the current pending reply.
std::vector<byte> reply = reply_;
if (qid_ >= 0) {
// Use the explicitly specified query ID.
qid = qid_;
}
// Overwrite the query ID.
reply[0] = (byte)((qid >> 8) & 0xff);
reply[1] = (byte)(qid & 0xff);
if (reply.size() >= 2) {
// Overwrite the query ID if space to do so.
reply[0] = (byte)((qid >> 8) & 0xff);
reply[1] = (byte)(qid & 0xff);
}
if (verbose) std::cerr << "sending reply " << PacketToString(reply) << std::endl;
// Prefix with 2-byte length if TCP.
if (tcp_) {
if (fd != udpfd_) {
int len = reply.size();
std::vector<byte> vlen = {(byte)((len & 0xFF00) >> 8), (byte)(len & 0xFF)};
reply.insert(reply.begin(), vlen.begin(), vlen.end());
@ -321,7 +329,7 @@ MockChannelOptsTest::MockChannelOptsTest(int family,
bool force_tcp,
struct ares_options* givenopts,
int optmask)
: server_(family, force_tcp, mock_port), channel_(nullptr) {
: server_(family, mock_port), channel_(nullptr) {
// Set up channel options.
struct ares_options opts;
if (givenopts) {

@ -114,7 +114,7 @@ class DefaultChannelModeTest
// Mock DNS server to allow responses to be scripted by tests.
class MockServer {
public:
MockServer(int family, bool tcp, int port);
MockServer(int family, int port);
~MockServer();
// Mock method indicating the processing of a particular <name, RRtype>
@ -140,9 +140,9 @@ class MockServer {
void ProcessRequest(int fd, struct sockaddr_storage* addr, int addrlen,
int qid, const std::string& name, int rrtype);
bool tcp_;
int port_;
int sockfd_;
int udpfd_;
int tcpfd_;
std::set<int> connfds_;
std::vector<byte> reply_;
int qid_;

@ -232,7 +232,11 @@ std::string QuestionToString(const std::vector<byte>& packet,
char *name = nullptr;
long enclen;
ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen);
int rc = ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen);
if (rc != ARES_SUCCESS) {
ss << "(error from ares_expand_name)";
return ss.str();
}
if (enclen > *len) {
ss << "(error, encoded name len " << enclen << "bigger than remaining data " << *len << " bytes)";
return ss.str();
@ -264,7 +268,11 @@ std::string RRToString(const std::vector<byte>& packet,
char *name = nullptr;
long enclen;
ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen);
int rc = ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen);
if (rc != ARES_SUCCESS) {
ss << "(error from ares_expand_name)";
return ss.str();
}
if (enclen > *len) {
ss << "(error, encoded name len " << enclen << "bigger than remaining data " << *len << " bytes)";
return ss.str();
@ -275,7 +283,7 @@ std::string RRToString(const std::vector<byte>& packet,
free(name);
name = nullptr;
if (*len < NS_QFIXEDSZ) {
if (*len < NS_RRFIXEDSZ) {
ss << "(too short, len left " << *len << ")";
return ss.str();
}
@ -293,92 +301,141 @@ std::string RRToString(const std::vector<byte>& packet,
*data += NS_RRFIXEDSZ;
*len -= NS_RRFIXEDSZ;
switch (rrtype) {
case ns_t_a:
case ns_t_aaaa:
ss << " " << AddressToString(*data, rdatalen);
break;
case ns_t_txt: {
const byte* p = *data;
while (p < (*data + rdatalen)) {
int len = *p++;
std::string txt(p, p + len);
ss << " " << len << ":'" << txt << "'";
p += len;
if (*len < rdatalen) {
ss << "(RR too long at " << rdatalen << ", len left " << *len << ")";
} else {
switch (rrtype) {
case ns_t_a:
case ns_t_aaaa:
ss << " " << AddressToString(*data, rdatalen);
break;
case ns_t_txt: {
const byte* p = *data;
while (p < (*data + rdatalen)) {
int len = *p++;
if ((p + len) <= (*data + rdatalen)) {
std::string txt(p, p + len);
ss << " " << len << ":'" << txt << "'";
} else {
ss << "(string too long)";
}
p += len;
}
break;
}
case ns_t_cname:
case ns_t_ns:
case ns_t_ptr: {
int rc = ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen);
if (rc != ARES_SUCCESS) {
ss << "(error from ares_expand_name)";
break;
}
ss << " '" << name << "'";
free(name);
break;
}
case ns_t_mx:
if (rdatalen > 2) {
int rc = ares_expand_name(*data + 2, packet.data(), packet.size(), &name, &enclen);
if (rc != ARES_SUCCESS) {
ss << "(error from ares_expand_name)";
break;
}
ss << " " << DNS__16BIT(*data) << " '" << name << "'";
free(name);
} else {
ss << "(RR too short)";
}
break;
case ns_t_srv: {
if (rdatalen > 6) {
const byte* p = *data;
unsigned long prio = DNS__16BIT(p);
unsigned long weight = DNS__16BIT(p + 2);
unsigned long port = DNS__16BIT(p + 4);
p += 6;
int rc = ares_expand_name(p, packet.data(), packet.size(), &name, &enclen);
if (rc != ARES_SUCCESS) {
ss << "(error from ares_expand_name)";
break;
}
ss << prio << " " << weight << " " << port << " '" << name << "'";
free(name);
} else {
ss << "(RR too short)";
}
break;
}
case ns_t_soa: {
const byte* p = *data;
int rc = ares_expand_name(p, packet.data(), packet.size(), &name, &enclen);
if (rc != ARES_SUCCESS) {
ss << "(error from ares_expand_name)";
break;
}
ss << " '" << name << "'";
free(name);
p += enclen;
rc = ares_expand_name(p, packet.data(), packet.size(), &name, &enclen);
if (rc != ARES_SUCCESS) {
ss << "(error from ares_expand_name)";
break;
}
ss << " '" << name << "'";
free(name);
p += enclen;
if ((p + 20) <= (*data + rdatalen)) {
unsigned long serial = DNS__32BIT(p);
unsigned long refresh = DNS__32BIT(p + 4);
unsigned long retry = DNS__32BIT(p + 8);
unsigned long expire = DNS__32BIT(p + 12);
unsigned long minimum = DNS__32BIT(p + 16);
ss << " " << serial << " " << refresh << " " << retry << " " << expire << " " << minimum;
} else {
ss << "(RR too short)";
}
break;
}
case ns_t_naptr: {
if (rdatalen > 7) {
const byte* p = *data;
unsigned long order = DNS__16BIT(p);
unsigned long pref = DNS__16BIT(p + 2);
p += 4;
ss << order << " " << pref;
int len = *p++;
std::string flags(p, p + len);
ss << " " << flags;
p += len;
len = *p++;
std::string service(p, p + len);
ss << " '" << service << "'";
p += len;
len = *p++;
std::string regexp(p, p + len);
ss << " '" << regexp << "'";
p += len;
int rc = ares_expand_name(p, packet.data(), packet.size(), &name, &enclen);
if (rc != ARES_SUCCESS) {
ss << "(error from ares_expand_name)";
break;
}
ss << " '" << name << "'";
free(name);
} else {
ss << "(RR too short)";
}
break;
}
default:
ss << " " << HexDump(*data, rdatalen);
break;
}
break;
}
case ns_t_cname:
case ns_t_ns:
case ns_t_ptr:
ares_expand_name(*data, packet.data(), packet.size(), &name, &enclen);
ss << " '" << name << "'";
free(name);
break;
case ns_t_mx:
ares_expand_name(*data + 2, packet.data(), packet.size(), &name, &enclen);
ss << " " << DNS__16BIT(*data) << " '" << name << "'";
free(name);
break;
case ns_t_srv: {
const byte* p = *data;
unsigned long prio = DNS__16BIT(p);
unsigned long weight = DNS__16BIT(p + 2);
unsigned long port = DNS__16BIT(p + 4);
p += 6;
ares_expand_name(p, packet.data(), packet.size(), &name, &enclen);
ss << prio << " " << weight << " " << port << " '" << name << "'";
free(name);
break;
}
case ns_t_soa: {
const byte* p = *data;
ares_expand_name(p, packet.data(), packet.size(), &name, &enclen);
ss << " '" << name << "'";
free(name);
p += enclen;
ares_expand_name(p, packet.data(), packet.size(), &name, &enclen);
ss << " '" << name << "'";
free(name);
p += enclen;
unsigned long serial = DNS__32BIT(p);
unsigned long refresh = DNS__32BIT(p + 4);
unsigned long retry = DNS__32BIT(p + 8);
unsigned long expire = DNS__32BIT(p + 12);
unsigned long minimum = DNS__32BIT(p + 16);
ss << " " << serial << " " << refresh << " " << retry << " " << expire << " " << minimum;
break;
}
case ns_t_naptr: {
const byte* p = *data;
unsigned long order = DNS__16BIT(p);
unsigned long pref = DNS__16BIT(p + 2);
p += 4;
ss << order << " " << pref;
int len = *p++;
std::string flags(p, p + len);
ss << " " << flags;
p += len;
len = *p++;
std::string service(p, p + len);
ss << " '" << service << "'";
p += len;
len = *p++;
std::string regexp(p, p + len);
ss << " '" << regexp << "'";
p += len;
ares_expand_name(p, packet.data(), packet.size(), &name, &enclen);
ss << " '" << name << "'";
free(name);
break;
}
default:
ss << " " << HexDump(*data, rdatalen);
break;
}
*data += rdatalen;
*len -= rdatalen;

Loading…
Cancel
Save