diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index f04c6cdd761..0ff5c0a83f8 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -233,7 +233,7 @@ class Call(_base_call.Call): if self._code is grpc.StatusCode.OK: return _OK_CALL_REPRESENTATION.format( self.__class__.__name__, self._code, - self._status.result().self._status.result().details()) + self._status.result().details()) else: return _NON_OK_CALL_REPRESENTATION.format( self.__class__.__name__, self._code, diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 00ea17924a4..4c643443dda 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -22,7 +22,7 @@ import grpc from grpc._cython import cygrpc from . import _base_call -from ._call import UnaryUnaryCall +from ._call import UnaryUnaryCall, AioRpcError from ._utils import _timeout_to_deadline from ._typing import (RequestType, SerializingFunction, DeserializingFunction, MetadataType, ResponseType) @@ -135,19 +135,9 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): if interceptor: continuation = functools.partial(_run_interceptor, interceptors) - try: - call_or_response = await interceptor.intercept_unary_unary( - continuation, client_call_details, request) - except grpc.RpcError as err: - # gRPC error is masked inside an artificial call, - # caller will see this error if and only - # if it runs an `await call` operation - return UnaryUnaryCallRpcError(err) - except asyncio.CancelledError: - # Cancellation is masked inside an artificial call, - # caller will see this error if and only - # if it runs an `await call` operation - return UnaryUnaryCancelledError() + + call_or_response = await interceptor.intercept_unary_unary( + continuation, client_call_details, request) if isinstance(call_or_response, _base_call.UnaryUnaryCall): return call_or_response @@ -176,14 +166,25 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): if not self._interceptors_task.done(): return False - call = self._interceptors_task.result() - return call.cancelled() + try: + call = self._interceptors_task.result() + except AioRpcError: + return False + except asyncio.CancelledError: + return True + else: + return call.cancelled() def done(self) -> bool: if not self._interceptors_task.done(): return False - return True + try: + call = self._interceptors_task.result() + except (AioRpcError, asyncio.CancelledError): + return True + else: + return call.done() def add_done_callback(self, unused_callback) -> None: raise NotImplementedError() @@ -192,19 +193,54 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): raise NotImplementedError() async def initial_metadata(self) -> Optional[MetadataType]: - return await (await self._interceptors_task).initial_metadata() + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.initial_metadata() + except asyncio.CancelledError: + return None + else: + return await call.initial_metadata() async def trailing_metadata(self) -> Optional[MetadataType]: - return await (await self._interceptors_task).trailing_metadata() + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.trailing_metadata() + except asyncio.CancelledError: + return None + else: + return await call.trailing_metadata() async def code(self) -> grpc.StatusCode: - return await (await self._interceptors_task).code() + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.code() + except asyncio.CancelledError: + return grpc.StatusCode.CANCELLED + else: + return await call.code() async def details(self) -> str: - return await (await self._interceptors_task).details() + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.details() + except asyncio.CancelledError: + return _LOCAL_CANCELLATION_DETAILS + else: + return await call.details() async def debug_error_string(self) -> Optional[str]: - return await (await self._interceptors_task).debug_error_string() + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.debug_error_string() + except asyncio.CancelledError: + return '' + else: + return await call.debug_error_string() def __await__(self): call = yield from self._interceptors_task.__await__() @@ -212,47 +248,6 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): return response -class UnaryUnaryCallRpcError(_base_call.UnaryUnaryCall): - """Final UnaryUnaryCall class finished with an RpcError.""" - _error: grpc.RpcError - - def __init__(self, error: grpc.RpcError) -> None: - self._error = error - - def cancel(self) -> bool: - return False - - def cancelled(self) -> bool: - return False - - def done(self) -> bool: - return True - - def add_done_callback(self, unused_callback) -> None: - raise NotImplementedError() - - def time_remaining(self) -> Optional[float]: - raise NotImplementedError() - - async def initial_metadata(self) -> Optional[MetadataType]: - return None - - async def trailing_metadata(self) -> Optional[MetadataType]: - return self._error.initial_metadata() - - async def code(self) -> grpc.StatusCode: - return self._error.code() - - async def details(self) -> str: - return self._error.details() - - async def debug_error_string(self) -> Optional[str]: - return self._error.debug_error_string() - - def __await__(self): - raise self._error - - class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): """Final UnaryUnaryCall class finished with a response.""" _response: ResponseType @@ -296,40 +291,3 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): # for telling the interpreter that __await__ is a generator. yield None return self._response - - -class UnaryUnaryCancelledError(_base_call.UnaryUnaryCall): - """Final UnaryUnaryCall class finished with an asyncio.CancelledError.""" - - def cancel(self) -> bool: - return False - - def cancelled(self) -> bool: - return True - - def done(self) -> bool: - return True - - def add_done_callback(self, unused_callback) -> None: - raise NotImplementedError() - - def time_remaining(self) -> Optional[float]: - raise NotImplementedError() - - async def initial_metadata(self) -> Optional[MetadataType]: - return None - - async def trailing_metadata(self) -> Optional[MetadataType]: - return None - - async def code(self) -> grpc.StatusCode: - return grpc.StatusCode.CANCELLED - - async def details(self) -> str: - return _LOCAL_CANCELLATION_DETAILS - - async def debug_error_string(self) -> Optional[str]: - return None - - def __await__(self): - raise asyncio.CancelledError() diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index f97fbe171d3..f39360d2f3e 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -177,6 +177,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): self.calls.append(call) + new_client_call_details = aio.ClientCallDetails( method=client_call_details.method, timeout=None, @@ -212,61 +213,6 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): self.assertEqual(await interceptor.calls[1].code(), grpc.StatusCode.OK) - async def test_rpcerror_raised_when_call_is_awaited(self): - - class Interceptor(aio.UnaryUnaryClientInterceptor): - """RpcErrors are only seen when the call is awaited""" - - def __init__(self): - self.deadline_seen = False - - async def intercept_unary_unary(self, continuation, - client_call_details, request): - call = await continuation(client_call_details, request) - - try: - await call - except aio.AioRpcError as err: - if err.code() == grpc.StatusCode.DEADLINE_EXCEEDED: - self.deadline_seen = True - raise - - # This point should never be reached - raise Exception() - - interceptor_a, interceptor_b = (Interceptor(), Interceptor()) - server_target, server = await start_test_server() - - async with aio.insecure_channel( - server_target, interceptors=[interceptor_a, - interceptor_b]) as channel: - - multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', - request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) - - call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) - - with self.assertRaises(grpc.RpcError) as exception_context: - await call - - # Check that the two interceptors catch the deadline exception - # only when the call was awaited - self.assertTrue(interceptor_a.deadline_seen) - self.assertTrue(interceptor_b.deadline_seen) - - # Check all of the UnaryUnaryCallRpcError attributes - self.assertTrue(call.done()) - self.assertFalse(call.cancel()) - self.assertFalse(call.cancelled()) - self.assertEqual(await call.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) - self.assertEqual(await call.details(), 'Deadline Exceeded') - self.assertEqual(await call.initial_metadata(), None) - self.assertEqual(await call.trailing_metadata(), ()) - self.assertEqual(await call.debug_error_string(), None) - async def test_rpcresponse(self): class Interceptor(aio.UnaryUnaryClientInterceptor): @@ -348,6 +294,106 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.initial_metadata(), ()) self.assertEqual(await call.trailing_metadata(), ()) + async def test_call_ok_awaited(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + await call + return call + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel(server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(type(response), messages_pb2.SimpleResponse) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + + async def test_call_rpcerror(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + return call + + server_target, server = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel(server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + await server.stop(None) + + call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await call.details(), 'Deadline Exceeded') + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + + async def test_call_rpcerror_awaited(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + await call + return call + + server_target, server = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel(server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + await server.stop(None) + + call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await call.details(), 'Deadline Exceeded') + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + async def test_cancel_before_rpc(self): interceptor_reached = asyncio.Event()