Make sync handlers runnable in AsyncIO server

pull/22812/head
Lidi Zheng 5 years ago
parent 6b0b2602f6
commit df065d41fa
  1. src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi (54 changed lines)
  2. src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi (10 changed lines)
  3. src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi (171 changed lines)
  4. src/python/grpcio_tests/tests_aio/unit/_test_server.py (6 changed lines)
  5. src/python/grpcio_tests/tests_aio/unit/compatibility_test.py (179 changed lines)

@@ -112,3 +112,57 @@ def schedule_coro_threadsafe(object coro, object loop):
            )
        else:
            raise


def async_generator_to_generator(object agen, object loop):
    """Converts an async generator into a generator."""
    try:
        while True:
            future = asyncio.run_coroutine_threadsafe(
                agen.__anext__(),
                loop
            )
            response = future.result()
            if response is EOF:
                break
            else:
                yield response
    except StopAsyncIteration:
        # If StopAsyncIteration is raised, end this generator.
        pass


async def generator_to_async_generator(object gen, object loop, object thread_pool):
    """Converts a generator into an async generator.

    The generator might block, so we need to delegate the iteration to a
    thread pool. Also, we can't simply delegate __next__ to the thread pool,
    otherwise we will see the following error:

        TypeError: StopIteration interacts badly with generators and cannot be
            raised into a Future
    """
    queue = asyncio.Queue(loop=loop)

    def yield_to_queue():
        try:
            for item in gen:
                # For an infinite-sized queue, put_nowait should always succeed.
                loop.call_soon_threadsafe(queue.put_nowait, item)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, EOF)

    future = loop.run_in_executor(
        thread_pool,
        yield_to_queue,
    )

    while True:
        response = await queue.get()
        if response is EOF:
            break
        else:
            yield response

    # Propagate the exception from the generator, if any.
    future.result()
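
The two helpers above bridge the two worlds in opposite directions: async_generator_to_generator blocks a worker thread on run_coroutine_threadsafe futures, while generator_to_async_generator pumps a blocking generator through an asyncio.Queue from the thread pool. A minimal standalone sketch of the latter pattern, assuming plain asyncio and an illustrative bridge_generator name (not part of this change):

import asyncio
from concurrent.futures import ThreadPoolExecutor

_EOF = object()  # sentinel, mirroring the EOF object used above


async def bridge_generator(gen, loop, thread_pool):
    """Pump a possibly-blocking generator into an async generator."""
    queue = asyncio.Queue()

    def pump():
        try:
            for item in gen:  # may block; runs on the thread pool
                loop.call_soon_threadsafe(queue.put_nowait, item)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, _EOF)

    future = loop.run_in_executor(thread_pool, pump)
    while True:
        item = await queue.get()
        if item is _EOF:
            break
        yield item
    await future  # re-raise any exception from the generator


async def main():
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as pool:
        async for value in bridge_generator(iter(range(3)), loop, pool):
            print(value)


asyncio.run(main())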

@@ -48,6 +48,12 @@ cdef class _ServicerContext:
    cdef object _response_serializer  # Callable[[Any], bytes]


cdef class _SyncServicerContext:
    cdef _ServicerContext _context
    cdef list _callbacks
    cdef object _loop  # asyncio.AbstractEventLoop


cdef class _MessageReceiver:
    cdef _ServicerContext _servicer_context
    cdef object _agen

@@ -71,5 +77,7 @@ cdef class AioServer:
    cdef object _shutdown_completed  # asyncio.Future
    cdef CallbackWrapper _shutdown_callback_wrapper
    cdef object _crash_exception  # Exception
    cdef set _ongoing_rpc_tasks
    cdef tuple _interceptors
    cdef object _thread_pool  # concurrent.futures.ThreadPoolExecutor

    cdef thread_pool(self)

@@ -211,6 +211,65 @@ cdef class _ServicerContext:
        self._rpc_state.disable_next_compression = True


cdef class _SyncServicerContext:
    """Sync servicer context for sync handler compatibility."""

    def __cinit__(self,
                  _ServicerContext context):
        self._context = context
        self._callbacks = []
        self._loop = context._loop

    def read(self):
        future = asyncio.run_coroutine_threadsafe(
            self._context.read(),
            self._loop)
        return future.result()

    def write(self, object message):
        future = asyncio.run_coroutine_threadsafe(
            self._context.write(message),
            self._loop)
        future.result()

    def abort(self,
              object code,
              str details='',
              tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA):
        future = asyncio.run_coroutine_threadsafe(
            self._context.abort(code, details, trailing_metadata),
            self._loop)
        # Abort should raise an AbortError
        future.exception()

    def send_initial_metadata(self, tuple metadata):
        future = asyncio.run_coroutine_threadsafe(
            self._context.send_initial_metadata(metadata),
            self._loop)
        future.result()

    def set_trailing_metadata(self, tuple metadata):
        self._context.set_trailing_metadata(metadata)

    def invocation_metadata(self):
        return self._context.invocation_metadata()

    def set_code(self, object code):
        self._context.set_code(code)

    def set_details(self, str details):
        self._context.set_details(details)

    def set_compression(self, object compression):
        self._context.set_compression(compression)

    def disable_next_message_compression(self):
        self._context.disable_next_message_compression()

    def add_callback(self, object callback):
        self._callbacks.append(callback)


async def _run_interceptor(object interceptors, object query_handler,
                           object handler_call_details):
    interceptor = next(interceptors, None)
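
_SyncServicerContext lets a handler running on a worker thread drive the asyncio _ServicerContext by scheduling each call back onto the event loop and blocking on the result. The same thread-to-loop round trip, reduced to the standard library (the class and function names below are made up for illustration):

import asyncio


class SyncFacade:
    """Blocks a worker thread on coroutines that must run on the event loop."""

    def __init__(self, loop):
        self._loop = loop

    def call(self, coro):
        # Same pattern as read()/write() above: schedule on the loop,
        # then block this (non-loop) thread until the result is ready.
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result()


async def greet(name):
    await asyncio.sleep(0)  # stand-in for real asyncio I/O
    return 'hello ' + name


async def main():
    loop = asyncio.get_running_loop()
    facade = SyncFacade(loop)
    # The blocking call happens on the default executor thread, never on the loop.
    print(await loop.run_in_executor(None, facade.call, greet('gRPC')))


asyncio.run(main())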
@@ -222,6 +281,11 @@ async def _run_interceptor(object interceptors, object query_handler,
        return query_handler(handler_call_details)


def _is_async_handler(object handler):
    """Inspect if a method handler is async or sync."""
    return inspect.isawaitable(handler) or inspect.iscoroutinefunction(
        handler) or inspect.isasyncgenfunction(handler)


async def _find_method_handler(str method, tuple metadata, list generic_handlers,
                               tuple interceptors):

    def query_handlers(handler_call_details):
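
_is_async_handler is what routes each registered handler down the async or sync path. A quick, self-contained demonstration of the inspect checks it combines (nothing gRPC-specific):

import inspect


async def coroutine_handler(request, context):
    return request


async def async_gen_handler(request, context):
    yield request


def sync_handler(request, context):
    return request


print(inspect.iscoroutinefunction(coroutine_handler))  # True
print(inspect.isasyncgenfunction(async_gen_handler))   # True
print(inspect.iscoroutinefunction(sync_handler),
      inspect.isasyncgenfunction(sync_handler))        # False False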
@@ -254,11 +318,27 @@ async def _finish_handler_with_unary_response(RPCState rpc_state,
    stream-unary handlers.
    """
    # Executes application logic
    cdef object response_message = await unary_handler(
        request,
        servicer_context,
    )
    cdef object response_message
    cdef _SyncServicerContext sync_servicer_context

    if _is_async_handler(unary_handler):
        # Run async method handlers in this coroutine
        response_message = await unary_handler(
            request,
            servicer_context,
        )
    else:
        # Run sync method handlers in the thread pool
        sync_servicer_context = _SyncServicerContext(servicer_context)
        response_message = await loop.run_in_executor(
            rpc_state.server.thread_pool(),
            unary_handler,
            request,
            sync_servicer_context,
        )
        # Support sync-stack callback
        for callback in sync_servicer_context._callbacks:
            callback()

    # Raises exception if aborted
    rpc_state.raise_for_termination()
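
The unary path either awaits the handler on the event loop or hands it to the server's thread pool via run_in_executor, as shown above. A condensed sketch of that dispatch decision outside of Cython, with hypothetical handler, request, and pool names:

import asyncio
import inspect
from concurrent.futures import ThreadPoolExecutor


def blocking_handler(request, unused_context):
    # A legacy sync-stack style handler: a plain function that may block.
    return request.upper()


async def dispatch_unary(handler, request, context, pool):
    loop = asyncio.get_running_loop()
    if inspect.iscoroutinefunction(handler):
        # Async handler: stay on the event loop.
        return await handler(request, context)
    # Sync handler: run on the executor so the loop is never blocked.
    return await loop.run_in_executor(pool, handler, request, context)


async def main():
    with ThreadPoolExecutor() as pool:
        print(await dispatch_unary(blocking_handler, 'ping', None, pool))


asyncio.run(main())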
@@ -307,18 +387,31 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state,
    """
    cdef object async_response_generator
    cdef object response_message

    if inspect.iscoroutinefunction(stream_handler):
        # Case 1: Coroutine async handler - using reader-writer API
        # The handler uses reader / writer API, returns None.
        await stream_handler(
            request,
            servicer_context,
        )
    else:
        # The handler uses async generator API
        async_response_generator = stream_handler(
            request,
            servicer_context,
        )
        if inspect.isasyncgenfunction(stream_handler):
            # Case 2: Async handler - async generator
            # The handler uses the async generator API
            async_response_generator = stream_handler(
                request,
                servicer_context,
            )
        else:
            # Case 3: Sync handler - normal generator
            # NOTE(lidiz) A streaming handler in the sync stack is either a
            # generator function or a function that returns a generator.
            sync_servicer_context = _SyncServicerContext(servicer_context)
            gen = stream_handler(request, sync_servicer_context)
            async_response_generator = generator_to_async_generator(gen,
                                                                    loop,
                                                                    rpc_state.server.thread_pool())

        # Consumes messages from the generator
        async for response_message in async_response_generator:
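
The three branches above correspond to the three shapes a response-streaming handler may take; roughly (the servicer method names here are illustrative only):

# Case 1: coroutine using the reader/writer API; returns None.
async def reply_with_writer(request, context):
    for _ in range(3):
        await context.write(b'\x00')


# Case 2: async generator; responses are yielded on the event loop.
async def reply_with_async_generator(request, context):
    for _ in range(3):
        yield b'\x00'


# Case 3: sync-stack generator; iterated on the thread pool through
# generator_to_async_generator, as wired up above.
def reply_with_sync_generator(request, context):
    for _ in range(3):
        yield b'\x00'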
@@ -438,6 +531,9 @@ cdef class _MessageReceiver:
            self._agen = self._async_message_receiver()
        return self._agen

    async def __anext__(self):
        return await self.__aiter__().__anext__()


async def _handle_stream_unary_rpc(object method_handler,
                                   RPCState rpc_state,
@@ -451,13 +547,20 @@ async def _handle_stream_unary_rpc(object method_handler,
    )

    # Prepares the request generator
    cdef object request_async_iterator = _MessageReceiver(servicer_context)
    cdef object request_iterator
    if _is_async_handler(method_handler.stream_unary):
        request_iterator = _MessageReceiver(servicer_context)
    else:
        request_iterator = async_generator_to_generator(
            _MessageReceiver(servicer_context),
            loop
        )

    # Finishes the application handler
    await _finish_handler_with_unary_response(
        rpc_state,
        method_handler.stream_unary,
        request_async_iterator,
        request_iterator,
        servicer_context,
        method_handler.response_serializer,
        loop
@@ -476,13 +579,20 @@ async def _handle_stream_stream_rpc(object method_handler,
    )

    # Prepares the request generator
    cdef object request_async_iterator = _MessageReceiver(servicer_context)
    cdef object request_iterator
    if _is_async_handler(method_handler.stream_stream):
        request_iterator = _MessageReceiver(servicer_context)
    else:
        request_iterator = async_generator_to_generator(
            _MessageReceiver(servicer_context),
            loop
        )

    # Finishes the application handler
    await _finish_handler_with_stream_responses(
        rpc_state,
        method_handler.stream_stream,
        request_async_iterator,
        request_iterator,
        servicer_context,
        loop,
    )
@@ -591,22 +701,22 @@ async def _handle_rpc(list generic_handlers, tuple interceptors,
    # Handles unary-unary case
    if not method_handler.request_streaming and not method_handler.response_streaming:
        await _handle_unary_unary_rpc(method_handler,
                                      rpc_state,
                                      loop)
                                      rpc_state,
                                      loop)
        return

    # Handles unary-stream case
    if not method_handler.request_streaming and method_handler.response_streaming:
        await _handle_unary_stream_rpc(method_handler,
                                       rpc_state,
                                       loop)
                                       rpc_state,
                                       loop)
        return

    # Handles stream-unary case
    if method_handler.request_streaming and not method_handler.response_streaming:
        await _handle_stream_unary_rpc(method_handler,
                                       rpc_state,
                                       loop)
                                       rpc_state,
                                       loop)
        return

    # Handles stream-stream case
@@ -648,7 +758,6 @@ cdef class AioServer:
        self._generic_handlers = []
        self.add_generic_rpc_handlers(generic_handlers)
        self._serving_task = None
        self._ongoing_rpc_tasks = set()

        self._shutdown_lock = asyncio.Lock(loop=self._loop)
        self._shutdown_completed = self._loop.create_future()
@@ -658,17 +767,18 @@ cdef class AioServer:
            SERVER_SHUTDOWN_FAILURE_HANDLER)
        self._crash_exception = None
        self._interceptors = ()
        if interceptors:
            self._interceptors = interceptors
        else:
            self._interceptors = ()
        self._thread_pool = thread_pool

        if maximum_concurrent_rpcs:
            raise NotImplementedError()
        if thread_pool:
            raise NotImplementedError()

    def add_generic_rpc_handlers(self, generic_rpc_handlers):
        for h in generic_rpc_handlers:
            self._generic_handlers.append(h)

    def add_generic_rpc_handlers(self, object generic_rpc_handlers):
        self._generic_handlers.extend(generic_rpc_handlers)

    def add_insecure_port(self, address):
        return self._server.add_http2_port(address)
@@ -846,3 +956,10 @@ cdef class AioServer:
                self._status
            )
        shutdown_grpc_aio()

    cdef thread_pool(self):
        """Access the thread pool instance."""
        if self._thread_pool:
            return self._thread_pool
        else:
            raise UsageError('Please provide an Executor upon server creation.')
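
With the thread_pool() accessor in place, the executor is supplied through the migration_thread_pool argument exercised by the compatibility test below. Roughly, an unmodified sync-stack servicer can then be attached to an asyncio server like this (a sketch with placeholder proto names, not a verbatim example from this change):

import asyncio
from concurrent.futures import ThreadPoolExecutor

from grpc.experimental import aio

# Placeholder: any generated *_pb2_grpc module with a sync-style servicer.
# from my_proto import service_pb2_grpc


class LegacyServicer:
    """Plain blocking methods, exactly as written for the sync stack."""

    def UnaryCall(self, request, unused_context):
        return request


async def serve():
    server = aio.server(migration_thread_pool=ThreadPoolExecutor())
    # service_pb2_grpc.add_MyServiceServicer_to_server(LegacyServicer(), server)
    server.add_insecure_port('[::]:50051')
    await server.start()
    await server.wait_for_termination()


if __name__ == '__main__':
    asyncio.run(serve())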

@@ -47,7 +47,7 @@ async def _maybe_echo_status(request: messages_pb2.SimpleRequest,
                            request.response_status.message)


class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
class TestServiceServicer(test_pb2_grpc.TestServiceServicer):

    async def UnaryCall(self, request, context):
        await _maybe_echo_metadata(context)

@@ -102,7 +102,7 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
                            response_parameters.size))


def _create_extra_generic_handler(servicer: _TestServiceServicer):
def _create_extra_generic_handler(servicer: TestServiceServicer):
    # Programmatically add extra methods not provided by the proto file
    # that are used during the tests
    rpc_method_handlers = {

@@ -123,7 +123,7 @@ async def start_test_server(port=0,
                            interceptors=None):
    server = aio.server(options=(('grpc.so_reuseport', 0),),
                        interceptors=interceptors)
    servicer = _TestServiceServicer()
    servicer = TestServiceServicer()
    test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)
    server.add_generic_rpc_handlers((_create_extra_generic_handler(servicer),))

@@ -20,32 +20,63 @@ import random
import threading
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Sequence, Tuple
from typing import Callable, Iterable, Sequence, Tuple

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit import _common
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_server import TestServiceServicer, start_test_server

_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST = b'\x03\x07'
_ADHOC_METHOD = '/test/AdHoc'


def _unique_options() -> Sequence[Tuple[str, float]]:
    return (('iv', random.random()),)


class _AdhocGenericHandler(grpc.GenericRpcHandler):
    _handler: grpc.RpcMethodHandler

    def __init__(self):
        self._handler = None

    def set_adhoc_handler(self, handler: grpc.RpcMethodHandler):
        self._handler = handler

    def service(self, handler_call_details: grpc.HandlerCallDetails):
        if handler_call_details.method == _ADHOC_METHOD:
            return self._handler
        else:
            return None


@unittest.skipIf(
    os.environ.get('GRPC_ASYNCIO_ENGINE', '').lower() != 'poller',
    os.environ.get('GRPC_ASYNCIO_ENGINE', '').lower() == 'custom_io_manager',
    'Compatible mode needs POLLER completion queue.')
class TestCompatibility(AioTestBase):

    async def setUp(self):
        address, self._async_server = await start_test_server()
        self._async_server = aio.server(
            options=(('grpc.so_reuseport', 0),),
            migration_thread_pool=ThreadPoolExecutor())
        test_pb2_grpc.add_TestServiceServicer_to_server(TestServiceServicer(),
                                                        self._async_server)
        self._adhoc_handlers = _AdhocGenericHandler()
        self._async_server.add_generic_rpc_handlers((self._adhoc_handlers,))
        port = self._async_server.add_insecure_port('[::]:0')
        address = 'localhost:%d' % port
        await self._async_server.start()

        # Create async stub
        self._async_channel = aio.insecure_channel(address,
                                                   options=_unique_options())
@@ -202,6 +233,146 @@ class TestCompatibility(AioTestBase):
        await self._run_in_another_thread(sync_work)
        await server.stop(None)

    async def test_sync_unary_unary_success(self):

        @grpc.unary_unary_rpc_method_handler
        def echo_unary_unary(request: bytes, unused_context):
            return request

        self._adhoc_handlers.set_adhoc_handler(echo_unary_unary)
        response = await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
        self.assertEqual(_REQUEST, response)

    async def test_sync_unary_unary_metadata(self):
        metadata = (('unique', 'key-42'),)

        @grpc.unary_unary_rpc_method_handler
        def metadata_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.send_initial_metadata(metadata)
            return request

        self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary)
        call = self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
        self.assertTrue(
            _common.seen_metadata(metadata, await call.initial_metadata()))

    async def test_sync_unary_unary_abort(self):

        @grpc.unary_unary_rpc_method_handler
        def abort_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.abort(grpc.StatusCode.INTERNAL, 'Test')

        self._adhoc_handlers.set_adhoc_handler(abort_unary_unary)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
        self.assertEqual(grpc.StatusCode.INTERNAL,
                         exception_context.exception.code())

    async def test_sync_unary_unary_set_code(self):

        @grpc.unary_unary_rpc_method_handler
        def set_code_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.set_code(grpc.StatusCode.INTERNAL)

        self._adhoc_handlers.set_adhoc_handler(set_code_unary_unary)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
        self.assertEqual(grpc.StatusCode.INTERNAL,
                         exception_context.exception.code())

    async def test_sync_unary_stream_success(self):

        @grpc.unary_stream_rpc_method_handler
        def echo_unary_stream(request: bytes, unused_context):
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        self._adhoc_handlers.set_adhoc_handler(echo_unary_stream)
        call = self._async_channel.unary_stream(_ADHOC_METHOD)(_REQUEST)
        async for response in call:
            self.assertEqual(_REQUEST, response)

    async def test_sync_unary_stream_error(self):

        @grpc.unary_stream_rpc_method_handler
        def error_unary_stream(request: bytes, unused_context):
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request
            raise RuntimeError('Test')

        self._adhoc_handlers.set_adhoc_handler(error_unary_stream)
        call = self._async_channel.unary_stream(_ADHOC_METHOD)(_REQUEST)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            async for response in call:
                self.assertEqual(_REQUEST, response)
        self.assertEqual(grpc.StatusCode.UNKNOWN,
                         exception_context.exception.code())

    async def test_sync_stream_unary_success(self):

        @grpc.stream_unary_rpc_method_handler
        def echo_stream_unary(request_iterator: Iterable[bytes],
                              unused_context):
            self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES)
            return _REQUEST

        self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        response = await self._async_channel.stream_unary(_ADHOC_METHOD)(
            request_iterator)
        self.assertEqual(_REQUEST, response)

    async def test_sync_stream_unary_error(self):

        @grpc.stream_unary_rpc_method_handler
        def echo_stream_unary(request_iterator: Iterable[bytes],
                              unused_context):
            self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES)
            raise RuntimeError('Test')

        self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            response = await self._async_channel.stream_unary(_ADHOC_METHOD)(
                request_iterator)
        self.assertEqual(grpc.StatusCode.UNKNOWN,
                         exception_context.exception.code())

    async def test_sync_stream_stream_success(self):

        @grpc.stream_stream_rpc_method_handler
        def echo_stream_stream(request_iterator: Iterable[bytes],
                               unused_context):
            for request in request_iterator:
                yield request

        self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        call = self._async_channel.stream_stream(_ADHOC_METHOD)(
            request_iterator)
        async for response in call:
            self.assertEqual(_REQUEST, response)

    async def test_sync_stream_stream_error(self):

        @grpc.stream_stream_rpc_method_handler
        def echo_stream_stream(request_iterator: Iterable[bytes],
                               unused_context):
            for request in request_iterator:
                yield request
            raise RuntimeError('test')

        self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        call = self._async_channel.stream_stream(_ADHOC_METHOD)(
            request_iterator)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            async for response in call:
                self.assertEqual(_REQUEST, response)
        self.assertEqual(grpc.StatusCode.UNKNOWN,
                         exception_context.exception.code())


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
