|
|
|
#!/usr/bin/env python2.7
|
|
|
|
# Copyright 2015 gRPC authors.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
"""Starts a local DNS server for use in tests"""
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import sys
|
|
|
|
import yaml
|
|
|
|
import signal
|
|
|
|
import os
|
|
|
|
import threading
|
|
|
|
import time
|
|
|
|
|
|
|
|
import twisted
|
|
|
|
import twisted.internet
|
|
|
|
import twisted.internet.reactor
|
|
|
|
import twisted.internet.threads
|
|
|
|
import twisted.internet.defer
|
|
|
|
import twisted.internet.protocol
|
|
|
|
import twisted.names
|
|
|
|
import twisted.names.client
|
|
|
|
import twisted.names.dns
|
|
|
|
import twisted.names.server
|
|
|
|
from twisted.names import client, server, common, authority, dns
|
|
|
|
import argparse
|
|
|
|
import platform
|
|
|
|
|
|
|
|
_SERVER_HEALTH_CHECK_RECORD_NAME = 'health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp' # missing end '.' for twisted syntax
|
|
|
|
_SERVER_HEALTH_CHECK_RECORD_DATA = '123.123.123.123'
|
|
|
|
|
|
|
|
|
|
|
|
class NoFileAuthority(authority.FileAuthority):
|
|
|
|
|
|
|
|
def __init__(self, soa, records):
|
|
|
|
# skip FileAuthority
|
|
|
|
common.ResolverBase.__init__(self)
|
|
|
|
self.soa = soa
|
|
|
|
self.records = records
|
|
|
|
|
|
|
|
|
|
|
|
def start_local_dns_server(args):
|
|
|
|
all_records = {}
|
|
|
|
|
|
|
|
def _push_record(name, r):
|
|
|
|
name = bytes(name.encode('ascii'))
|
|
|
|
print('pushing record: |%s|' % name)
|
|
|
|
if all_records.get(name) is not None:
|
|
|
|
all_records[name].append(r)
|
|
|
|
return
|
|
|
|
all_records[name] = [r]
|
|
|
|
|
|
|
|
def _maybe_split_up_txt_data(name, txt_data, r_ttl):
|
|
|
|
txt_data = bytes(txt_data.encode('ascii'))
|
|
|
|
start = 0
|
|
|
|
txt_data_list = []
|
|
|
|
while len(txt_data[start:]) > 0:
|
|
|
|
next_read = len(txt_data[start:])
|
|
|
|
if next_read > 255:
|
|
|
|
next_read = 255
|
|
|
|
txt_data_list.append(txt_data[start:start + next_read])
|
|
|
|
start += next_read
|
|
|
|
_push_record(name, dns.Record_TXT(*txt_data_list, ttl=r_ttl))
|
|
|
|
|
|
|
|
with open(args.records_config_path) as config:
|
|
|
|
test_records_config = yaml.load(config)
|
|
|
|
common_zone_name = test_records_config['resolver_tests_common_zone_name']
|
|
|
|
for group in test_records_config['resolver_component_tests']:
|
|
|
|
for name in group['records'].keys():
|
|
|
|
for record in group['records'][name]:
|
|
|
|
r_type = record['type']
|
|
|
|
r_data = record['data']
|
|
|
|
r_ttl = int(record['TTL'])
|
|
|
|
record_full_name = '%s.%s' % (name, common_zone_name)
|
|
|
|
assert record_full_name[-1] == '.'
|
|
|
|
record_full_name = record_full_name[:-1]
|
|
|
|
if r_type == 'A':
|
|
|
|
_push_record(record_full_name,
|
|
|
|
dns.Record_A(r_data, ttl=r_ttl))
|
|
|
|
if r_type == 'AAAA':
|
|
|
|
_push_record(record_full_name,
|
|
|
|
dns.Record_AAAA(r_data, ttl=r_ttl))
|
|
|
|
if r_type == 'SRV':
|
|
|
|
p, w, port, target = r_data.split(' ')
|
|
|
|
p = int(p)
|
|
|
|
w = int(w)
|
|
|
|
port = int(port)
|
|
|
|
target_full_name = bytes(
|
|
|
|
('%s.%s' % (target, common_zone_name)).encode('ascii'))
|
|
|
|
_push_record(
|
|
|
|
record_full_name,
|
|
|
|
dns.Record_SRV(p, w, port, target_full_name, ttl=r_ttl))
|
|
|
|
if r_type == 'TXT':
|
|
|
|
_maybe_split_up_txt_data(record_full_name, r_data, r_ttl)
|
|
|
|
# Add an optional IPv4 record is specified
|
|
|
|
if args.add_a_record:
|
|
|
|
extra_host, extra_host_ipv4 = args.add_a_record.split(':')
|
|
|
|
_push_record(extra_host, dns.Record_A(extra_host_ipv4, ttl=0))
|
|
|
|
# Server health check record
|
|
|
|
_push_record(_SERVER_HEALTH_CHECK_RECORD_NAME,
|
|
|
|
dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0))
|
|
|
|
soa_record = dns.Record_SOA(mname=bytes(common_zone_name.encode('ascii')))
|
|
|
|
test_domain_com = NoFileAuthority(
|
|
|
|
soa=(bytes(common_zone_name.encode('ascii')), soa_record),
|
|
|
|
records=all_records,
|
|
|
|
)
|
|
|
|
server = twisted.names.server.DNSServerFactory(
|
|
|
|
authorities=[test_domain_com], verbose=2)
|
|
|
|
server.noisy = 2
|
|
|
|
twisted.internet.reactor.listenTCP(args.port, server)
|
|
|
|
dns_proto = twisted.names.dns.DNSDatagramProtocol(server)
|
|
|
|
dns_proto.noisy = 2
|
|
|
|
twisted.internet.reactor.listenUDP(args.port, dns_proto)
|
|
|
|
print('starting local dns server on 127.0.0.1:%s' % args.port)
|
|
|
|
print('starting twisted.internet.reactor')
|
|
|
|
twisted.internet.reactor.suggestThreadPoolSize(1)
|
|
|
|
twisted.internet.reactor.run()
|
|
|
|
|
|
|
|
|
|
|
|
def _quit_on_signal(signum, _frame):
|
|
|
|
print('Received SIGNAL %d. Quitting with exit code 0' % signum)
|
|
|
|
twisted.internet.reactor.stop()
|
|
|
|
sys.stdout.flush()
|
|
|
|
sys.exit(0)
|
|
|
|
|
|
|
|
|
|
|
|
def flush_stdout_loop():
|
|
|
|
num_timeouts_so_far = 0
|
|
|
|
sleep_time = 1
|
|
|
|
# Prevent zombies. Tests that use this server are short-lived.
|
|
|
|
max_timeouts = 60 * 10
|
|
|
|
while num_timeouts_so_far < max_timeouts:
|
|
|
|
sys.stdout.flush()
|
|
|
|
time.sleep(sleep_time)
|
|
|
|
num_timeouts_so_far += 1
|
|
|
|
print('Process timeout reached, or cancelled. Exitting 0.')
|
|
|
|
os.kill(os.getpid(), signal.SIGTERM)
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
argp = argparse.ArgumentParser(
|
|
|
|
description='Local DNS Server for resolver tests')
|
|
|
|
argp.add_argument('-p',
|
|
|
|
'--port',
|
|
|
|
default=None,
|
|
|
|
type=int,
|
|
|
|
help='Port for DNS server to listen on for TCP and UDP.')
|
|
|
|
argp.add_argument(
|
|
|
|
'-r',
|
|
|
|
'--records_config_path',
|
|
|
|
default=None,
|
|
|
|
type=str,
|
|
|
|
help=('Directory of resolver_test_record_groups.yaml file. '
|
|
|
|
'Defaults to path needed when the test is invoked as part '
|
|
|
|
'of run_tests.py.'))
|
|
|
|
argp.add_argument(
|
|
|
|
'--add_a_record',
|
|
|
|
default=None,
|
|
|
|
type=str,
|
|
|
|
help=('Add an A record via the command line. Useful for when we '
|
|
|
|
'need to serve a one-off A record that is under a '
|
|
|
|
'different domain then the rest the records configured in '
|
|
|
|
'--records_config_path (which all need to be under the '
|
|
|
|
'same domain). Format: <name>:<ipv4 address>'))
|
|
|
|
args = argp.parse_args()
|
|
|
|
signal.signal(signal.SIGTERM, _quit_on_signal)
|
|
|
|
signal.signal(signal.SIGINT, _quit_on_signal)
|
|
|
|
output_flush_thread = threading.Thread(target=flush_stdout_loop)
|
|
|
|
output_flush_thread.setDaemon(True)
|
|
|
|
output_flush_thread.start()
|
|
|
|
start_local_dns_server(args)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
main()
|