From 28d198008a6da19d28f0c34f7b3a4441f9c9d9ff Mon Sep 17 00:00:00 2001 From: Makarand Dharmapurikar Date: Thu, 1 Dec 2016 10:56:52 -0800 Subject: [PATCH] minor cleanup.. --- test/http2_test/http2_base_server.py | 19 ++-- test/http2_test/http2_test_server.py | 149 +++------------------------ 2 files changed, 26 insertions(+), 142 deletions(-) diff --git a/test/http2_test/http2_base_server.py b/test/http2_test/http2_base_server.py index 07bd37cae9a..91caa74fcc2 100644 --- a/test/http2_test/http2_base_server.py +++ b/test/http2_test/http2_base_server.py @@ -1,16 +1,11 @@ import struct import messages_pb2 -import functools -import argparse import logging -import time -from twisted.internet.defer import Deferred, inlineCallbacks -from twisted.internet.protocol import Protocol, Factory -from twisted.internet import endpoints, reactor, error, defer +from twisted.internet.protocol import Protocol +from twisted.internet import reactor from h2.connection import H2Connection from h2.events import RequestReceived, DataReceived, WindowUpdated, RemoteSettingsChanged, PingAcknowledged -from threading import Lock READ_CHUNK_SIZE = 16384 GRPC_HEADER_SIZE = 5 @@ -20,6 +15,7 @@ class H2ProtocolBaseServer(Protocol): self._conn = H2Connection(client_side=False) self._recv_buffer = '' self._handlers = {} + self._handlers['ConnectionMade'] = self.on_connection_made_default self._handlers['DataReceived'] = self.on_data_received_default self._handlers['WindowUpdated'] = self.on_window_update_default self._handlers['RequestReceived'] = self.on_request_received_default @@ -33,14 +29,17 @@ class H2ProtocolBaseServer(Protocol): self._handlers = handlers def connectionMade(self): + self._handlers['ConnectionMade']() + + def connectionLost(self, reason): + self._handlers['ConnectionLost'](reason) + + def on_connection_made_default(self): logging.info('Connection Made') self._conn.initiate_connection() self.transport.setTcpNoDelay(True) self.transport.write(self._conn.data_to_send()) - def connectionLost(self, reason): - self._handlers['ConnectionLost'](reason) - def on_connection_lost(self, reason): logging.info('Disconnected %s'%reason) reactor.callFromThread(reactor.stop) diff --git a/test/http2_test/http2_test_server.py b/test/http2_test/http2_test_server.py index be5f1593ebc..7ec781d2aa5 100644 --- a/test/http2_test/http2_test_server.py +++ b/test/http2_test/http2_test_server.py @@ -1,135 +1,17 @@ """ HTTP2 Test Server. Highly experimental work in progress. """ -import struct -import messages_pb2 import argparse import logging -import time -from twisted.internet.defer import Deferred, inlineCallbacks -from twisted.internet.protocol import Protocol, Factory -from twisted.internet import endpoints, reactor, error, defer -from h2.connection import H2Connection -from h2.events import RequestReceived, DataReceived, WindowUpdated, RemoteSettingsChanged -from threading import Lock +from twisted.internet.protocol import Factory +from twisted.internet import endpoints, reactor import http2_base_server - -READ_CHUNK_SIZE = 16384 -GRPC_HEADER_SIZE = 5 - -class TestcaseRstStreamAfterHeader(object): - def __init__(self): - self._base_server = http2_base_server.H2ProtocolBaseServer() - self._base_server._handlers['RequestReceived'] = self.on_request_received - - def get_base_server(self): - return self._base_server - - def on_request_received(self, event): - # send initial headers - self._base_server.on_request_received_default(event) - # send reset stream - self._base_server.send_reset_stream() - -class TestcaseRstStreamAfterData(object): - def __init__(self): - self._base_server = http2_base_server.H2ProtocolBaseServer() - self._base_server._handlers['DataReceived'] = self.on_data_received - - def get_base_server(self): - return self._base_server - - def on_data_received(self, event): - self._base_server.on_data_received_default(event) - sr = self._base_server.parse_received_data(self._base_server._recv_buffer) - assert(sr is not None) - assert(sr.response_size <= 2048) # so it can fit into one flow control window - response_data = self._base_server.default_response_data(sr.response_size) - self._ready_to_send = True - self._base_server.setup_send(response_data) - # send reset stream - self._base_server.send_reset_stream() - -class TestcaseGoaway(object): - """ - Process incoming request normally. After sending trailer response, - send GOAWAY with stream id = 1. - assert that the next request is made on a different connection. - """ - def __init__(self, iteration): - self._base_server = http2_base_server.H2ProtocolBaseServer() - self._base_server._handlers['RequestReceived'] = self.on_request_received - self._base_server._handlers['DataReceived'] = self.on_data_received - self._base_server._handlers['WindowUpdated'] = self.on_window_update_default - self._base_server._handlers['SendDone'] = self.on_send_done - self._base_server._handlers['ConnectionLost'] = self.on_connection_lost - self._ready_to_send = False - self._iteration = iteration - - def get_base_server(self): - return self._base_server - - def on_connection_lost(self, reason): - logging.info('Disconnect received. Count %d'%self._iteration) - # _iteration == 2 => Two different connections have been used. - if self._iteration == 2: - self._base_server.on_connection_lost(reason) - - def on_send_done(self): - self._base_server.on_send_done_default() - if self._base_server._stream_id == 1: - logging.info('Sending GOAWAY for stream 1') - self._base_server._conn.close_connection(error_code=0, additional_data=None, last_stream_id=1) - - def on_request_received(self, event): - self._ready_to_send = False - self._base_server.on_request_received_default(event) - - def on_data_received(self, event): - self._base_server.on_data_received_default(event) - sr = self._base_server.parse_received_data(self._base_server._recv_buffer) - if sr: - time.sleep(1) - logging.info('Creating response size = %s'%sr.response_size) - response_data = self._base_server.default_response_data(sr.response_size) - self._ready_to_send = True - self._base_server.setup_send(response_data) - - def on_window_update_default(self, event): - if self._ready_to_send: - self._base_server.default_send() - -class TestcasePing(object): - """ - """ - def __init__(self, iteration): - self._base_server = http2_base_server.H2ProtocolBaseServer() - self._base_server._handlers['RequestReceived'] = self.on_request_received - self._base_server._handlers['DataReceived'] = self.on_data_received - self._base_server._handlers['ConnectionLost'] = self.on_connection_lost - - def get_base_server(self): - return self._base_server - - def on_request_received(self, event): - self._base_server.default_ping() - self._base_server.on_request_received_default(event) - self._base_server.default_ping() - - def on_data_received(self, event): - self._base_server.on_data_received_default(event) - sr = self._base_server.parse_received_data(self._base_server._recv_buffer) - logging.info('Creating response size = %s'%sr.response_size) - response_data = self._base_server.default_response_data(sr.response_size) - self._base_server.default_ping() - self._base_server.setup_send(response_data) - self._base_server.default_ping() - - def on_connection_lost(self, reason): - logging.info('Disconnect received. Ping Count %d'%self._base_server._outstanding_pings) - assert(self._base_server._outstanding_pings == 0) - self._base_server.on_connection_lost(reason) +import test_rst_after_header +import test_rst_after_data +import test_goaway +import test_ping +import test_max_streams class H2Factory(Factory): def __init__(self, testcase): @@ -139,15 +21,18 @@ class H2Factory(Factory): def buildProtocol(self, addr): self._num_streams += 1 - if self._testcase == 'rst_stream_after_header': - t = TestcaseRstStreamAfterHeader(self._num_streams) - elif self._testcase == 'rst_stream_after_data': - t = TestcaseRstStreamAfterData(self._num_streams) + if self._testcase == 'rst_after_header': + t = test_rst_after_header.TestcaseRstStreamAfterHeader() + elif self._testcase == 'rst_after_data': + t = test_rst_after_data.TestcaseRstStreamAfterData() elif self._testcase == 'goaway': - t = TestcaseGoaway(self._num_streams) + t = test_goaway.TestcaseGoaway(self._num_streams) elif self._testcase == 'ping': - t = TestcasePing(self._num_streams) + t = test_ping.TestcasePing() + elif self._testcase == 'max_streams': + t = TestcaseSettingsMaxStreams(self._num_streams) else: + logging.error('Unknown test case: %s'%self._testcase) assert(0) return t.get_base_server() @@ -157,7 +42,7 @@ if __name__ == "__main__": parser.add_argument("test") parser.add_argument("port") args = parser.parse_args() - if args.test not in ['rst_stream_after_header', 'rst_stream_after_data', 'goaway', 'ping']: + if args.test not in ['rst_after_header', 'rst_after_data', 'goaway', 'ping', 'max_streams']: print 'unknown test: ', args.test endpoint = endpoints.TCP4ServerEndpoint(reactor, int(args.port), backlog=128) endpoint.listen(H2Factory(args.test))