|
|
|
@ -177,6 +177,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): |
|
|
|
|
|
|
|
|
|
self.calls.append(call) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
new_client_call_details = aio.ClientCallDetails( |
|
|
|
|
method=client_call_details.method, |
|
|
|
|
timeout=None, |
|
|
|
@ -212,61 +213,6 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): |
|
|
|
|
self.assertEqual(await interceptor.calls[1].code(), |
|
|
|
|
grpc.StatusCode.OK) |
|
|
|
|
|
|
|
|
|
async def test_rpcerror_raised_when_call_is_awaited(self): |
|
|
|
|
|
|
|
|
|
class Interceptor(aio.UnaryUnaryClientInterceptor): |
|
|
|
|
"""RpcErrors are only seen when the call is awaited""" |
|
|
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
|
self.deadline_seen = False |
|
|
|
|
|
|
|
|
|
async def intercept_unary_unary(self, continuation, |
|
|
|
|
client_call_details, request): |
|
|
|
|
call = await continuation(client_call_details, request) |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
await call |
|
|
|
|
except aio.AioRpcError as err: |
|
|
|
|
if err.code() == grpc.StatusCode.DEADLINE_EXCEEDED: |
|
|
|
|
self.deadline_seen = True |
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
# This point should never be reached |
|
|
|
|
raise Exception() |
|
|
|
|
|
|
|
|
|
interceptor_a, interceptor_b = (Interceptor(), Interceptor()) |
|
|
|
|
server_target, server = await start_test_server() |
|
|
|
|
|
|
|
|
|
async with aio.insecure_channel( |
|
|
|
|
server_target, interceptors=[interceptor_a, |
|
|
|
|
interceptor_b]) as channel: |
|
|
|
|
|
|
|
|
|
multicallable = channel.unary_unary( |
|
|
|
|
'/grpc.testing.TestService/UnaryCall', |
|
|
|
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
|
|
|
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
|
|
|
|
|
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) |
|
|
|
|
|
|
|
|
|
with self.assertRaises(grpc.RpcError) as exception_context: |
|
|
|
|
await call |
|
|
|
|
|
|
|
|
|
# Check that the two interceptors catch the deadline exception |
|
|
|
|
# only when the call was awaited |
|
|
|
|
self.assertTrue(interceptor_a.deadline_seen) |
|
|
|
|
self.assertTrue(interceptor_b.deadline_seen) |
|
|
|
|
|
|
|
|
|
# Check all of the UnaryUnaryCallRpcError attributes |
|
|
|
|
self.assertTrue(call.done()) |
|
|
|
|
self.assertFalse(call.cancel()) |
|
|
|
|
self.assertFalse(call.cancelled()) |
|
|
|
|
self.assertEqual(await call.code(), |
|
|
|
|
grpc.StatusCode.DEADLINE_EXCEEDED) |
|
|
|
|
self.assertEqual(await call.details(), 'Deadline Exceeded') |
|
|
|
|
self.assertEqual(await call.initial_metadata(), None) |
|
|
|
|
self.assertEqual(await call.trailing_metadata(), ()) |
|
|
|
|
self.assertEqual(await call.debug_error_string(), None) |
|
|
|
|
|
|
|
|
|
async def test_rpcresponse(self): |
|
|
|
|
|
|
|
|
|
class Interceptor(aio.UnaryUnaryClientInterceptor): |
|
|
|
@ -348,6 +294,106 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): |
|
|
|
|
self.assertEqual(await call.initial_metadata(), ()) |
|
|
|
|
self.assertEqual(await call.trailing_metadata(), ()) |
|
|
|
|
|
|
|
|
|
async def test_call_ok_awaited(self): |
|
|
|
|
|
|
|
|
|
class Interceptor(aio.UnaryUnaryClientInterceptor): |
|
|
|
|
|
|
|
|
|
async def intercept_unary_unary(self, continuation, |
|
|
|
|
client_call_details, request): |
|
|
|
|
call = await continuation(client_call_details, request) |
|
|
|
|
await call |
|
|
|
|
return call |
|
|
|
|
|
|
|
|
|
server_target, _ = await start_test_server() # pylint: disable=unused-variable |
|
|
|
|
|
|
|
|
|
async with aio.insecure_channel(server_target, |
|
|
|
|
interceptors=[Interceptor() |
|
|
|
|
]) as channel: |
|
|
|
|
|
|
|
|
|
multicallable = channel.unary_unary( |
|
|
|
|
'/grpc.testing.TestService/UnaryCall', |
|
|
|
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
|
|
|
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest()) |
|
|
|
|
response = await call |
|
|
|
|
|
|
|
|
|
self.assertTrue(call.done()) |
|
|
|
|
self.assertFalse(call.cancelled()) |
|
|
|
|
self.assertEqual(type(response), messages_pb2.SimpleResponse) |
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
|
|
|
|
self.assertEqual(await call.details(), '') |
|
|
|
|
self.assertEqual(await call.initial_metadata(), ()) |
|
|
|
|
self.assertEqual(await call.trailing_metadata(), ()) |
|
|
|
|
|
|
|
|
|
async def test_call_rpcerror(self): |
|
|
|
|
|
|
|
|
|
class Interceptor(aio.UnaryUnaryClientInterceptor): |
|
|
|
|
|
|
|
|
|
async def intercept_unary_unary(self, continuation, |
|
|
|
|
client_call_details, request): |
|
|
|
|
call = await continuation(client_call_details, request) |
|
|
|
|
return call |
|
|
|
|
|
|
|
|
|
server_target, server = await start_test_server() # pylint: disable=unused-variable |
|
|
|
|
|
|
|
|
|
async with aio.insecure_channel(server_target, |
|
|
|
|
interceptors=[Interceptor() |
|
|
|
|
]) as channel: |
|
|
|
|
|
|
|
|
|
multicallable = channel.unary_unary( |
|
|
|
|
'/grpc.testing.TestService/UnaryCall', |
|
|
|
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
|
|
|
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
|
|
|
|
|
|
|
|
|
await server.stop(None) |
|
|
|
|
|
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) |
|
|
|
|
|
|
|
|
|
with self.assertRaises(aio.AioRpcError) as exception_context: |
|
|
|
|
await call |
|
|
|
|
|
|
|
|
|
self.assertTrue(call.done()) |
|
|
|
|
self.assertFalse(call.cancelled()) |
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) |
|
|
|
|
self.assertEqual(await call.details(), 'Deadline Exceeded') |
|
|
|
|
self.assertEqual(await call.initial_metadata(), ()) |
|
|
|
|
self.assertEqual(await call.trailing_metadata(), ()) |
|
|
|
|
|
|
|
|
|
async def test_call_rpcerror_awaited(self): |
|
|
|
|
|
|
|
|
|
class Interceptor(aio.UnaryUnaryClientInterceptor): |
|
|
|
|
|
|
|
|
|
async def intercept_unary_unary(self, continuation, |
|
|
|
|
client_call_details, request): |
|
|
|
|
call = await continuation(client_call_details, request) |
|
|
|
|
await call |
|
|
|
|
return call |
|
|
|
|
|
|
|
|
|
server_target, server = await start_test_server() # pylint: disable=unused-variable |
|
|
|
|
|
|
|
|
|
async with aio.insecure_channel(server_target, |
|
|
|
|
interceptors=[Interceptor() |
|
|
|
|
]) as channel: |
|
|
|
|
|
|
|
|
|
multicallable = channel.unary_unary( |
|
|
|
|
'/grpc.testing.TestService/UnaryCall', |
|
|
|
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
|
|
|
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
|
|
|
|
|
|
|
|
|
await server.stop(None) |
|
|
|
|
|
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) |
|
|
|
|
|
|
|
|
|
with self.assertRaises(aio.AioRpcError) as exception_context: |
|
|
|
|
await call |
|
|
|
|
|
|
|
|
|
self.assertTrue(call.done()) |
|
|
|
|
self.assertFalse(call.cancelled()) |
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) |
|
|
|
|
self.assertEqual(await call.details(), 'Deadline Exceeded') |
|
|
|
|
self.assertEqual(await call.initial_metadata(), ()) |
|
|
|
|
self.assertEqual(await call.trailing_metadata(), ()) |
|
|
|
|
|
|
|
|
|
async def test_cancel_before_rpc(self): |
|
|
|
|
|
|
|
|
|
interceptor_reached = asyncio.Event() |
|
|
|
|