Adopt reviewer's advice

pull/20905/head
Richard Belleville 5 years ago
parent e4d58fba6d
commit f7249fcd3a
  1. 24
      src/python/grpcio/grpc/_channel.py
  2. 54
      src/python/grpcio_tests/tests/unit/_interceptor_test.py

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Invocation-side implementation of gRPC Python.""" """Invocation-side implementation of gRPC Python."""
import copy
import functools import functools
import logging import logging
import sys import sys
@ -266,27 +267,20 @@ def _rpc_state_string(class_name, rpc_state):
class _RpcError(grpc.RpcError, grpc.Call, grpc.Future): class _RpcError(grpc.RpcError, grpc.Call, grpc.Future):
"""An RPC error not tied to the execution of a particular RPC. """An RPC error not tied to the execution of a particular RPC.
The state passed to _RpcError must be guaranteed not to be accessed by any The RPC represented by the state object must not be in-progress or
other threads. cancelled.
The RPC represented by the state object must not be in-progress.
Attributes: Attributes:
_state: An instance of _RPCState. _state: An instance of _RPCState.
""" """
def __init__(self, state): def __init__(self, state):
if state.cancelled: with state.condition:
raise ValueError( self._state = _RPCState((), copy.deepcopy(state.initial_metadata),
"Cannot instantiate an _RpcError for a cancelled RPC.") copy.deepcopy(state.trailing_metadata),
if state.code is grpc.StatusCode.OK: state.code, copy.deepcopy(state.details))
raise ValueError( self._state.response = copy.copy(state.response)
"Cannot instantiate an _RpcError for a successfully completed RPC." self._state.debug_error_string = copy.copy(state.debug_error_string)
)
if state.code is None:
raise ValueError(
"Cannot instantiate an _RpcError for an incomplete RPC.")
self._state = state
def initial_metadata(self): def initial_metadata(self):
return self._state.initial_metadata return self._state.initial_metadata

@ -40,6 +40,10 @@ _STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream' _STREAM_STREAM = '/test/StreamStream'
class _ApplicationErrorStandin(Exception):
pass
class _Callback(object): class _Callback(object):
def __init__(self): def __init__(self):
@ -73,12 +77,12 @@ class _Handler(object):
'testvalue', 'testvalue',
),)) ),))
if request == _EXCEPTION_REQUEST: if request == _EXCEPTION_REQUEST:
raise RuntimeError() raise _ApplicationErrorStandin()
return request return request
def handle_unary_stream(self, request, servicer_context): def handle_unary_stream(self, request, servicer_context):
if request == _EXCEPTION_REQUEST: if request == _EXCEPTION_REQUEST:
raise RuntimeError() raise _ApplicationErrorStandin()
for _ in range(test_constants.STREAM_LENGTH): for _ in range(test_constants.STREAM_LENGTH):
self._control.control() self._control.control()
yield request yield request
@ -104,7 +108,7 @@ class _Handler(object):
'testvalue', 'testvalue',
),)) ),))
if _EXCEPTION_REQUEST in response_elements: if _EXCEPTION_REQUEST in response_elements:
raise RuntimeError() raise _ApplicationErrorStandin()
return b''.join(response_elements) return b''.join(response_elements)
def handle_stream_stream(self, request_iterator, servicer_context): def handle_stream_stream(self, request_iterator, servicer_context):
@ -116,7 +120,7 @@ class _Handler(object):
),)) ),))
for request in request_iterator: for request in request_iterator:
if request == _EXCEPTION_REQUEST: if request == _EXCEPTION_REQUEST:
raise RuntimeError() raise _ApplicationErrorStandin()
self._control.control() self._control.control()
yield request yield request
self._control.control() self._control.control()
@ -245,10 +249,12 @@ class _LoggingInterceptor(
result = continuation(client_call_details, request) result = continuation(client_call_details, request)
assert isinstance( assert isinstance(
result, result,
grpc.Call), '{} is not an instance of grpc.Call'.format(result) grpc.Call), '{} ({}) is not an instance of grpc.Call'.format(
result, type(result))
assert isinstance( assert isinstance(
result, result,
grpc.Future), '{} is not an instance of grpc.Future'.format(result) grpc.Future), '{} ({}) is not an instance of grpc.Future'.format(
result, type(result))
return result return result
def intercept_unary_stream(self, continuation, client_call_details, def intercept_unary_stream(self, continuation, client_call_details,
@ -476,11 +482,18 @@ class InterceptorTest(unittest.TestCase):
'c2', self._record)) 'c2', self._record))
multi_callable = _unary_unary_multi_callable(channel) multi_callable = _unary_unary_multi_callable(channel)
with self.assertRaises(grpc.RpcError): with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable( multi_callable(
request, request,
metadata=(('test', metadata=(('test',
'InterceptedUnaryRequestBlockingUnaryResponse'),)) 'InterceptedUnaryRequestBlockingUnaryResponse'),))
exception = exception_context.exception
self.assertFalse(exception.cancelled())
self.assertFalse(exception.running())
self.assertTrue(exception.done())
with self.assertRaises(grpc.RpcError):
exception.result()
self.assertIsInstance(exception.exception(), grpc.RpcError)
def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self): def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self):
request = b'\x07\x08' request = b'\x07\x08'
@ -561,8 +574,15 @@ class InterceptorTest(unittest.TestCase):
response_iterator = multi_callable( response_iterator = multi_callable(
request, request,
metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),)) metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),))
with self.assertRaises(grpc.RpcError): with self.assertRaises(grpc.RpcError) as exception_context:
tuple(response_iterator) tuple(response_iterator)
exception = exception_context.exception
self.assertFalse(exception.cancelled())
self.assertFalse(exception.running())
self.assertTrue(exception.done())
with self.assertRaises(grpc.RpcError):
exception.result()
self.assertIsInstance(exception.exception(), grpc.RpcError)
def testInterceptedStreamRequestBlockingUnaryResponse(self): def testInterceptedStreamRequestBlockingUnaryResponse(self):
requests = tuple( requests = tuple(
@ -650,8 +670,15 @@ class InterceptorTest(unittest.TestCase):
response_future = multi_callable.future( response_future = multi_callable.future(
request_iterator, request_iterator,
metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),)) metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),))
with self.assertRaises(grpc.RpcError): with self.assertRaises(grpc.RpcError) as exception_context:
response_future.result() response_future.result()
exception = exception_context.exception
self.assertFalse(exception.cancelled())
self.assertFalse(exception.running())
self.assertTrue(exception.done())
with self.assertRaises(grpc.RpcError):
exception.result()
self.assertIsInstance(exception.exception(), grpc.RpcError)
def testInterceptedStreamRequestStreamResponse(self): def testInterceptedStreamRequestStreamResponse(self):
requests = tuple( requests = tuple(
@ -692,8 +719,15 @@ class InterceptorTest(unittest.TestCase):
response_iterator = multi_callable( response_iterator = multi_callable(
request_iterator, request_iterator,
metadata=(('test', 'InterceptedStreamRequestStreamResponse'),)) metadata=(('test', 'InterceptedStreamRequestStreamResponse'),))
with self.assertRaises(grpc.RpcError): with self.assertRaises(grpc.RpcError) as exception_context:
tuple(response_iterator) tuple(response_iterator)
exception = exception_context.exception
self.assertFalse(exception.cancelled())
self.assertFalse(exception.running())
self.assertTrue(exception.done())
with self.assertRaises(grpc.RpcError):
exception.result()
self.assertIsInstance(exception.exception(), grpc.RpcError)
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save