|
|
|
@ -14,9 +14,7 @@ |
|
|
|
|
"""Testing the done callbacks mechanism.""" |
|
|
|
|
|
|
|
|
|
import asyncio |
|
|
|
|
import gc |
|
|
|
|
import logging |
|
|
|
|
import time |
|
|
|
|
import unittest |
|
|
|
|
|
|
|
|
|
import grpc |
|
|
|
@ -24,7 +22,6 @@ from grpc.experimental import aio |
|
|
|
|
|
|
|
|
|
from src.proto.grpc.testing import messages_pb2 |
|
|
|
|
from src.proto.grpc.testing import test_pb2_grpc |
|
|
|
|
from tests.unit.framework.common import test_constants |
|
|
|
|
from tests_aio.unit._common import inject_callbacks |
|
|
|
|
from tests_aio.unit._test_base import AioTestBase |
|
|
|
|
from tests_aio.unit._test_server import start_test_server |
|
|
|
@ -32,9 +29,13 @@ from tests_aio.unit._test_server import start_test_server |
|
|
|
|
_NUM_STREAM_RESPONSES = 5 |
|
|
|
|
_REQUEST_PAYLOAD_SIZE = 7 |
|
|
|
|
_RESPONSE_PAYLOAD_SIZE = 42 |
|
|
|
|
_REQUEST = b'\x01\x02\x03' |
|
|
|
|
_RESPONSE = b'\x04\x05\x06' |
|
|
|
|
_TEST_METHOD = '/test/Test' |
|
|
|
|
_FAKE_METHOD = '/test/Fake' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDoneCallback(AioTestBase): |
|
|
|
|
class TestClientSideDoneCallback(AioTestBase): |
|
|
|
|
|
|
|
|
|
async def setUp(self): |
|
|
|
|
address, self._server = await start_test_server() |
|
|
|
@ -121,6 +122,155 @@ class TestDoneCallback(AioTestBase): |
|
|
|
|
await validation |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestServerSideDoneCallback(AioTestBase): |
|
|
|
|
|
|
|
|
|
async def setUp(self): |
|
|
|
|
self._server = aio.server() |
|
|
|
|
port = self._server.add_insecure_port('[::]:0') |
|
|
|
|
self._channel = aio.insecure_channel('localhost:%d' % port) |
|
|
|
|
|
|
|
|
|
async def tearDown(self): |
|
|
|
|
await self._channel.close() |
|
|
|
|
await self._server.stop(None) |
|
|
|
|
|
|
|
|
|
async def _register_method_handler(self, method_handler): |
|
|
|
|
"""Registers method handler and starts the server""" |
|
|
|
|
generic_handler = grpc.method_handlers_generic_handler( |
|
|
|
|
'test', |
|
|
|
|
dict(Test=method_handler), |
|
|
|
|
) |
|
|
|
|
self._server.add_generic_rpc_handlers((generic_handler,)) |
|
|
|
|
await self._server.start() |
|
|
|
|
|
|
|
|
|
async def test_unary_unary(self): |
|
|
|
|
validation_future = self.loop.create_future() |
|
|
|
|
|
|
|
|
|
async def test_handler(request: bytes, context: aio.ServicerContext): |
|
|
|
|
self.assertEqual(_REQUEST, request) |
|
|
|
|
validation_future.set_result(inject_callbacks(context)) |
|
|
|
|
return _RESPONSE |
|
|
|
|
|
|
|
|
|
await self._register_method_handler( |
|
|
|
|
grpc.unary_unary_rpc_method_handler(test_handler)) |
|
|
|
|
response = await self._channel.unary_unary(_TEST_METHOD)(_REQUEST) |
|
|
|
|
self.assertEqual(_RESPONSE, response) |
|
|
|
|
|
|
|
|
|
validation = await validation_future |
|
|
|
|
await validation |
|
|
|
|
|
|
|
|
|
async def test_unary_stream(self): |
|
|
|
|
validation_future = self.loop.create_future() |
|
|
|
|
|
|
|
|
|
async def test_handler(request: bytes, context: aio.ServicerContext): |
|
|
|
|
self.assertEqual(_REQUEST, request) |
|
|
|
|
validation_future.set_result(inject_callbacks(context)) |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
yield _RESPONSE |
|
|
|
|
|
|
|
|
|
await self._register_method_handler( |
|
|
|
|
grpc.unary_stream_rpc_method_handler(test_handler)) |
|
|
|
|
call = self._channel.unary_stream(_TEST_METHOD)(_REQUEST) |
|
|
|
|
async for response in call: |
|
|
|
|
self.assertEqual(_RESPONSE, response) |
|
|
|
|
|
|
|
|
|
validation = await validation_future |
|
|
|
|
await validation |
|
|
|
|
|
|
|
|
|
async def test_stream_unary(self): |
|
|
|
|
validation_future = self.loop.create_future() |
|
|
|
|
|
|
|
|
|
async def test_handler(request_iterator, context: aio.ServicerContext): |
|
|
|
|
validation_future.set_result(inject_callbacks(context)) |
|
|
|
|
|
|
|
|
|
async for request in request_iterator: |
|
|
|
|
self.assertEqual(_REQUEST, request) |
|
|
|
|
return _RESPONSE |
|
|
|
|
|
|
|
|
|
await self._register_method_handler( |
|
|
|
|
grpc.stream_unary_rpc_method_handler(test_handler)) |
|
|
|
|
call = self._channel.stream_unary(_TEST_METHOD)() |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
await call.write(_REQUEST) |
|
|
|
|
await call.done_writing() |
|
|
|
|
self.assertEqual(_RESPONSE, await call) |
|
|
|
|
|
|
|
|
|
validation = await validation_future |
|
|
|
|
await validation |
|
|
|
|
|
|
|
|
|
async def test_stream_stream(self): |
|
|
|
|
validation_future = self.loop.create_future() |
|
|
|
|
|
|
|
|
|
async def test_handler(request_iterator, context: aio.ServicerContext): |
|
|
|
|
validation_future.set_result(inject_callbacks(context)) |
|
|
|
|
|
|
|
|
|
async for request in request_iterator: |
|
|
|
|
self.assertEqual(_REQUEST, request) |
|
|
|
|
return _RESPONSE |
|
|
|
|
|
|
|
|
|
await self._register_method_handler( |
|
|
|
|
grpc.stream_stream_rpc_method_handler(test_handler)) |
|
|
|
|
call = self._channel.stream_stream(_TEST_METHOD)() |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
await call.write(_REQUEST) |
|
|
|
|
await call.done_writing() |
|
|
|
|
async for response in call: |
|
|
|
|
self.assertEqual(_RESPONSE, response) |
|
|
|
|
|
|
|
|
|
validation = await validation_future |
|
|
|
|
await validation |
|
|
|
|
|
|
|
|
|
async def test_error_in_handler(self): |
|
|
|
|
"""Errors in the handler still triggers callbacks.""" |
|
|
|
|
validation_future = self.loop.create_future() |
|
|
|
|
|
|
|
|
|
async def test_handler(request: bytes, context: aio.ServicerContext): |
|
|
|
|
self.assertEqual(_REQUEST, request) |
|
|
|
|
validation_future.set_result(inject_callbacks(context)) |
|
|
|
|
raise RuntimeError('A test RuntimeError') |
|
|
|
|
|
|
|
|
|
await self._register_method_handler( |
|
|
|
|
grpc.unary_unary_rpc_method_handler(test_handler)) |
|
|
|
|
with self.assertRaises(aio.AioRpcError) as exception_context: |
|
|
|
|
await self._channel.unary_unary(_TEST_METHOD)(_REQUEST) |
|
|
|
|
rpc_error = exception_context.exception |
|
|
|
|
self.assertEqual(grpc.StatusCode.UNKNOWN, rpc_error.code()) |
|
|
|
|
|
|
|
|
|
validation = await validation_future |
|
|
|
|
await validation |
|
|
|
|
|
|
|
|
|
async def test_error_in_callback(self): |
|
|
|
|
"""Errors in the callback won't be propagated to client.""" |
|
|
|
|
validation_future = self.loop.create_future() |
|
|
|
|
|
|
|
|
|
async def test_handler(request: bytes, context: aio.ServicerContext): |
|
|
|
|
self.assertEqual(_REQUEST, request) |
|
|
|
|
|
|
|
|
|
def exception_raiser(unused_context): |
|
|
|
|
raise RuntimeError('A test RuntimeError') |
|
|
|
|
|
|
|
|
|
context.add_done_callback(exception_raiser) |
|
|
|
|
validation_future.set_result(inject_callbacks(context)) |
|
|
|
|
return _RESPONSE |
|
|
|
|
|
|
|
|
|
await self._register_method_handler( |
|
|
|
|
grpc.unary_unary_rpc_method_handler(test_handler)) |
|
|
|
|
|
|
|
|
|
response = await self._channel.unary_unary(_TEST_METHOD)(_REQUEST) |
|
|
|
|
self.assertEqual(_RESPONSE, response) |
|
|
|
|
|
|
|
|
|
# Following callbacks won't be invoked, if one of the callback crashed. |
|
|
|
|
validation = await validation_future |
|
|
|
|
with self.assertRaises(asyncio.TimeoutError): |
|
|
|
|
await validation |
|
|
|
|
|
|
|
|
|
# Invoke RPC one more time to ensure the toxic callback won't break the |
|
|
|
|
# server. |
|
|
|
|
with self.assertRaises(aio.AioRpcError) as exception_context: |
|
|
|
|
await self._channel.unary_unary(_FAKE_METHOD)(_REQUEST) |
|
|
|
|
rpc_error = exception_context.exception |
|
|
|
|
self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
logging.basicConfig(level=logging.DEBUG) |
|
|
|
|
unittest.main(verbosity=2) |
|
|
|
|