From f7249fcd3accbc001e45d7dca36fcaad2e70b96a Mon Sep 17 00:00:00 2001 From: Richard Belleville Date: Mon, 4 Nov 2019 13:44:39 -0800 Subject: [PATCH] Adopt reviewer's advice --- src/python/grpcio/grpc/_channel.py | 24 ++++----- .../tests/unit/_interceptor_test.py | 54 +++++++++++++++---- 2 files changed, 53 insertions(+), 25 deletions(-) diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index fb25d91e401..4044228bbb8 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -13,6 +13,7 @@ # limitations under the License. """Invocation-side implementation of gRPC Python.""" +import copy import functools import logging import sys @@ -266,27 +267,20 @@ def _rpc_state_string(class_name, rpc_state): class _RpcError(grpc.RpcError, grpc.Call, grpc.Future): """An RPC error not tied to the execution of a particular RPC. - The state passed to _RpcError must be guaranteed not to be accessed by any - other threads. - - The RPC represented by the state object must not be in-progress. + The RPC represented by the state object must not be in-progress or + cancelled. Attributes: _state: An instance of _RPCState. """ def __init__(self, state): - if state.cancelled: - raise ValueError( - "Cannot instantiate an _RpcError for a cancelled RPC.") - if state.code is grpc.StatusCode.OK: - raise ValueError( - "Cannot instantiate an _RpcError for a successfully completed RPC." - ) - if state.code is None: - raise ValueError( - "Cannot instantiate an _RpcError for an incomplete RPC.") - self._state = state + with state.condition: + self._state = _RPCState((), copy.deepcopy(state.initial_metadata), + copy.deepcopy(state.trailing_metadata), + state.code, copy.deepcopy(state.details)) + self._state.response = copy.copy(state.response) + self._state.debug_error_string = copy.copy(state.debug_error_string) def initial_metadata(self): return self._state.initial_metadata diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py index 5a6434e987e..6da3d53816a 100644 --- a/src/python/grpcio_tests/tests/unit/_interceptor_test.py +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -40,6 +40,10 @@ _STREAM_UNARY = '/test/StreamUnary' _STREAM_STREAM = '/test/StreamStream' +class _ApplicationErrorStandin(Exception): + pass + + class _Callback(object): def __init__(self): @@ -73,12 +77,12 @@ class _Handler(object): 'testvalue', ),)) if request == _EXCEPTION_REQUEST: - raise RuntimeError() + raise _ApplicationErrorStandin() return request def handle_unary_stream(self, request, servicer_context): if request == _EXCEPTION_REQUEST: - raise RuntimeError() + raise _ApplicationErrorStandin() for _ in range(test_constants.STREAM_LENGTH): self._control.control() yield request @@ -104,7 +108,7 @@ class _Handler(object): 'testvalue', ),)) if _EXCEPTION_REQUEST in response_elements: - raise RuntimeError() + raise _ApplicationErrorStandin() return b''.join(response_elements) def handle_stream_stream(self, request_iterator, servicer_context): @@ -116,7 +120,7 @@ class _Handler(object): ),)) for request in request_iterator: if request == _EXCEPTION_REQUEST: - raise RuntimeError() + raise _ApplicationErrorStandin() self._control.control() yield request self._control.control() @@ -245,10 +249,12 @@ class _LoggingInterceptor( result = continuation(client_call_details, request) assert isinstance( result, - grpc.Call), '{} is not an instance of grpc.Call'.format(result) + grpc.Call), '{} ({}) is not an instance of grpc.Call'.format( + result, type(result)) assert isinstance( result, - grpc.Future), '{} is not an instance of grpc.Future'.format(result) + grpc.Future), '{} ({}) is not an instance of grpc.Future'.format( + result, type(result)) return result def intercept_unary_stream(self, continuation, client_call_details, @@ -476,11 +482,18 @@ class InterceptorTest(unittest.TestCase): 'c2', self._record)) multi_callable = _unary_unary_multi_callable(channel) - with self.assertRaises(grpc.RpcError): + with self.assertRaises(grpc.RpcError) as exception_context: multi_callable( request, metadata=(('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + exception = exception_context.exception + self.assertFalse(exception.cancelled()) + self.assertFalse(exception.running()) + self.assertTrue(exception.done()) + with self.assertRaises(grpc.RpcError): + exception.result() + self.assertIsInstance(exception.exception(), grpc.RpcError) def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self): request = b'\x07\x08' @@ -561,8 +574,15 @@ class InterceptorTest(unittest.TestCase): response_iterator = multi_callable( request, metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),)) - with self.assertRaises(grpc.RpcError): + with self.assertRaises(grpc.RpcError) as exception_context: tuple(response_iterator) + exception = exception_context.exception + self.assertFalse(exception.cancelled()) + self.assertFalse(exception.running()) + self.assertTrue(exception.done()) + with self.assertRaises(grpc.RpcError): + exception.result() + self.assertIsInstance(exception.exception(), grpc.RpcError) def testInterceptedStreamRequestBlockingUnaryResponse(self): requests = tuple( @@ -650,8 +670,15 @@ class InterceptorTest(unittest.TestCase): response_future = multi_callable.future( request_iterator, metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),)) - with self.assertRaises(grpc.RpcError): + with self.assertRaises(grpc.RpcError) as exception_context: response_future.result() + exception = exception_context.exception + self.assertFalse(exception.cancelled()) + self.assertFalse(exception.running()) + self.assertTrue(exception.done()) + with self.assertRaises(grpc.RpcError): + exception.result() + self.assertIsInstance(exception.exception(), grpc.RpcError) def testInterceptedStreamRequestStreamResponse(self): requests = tuple( @@ -692,8 +719,15 @@ class InterceptorTest(unittest.TestCase): response_iterator = multi_callable( request_iterator, metadata=(('test', 'InterceptedStreamRequestStreamResponse'),)) - with self.assertRaises(grpc.RpcError): + with self.assertRaises(grpc.RpcError) as exception_context: tuple(response_iterator) + exception = exception_context.exception + self.assertFalse(exception.cancelled()) + self.assertFalse(exception.running()) + self.assertTrue(exception.done()) + with self.assertRaises(grpc.RpcError): + exception.result() + self.assertIsInstance(exception.exception(), grpc.RpcError) if __name__ == '__main__':