diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi index 46a47bd1ba7..4651a6b6f22 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi @@ -67,6 +67,13 @@ cdef enum AioServerStatus: AIO_SERVER_STATUS_STOPPING +cdef class _ConcurrentRpcLimiter: + cdef int _maximum_concurrent_rpcs + cdef int _active_rpcs + cdef object _active_rpcs_condition # asyncio.Condition + cdef object _loop # asyncio.EventLoop + + cdef class AioServer: cdef Server _server cdef list _generic_handlers @@ -79,5 +86,6 @@ cdef class AioServer: cdef object _crash_exception # Exception cdef tuple _interceptors cdef object _thread_pool # concurrent.futures.ThreadPoolExecutor + cdef _ConcurrentRpcLimiter _limiter cdef thread_pool(self) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index 8c74d3ee22b..73d9fb4ea97 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -781,6 +781,40 @@ cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHan InternalError) +cdef class _ConcurrentRpcLimiter: + + def __cinit__(self, int maximum_concurrent_rpcs, object loop): + if maximum_concurrent_rpcs <= 0: + raise ValueError("maximum_concurrent_rpcs should be a postive integer") + self._maximum_concurrent_rpcs = maximum_concurrent_rpcs + self._active_rpcs = 0 + self._active_rpcs_condition = asyncio.Condition() + self._loop = loop + + async def check_before_request_call(self): + await self._active_rpcs_condition.acquire() + try: + predicate = lambda: self._active_rpcs < self._maximum_concurrent_rpcs + await self._active_rpcs_condition.wait_for(predicate) + self._active_rpcs += 1 + finally: + self._active_rpcs_condition.release() + + async def _decrease_active_rpcs_count_with_lock(self): + await self._active_rpcs_condition.acquire() + try: + self._active_rpcs -= 1 + self._active_rpcs_condition.notify() + finally: + self._active_rpcs_condition.release() + + def _decrease_active_rpcs_count(self, unused_future): + self._loop.create_task(self._decrease_active_rpcs_count_with_lock()) + + def decrease_once_finished(self, object rpc_task): + rpc_task.add_done_callback(self._decrease_active_rpcs_count) + + cdef class AioServer: def __init__(self, loop, thread_pool, generic_handlers, interceptors, @@ -815,9 +849,9 @@ cdef class AioServer: self._interceptors = () self._thread_pool = thread_pool - - if maximum_concurrent_rpcs: - raise NotImplementedError() + if maximum_concurrent_rpcs is not None: + self._limiter = _ConcurrentRpcLimiter(maximum_concurrent_rpcs, + loop) def add_generic_rpc_handlers(self, object generic_rpc_handlers): self._generic_handlers.extend(generic_rpc_handlers) @@ -860,6 +894,9 @@ cdef class AioServer: if self._status != AIO_SERVER_STATUS_RUNNING: break + if self._limiter is not None: + await self._limiter.check_before_request_call() + # Accepts new request from Core rpc_state = await self._request_call() @@ -874,7 +911,7 @@ cdef class AioServer: self._loop) # Fires off a task that listens on the cancellation from client. - self._loop.create_task( + rpc_task = self._loop.create_task( _schedule_rpc_coro( rpc_coro, rpc_state, @@ -882,6 +919,9 @@ cdef class AioServer: ) ) + if self._limiter is not None: + self._limiter.decrease_once_finished(rpc_task) + def _serving_task_crash_handler(self, object task): """Shutdown the server immediately if unexpectedly exited.""" if task.cancelled(): diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py index 61d1edd5231..8ba3ce1901e 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -47,6 +47,7 @@ _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x01\x01\x01' _NUM_STREAM_REQUESTS = 3 _NUM_STREAM_RESPONSES = 5 +_MAXIMUM_CONCURRENT_RPCS = 5 class _GenericHandler(grpc.GenericRpcHandler): @@ -189,7 +190,8 @@ class _GenericHandler(grpc.GenericRpcHandler): context.set_code(grpc.StatusCode.INTERNAL) def service(self, handler_details): - self._called.set_result(None) + if not self._called.done(): + self._called.set_result(None) return self._routing_table.get(handler_details.method) async def wait_for_call(self): @@ -480,6 +482,30 @@ class TestServer(AioTestBase): with self.assertRaises(RuntimeError): server.add_secure_port(bind_address, server_credentials) + async def test_maximum_concurrent_rpcs(self): + # Build the server with concurrent rpc argument + server = aio.server(maximum_concurrent_rpcs=_MAXIMUM_CONCURRENT_RPCS) + port = server.add_insecure_port('localhost:0') + bind_address = "localhost:%d" % port + server.add_generic_rpc_handlers((_GenericHandler(),)) + await server.start() + # Build the channel + channel = aio.insecure_channel(bind_address) + # Deplete the concurrent quota with 3 times of max RPCs + rpcs = [] + for _ in range(3 * _MAXIMUM_CONCURRENT_RPCS): + rpcs.append(channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)) + task = self.loop.create_task( + asyncio.wait(rpcs, return_when=asyncio.FIRST_EXCEPTION)) + # Each batch took test_constants.SHORT_TIMEOUT /2 + start_time = time.time() + await task + elapsed_time = time.time() - start_time + self.assertGreater(elapsed_time, test_constants.SHORT_TIMEOUT * 3 / 2) + # Clean-up + await channel.close() + await server.stop(0) + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG)