Add limit concurrent RPC feature to asyncio server

* Reduce the allocation of new function
pull/24818/head
Lidi Zheng 4 years ago
parent 3f46d68975
commit 3da3cc2168
  1. 8
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  2. 48
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 28
      src/python/grpcio_tests/tests_aio/unit/server_test.py

@ -67,6 +67,13 @@ cdef enum AioServerStatus:
AIO_SERVER_STATUS_STOPPING
cdef class _ConcurrentRpcLimiter:
cdef int _maximum_concurrent_rpcs
cdef int _active_rpcs
cdef object _active_rpcs_condition # asyncio.Condition
cdef object _loop # asyncio.EventLoop
cdef class AioServer:
cdef Server _server
cdef list _generic_handlers
@ -79,5 +86,6 @@ cdef class AioServer:
cdef object _crash_exception # Exception
cdef tuple _interceptors
cdef object _thread_pool # concurrent.futures.ThreadPoolExecutor
cdef _ConcurrentRpcLimiter _limiter
cdef thread_pool(self)

@ -781,6 +781,40 @@ cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHan
InternalError)
cdef class _ConcurrentRpcLimiter:
def __cinit__(self, int maximum_concurrent_rpcs, object loop):
if maximum_concurrent_rpcs <= 0:
raise ValueError("maximum_concurrent_rpcs should be a postive integer")
self._maximum_concurrent_rpcs = maximum_concurrent_rpcs
self._active_rpcs = 0
self._active_rpcs_condition = asyncio.Condition()
self._loop = loop
async def check_before_request_call(self):
await self._active_rpcs_condition.acquire()
try:
predicate = lambda: self._active_rpcs < self._maximum_concurrent_rpcs
await self._active_rpcs_condition.wait_for(predicate)
self._active_rpcs += 1
finally:
self._active_rpcs_condition.release()
async def _decrease_active_rpcs_count_with_lock(self):
await self._active_rpcs_condition.acquire()
try:
self._active_rpcs -= 1
self._active_rpcs_condition.notify()
finally:
self._active_rpcs_condition.release()
def _decrease_active_rpcs_count(self, unused_future):
self._loop.create_task(self._decrease_active_rpcs_count_with_lock())
def decrease_once_finished(self, object rpc_task):
rpc_task.add_done_callback(self._decrease_active_rpcs_count)
cdef class AioServer:
def __init__(self, loop, thread_pool, generic_handlers, interceptors,
@ -815,9 +849,9 @@ cdef class AioServer:
self._interceptors = ()
self._thread_pool = thread_pool
if maximum_concurrent_rpcs:
raise NotImplementedError()
if maximum_concurrent_rpcs is not None:
self._limiter = _ConcurrentRpcLimiter(maximum_concurrent_rpcs,
loop)
def add_generic_rpc_handlers(self, object generic_rpc_handlers):
self._generic_handlers.extend(generic_rpc_handlers)
@ -860,6 +894,9 @@ cdef class AioServer:
if self._status != AIO_SERVER_STATUS_RUNNING:
break
if self._limiter is not None:
await self._limiter.check_before_request_call()
# Accepts new request from Core
rpc_state = await self._request_call()
@ -874,7 +911,7 @@ cdef class AioServer:
self._loop)
# Fires off a task that listens on the cancellation from client.
self._loop.create_task(
rpc_task = self._loop.create_task(
_schedule_rpc_coro(
rpc_coro,
rpc_state,
@ -882,6 +919,9 @@ cdef class AioServer:
)
)
if self._limiter is not None:
self._limiter.decrease_once_finished(rpc_task)
def _serving_task_crash_handler(self, object task):
"""Shutdown the server immediately if unexpectedly exited."""
if task.cancelled():

@ -47,6 +47,7 @@ _REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
_NUM_STREAM_REQUESTS = 3
_NUM_STREAM_RESPONSES = 5
_MAXIMUM_CONCURRENT_RPCS = 5
class _GenericHandler(grpc.GenericRpcHandler):
@ -189,7 +190,8 @@ class _GenericHandler(grpc.GenericRpcHandler):
context.set_code(grpc.StatusCode.INTERNAL)
def service(self, handler_details):
self._called.set_result(None)
if not self._called.done():
self._called.set_result(None)
return self._routing_table.get(handler_details.method)
async def wait_for_call(self):
@ -480,6 +482,30 @@ class TestServer(AioTestBase):
with self.assertRaises(RuntimeError):
server.add_secure_port(bind_address, server_credentials)
async def test_maximum_concurrent_rpcs(self):
# Build the server with concurrent rpc argument
server = aio.server(maximum_concurrent_rpcs=_MAXIMUM_CONCURRENT_RPCS)
port = server.add_insecure_port('localhost:0')
bind_address = "localhost:%d" % port
server.add_generic_rpc_handlers((_GenericHandler(),))
await server.start()
# Build the channel
channel = aio.insecure_channel(bind_address)
# Deplete the concurrent quota with 3 times of max RPCs
rpcs = []
for _ in range(3 * _MAXIMUM_CONCURRENT_RPCS):
rpcs.append(channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST))
task = self.loop.create_task(
asyncio.wait(rpcs, return_when=asyncio.FIRST_EXCEPTION))
# Each batch took test_constants.SHORT_TIMEOUT /2
start_time = time.time()
await task
elapsed_time = time.time() - start_time
self.assertGreater(elapsed_time, test_constants.SHORT_TIMEOUT * 3 / 2)
# Clean-up
await channel.close()
await server.stop(0)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save