Add tests for other arities

pull/20905/head
Richard Belleville 5 years ago
parent acc6053716
commit 6f0b772afa
  1. 71
      src/python/grpcio_tests/tests/unit/_interceptor_test.py

@ -76,8 +76,9 @@ class _Handler(object):
raise RuntimeError()
return request
# TODO(gnossen): Instrument this for a test of exception handling.
def handle_unary_stream(self, request, servicer_context):
if request == _EXCEPTION_REQUEST:
raise RuntimeError()
for _ in range(test_constants.STREAM_LENGTH):
self._control.control()
yield request
@ -102,6 +103,8 @@ class _Handler(object):
'testkey',
'testvalue',
),))
if _EXCEPTION_REQUEST in response_elements:
raise RuntimeError()
return b''.join(response_elements)
def handle_stream_stream(self, request_iterator, servicer_context):
@ -112,6 +115,8 @@ class _Handler(object):
'testvalue',
),))
for request in request_iterator:
if request == _EXCEPTION_REQUEST:
raise RuntimeError()
self._control.control()
yield request
self._control.control()
@ -250,7 +255,10 @@ class _LoggingInterceptor(
def intercept_stream_unary(self, continuation, client_call_details,
request_iterator):
self.record.append(self.tag + ':intercept_stream_unary')
return continuation(client_call_details, request_iterator)
result = continuation(client_call_details, request_iterator)
assert isinstance(result, grpc.Call), '{} is not an instance of grpc.Call'.format(result)
assert isinstance(result, grpc.Future), '{} is not an instance of grpc.Future'.format(result)
return result
def intercept_stream_stream(self, continuation, client_call_details,
request_iterator):
@ -448,7 +456,7 @@ class InterceptorTest(unittest.TestCase):
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedUnaryRequestBlockingUnaryResponseWithException(self):
def testInterceptedUnaryRequestBlockingUnaryResponseWithError(self):
request = _EXCEPTION_REQUEST
self._record[:] = []
@ -460,7 +468,7 @@ class InterceptorTest(unittest.TestCase):
'c2', self._record))
multi_callable = _unary_unary_multi_callable(channel)
with self.assertRaises(grpc.RpcError) as exception_context:
with self.assertRaises(grpc.RpcError):
multi_callable(
request,
metadata=(('test',
@ -531,6 +539,23 @@ class InterceptorTest(unittest.TestCase):
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedUnaryRequestStreamResponseWithError(self):
request = _EXCEPTION_REQUEST
self._record[:] = []
channel = grpc.intercept_channel(self._channel,
_LoggingInterceptor(
'c1', self._record),
_LoggingInterceptor(
'c2', self._record))
multi_callable = _unary_stream_multi_callable(channel)
response_iterator = multi_callable(
request,
metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),))
with self.assertRaises(grpc.RpcError):
tuple(response_iterator)
def testInterceptedStreamRequestBlockingUnaryResponse(self):
requests = tuple(
b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
@ -601,6 +626,25 @@ class InterceptorTest(unittest.TestCase):
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedStreamRequestFutureUnaryResponseWithError(self):
requests = tuple(
_EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
self._record[:] = []
channel = grpc.intercept_channel(self._channel,
_LoggingInterceptor(
'c1', self._record),
_LoggingInterceptor(
'c2', self._record))
multi_callable = _stream_unary_multi_callable(channel)
response_future = multi_callable.future(
request_iterator,
metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),))
with self.assertRaises(grpc.RpcError):
response_future.result()
def testInterceptedStreamRequestStreamResponse(self):
requests = tuple(
b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH))
@ -624,6 +668,25 @@ class InterceptorTest(unittest.TestCase):
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedStreamRequestStreamResponseWithError(self):
requests = tuple(
_EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
self._record[:] = []
channel = grpc.intercept_channel(self._channel,
_LoggingInterceptor(
'c1', self._record),
_LoggingInterceptor(
'c2', self._record))
multi_callable = _stream_stream_multi_callable(channel)
response_iterator = multi_callable(
request_iterator,
metadata=(('test', 'InterceptedStreamRequestStreamResponse'),))
with self.assertRaises(grpc.RpcError):
tuple(response_iterator)
if __name__ == '__main__':
logging.basicConfig()

Loading…
Cancel
Save