diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callbackcontext.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callbackcontext.pxd.pxi index 8e52c856dd2..beada919f4d 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callbackcontext.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callbackcontext.pxd.pxi @@ -15,6 +15,15 @@ cimport cpython cdef struct CallbackContext: + # C struct to store callback context in the form of pointers. + # + # Attributes: + # functor: A grpc_experimental_completion_queue_functor represents the + # callback function in the only way C-Core understands. + # waiter: An asyncio.Future object that fulfills when the callback is + # invoked by C-Core. + # failure_handler: A CallbackFailureHandler object that called when C-Core + # returns 'success == 0' state. grpc_experimental_completion_queue_functor functor cpython.PyObject *waiter - + cpython.PyObject *failure_handler diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi index 2d56a568348..eed496df7c6 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi @@ -152,6 +152,13 @@ cdef class _AsyncioSocket: cdef void close(self): if self.is_connected(): self._writer.close() + if self._server: + self._server.close() + # NOTE(lidiz) If the asyncio.Server is created from a Python socket, + # the server.close() won't release the fd until the close() is called + # for the Python socket. + if self._py_socket: + self._py_socket.close() def _new_connection_callback(self, object reader, object writer): client_socket = _AsyncioSocket.create( diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi index 1906463d088..ca4a6a837ea 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi @@ -25,16 +25,33 @@ cdef class RPCState: cdef bytes method(self) +cdef class CallbackWrapper: + cdef CallbackContext context + cdef object _reference_of_future + cdef object _reference_of_failure_handler + + @staticmethod + cdef void functor_run( + grpc_experimental_completion_queue_functor* functor, + int succeed) + + cdef grpc_experimental_completion_queue_functor *c_functor(self) + + cdef enum AioServerStatus: AIO_SERVER_STATUS_UNKNOWN AIO_SERVER_STATUS_READY AIO_SERVER_STATUS_RUNNING AIO_SERVER_STATUS_STOPPED + AIO_SERVER_STATUS_STOPPING cdef class _CallbackCompletionQueue: cdef grpc_completion_queue *_cq cdef grpc_completion_queue* c_ptr(self) + cdef object _shutdown_completed # asyncio.Future + cdef CallbackWrapper _wrapper + cdef object _loop # asyncio.EventLoop cdef class AioServer: @@ -42,3 +59,9 @@ cdef class AioServer: cdef _CallbackCompletionQueue _cq cdef list _generic_handlers cdef AioServerStatus _status + cdef object _loop # asyncio.EventLoop + cdef object _serving_task # asyncio.Task + cdef object _shutdown_lock # asyncio.Lock + cdef object _shutdown_completed # asyncio.Future + cdef CallbackWrapper _shutdown_callback_wrapper + cdef object _crash_exception # Exception diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index dd6ff8b29d8..61335ca9e60 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +# TODO(https://github.com/grpc/grpc/issues/20850) refactor this. +_LOGGER = logging.getLogger(__name__) +cdef int _EMPTY_FLAG = 0 + + cdef class _HandlerCallDetails: def __cinit__(self, str method, tuple invocation_metadata): self.method = method @@ -21,16 +26,38 @@ cdef class _HandlerCallDetails: class _ServicerContextPlaceHolder(object): pass +cdef class _CallbackFailureHandler: + cdef str _core_function_name + cdef object _error_details + cdef object _exception_type + + def __cinit__(self, + str core_function_name, + object error_details, + object exception_type): + """Handles failure by raising exception.""" + self._core_function_name = core_function_name + self._error_details = error_details + self._exception_type = exception_type + + cdef handle(self, object future): + future.set_exception(self._exception_type( + 'Failed "%s": %s' % (self._core_function_name, self._error_details) + )) + + # TODO(https://github.com/grpc/grpc/issues/20669) # Apply this to the client-side cdef class CallbackWrapper: - cdef CallbackContext context - cdef object _reference - def __cinit__(self, object future): + def __cinit__(self, object future, _CallbackFailureHandler failure_handler): self.context.functor.functor_run = self.functor_run - self.context.waiter = (future) - self._reference = future + self.context.waiter = future + self.context.failure_handler = failure_handler + # NOTE(lidiz) Not using a list here, because this class is critical in + # data path. We should make it as efficient as possible. + self._reference_of_future = future + self._reference_of_failure_handler = failure_handler @staticmethod cdef void functor_run( @@ -38,7 +65,8 @@ cdef class CallbackWrapper: int success): cdef CallbackContext *context = functor if success == 0: - (context.waiter).set_exception(RuntimeError()) + (<_CallbackFailureHandler>context.failure_handler).handle( + context.waiter) else: (context.waiter).set_result(None) @@ -85,7 +113,9 @@ async def callback_start_batch(RPCState rpc_state, batch_operation_tag.prepare() cdef object future = loop.create_future() - cdef CallbackWrapper wrapper = CallbackWrapper(future) + cdef CallbackWrapper wrapper = CallbackWrapper( + future, + _CallbackFailureHandler('callback_start_batch', operations, RuntimeError)) # NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed # when calling "await". This is an over-optimization by Cython. cpython.Py_INCREF(wrapper) @@ -142,6 +172,9 @@ async def _handle_unary_unary_rpc(object method_handler, await callback_start_batch(rpc_state, send_ops, loop) + + + async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): # Finds the method handler (application logic) cdef object method_handler = _find_method_handler( @@ -151,6 +184,7 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): if method_handler is None: # TODO(lidiz) return unimplemented error to client side raise NotImplementedError() + # TODO(lidiz) extend to all 4 types of RPC if method_handler.request_streaming or method_handler.response_streaming: raise NotImplementedError() @@ -162,13 +196,21 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): ) +class _RequestCallError(Exception): pass + +cdef _CallbackFailureHandler REQUEST_CALL_FAILURE_HANDLER = _CallbackFailureHandler( + 'grpc_server_request_call', 'server shutdown', _RequestCallError) + + async def _server_call_request_call(Server server, _CallbackCompletionQueue cq, object loop): cdef grpc_call_error error cdef RPCState rpc_state = RPCState() cdef object future = loop.create_future() - cdef CallbackWrapper wrapper = CallbackWrapper(future) + cdef CallbackWrapper wrapper = CallbackWrapper( + future, + REQUEST_CALL_FAILURE_HANDLER) # NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed # when calling "await". This is an over-optimization by Cython. cpython.Py_INCREF(wrapper) @@ -186,54 +228,76 @@ async def _server_call_request_call(Server server, return rpc_state -async def _server_main_loop(Server server, - _CallbackCompletionQueue cq, - list generic_handlers): - cdef object loop = asyncio.get_event_loop() - cdef RPCState rpc_state - - while True: - rpc_state = await _server_call_request_call( - server, - cq, - loop) +async def _handle_cancellation_from_core(object rpc_task, + RPCState rpc_state, + object loop): + cdef ReceiveCloseOnServerOperation op = ReceiveCloseOnServerOperation(_EMPTY_FLAG) + cdef tuple ops = (op,) + await callback_start_batch(rpc_state, ops, loop) + if op.cancelled() and not rpc_task.done(): + rpc_task.cancel() - loop.create_task(_handle_rpc(generic_handlers, rpc_state, loop)) - -async def _server_start(Server server, - _CallbackCompletionQueue cq, - list generic_handlers): - server.start() - await _server_main_loop(server, cq, generic_handlers) +cdef _CallbackFailureHandler CQ_SHUTDOWN_FAILURE_HANDLER = _CallbackFailureHandler( + 'grpc_completion_queue_shutdown', + 'Unknown', + RuntimeError) cdef class _CallbackCompletionQueue: - def __cinit__(self): + def __cinit__(self, object loop): + self._loop = loop + self._shutdown_completed = loop.create_future() + self._wrapper = CallbackWrapper( + self._shutdown_completed, + CQ_SHUTDOWN_FAILURE_HANDLER) self._cq = grpc_completion_queue_create_for_callback( - NULL, + self._wrapper.c_functor(), NULL ) cdef grpc_completion_queue* c_ptr(self): return self._cq + + async def shutdown(self): + grpc_completion_queue_shutdown(self._cq) + await self._shutdown_completed + grpc_completion_queue_destroy(self._cq) + + +cdef _CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = _CallbackFailureHandler( + 'grpc_server_shutdown_and_notify', + 'Unknown', + RuntimeError) cdef class AioServer: - def __init__(self, thread_pool, generic_handlers, interceptors, options, - maximum_concurrent_rpcs, compression): + def __init__(self, loop, thread_pool, generic_handlers, interceptors, + options, maximum_concurrent_rpcs, compression): + # NOTE(lidiz) Core objects won't be deallocated automatically. + # If AioServer.shutdown is not called, those objects will leak. self._server = Server(options) - self._cq = _CallbackCompletionQueue() - self._status = AIO_SERVER_STATUS_READY - self._generic_handlers = [] + self._cq = _CallbackCompletionQueue(loop) grpc_server_register_completion_queue( self._server.c_server, self._cq.c_ptr(), NULL ) + + self._loop = loop + self._status = AIO_SERVER_STATUS_READY + self._generic_handlers = [] self.add_generic_rpc_handlers(generic_handlers) + self._serving_task = None + + self._shutdown_lock = asyncio.Lock(loop=self._loop) + self._shutdown_completed = self._loop.create_future() + self._shutdown_callback_wrapper = CallbackWrapper( + self._shutdown_completed, + SERVER_SHUTDOWN_FAILURE_HANDLER) + self._crash_exception = None if interceptors: raise NotImplementedError() @@ -255,6 +319,46 @@ cdef class AioServer: return self._server.add_http2_port(address, server_credentials._credentials) + async def _server_main_loop(self, + object server_started): + self._server.start() + cdef RPCState rpc_state + server_started.set_result(True) + + while True: + # When shutdown begins, no more new connections. + if self._status != AIO_SERVER_STATUS_RUNNING: + break + + rpc_state = await _server_call_request_call( + self._server, + self._cq, + self._loop) + + rpc_task = self._loop.create_task( + _handle_rpc( + self._generic_handlers, + rpc_state, + self._loop + ) + ) + self._loop.create_task( + _handle_cancellation_from_core( + rpc_task, + rpc_state, + self._loop + ) + ) + + def _serving_task_crash_handler(self, object task): + """Shutdown the server immediately if unexpectedly exited.""" + if task.exception() is None: + return + if self._status != AIO_SERVER_STATUS_STOPPING: + self._crash_exception = task.exception() + _LOGGER.exception(self._crash_exception) + self._loop.create_task(self.shutdown(None)) + async def start(self): if self._status == AIO_SERVER_STATUS_RUNNING: return @@ -262,14 +366,103 @@ cdef class AioServer: raise RuntimeError('Server not in ready state') self._status = AIO_SERVER_STATUS_RUNNING - loop = asyncio.get_event_loop() - loop.create_task(_server_start( - self._server, - self._cq, - self._generic_handlers, - )) + cdef object server_started = self._loop.create_future() + self._serving_task = self._loop.create_task(self._server_main_loop(server_started)) + self._serving_task.add_done_callback(self._serving_task_crash_handler) + # Needs to explicitly wait for the server to start up. + # Otherwise, the actual start time of the server is un-controllable. + await server_started + + async def _start_shutting_down(self): + """Prepares the server to shutting down. + + This coroutine function is NOT coroutine-safe. + """ + # The shutdown callback won't be called until there is no live RPC. + grpc_server_shutdown_and_notify( + self._server.c_server, + self._cq._cq, + self._shutdown_callback_wrapper.c_functor()) + + # Ensures the serving task (coroutine) exits. + try: + await self._serving_task + except _RequestCallError: + pass + + async def shutdown(self, grace): + """Gracefully shutdown the C-Core server. + + Application should only call shutdown once. + + Args: + grace: An optional float indicating the length of grace period in + seconds. + """ + if self._status == AIO_SERVER_STATUS_READY or self._status == AIO_SERVER_STATUS_STOPPED: + return + + async with self._shutdown_lock: + if self._status == AIO_SERVER_STATUS_RUNNING: + self._server.is_shutting_down = True + self._status = AIO_SERVER_STATUS_STOPPING + await self._start_shutting_down() + + if grace is None: + # Directly cancels all calls + grpc_server_cancel_all_calls(self._server.c_server) + await self._shutdown_completed + else: + try: + await asyncio.wait_for( + asyncio.shield( + self._shutdown_completed, + loop=self._loop + ), + grace, + loop=self._loop, + ) + except asyncio.TimeoutError: + # Cancels all ongoing calls by the end of grace period. + grpc_server_cancel_all_calls(self._server.c_server) + await self._shutdown_completed + + async with self._shutdown_lock: + if self._status == AIO_SERVER_STATUS_STOPPING: + grpc_server_destroy(self._server.c_server) + self._server.c_server = NULL + self._server.is_shutdown = True + self._status = AIO_SERVER_STATUS_STOPPED + + # Shuts down the completion queue + await self._cq.shutdown() + + async def wait_for_termination(self, float timeout): + if timeout is None: + await self._shutdown_completed + else: + try: + await asyncio.wait_for( + asyncio.shield( + self._shutdown_completed, + loop=self._loop, + ), + timeout, + loop=self._loop, + ) + except asyncio.TimeoutError: + if self._crash_exception is not None: + raise self._crash_exception + return False + if self._crash_exception is not None: + raise self._crash_exception + return True + + def __dealloc__(self): + """Deallocation of Core objects are ensured by Python grpc.aio.Server. - # TODO(https://github.com/grpc/grpc/issues/20668) - # Implement Destruction Methods for AsyncIO Server - def stop(self, unused_grace): - pass + If the Cython representation is deallocated without underlying objects + freed, raise an RuntimeError. + """ + if self._status != AIO_SERVER_STATUS_STOPPED: + raise RuntimeError('__dealloc__ called on running server: %d', self._status) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index 67b2e9d4e88..4ce554d078f 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -61,16 +61,25 @@ cdef class Server: self.c_server, queue.c_completion_queue, NULL) self.registered_completion_queues.append(queue) - def start(self): + def start(self, backup_queue=True): + """Start the Cython gRPC Server. + + Args: + backup_queue: a bool indicates whether to spawn a backup completion + queue. In the case that no CQ is bound to the server, and the shutdown + of server becomes un-observable. + """ if self.is_started: raise ValueError("the server has already started") - self.backup_shutdown_queue = CompletionQueue(shutdown_cq=True) - self.register_completion_queue(self.backup_shutdown_queue) + if backup_queue: + self.backup_shutdown_queue = CompletionQueue(shutdown_cq=True) + self.register_completion_queue(self.backup_shutdown_queue) self.is_started = True with nogil: grpc_server_start(self.c_server) - # Ensure the core has gotten a chance to do the start-up work - self.backup_shutdown_queue.poll(deadline=time.time()) + if backup_queue: + # Ensure the core has gotten a chance to do the start-up work + self.backup_shutdown_queue.poll(deadline=time.time()) def add_http2_port(self, bytes address, ServerCredentials server_credentials=None): @@ -134,11 +143,14 @@ cdef class Server: elif self.is_shutdown: pass elif not self.is_shutting_down: - # the user didn't call shutdown - use our backup queue - self._c_shutdown(self.backup_shutdown_queue, None) - # and now we wait - while not self.is_shutdown: - self.backup_shutdown_queue.poll() + if self.backup_shutdown_queue is None: + raise RuntimeError('Server shutdown failed: no completion queue.') + else: + # the user didn't call shutdown - use our backup queue + self._c_shutdown(self.backup_shutdown_queue, None) + # and now we wait + while not self.is_shutdown: + self.backup_shutdown_queue.poll() else: # We're in the process of shutting down, but have not shutdown; can't do # much but repeatedly release the GIL and wait diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index 696db001133..3f6b96eaa54 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -17,6 +17,8 @@ import abc import six import grpc +from grpc import _common +from grpc._cython import cygrpc from grpc._cython.cygrpc import init_grpc_aio from ._call import AioRpcError diff --git a/src/python/grpcio/grpc/experimental/aio/_server.py b/src/python/grpcio/grpc/experimental/aio/_server.py index 6bc3d210aed..8b53fdd0d03 100644 --- a/src/python/grpcio/grpc/experimental/aio/_server.py +++ b/src/python/grpcio/grpc/experimental/aio/_server.py @@ -25,8 +25,9 @@ class Server: def __init__(self, thread_pool, generic_handlers, interceptors, options, maximum_concurrent_rpcs, compression): - self._server = cygrpc.AioServer(thread_pool, generic_handlers, - interceptors, options, + self._loop = asyncio.get_event_loop() + self._server = cygrpc.AioServer(self._loop, thread_pool, + generic_handlers, interceptors, options, maximum_concurrent_rpcs, compression) def add_generic_rpc_handlers( @@ -83,35 +84,29 @@ class Server: """ await self._server.start() - def stop(self, grace: Optional[float]) -> asyncio.Event: + async def stop(self, grace: Optional[float]) -> None: """Stops this Server. - "This method immediately stops the server from servicing new RPCs in + This method immediately stops the server from servicing new RPCs in all cases. - If a grace period is specified, this method returns immediately - and all RPCs active at the end of the grace period are aborted. - If a grace period is not specified (by passing None for `grace`), - all existing RPCs are aborted immediately and this method - blocks until the last RPC handler terminates. + If a grace period is specified, this method returns immediately and all + RPCs active at the end of the grace period are aborted. If a grace + period is not specified (by passing None for grace), all existing RPCs + are aborted immediately and this method blocks until the last RPC + handler terminates. - This method is idempotent and may be called at any time. - Passing a smaller grace value in a subsequent call will have - the effect of stopping the Server sooner (passing None will - have the effect of stopping the server immediately). Passing - a larger grace value in a subsequent call *will not* have the - effect of stopping the server later (i.e. the most restrictive - grace value is used). + This method is idempotent and may be called at any time. Passing a + smaller grace value in a subsequent call will have the effect of + stopping the Server sooner (passing None will have the effect of + stopping the server immediately). Passing a larger grace value in a + subsequent call will not have the effect of stopping the server later + (i.e. the most restrictive grace value is used). Args: grace: A duration of time in seconds or None. - - Returns: - A threading.Event that will be set when this Server has completely - stopped, i.e. when running RPCs either complete or are aborted and - all handlers have terminated. """ - raise NotImplementedError() + await self._server.shutdown(grace) async def wait_for_termination(self, timeout: Optional[float] = None) -> bool: @@ -135,11 +130,15 @@ class Server: Returns: A bool indicates if the operation times out. """ - if timeout: - raise NotImplementedError() - # TODO(lidiz) replace this wait forever logic - future = asyncio.get_event_loop().create_future() - await future + return await self._server.wait_for_termination(timeout) + + def __del__(self): + """Schedules a graceful shutdown in current event loop. + + The Cython AioServer doesn't hold a ref-count to this class. It should + be safe to slightly extend the underlying Cython object's life span. + """ + self._loop.create_task(self._server.shutdown(None)) def server(migration_thread_pool=None, diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 96817c61a6f..076300786fb 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -22,6 +22,9 @@ from tests.unit.framework.common import test_constants from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase +_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' +_EMPTY_CALL_METHOD = '/grpc.testing.TestService/EmptyCall' + class TestChannel(AioTestBase): @@ -32,7 +35,7 @@ class TestChannel(AioTestBase): async with aio.insecure_channel(server_target) as channel: hi = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + _UNARY_CALL_METHOD, request_serializer=messages_pb2.SimpleRequest. SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString @@ -48,7 +51,7 @@ class TestChannel(AioTestBase): channel = aio.insecure_channel(server_target) hi = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + _UNARY_CALL_METHOD, request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) response = await hi(messages_pb2.SimpleRequest()) @@ -66,7 +69,7 @@ class TestChannel(AioTestBase): async with aio.insecure_channel(server_target) as channel: empty_call_with_sleep = channel.unary_unary( - "/grpc.testing.TestService/EmptyCall", + _EMPTY_CALL_METHOD, request_serializer=messages_pb2.SimpleRequest. SerializeToString, response_deserializer=messages_pb2.SimpleResponse. @@ -94,6 +97,23 @@ class TestChannel(AioTestBase): self.loop.run_until_complete(coro()) + @unittest.skip('https://github.com/grpc/grpc/issues/20818') + def test_call_to_the_void(self): + + async def coro(): + channel = aio.insecure_channel('0.1.1.1:1111') + hi = channel.unary_unary( + _UNARY_CALL_METHOD, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + response = await hi(messages_pb2.SimpleRequest()) + + self.assertIs(type(response), messages_pb2.SimpleResponse) + + await channel.close() + + self.loop.run_until_complete(coro()) + if __name__ == '__main__': logging.basicConfig() diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py index 937cce9eebb..1e86de65404 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -12,27 +12,61 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging import unittest +import time +import gc import grpc from grpc.experimental import aio from tests_aio.unit._test_base import AioTestBase +from tests.unit.framework.common import test_constants -_TEST_METHOD_PATH = '' +_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary' +_BLOCK_FOREVER = '/test/BlockForever' +_BLOCK_BRIEFLY = '/test/BlockBriefly' _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x01\x01\x01' -async def unary_unary(unused_request, unused_context): - return _RESPONSE +class _GenericHandler(grpc.GenericRpcHandler): + def __init__(self): + self._called = asyncio.get_event_loop().create_future() -class GenericHandler(grpc.GenericRpcHandler): + @staticmethod + async def _unary_unary(unused_request, unused_context): + return _RESPONSE - def service(self, unused_handler_details): - return grpc.unary_unary_rpc_method_handler(unary_unary) + async def _block_forever(self, unused_request, unused_context): + await asyncio.get_event_loop().create_future() + + async def _BLOCK_BRIEFLY(self, unused_request, unused_context): + await asyncio.sleep(test_constants.SHORT_TIMEOUT / 2) + return _RESPONSE + + def service(self, handler_details): + self._called.set_result(None) + if handler_details.method == _SIMPLE_UNARY_UNARY: + return grpc.unary_unary_rpc_method_handler(self._unary_unary) + if handler_details.method == _BLOCK_FOREVER: + return grpc.unary_unary_rpc_method_handler(self._block_forever) + if handler_details.method == _BLOCK_BRIEFLY: + return grpc.unary_unary_rpc_method_handler(self._BLOCK_BRIEFLY) + + async def wait_for_call(self): + await self._called + + +async def _start_test_server(): + server = aio.server() + port = server.add_insecure_port('[::]:0') + generic_handler = _GenericHandler() + server.add_generic_rpc_handlers((generic_handler,)) + await server.start() + return 'localhost:%d' % port, server, generic_handler class TestServer(AioTestBase): @@ -40,18 +74,146 @@ class TestServer(AioTestBase): def test_unary_unary(self): async def test_unary_unary_body(): - server = aio.server() - port = server.add_insecure_port('[::]:0') - server.add_generic_rpc_handlers((GenericHandler(),)) - await server.start() + result = await _start_test_server() + server_target = result[0] - async with aio.insecure_channel('localhost:%d' % port) as channel: - unary_call = channel.unary_unary(_TEST_METHOD_PATH) + async with aio.insecure_channel(server_target) as channel: + unary_call = channel.unary_unary(_SIMPLE_UNARY_UNARY) response = await unary_call(_REQUEST) self.assertEqual(response, _RESPONSE) self.loop.run_until_complete(test_unary_unary_body()) + def test_shutdown(self): + + async def test_shutdown_body(): + _, server, _ = await _start_test_server() + await server.stop(None) + + self.loop.run_until_complete(test_shutdown_body()) + # Ensures no SIGSEGV triggered, and ends within timeout. + + def test_shutdown_after_call(self): + + async def test_shutdown_body(): + server_target, server, _ = await _start_test_server() + + async with aio.insecure_channel(server_target) as channel: + await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST) + + await server.stop(None) + + self.loop.run_until_complete(test_shutdown_body()) + + def test_graceful_shutdown_success(self): + + async def test_graceful_shutdown_success_body(): + server_target, server, generic_handler = await _start_test_server() + + channel = aio.insecure_channel(server_target) + call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) + await generic_handler.wait_for_call() + + shutdown_start_time = time.time() + await server.stop(test_constants.SHORT_TIMEOUT) + grace_period_length = time.time() - shutdown_start_time + self.assertGreater(grace_period_length, + test_constants.SHORT_TIMEOUT / 3) + + # Validates the states. + await channel.close() + self.assertEqual(_RESPONSE, await call) + self.assertTrue(call.done()) + + self.loop.run_until_complete(test_graceful_shutdown_success_body()) + + def test_graceful_shutdown_failed(self): + + async def test_graceful_shutdown_failed_body(): + server_target, server, generic_handler = await _start_test_server() + + channel = aio.insecure_channel(server_target) + call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) + await generic_handler.wait_for_call() + + await server.stop(test_constants.SHORT_TIMEOUT) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + self.assertIn('GOAWAY', exception_context.exception.details()) + await channel.close() + + self.loop.run_until_complete(test_graceful_shutdown_failed_body()) + + def test_concurrent_graceful_shutdown(self): + + async def test_concurrent_graceful_shutdown_body(): + server_target, server, generic_handler = await _start_test_server() + + channel = aio.insecure_channel(server_target) + call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) + await generic_handler.wait_for_call() + + # Expects the shortest grace period to be effective. + shutdown_start_time = time.time() + await asyncio.gather( + server.stop(test_constants.LONG_TIMEOUT), + server.stop(test_constants.SHORT_TIMEOUT), + server.stop(test_constants.LONG_TIMEOUT), + ) + grace_period_length = time.time() - shutdown_start_time + self.assertGreater(grace_period_length, + test_constants.SHORT_TIMEOUT / 3) + + await channel.close() + self.assertEqual(_RESPONSE, await call) + self.assertTrue(call.done()) + + self.loop.run_until_complete(test_concurrent_graceful_shutdown_body()) + + def test_concurrent_graceful_shutdown_immediate(self): + + async def test_concurrent_graceful_shutdown_immediate_body(): + server_target, server, generic_handler = await _start_test_server() + + channel = aio.insecure_channel(server_target) + call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) + await generic_handler.wait_for_call() + + # Expects no grace period, due to the "server.stop(None)". + await asyncio.gather( + server.stop(test_constants.LONG_TIMEOUT), + server.stop(None), + server.stop(test_constants.SHORT_TIMEOUT), + server.stop(test_constants.LONG_TIMEOUT), + ) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + self.assertIn('GOAWAY', exception_context.exception.details()) + await channel.close() + + self.loop.run_until_complete( + test_concurrent_graceful_shutdown_immediate_body()) + + @unittest.skip('https://github.com/grpc/grpc/issues/20818') + def test_shutdown_before_call(self): + + async def test_shutdown_body(): + server_target, server, _ = _start_test_server() + await server.stop(None) + + # Ensures the server is cleaned up at this point. + # Some proper exception should be raised. + async with aio.insecure_channel('localhost:%d' % port) as channel: + await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST) + + self.loop.run_until_complete(test_shutdown_body()) + if __name__ == '__main__': logging.basicConfig()