|
|
|
@ -33,6 +33,8 @@ _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' |
|
|
|
|
_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000 |
|
|
|
|
_UNREACHABLE_TARGET = '0.1:1111' |
|
|
|
|
|
|
|
|
|
_INFINITE_INTERVAL_US = 2**31-1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestUnaryUnaryCall(AioTestBase): |
|
|
|
|
|
|
|
|
@ -143,6 +145,29 @@ class TestUnaryUnaryCall(AioTestBase): |
|
|
|
|
self.assertEqual(await call.details(), |
|
|
|
|
'Locally cancelled by application!') |
|
|
|
|
|
|
|
|
|
async def test_cancel_unary_unary_in_task(self): |
|
|
|
|
async with aio.insecure_channel(self._server_target) as channel: |
|
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
coro_started = asyncio.Event() |
|
|
|
|
call = stub.EmptyCall(messages_pb2.SimpleRequest()) |
|
|
|
|
|
|
|
|
|
async def another_coro(): |
|
|
|
|
coro_started.set() |
|
|
|
|
await call |
|
|
|
|
|
|
|
|
|
task = self.loop.create_task(another_coro()) |
|
|
|
|
await coro_started.wait() |
|
|
|
|
|
|
|
|
|
self.assertFalse(task.done()) |
|
|
|
|
task.cancel() |
|
|
|
|
|
|
|
|
|
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) |
|
|
|
|
|
|
|
|
|
with self.assertRaises(grpc.RpcError) as exception_context: |
|
|
|
|
await task |
|
|
|
|
self.assertEqual(grpc.StatusCode.CANCELLED, |
|
|
|
|
exception_context.exception.code()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestUnaryStreamCall(AioTestBase): |
|
|
|
|
|
|
|
|
@ -328,6 +353,73 @@ class TestUnaryStreamCall(AioTestBase): |
|
|
|
|
|
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
|
|
|
|
|
|
|
|
|
async def test_cancel_unary_stream_in_task_using_read(self): |
|
|
|
|
async with aio.insecure_channel(self._server_target) as channel: |
|
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
coro_started = asyncio.Event() |
|
|
|
|
|
|
|
|
|
# Configs the server method to block forever |
|
|
|
|
request = messages_pb2.StreamingOutputCallRequest() |
|
|
|
|
request.response_parameters.append( |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE, |
|
|
|
|
interval_us=_INFINITE_INTERVAL_US, |
|
|
|
|
)) |
|
|
|
|
|
|
|
|
|
# Invokes the actual RPC |
|
|
|
|
call = stub.StreamingOutputCall(request) |
|
|
|
|
|
|
|
|
|
async def another_coro(): |
|
|
|
|
coro_started.set() |
|
|
|
|
await call.read() |
|
|
|
|
|
|
|
|
|
task = self.loop.create_task(another_coro()) |
|
|
|
|
await coro_started.wait() |
|
|
|
|
|
|
|
|
|
self.assertFalse(task.done()) |
|
|
|
|
task.cancel() |
|
|
|
|
|
|
|
|
|
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) |
|
|
|
|
|
|
|
|
|
with self.assertRaises(grpc.RpcError) as exception_context: |
|
|
|
|
await task |
|
|
|
|
self.assertEqual(grpc.StatusCode.CANCELLED, |
|
|
|
|
exception_context.exception.code()) |
|
|
|
|
|
|
|
|
|
async def test_cancel_unary_stream_in_task_using_async_for(self): |
|
|
|
|
async with aio.insecure_channel(self._server_target) as channel: |
|
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
coro_started = asyncio.Event() |
|
|
|
|
|
|
|
|
|
# Configs the server method to block forever |
|
|
|
|
request = messages_pb2.StreamingOutputCallRequest() |
|
|
|
|
request.response_parameters.append( |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE, |
|
|
|
|
interval_us=_INFINITE_INTERVAL_US, |
|
|
|
|
)) |
|
|
|
|
|
|
|
|
|
# Invokes the actual RPC |
|
|
|
|
call = stub.StreamingOutputCall(request) |
|
|
|
|
|
|
|
|
|
async def another_coro(): |
|
|
|
|
coro_started.set() |
|
|
|
|
async for _ in call: |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
task = self.loop.create_task(another_coro()) |
|
|
|
|
await coro_started.wait() |
|
|
|
|
|
|
|
|
|
self.assertFalse(task.done()) |
|
|
|
|
task.cancel() |
|
|
|
|
|
|
|
|
|
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) |
|
|
|
|
|
|
|
|
|
with self.assertRaises(grpc.RpcError) as exception_context: |
|
|
|
|
await task |
|
|
|
|
self.assertEqual(grpc.StatusCode.CANCELLED, |
|
|
|
|
exception_context.exception.code()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
logging.basicConfig(level=logging.DEBUG) |
|
|
|
|