diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py index 1b72123153e..243c95468bc 100644 --- a/src/python/grpcio_tests/tests/unit/_interceptor_test.py +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -76,8 +76,9 @@ class _Handler(object): raise RuntimeError() return request - # TODO(gnossen): Instrument this for a test of exception handling. def handle_unary_stream(self, request, servicer_context): + if request == _EXCEPTION_REQUEST: + raise RuntimeError() for _ in range(test_constants.STREAM_LENGTH): self._control.control() yield request @@ -102,6 +103,8 @@ class _Handler(object): 'testkey', 'testvalue', ),)) + if _EXCEPTION_REQUEST in response_elements: + raise RuntimeError() return b''.join(response_elements) def handle_stream_stream(self, request_iterator, servicer_context): @@ -112,6 +115,8 @@ class _Handler(object): 'testvalue', ),)) for request in request_iterator: + if request == _EXCEPTION_REQUEST: + raise RuntimeError() self._control.control() yield request self._control.control() @@ -250,7 +255,10 @@ class _LoggingInterceptor( def intercept_stream_unary(self, continuation, client_call_details, request_iterator): self.record.append(self.tag + ':intercept_stream_unary') - return continuation(client_call_details, request_iterator) + result = continuation(client_call_details, request_iterator) + assert isinstance(result, grpc.Call), '{} is not an instance of grpc.Call'.format(result) + assert isinstance(result, grpc.Future), '{} is not an instance of grpc.Future'.format(result) + return result def intercept_stream_stream(self, continuation, client_call_details, request_iterator): @@ -448,7 +456,7 @@ class InterceptorTest(unittest.TestCase): 's1:intercept_service', 's2:intercept_service' ]) - def testInterceptedUnaryRequestBlockingUnaryResponseWithException(self): + def testInterceptedUnaryRequestBlockingUnaryResponseWithError(self): request = _EXCEPTION_REQUEST self._record[:] = [] @@ -460,7 +468,7 @@ class InterceptorTest(unittest.TestCase): 'c2', self._record)) multi_callable = _unary_unary_multi_callable(channel) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(grpc.RpcError): multi_callable( request, metadata=(('test', @@ -531,6 +539,23 @@ class InterceptorTest(unittest.TestCase): 's1:intercept_service', 's2:intercept_service' ]) + def testInterceptedUnaryRequestStreamResponseWithError(self): + request = _EXCEPTION_REQUEST + + self._record[:] = [] + channel = grpc.intercept_channel(self._channel, + _LoggingInterceptor( + 'c1', self._record), + _LoggingInterceptor( + 'c2', self._record)) + + multi_callable = _unary_stream_multi_callable(channel) + response_iterator = multi_callable( + request, + metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),)) + with self.assertRaises(grpc.RpcError): + tuple(response_iterator) + def testInterceptedStreamRequestBlockingUnaryResponse(self): requests = tuple( b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) @@ -601,6 +626,25 @@ class InterceptorTest(unittest.TestCase): 's1:intercept_service', 's2:intercept_service' ]) + def testInterceptedStreamRequestFutureUnaryResponseWithError(self): + requests = tuple( + _EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel(self._channel, + _LoggingInterceptor( + 'c1', self._record), + _LoggingInterceptor( + 'c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + response_future = multi_callable.future( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),)) + with self.assertRaises(grpc.RpcError): + response_future.result() + def testInterceptedStreamRequestStreamResponse(self): requests = tuple( b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH)) @@ -624,6 +668,25 @@ class InterceptorTest(unittest.TestCase): 's1:intercept_service', 's2:intercept_service' ]) + def testInterceptedStreamRequestStreamResponseWithError(self): + requests = tuple( + _EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel(self._channel, + _LoggingInterceptor( + 'c1', self._record), + _LoggingInterceptor( + 'c2', self._record)) + + multi_callable = _stream_stream_multi_callable(channel) + response_iterator = multi_callable( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestStreamResponse'),)) + with self.assertRaises(grpc.RpcError): + tuple(response_iterator) + if __name__ == '__main__': logging.basicConfig()