diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index ef7160a3059..ef4b7f8c51d 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -14,6 +14,7 @@ # TODO(https://github.com/grpc/grpc/issues/20850) refactor this. _LOGGER = logging.getLogger(__name__) +cdef int _EMPTY_FLAG = 0 cdef class _HandlerCallDetails: @@ -171,6 +172,9 @@ async def _handle_unary_unary_rpc(object method_handler, await callback_start_batch(rpc_state, send_ops, loop) + + + async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): # Finds the method handler (application logic) cdef object method_handler = _find_method_handler( @@ -180,6 +184,7 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): if method_handler is None: # TODO(lidiz) return unimplemented error to client side raise NotImplementedError() + # TODO(lidiz) extend to all 4 types of RPC if method_handler.request_streaming or method_handler.response_streaming: raise NotImplementedError() @@ -223,6 +228,16 @@ async def _server_call_request_call(Server server, return rpc_state +async def _handle_cancellation_from_core(object rpc_task, + RPCState rpc_state, + object loop): + cdef ReceiveCloseOnServerOperation op = ReceiveCloseOnServerOperation(_EMPTY_FLAG) + cdef tuple ops = (op,) + await callback_start_batch(rpc_state, ops, loop) + if op.cancelled() and not rpc_task.done(): + rpc_task.cancel() + + cdef _CallbackFailureHandler CQ_SHUTDOWN_FAILURE_HANDLER = _CallbackFailureHandler( 'grpc_completion_queue_shutdown', 'Unknown', @@ -277,7 +292,7 @@ cdef class AioServer: self.add_generic_rpc_handlers(generic_handlers) self._serving_task = None - self._shutdown_lock = asyncio.Lock() + self._shutdown_lock = asyncio.Lock(loop=self._loop) self._shutdown_completed = self._loop.create_future() self._shutdown_callback_wrapper = CallbackWrapper( self._shutdown_completed, @@ -320,10 +335,20 @@ cdef class AioServer: self._cq, self._loop) - self._loop.create_task(_handle_rpc( - self._generic_handlers, - rpc_state, - self._loop)) + rpc_task = self._loop.create_task( + _handle_rpc( + self._generic_handlers, + rpc_state, + self._loop + ) + ) + self._loop.create_task( + _handle_cancellation_from_core( + rpc_task, + rpc_state, + self._loop + ) + ) def _serving_task_crash_handler(self, object task): """Shutdown the server immediately if unexpectedly exited.""" @@ -389,7 +414,14 @@ cdef class AioServer: await self._shutdown_completed else: try: - await asyncio.wait_for(asyncio.shield(self._shutdown_completed), grace) + await asyncio.wait_for( + asyncio.shield( + self._shutdown_completed, + loop=self._loop + ), + grace, + loop=self._loop, + ) except asyncio.TimeoutError: # Cancels all ongoing calls by the end of grace period. grpc_server_cancel_all_calls(self._server.c_server) @@ -410,7 +442,14 @@ cdef class AioServer: await self._shutdown_completed else: try: - await asyncio.wait_for(asyncio.shield(self._shutdown_completed), timeout) + await asyncio.wait_for( + asyncio.shield( + self._shutdown_completed, + loop=self._loop, + ), + timeout, + loop=self._loop, + ) except asyncio.TimeoutError: if self._crash_exception is not None: raise self._crash_exception