Fix interceptors for unary-unary case

pull/20905/head
Richard Belleville 5 years ago
parent 6dde2f43f7
commit acc6053716
  1. 63
      src/python/grpcio/grpc/_channel.py
  2. 28
      src/python/grpcio_tests/tests/unit/_interceptor_test.py

@ -263,35 +263,41 @@ def _rpc_state_string(class_name, rpc_state):
rpc_state.debug_error_string)
class _RpcError(grpc.RpcError, grpc.Call):
class _RpcError(grpc.RpcError, grpc.Call, grpc.Future):
"""An RPC error not tied to the execution of a particular RPC.
The state passed to _RpcError must be guaranteed not to be accessed by any
other threads.
The RPC represented by the state object must not be in-progress.
Attributes:
_state: An instance of _RPCState.
"""
def __init__(self, state):
if state.cancelled:
raise ValueError("Cannot instantiate an _RpcError for a cancelled RPC.")
if state.code is grpc.StatusCode.OK:
raise ValueError("Cannot instantiate an _RpcError for a successfully completed RPC.")
if state.code is None:
raise ValueError("Cannot instantiate an _RpcError for an incomplete RPC.")
self._state = state
def initial_metadata(self):
with self._state.condition:
return self._state.initial_metadata
return self._state.initial_metadata
def trailing_metadata(self):
with self._state.condition:
return self._state.trailing_metadata
return self._state.trailing_metadata
def code(self):
with self._state.condition:
return self._state.code
return self._state.code
def details(self):
with self._state.condition:
return _common.decode(self._state.details)
return _common.decode(self._state.details)
def debug_error_string(self):
with self._state.condition:
return _common.decode(self._state.debug_error_string)
return _common.decode(self._state.debug_error_string)
def _repr(self):
return _rpc_state_string(self.__class__.__name__, self._state)
@ -302,6 +308,41 @@ class _RpcError(grpc.RpcError, grpc.Call):
def __str__(self):
return self._repr()
def cancel(self):
"""See grpc.Future.cancel."""
return False
def cancelled(self):
"""See grpc.Future.cancelled."""
return False
def running(self):
"""See grpc.Future.running."""
return False
def done(self):
"""See grpc.Future.done."""
return True
def result(self, timeout=None):
"""See grpc.Future.result."""
raise self
def exception(self, timeout=None):
"""See grpc.Future.exception."""
return self
def traceback(self, timeout=None):
"""See grpc.Future.traceback."""
try:
raise self
except grpc.RpcError:
return sys.exc_info()[2]
def add_done_callback(self, timeout=None):
"""See grpc.Future.add_done_callback."""
fn(self)
class _Rendezvous(grpc.RpcError, grpc.RpcContext):
"""An RPC iterator.

@ -32,6 +32,8 @@ _DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]
_EXCEPTION_REQUEST = b'\x09\x0a'
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
@ -70,8 +72,11 @@ class _Handler(object):
'testkey',
'testvalue',
),))
if request == _EXCEPTION_REQUEST:
raise RuntimeError()
return request
# TODO(gnossen): Instrument this for a test of exception handling.
def handle_unary_stream(self, request, servicer_context):
for _ in range(test_constants.STREAM_LENGTH):
self._control.control()
@ -232,7 +237,10 @@ class _LoggingInterceptor(
def intercept_unary_unary(self, continuation, client_call_details, request):
self.record.append(self.tag + ':intercept_unary_unary')
return continuation(client_call_details, request)
result = continuation(client_call_details, request)
assert isinstance(result, grpc.Call), '{} is not an instance of grpc.Call'.format(result)
assert isinstance(result, grpc.Future), '{} is not an instance of grpc.Future'.format(result)
return result
def intercept_unary_stream(self, continuation, client_call_details,
request):
@ -440,6 +448,24 @@ class InterceptorTest(unittest.TestCase):
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedUnaryRequestBlockingUnaryResponseWithException(self):
request = _EXCEPTION_REQUEST
self._record[:] = []
channel = grpc.intercept_channel(self._channel,
_LoggingInterceptor(
'c1', self._record),
_LoggingInterceptor(
'c2', self._record))
multi_callable = _unary_unary_multi_callable(channel)
with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable(
request,
metadata=(('test',
'InterceptedUnaryRequestBlockingUnaryResponse'),))
def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self):
request = b'\x07\x08'

Loading…
Cancel
Save