|
|
|
@ -1,4 +1,4 @@ |
|
|
|
|
# Copyright 2019 The gRPC Authors. |
|
|
|
|
# Copyright 2020 The gRPC Authors. |
|
|
|
|
# |
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
@ -20,67 +20,34 @@ import grpc |
|
|
|
|
|
|
|
|
|
from grpc.experimental import aio |
|
|
|
|
from tests_aio.unit._constants import UNREACHABLE_TARGET |
|
|
|
|
from tests_aio.unit._common import inject_callbacks |
|
|
|
|
from tests_aio.unit._test_server import start_test_server |
|
|
|
|
from tests_aio.unit._test_base import AioTestBase |
|
|
|
|
from tests.unit.framework.common import test_constants |
|
|
|
|
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc |
|
|
|
|
|
|
|
|
|
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds() |
|
|
|
|
_SHORT_TIMEOUT_S = 1.0 |
|
|
|
|
|
|
|
|
|
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' |
|
|
|
|
_NUM_STREAM_RESPONSES = 5 |
|
|
|
|
_REQUEST_PAYLOAD_SIZE = 7 |
|
|
|
|
_RESPONSE_PAYLOAD_SIZE = 7 |
|
|
|
|
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ResponseIterator: |
|
|
|
|
class _CountingResponseIterator: |
|
|
|
|
|
|
|
|
|
def __init__(self, response_iterator): |
|
|
|
|
self._response_cnt = 0 |
|
|
|
|
self.response_cnt = 0 |
|
|
|
|
self._response_iterator = response_iterator |
|
|
|
|
|
|
|
|
|
async def _forward_responses(self): |
|
|
|
|
async for response in self._response_iterator: |
|
|
|
|
self._response_cnt += 1 |
|
|
|
|
self.response_cnt += 1 |
|
|
|
|
yield response |
|
|
|
|
|
|
|
|
|
def __aiter__(self): |
|
|
|
|
return self._forward_responses() |
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|
def response_cnt(self): |
|
|
|
|
return self._response_cnt |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _inject_callbacks(call): |
|
|
|
|
first_callback_ran = asyncio.Event() |
|
|
|
|
|
|
|
|
|
def first_callback(call): |
|
|
|
|
# Validate that all resopnses have been received |
|
|
|
|
# and the call is an end state. |
|
|
|
|
assert call.done() |
|
|
|
|
first_callback_ran.set() |
|
|
|
|
|
|
|
|
|
second_callback_ran = asyncio.Event() |
|
|
|
|
|
|
|
|
|
def second_callback(call): |
|
|
|
|
# Validate that all resopnses have been received |
|
|
|
|
# and the call is an end state. |
|
|
|
|
assert call.done() |
|
|
|
|
second_callback_ran.set() |
|
|
|
|
|
|
|
|
|
call.add_done_callback(first_callback) |
|
|
|
|
call.add_done_callback(second_callback) |
|
|
|
|
|
|
|
|
|
async def validation(): |
|
|
|
|
await asyncio.wait_for( |
|
|
|
|
asyncio.gather(first_callback_ran.wait(), |
|
|
|
|
second_callback_ran.wait()), |
|
|
|
|
test_constants.SHORT_TIMEOUT) |
|
|
|
|
|
|
|
|
|
return validation() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor): |
|
|
|
|
|
|
|
|
@ -89,7 +56,7 @@ class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor): |
|
|
|
|
return await continuation(client_call_details, request) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _UnaryStreamInterceptorWith_ResponseIterator( |
|
|
|
|
class _UnaryStreamInterceptorWithResponseIterator( |
|
|
|
|
aio.UnaryStreamClientInterceptor): |
|
|
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
@ -98,7 +65,7 @@ class _UnaryStreamInterceptorWith_ResponseIterator( |
|
|
|
|
async def intercept_unary_stream(self, continuation, client_call_details, |
|
|
|
|
request): |
|
|
|
|
call = await continuation(client_call_details, request) |
|
|
|
|
self.response_iterator = _ResponseIterator(call) |
|
|
|
|
self.response_iterator = _CountingResponseIterator(call) |
|
|
|
|
return self.response_iterator |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -112,16 +79,15 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
|
|
|
|
|
async def test_intercepts(self): |
|
|
|
|
for interceptor_class in (_UnaryStreamInterceptorEmpty, |
|
|
|
|
_UnaryStreamInterceptorWith_ResponseIterator): |
|
|
|
|
_UnaryStreamInterceptorWithResponseIterator): |
|
|
|
|
|
|
|
|
|
with self.subTest(name=interceptor_class): |
|
|
|
|
interceptor = interceptor_class() |
|
|
|
|
|
|
|
|
|
request = messages_pb2.StreamingOutputCallRequest() |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
request.response_parameters.append( |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE)) |
|
|
|
|
request.response_parameters.extend([ |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) |
|
|
|
|
|
|
|
|
|
channel = aio.insecure_channel(self._server_target, |
|
|
|
|
interceptors=[interceptor]) |
|
|
|
@ -138,7 +104,7 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, |
|
|
|
|
len(response.payload.body)) |
|
|
|
|
|
|
|
|
|
self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES) |
|
|
|
|
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) |
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
|
|
|
|
self.assertEqual(await call.initial_metadata(), ()) |
|
|
|
|
self.assertEqual(await call.trailing_metadata(), ()) |
|
|
|
@ -148,31 +114,30 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
self.assertEqual(call.cancelled(), False) |
|
|
|
|
self.assertEqual(call.done(), True) |
|
|
|
|
|
|
|
|
|
if interceptor_class == _UnaryStreamInterceptorWith_ResponseIterator: |
|
|
|
|
self.assertTrue(interceptor.response_iterator.response_cnt, |
|
|
|
|
if interceptor_class == _UnaryStreamInterceptorWithResponseIterator: |
|
|
|
|
self.assertEqual(interceptor.response_iterator.response_cnt, |
|
|
|
|
_NUM_STREAM_RESPONSES) |
|
|
|
|
|
|
|
|
|
await channel.close() |
|
|
|
|
|
|
|
|
|
async def test_add_done_callback(self): |
|
|
|
|
async def test_add_done_callback_interceptor_task_not_finished(self): |
|
|
|
|
for interceptor_class in (_UnaryStreamInterceptorEmpty, |
|
|
|
|
_UnaryStreamInterceptorWith_ResponseIterator): |
|
|
|
|
_UnaryStreamInterceptorWithResponseIterator): |
|
|
|
|
|
|
|
|
|
with self.subTest(name=interceptor_class): |
|
|
|
|
interceptor = interceptor_class() |
|
|
|
|
|
|
|
|
|
request = messages_pb2.StreamingOutputCallRequest() |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
request.response_parameters.append( |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE)) |
|
|
|
|
request.response_parameters.extend([ |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) |
|
|
|
|
|
|
|
|
|
channel = aio.insecure_channel(self._server_target, |
|
|
|
|
interceptors=[interceptor]) |
|
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
call = stub.StreamingOutputCall(request) |
|
|
|
|
|
|
|
|
|
validation = _inject_callbacks(call) |
|
|
|
|
validation = inject_callbacks(call) |
|
|
|
|
|
|
|
|
|
async for response in call: |
|
|
|
|
pass |
|
|
|
@ -181,18 +146,17 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
|
|
|
|
|
await channel.close() |
|
|
|
|
|
|
|
|
|
async def test_add_done_callback_after_connection(self): |
|
|
|
|
async def test_add_done_callback_interceptor_task_finished(self): |
|
|
|
|
for interceptor_class in (_UnaryStreamInterceptorEmpty, |
|
|
|
|
_UnaryStreamInterceptorWith_ResponseIterator): |
|
|
|
|
_UnaryStreamInterceptorWithResponseIterator): |
|
|
|
|
|
|
|
|
|
with self.subTest(name=interceptor_class): |
|
|
|
|
interceptor = interceptor_class() |
|
|
|
|
|
|
|
|
|
request = messages_pb2.StreamingOutputCallRequest() |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
request.response_parameters.append( |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE)) |
|
|
|
|
request.response_parameters.extend([ |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) |
|
|
|
|
|
|
|
|
|
channel = aio.insecure_channel(self._server_target, |
|
|
|
|
interceptors=[interceptor]) |
|
|
|
@ -204,7 +168,7 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
# pending state list. |
|
|
|
|
await call.wait_for_connection() |
|
|
|
|
|
|
|
|
|
validation = _inject_callbacks(call) |
|
|
|
|
validation = inject_callbacks(call) |
|
|
|
|
|
|
|
|
|
async for response in call: |
|
|
|
|
pass |
|
|
|
@ -214,16 +178,16 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
await channel.close() |
|
|
|
|
|
|
|
|
|
async def test_response_iterator_using_read(self): |
|
|
|
|
interceptor = _UnaryStreamInterceptorWith_ResponseIterator() |
|
|
|
|
interceptor = _UnaryStreamInterceptorWithResponseIterator() |
|
|
|
|
|
|
|
|
|
channel = aio.insecure_channel(self._server_target, |
|
|
|
|
interceptors=[interceptor]) |
|
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
|
|
|
|
|
request = messages_pb2.StreamingOutputCallRequest() |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
request.response_parameters.append( |
|
|
|
|
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) |
|
|
|
|
request.response_parameters.extend([ |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) |
|
|
|
|
|
|
|
|
|
call = stub.StreamingOutputCall(request) |
|
|
|
|
|
|
|
|
@ -235,16 +199,16 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
messages_pb2.StreamingOutputCallResponse) |
|
|
|
|
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) |
|
|
|
|
|
|
|
|
|
self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES) |
|
|
|
|
self.assertTrue(interceptor.response_iterator.response_cnt, |
|
|
|
|
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) |
|
|
|
|
self.assertEqual(interceptor.response_iterator.response_cnt, |
|
|
|
|
_NUM_STREAM_RESPONSES) |
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
|
|
|
|
|
|
|
|
|
await channel.close() |
|
|
|
|
|
|
|
|
|
async def test_mulitple_interceptors_response_iterator(self): |
|
|
|
|
async def test_multiple_interceptors_response_iterator(self): |
|
|
|
|
for interceptor_class in (_UnaryStreamInterceptorEmpty, |
|
|
|
|
_UnaryStreamInterceptorWith_ResponseIterator): |
|
|
|
|
_UnaryStreamInterceptorWithResponseIterator): |
|
|
|
|
|
|
|
|
|
with self.subTest(name=interceptor_class): |
|
|
|
|
|
|
|
|
@ -255,10 +219,9 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
|
|
|
|
|
request = messages_pb2.StreamingOutputCallRequest() |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
request.response_parameters.append( |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE)) |
|
|
|
|
request.response_parameters.extend([ |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) |
|
|
|
|
|
|
|
|
|
call = stub.StreamingOutputCall(request) |
|
|
|
|
|
|
|
|
@ -270,14 +233,14 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, |
|
|
|
|
len(response.payload.body)) |
|
|
|
|
|
|
|
|
|
self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES) |
|
|
|
|
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) |
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
|
|
|
|
|
|
|
|
|
await channel.close() |
|
|
|
|
|
|
|
|
|
async def test_intercepts_response_iterator_rpc_error(self): |
|
|
|
|
for interceptor_class in (_UnaryStreamInterceptorEmpty, |
|
|
|
|
_UnaryStreamInterceptorWith_ResponseIterator): |
|
|
|
|
_UnaryStreamInterceptorWithResponseIterator): |
|
|
|
|
|
|
|
|
|
with self.subTest(name=interceptor_class): |
|
|
|
|
|
|
|
|
@ -329,8 +292,6 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
self.assertTrue(call.cancelled()) |
|
|
|
|
self.assertTrue(call.done()) |
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
|
|
|
|
self.assertEqual(await call.details(), |
|
|
|
|
_LOCAL_CANCEL_DETAILS_EXPECTATION) |
|
|
|
|
self.assertEqual(await call.initial_metadata(), None) |
|
|
|
|
self.assertEqual(await call.trailing_metadata(), None) |
|
|
|
|
await channel.close() |
|
|
|
@ -367,23 +328,19 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
self.assertTrue(call.cancelled()) |
|
|
|
|
self.assertTrue(call.done()) |
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
|
|
|
|
self.assertEqual(await call.details(), |
|
|
|
|
_LOCAL_CANCEL_DETAILS_EXPECTATION) |
|
|
|
|
self.assertEqual(await call.initial_metadata(), None) |
|
|
|
|
self.assertEqual(await call.trailing_metadata(), None) |
|
|
|
|
await channel.close() |
|
|
|
|
|
|
|
|
|
async def test_cancel_consuming_response_iterator(self): |
|
|
|
|
request = messages_pb2.StreamingOutputCallRequest() |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
request.response_parameters.append( |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE, |
|
|
|
|
interval_us=_RESPONSE_INTERVAL_US)) |
|
|
|
|
request.response_parameters.extend([ |
|
|
|
|
messages_pb2.ResponseParameters( |
|
|
|
|
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) |
|
|
|
|
|
|
|
|
|
channel = aio.insecure_channel( |
|
|
|
|
self._server_target, |
|
|
|
|
interceptors=[_UnaryStreamInterceptorWith_ResponseIterator()]) |
|
|
|
|
interceptors=[_UnaryStreamInterceptorWithResponseIterator()]) |
|
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
call = stub.StreamingOutputCall(request) |
|
|
|
|
|
|
|
|
@ -394,10 +351,57 @@ class TestUnaryStreamClientInterceptor(AioTestBase): |
|
|
|
|
self.assertTrue(call.cancelled()) |
|
|
|
|
self.assertTrue(call.done()) |
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
|
|
|
|
self.assertEqual(await call.details(), |
|
|
|
|
_LOCAL_CANCEL_DETAILS_EXPECTATION) |
|
|
|
|
await channel.close() |
|
|
|
|
|
|
|
|
|
async def test_cancel_by_the_interceptor(self): |
|
|
|
|
|
|
|
|
|
class Interceptor(aio.UnaryStreamClientInterceptor): |
|
|
|
|
|
|
|
|
|
async def intercept_unary_stream(self, continuation, |
|
|
|
|
client_call_details, request): |
|
|
|
|
call = await continuation(client_call_details, request) |
|
|
|
|
call.cancel() |
|
|
|
|
return call |
|
|
|
|
|
|
|
|
|
channel = aio.insecure_channel(UNREACHABLE_TARGET, |
|
|
|
|
interceptors=[Interceptor()]) |
|
|
|
|
request = messages_pb2.StreamingOutputCallRequest() |
|
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
call = stub.StreamingOutputCall(request) |
|
|
|
|
|
|
|
|
|
with self.assertRaises(asyncio.CancelledError): |
|
|
|
|
async for response in call: |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
self.assertTrue(call.cancelled()) |
|
|
|
|
self.assertTrue(call.done()) |
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
|
|
|
|
await channel.close() |
|
|
|
|
|
|
|
|
|
async def test_exception_raised_by_interceptor(self): |
|
|
|
|
|
|
|
|
|
class InterceptorException(Exception): |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
class Interceptor(aio.UnaryStreamClientInterceptor): |
|
|
|
|
|
|
|
|
|
async def intercept_unary_stream(self, continuation, |
|
|
|
|
client_call_details, request): |
|
|
|
|
raise InterceptorException |
|
|
|
|
|
|
|
|
|
channel = aio.insecure_channel(UNREACHABLE_TARGET, |
|
|
|
|
interceptors=[Interceptor()]) |
|
|
|
|
request = messages_pb2.StreamingOutputCallRequest() |
|
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
call = stub.StreamingOutputCall(request) |
|
|
|
|
|
|
|
|
|
with self.assertRaises(InterceptorException): |
|
|
|
|
async for response in call: |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
await channel.close() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
logging.basicConfig(level=logging.DEBUG) |
|
|
|
|