fixed feedback from review

pull/8900/head
Makarand Dharmapurikar 8 years ago
parent a16ea7f9b1
commit 8c57917f56
  1. 65
      test/http2_test/http2_base_server.py
  2. 36
      test/http2_test/http2_test_server.py
  3. 10
      test/http2_test/test_goaway.py
  4. 9
      test/http2_test/test_max_streams.py
  5. 5
      test/http2_test/test_ping.py
  6. 1
      test/http2_test/test_rst_during_data.py

@ -1,19 +1,19 @@
import struct
import messages_pb2
import logging
import messages_pb2
import struct
from twisted.internet.protocol import Protocol
from twisted.internet import reactor
from h2.connection import H2Connection
from h2.events import RequestReceived, DataReceived, WindowUpdated, RemoteSettingsChanged, PingAcknowledged
from h2.exceptions import ProtocolError
import h2
import h2.connection
import twisted
import twisted.internet
import twisted.internet.protocol
READ_CHUNK_SIZE = 16384
GRPC_HEADER_SIZE = 5
_READ_CHUNK_SIZE = 16384
_GRPC_HEADER_SIZE = 5
class H2ProtocolBaseServer(Protocol):
class H2ProtocolBaseServer(twisted.internet.protocol.Protocol):
def __init__(self):
self._conn = H2Connection(client_side=False)
self._conn = h2.connection.H2Connection(client_side=False)
self._recv_buffer = {}
self._handlers = {}
self._handlers['ConnectionMade'] = self.on_connection_made_default
@ -43,34 +43,35 @@ class H2ProtocolBaseServer(Protocol):
self.transport.write(self._conn.data_to_send())
def on_connection_lost(self, reason):
logging.info('Disconnected %s'%reason)
reactor.callFromThread(reactor.stop)
logging.info('Disconnected %s' % reason)
twisted.internet.reactor.callFromThread(twisted.internet.reactor.stop)
def dataReceived(self, data):
try:
events = self._conn.receive_data(data)
except ProtocolError:
except h2.exceptions.ProtocolError:
# this try/except block catches exceptions due to race between sending
# GOAWAY and processing a response in flight.
return
if self._conn.data_to_send:
self.transport.write(self._conn.data_to_send())
for event in events:
if isinstance(event, RequestReceived) and self._handlers.has_key('RequestReceived'):
logging.info('RequestReceived Event for stream: %d'%event.stream_id)
if isinstance(event, h2.events.RequestReceived) and self._handlers.has_key('RequestReceived'):
logging.info('RequestReceived Event for stream: %d' % event.stream_id)
self._handlers['RequestReceived'](event)
elif isinstance(event, DataReceived) and self._handlers.has_key('DataReceived'):
logging.info('DataReceived Event for stream: %d'%event.stream_id)
elif isinstance(event, h2.events.DataReceived) and self._handlers.has_key('DataReceived'):
logging.info('DataReceived Event for stream: %d' % event.stream_id)
self._handlers['DataReceived'](event)
elif isinstance(event, WindowUpdated) and self._handlers.has_key('WindowUpdated'):
logging.info('WindowUpdated Event for stream: %d'%event.stream_id)
elif isinstance(event, h2.events.WindowUpdated) and self._handlers.has_key('WindowUpdated'):
logging.info('WindowUpdated Event for stream: %d' % event.stream_id)
self._handlers['WindowUpdated'](event)
elif isinstance(event, PingAcknowledged) and self._handlers.has_key('PingAcknowledged'):
elif isinstance(event, h2.events.PingAcknowledged) and self._handlers.has_key('PingAcknowledged'):
logging.info('PingAcknowledged Event')
self._handlers['PingAcknowledged'](event)
self.transport.write(self._conn.data_to_send())
def on_ping_acknowledged_default(self, event):
logging.info('ping acknowledged')
self._outstanding_pings -= 1
def on_data_received_default(self, event):
@ -101,7 +102,7 @@ class H2ProtocolBaseServer(Protocol):
self.transport.write(self._conn.data_to_send())
def setup_send(self, data_to_send, stream_id):
logging.info('Setting up data to send for stream_id: %d'%stream_id)
logging.info('Setting up data to send for stream_id: %d' % stream_id)
self._send_remaining[stream_id] = len(data_to_send)
self._send_offset = 0
self._data_to_send = data_to_send
@ -116,16 +117,16 @@ class H2ProtocolBaseServer(Protocol):
lfcw = self._conn.local_flow_control_window(stream_id)
if lfcw == 0:
break
chunk_size = min(lfcw, READ_CHUNK_SIZE)
chunk_size = min(lfcw, _READ_CHUNK_SIZE)
bytes_to_send = min(chunk_size, self._send_remaining[stream_id])
logging.info('flow_control_window = %d. sending [%d:%d] stream_id %d'%
logging.info('flow_control_window = %d. sending [%d:%d] stream_id %d' %
(lfcw, self._send_offset, self._send_offset + bytes_to_send,
stream_id))
data = self._data_to_send[self._send_offset : self._send_offset + bytes_to_send]
try:
self._conn.send_data(stream_id, data, False)
except ProtocolError:
logging.info('Stream %d is closed'%stream_id)
except h2.exceptions.ProtocolError:
logging.info('Stream %d is closed' % stream_id)
break
self._send_remaining[stream_id] -= bytes_to_send
self._send_offset += bytes_to_send
@ -133,6 +134,7 @@ class H2ProtocolBaseServer(Protocol):
self._handlers['SendDone'](stream_id)
def default_ping(self):
logging.info('sending ping')
self._outstanding_pings += 1
self._conn.ping(b'\x00'*8)
self.transport.write(self._conn.data_to_send())
@ -141,9 +143,11 @@ class H2ProtocolBaseServer(Protocol):
if self._stream_status[stream_id]:
self._stream_status[stream_id] = False
self.default_send_trailer(stream_id)
else:
logging.error('Stream %d is already closed' % stream_id)
def default_send_trailer(self, stream_id):
logging.info('Sending trailer for stream id %d'%stream_id)
logging.info('Sending trailer for stream id %d' % stream_id)
self._conn.send_headers(stream_id,
headers=[ ('grpc-status', '0') ],
end_stream=True
@ -159,15 +163,14 @@ class H2ProtocolBaseServer(Protocol):
return response_data
def parse_received_data(self, stream_id):
recv_buffer = self._recv_buffer[stream_id]
""" returns a grpc framed string of bytes containing response proto of the size
asked in request """
recv_buffer = self._recv_buffer[stream_id]
grpc_msg_size = struct.unpack('i',recv_buffer[1:5][::-1])[0]
if len(recv_buffer) != GRPC_HEADER_SIZE + grpc_msg_size:
#logging.error('not enough data to decode req proto. size = %d, needed %s'%(len(recv_buffer), 5+grpc_msg_size))
if len(recv_buffer) != _GRPC_HEADER_SIZE + grpc_msg_size:
return None
req_proto_str = recv_buffer[5:5+grpc_msg_size]
sr = messages_pb2.SimpleRequest()
sr.ParseFromString(req_proto_str)
logging.info('Parsed request for stream %d: response_size=%s'%(stream_id, sr.response_size))
logging.info('Parsed request for stream %d: response_size=%s' % (stream_id, sr.response_size))
return sr

@ -3,18 +3,20 @@
"""
import argparse
import logging
import twisted
import twisted.internet
import twisted.internet.endpoints
import twisted.internet.reactor
from twisted.internet.protocol import Factory
from twisted.internet import endpoints, reactor
import http2_base_server
import test_rst_after_header
import test_rst_after_data
import test_rst_during_data
import test_goaway
import test_ping
import test_max_streams
import test_ping
import test_rst_after_data
import test_rst_after_header
import test_rst_during_data
test_case_mappings = {
_TEST_CASE_MAPPING = {
'rst_after_header': test_rst_after_header.TestcaseRstStreamAfterHeader,
'rst_after_data': test_rst_after_data.TestcaseRstStreamAfterData,
'rst_during_data': test_rst_during_data.TestcaseRstStreamDuringData,
@ -23,20 +25,20 @@ test_case_mappings = {
'max_streams': test_max_streams.TestcaseSettingsMaxStreams,
}
class H2Factory(Factory):
class H2Factory(twisted.internet.protocol.Factory):
def __init__(self, testcase):
logging.info('In H2Factory')
logging.info('Creating H2Factory for new connection.')
self._num_streams = 0
self._testcase = testcase
def buildProtocol(self, addr):
self._num_streams += 1
logging.info('New Connection: %d'%self._num_streams)
if not test_case_mappings.has_key(self._testcase):
logging.error('Unknown test case: %s'%self._testcase)
logging.info('New Connection: %d' % self._num_streams)
if not _TEST_CASE_MAPPING.has_key(self._testcase):
logging.error('Unknown test case: %s' % self._testcase)
assert(0)
else:
t = test_case_mappings[self._testcase]
t = _TEST_CASE_MAPPING[self._testcase]
if self._testcase == 'goaway':
return t(self._num_streams).get_base_server()
@ -49,9 +51,9 @@ if __name__ == "__main__":
parser.add_argument("test")
parser.add_argument("port")
args = parser.parse_args()
if args.test not in test_case_mappings.keys():
logging.error('unknown test: %s'%args.test)
if args.test not in _TEST_CASE_MAPPING.keys():
logging.error('unknown test: %s' % args.test)
else:
endpoint = endpoints.TCP4ServerEndpoint(reactor, int(args.port), backlog=128)
endpoint = twisted.internet.endpoints.TCP4ServerEndpoint(twisted.internet.reactor, int(args.port), backlog=128)
endpoint.listen(H2Factory(args.test))
reactor.run()
twisted.internet.reactor.run()

@ -1,5 +1,6 @@
import logging
import time
import http2_base_server
class TestcaseGoaway(object):
@ -7,7 +8,7 @@ class TestcaseGoaway(object):
This test does the following:
Process incoming request normally, i.e. send headers, data and trailers.
Then send a GOAWAY frame with the stream id of the processed request.
It assert that the next request is made on a different TCP connection.
It checks that the next request is made on a different TCP connection.
"""
def __init__(self, iteration):
self._base_server = http2_base_server.H2ProtocolBaseServer()
@ -22,15 +23,14 @@ class TestcaseGoaway(object):
return self._base_server
def on_connection_lost(self, reason):
logging.info('Disconnect received. Count %d'%self._iteration)
logging.info('Disconnect received. Count %d' % self._iteration)
# _iteration == 2 => Two different connections have been used.
if self._iteration == 2:
self._base_server.on_connection_lost(reason)
def on_send_done(self, stream_id):
self._base_server.on_send_done_default(stream_id)
time.sleep(1)
logging.info('Sending GOAWAY for stream %d:'%stream_id)
logging.info('Sending GOAWAY for stream %d:' % stream_id)
self._base_server._conn.close_connection(error_code=0, additional_data=None, last_stream_id=stream_id)
self._base_server._stream_status[stream_id] = False
@ -42,7 +42,7 @@ class TestcaseGoaway(object):
self._base_server.on_data_received_default(event)
sr = self._base_server.parse_received_data(event.stream_id)
if sr:
logging.info('Creating response size = %s'%sr.response_size)
logging.info('Creating response size = %s' % sr.response_size)
response_data = self._base_server.default_response_data(sr.response_size)
self._ready_to_send = True
self._base_server.setup_send(response_data, event.stream_id)

@ -1,6 +1,7 @@
import hyperframe.frame
import logging
import http2_base_server
from hyperframe.frame import SettingsFrame
class TestcaseSettingsMaxStreams(object):
"""
@ -18,7 +19,8 @@ class TestcaseSettingsMaxStreams(object):
def on_connection_made(self):
logging.info('Connection Made')
self._base_server._conn.initiate_connection()
self._base_server._conn.update_settings({SettingsFrame.MAX_CONCURRENT_STREAMS: 1})
self._base_server._conn.update_settings(
{hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1})
self._base_server.transport.setTcpNoDelay(True)
self._base_server.transport.write(self._base_server._conn.data_to_send())
@ -26,6 +28,7 @@ class TestcaseSettingsMaxStreams(object):
self._base_server.on_data_received_default(event)
sr = self._base_server.parse_received_data(event.stream_id)
if sr:
logging.info('Creating response of size = %s'%sr.response_size)
logging.info('Creating response of size = %s' % sr.response_size)
response_data = self._base_server.default_response_data(sr.response_size)
self._base_server.setup_send(response_data, event.stream_id)
# TODO (makdharma): Add assertion to check number of live streams

@ -1,4 +1,5 @@
import logging
import http2_base_server
class TestcasePing(object):
@ -25,13 +26,13 @@ class TestcasePing(object):
self._base_server.on_data_received_default(event)
sr = self._base_server.parse_received_data(event.stream_id)
if sr:
logging.info('Creating response size = %s'%sr.response_size)
logging.info('Creating response size = %s' % sr.response_size)
response_data = self._base_server.default_response_data(sr.response_size)
self._base_server.default_ping()
self._base_server.setup_send(response_data, event.stream_id)
self._base_server.default_ping()
def on_connection_lost(self, reason):
logging.info('Disconnect received. Ping Count %d'%self._base_server._outstanding_pings)
logging.info('Disconnect received. Ping Count %d' % self._base_server._outstanding_pings)
assert(self._base_server._outstanding_pings == 0)
self._base_server.on_connection_lost(reason)

@ -23,7 +23,6 @@ class TestcaseRstStreamDuringData(object):
response_len = len(response_data)
truncated_response_data = response_data[0:response_len/2]
self._base_server.setup_send(truncated_response_data, event.stream_id)
# send reset stream
def on_send_done(self, stream_id):
self._base_server.send_reset_stream()

Loading…
Cancel
Save