From acc605371647bd349dbcf79c0dc490f136c182ff Mon Sep 17 00:00:00 2001 From: Richard Belleville Date: Fri, 1 Nov 2019 15:28:48 -0700 Subject: [PATCH] Fix interceptors for unary-unary case --- src/python/grpcio/grpc/_channel.py | 63 +++++++++++++++---- .../tests/unit/_interceptor_test.py | 28 ++++++++- 2 files changed, 79 insertions(+), 12 deletions(-) diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 232688c1ca1..5fb843acbde 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -263,35 +263,41 @@ def _rpc_state_string(class_name, rpc_state): rpc_state.debug_error_string) -class _RpcError(grpc.RpcError, grpc.Call): +class _RpcError(grpc.RpcError, grpc.Call, grpc.Future): """An RPC error not tied to the execution of a particular RPC. + The state passed to _RpcError must be guaranteed not to be accessed by any + other threads. + + The RPC represented by the state object must not be in-progress. + Attributes: _state: An instance of _RPCState. """ def __init__(self, state): + if state.cancelled: + raise ValueError("Cannot instantiate an _RpcError for a cancelled RPC.") + if state.code is grpc.StatusCode.OK: + raise ValueError("Cannot instantiate an _RpcError for a successfully completed RPC.") + if state.code is None: + raise ValueError("Cannot instantiate an _RpcError for an incomplete RPC.") self._state = state def initial_metadata(self): - with self._state.condition: - return self._state.initial_metadata + return self._state.initial_metadata def trailing_metadata(self): - with self._state.condition: - return self._state.trailing_metadata + return self._state.trailing_metadata def code(self): - with self._state.condition: - return self._state.code + return self._state.code def details(self): - with self._state.condition: - return _common.decode(self._state.details) + return _common.decode(self._state.details) def debug_error_string(self): - with self._state.condition: - return _common.decode(self._state.debug_error_string) + return _common.decode(self._state.debug_error_string) def _repr(self): return _rpc_state_string(self.__class__.__name__, self._state) @@ -302,6 +308,41 @@ class _RpcError(grpc.RpcError, grpc.Call): def __str__(self): return self._repr() + def cancel(self): + """See grpc.Future.cancel.""" + return False + + def cancelled(self): + """See grpc.Future.cancelled.""" + return False + + def running(self): + """See grpc.Future.running.""" + return False + + def done(self): + """See grpc.Future.done.""" + return True + + def result(self, timeout=None): + """See grpc.Future.result.""" + raise self + + def exception(self, timeout=None): + """See grpc.Future.exception.""" + return self + + def traceback(self, timeout=None): + """See grpc.Future.traceback.""" + try: + raise self + except grpc.RpcError: + return sys.exc_info()[2] + + def add_done_callback(self, timeout=None): + """See grpc.Future.add_done_callback.""" + fn(self) + class _Rendezvous(grpc.RpcError, grpc.RpcContext): """An RPC iterator. diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py index a647e5e720c..1b72123153e 100644 --- a/src/python/grpcio_tests/tests/unit/_interceptor_test.py +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -32,6 +32,8 @@ _DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] _SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 _DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] +_EXCEPTION_REQUEST = b'\x09\x0a' + _UNARY_UNARY = '/test/UnaryUnary' _UNARY_STREAM = '/test/UnaryStream' _STREAM_UNARY = '/test/StreamUnary' @@ -70,8 +72,11 @@ class _Handler(object): 'testkey', 'testvalue', ),)) + if request == _EXCEPTION_REQUEST: + raise RuntimeError() return request + # TODO(gnossen): Instrument this for a test of exception handling. def handle_unary_stream(self, request, servicer_context): for _ in range(test_constants.STREAM_LENGTH): self._control.control() @@ -232,7 +237,10 @@ class _LoggingInterceptor( def intercept_unary_unary(self, continuation, client_call_details, request): self.record.append(self.tag + ':intercept_unary_unary') - return continuation(client_call_details, request) + result = continuation(client_call_details, request) + assert isinstance(result, grpc.Call), '{} is not an instance of grpc.Call'.format(result) + assert isinstance(result, grpc.Future), '{} is not an instance of grpc.Future'.format(result) + return result def intercept_unary_stream(self, continuation, client_call_details, request): @@ -440,6 +448,24 @@ class InterceptorTest(unittest.TestCase): 's1:intercept_service', 's2:intercept_service' ]) + def testInterceptedUnaryRequestBlockingUnaryResponseWithException(self): + request = _EXCEPTION_REQUEST + + self._record[:] = [] + + channel = grpc.intercept_channel(self._channel, + _LoggingInterceptor( + 'c1', self._record), + _LoggingInterceptor( + 'c2', self._record)) + + multi_callable = _unary_unary_multi_callable(channel) + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request, + metadata=(('test', + 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self): request = b'\x07\x08'