added new test (rst_during_data)

pull/8900/head
Makarand Dharmapurikar 8 years ago
parent 5b7070a15b
commit a16ea7f9b1
  1. 41
      test/http2_test/http2_test_server.py
  2. 2
      test/http2_test/test_goaway.py
  3. 1
      test/http2_test/test_rst_after_data.py
  4. 30
      test/http2_test/test_rst_during_data.py

@ -9,10 +9,20 @@ from twisted.internet import endpoints, reactor
import http2_base_server
import test_rst_after_header
import test_rst_after_data
import test_rst_during_data
import test_goaway
import test_ping
import test_max_streams
test_case_mappings = {
'rst_after_header': test_rst_after_header.TestcaseRstStreamAfterHeader,
'rst_after_data': test_rst_after_data.TestcaseRstStreamAfterData,
'rst_during_data': test_rst_during_data.TestcaseRstStreamDuringData,
'goaway': test_goaway.TestcaseGoaway,
'ping': test_ping.TestcasePing,
'max_streams': test_max_streams.TestcaseSettingsMaxStreams,
}
class H2Factory(Factory):
def __init__(self, testcase):
logging.info('In H2Factory')
@ -22,20 +32,16 @@ class H2Factory(Factory):
def buildProtocol(self, addr):
self._num_streams += 1
logging.info('New Connection: %d'%self._num_streams)
if self._testcase == 'rst_after_header':
t = test_rst_after_header.TestcaseRstStreamAfterHeader()
elif self._testcase == 'rst_after_data':
t = test_rst_after_data.TestcaseRstStreamAfterData()
elif self._testcase == 'goaway':
t = test_goaway.TestcaseGoaway(self._num_streams)
elif self._testcase == 'ping':
t = test_ping.TestcasePing()
elif self._testcase == 'max_streams':
t = test_max_streams.TestcaseSettingsMaxStreams()
else:
if not test_case_mappings.has_key(self._testcase):
logging.error('Unknown test case: %s'%self._testcase)
assert(0)
return t.get_base_server()
else:
t = test_case_mappings[self._testcase]
if self._testcase == 'goaway':
return t(self._num_streams).get_base_server()
else:
return t().get_base_server()
if __name__ == "__main__":
logging.basicConfig(format = "%(levelname) -10s %(asctime)s %(module)s:%(lineno)s | %(message)s", level=logging.INFO)
@ -43,8 +49,9 @@ if __name__ == "__main__":
parser.add_argument("test")
parser.add_argument("port")
args = parser.parse_args()
if args.test not in ['rst_after_header', 'rst_after_data', 'goaway', 'ping', 'max_streams']:
print 'unknown test: ', args.test
endpoint = endpoints.TCP4ServerEndpoint(reactor, int(args.port), backlog=128)
endpoint.listen(H2Factory(args.test))
reactor.run()
if args.test not in test_case_mappings.keys():
logging.error('unknown test: %s'%args.test)
else:
endpoint = endpoints.TCP4ServerEndpoint(reactor, int(args.port), backlog=128)
endpoint.listen(H2Factory(args.test))
reactor.run()

@ -24,7 +24,7 @@ class TestcaseGoaway(object):
def on_connection_lost(self, reason):
logging.info('Disconnect received. Count %d'%self._iteration)
# _iteration == 2 => Two different connections have been used.
if self._iteration == 200:
if self._iteration == 2:
self._base_server.on_connection_lost(reason)
def on_send_done(self, stream_id):

@ -4,6 +4,7 @@ class TestcaseRstStreamAfterData(object):
"""
In response to an incoming request, this test sends headers, followed by
data, followed by a reset stream frame. Client asserts that the RPC failed.
Client needs to deliver the complete message to the application layer.
"""
def __init__(self):
self._base_server = http2_base_server.H2ProtocolBaseServer()

@ -0,0 +1,30 @@
import http2_base_server
class TestcaseRstStreamDuringData(object):
"""
In response to an incoming request, this test sends headers, followed by
some data, followed by a reset stream frame. Client asserts that the RPC
failed and does not deliver the message to the application.
"""
def __init__(self):
self._base_server = http2_base_server.H2ProtocolBaseServer()
self._base_server._handlers['DataReceived'] = self.on_data_received
self._base_server._handlers['SendDone'] = self.on_send_done
def get_base_server(self):
return self._base_server
def on_data_received(self, event):
self._base_server.on_data_received_default(event)
sr = self._base_server.parse_received_data(event.stream_id)
if sr:
response_data = self._base_server.default_response_data(sr.response_size)
self._ready_to_send = True
response_len = len(response_data)
truncated_response_data = response_data[0:response_len/2]
self._base_server.setup_send(truncated_response_data, event.stream_id)
# send reset stream
def on_send_done(self, stream_id):
self._base_server.send_reset_stream()
self._base_server._stream_status[stream_id] = False
Loading…
Cancel
Save