|
|
|
@ -211,6 +211,65 @@ cdef class _ServicerContext: |
|
|
|
|
self._rpc_state.disable_next_compression = True |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cdef class _SyncServicerContext: |
|
|
|
|
"""Sync servicer context for sync handler compatibility.""" |
|
|
|
|
|
|
|
|
|
def __cinit__(self, |
|
|
|
|
_ServicerContext context): |
|
|
|
|
self._context = context |
|
|
|
|
self._callbacks = [] |
|
|
|
|
self._loop = context._loop |
|
|
|
|
|
|
|
|
|
def read(self): |
|
|
|
|
future = asyncio.run_coroutine_threadsafe( |
|
|
|
|
self._context.read(), |
|
|
|
|
self._loop) |
|
|
|
|
return future.result() |
|
|
|
|
|
|
|
|
|
def write(self, object message): |
|
|
|
|
future = asyncio.run_coroutine_threadsafe( |
|
|
|
|
self._context.write(message), |
|
|
|
|
self._loop) |
|
|
|
|
future.result() |
|
|
|
|
|
|
|
|
|
def abort(self, |
|
|
|
|
object code, |
|
|
|
|
str details='', |
|
|
|
|
tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA): |
|
|
|
|
future = asyncio.run_coroutine_threadsafe( |
|
|
|
|
self._context.abort(code, details, trailing_metadata), |
|
|
|
|
self._loop) |
|
|
|
|
# Abort should raise an AbortError |
|
|
|
|
future.exception() |
|
|
|
|
|
|
|
|
|
def send_initial_metadata(self, tuple metadata): |
|
|
|
|
future = asyncio.run_coroutine_threadsafe( |
|
|
|
|
self._context.send_initial_metadata(metadata), |
|
|
|
|
self._loop) |
|
|
|
|
future.result() |
|
|
|
|
|
|
|
|
|
def set_trailing_metadata(self, tuple metadata): |
|
|
|
|
self._context.set_trailing_metadata(metadata) |
|
|
|
|
|
|
|
|
|
def invocation_metadata(self): |
|
|
|
|
return self._context.invocation_metadata() |
|
|
|
|
|
|
|
|
|
def set_code(self, object code): |
|
|
|
|
self._context.set_code(code) |
|
|
|
|
|
|
|
|
|
def set_details(self, str details): |
|
|
|
|
self._context.set_details(details) |
|
|
|
|
|
|
|
|
|
def set_compression(self, object compression): |
|
|
|
|
self._context.set_compression(compression) |
|
|
|
|
|
|
|
|
|
def disable_next_message_compression(self): |
|
|
|
|
self._context.disable_next_message_compression() |
|
|
|
|
|
|
|
|
|
def add_callback(self, object callback): |
|
|
|
|
self._callbacks.append(callback) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _run_interceptor(object interceptors, object query_handler, |
|
|
|
|
object handler_call_details): |
|
|
|
|
interceptor = next(interceptors, None) |
|
|
|
@ -222,6 +281,11 @@ async def _run_interceptor(object interceptors, object query_handler, |
|
|
|
|
return query_handler(handler_call_details) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_async_handler(object handler): |
|
|
|
|
"""Inspect if a method handler is async or sync.""" |
|
|
|
|
return inspect.isawaitable(handler) or inspect.iscoroutinefunction(handler) or inspect.isasyncgenfunction(handler) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _find_method_handler(str method, tuple metadata, list generic_handlers, |
|
|
|
|
tuple interceptors): |
|
|
|
|
def query_handlers(handler_call_details): |
|
|
|
@ -254,11 +318,27 @@ async def _finish_handler_with_unary_response(RPCState rpc_state, |
|
|
|
|
stream-unary handlers. |
|
|
|
|
""" |
|
|
|
|
# Executes application logic |
|
|
|
|
|
|
|
|
|
cdef object response_message = await unary_handler( |
|
|
|
|
request, |
|
|
|
|
servicer_context, |
|
|
|
|
) |
|
|
|
|
cdef object response_message |
|
|
|
|
cdef _SyncServicerContext sync_servicer_context |
|
|
|
|
|
|
|
|
|
if _is_async_handler(unary_handler): |
|
|
|
|
# Run async method handlers in this coroutine |
|
|
|
|
response_message = await unary_handler( |
|
|
|
|
request, |
|
|
|
|
servicer_context, |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
# Run sync method handlers in the thread pool |
|
|
|
|
sync_servicer_context = _SyncServicerContext(servicer_context) |
|
|
|
|
response_message = await loop.run_in_executor( |
|
|
|
|
rpc_state.server.thread_pool(), |
|
|
|
|
unary_handler, |
|
|
|
|
request, |
|
|
|
|
sync_servicer_context, |
|
|
|
|
) |
|
|
|
|
# Support sync-stack callback |
|
|
|
|
for callback in sync_servicer_context._callbacks: |
|
|
|
|
callback() |
|
|
|
|
|
|
|
|
|
# Raises exception if aborted |
|
|
|
|
rpc_state.raise_for_termination() |
|
|
|
@ -307,18 +387,31 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state, |
|
|
|
|
""" |
|
|
|
|
cdef object async_response_generator |
|
|
|
|
cdef object response_message |
|
|
|
|
|
|
|
|
|
if inspect.iscoroutinefunction(stream_handler): |
|
|
|
|
# Case 1: Coroutine async handler - using reader-writer API |
|
|
|
|
# The handler uses reader / writer API, returns None. |
|
|
|
|
await stream_handler( |
|
|
|
|
request, |
|
|
|
|
servicer_context, |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
# The handler uses async generator API |
|
|
|
|
async_response_generator = stream_handler( |
|
|
|
|
request, |
|
|
|
|
servicer_context, |
|
|
|
|
) |
|
|
|
|
if inspect.isasyncgenfunction(stream_handler): |
|
|
|
|
# Case 2: Async handler - async generator |
|
|
|
|
# The handler uses async generator API |
|
|
|
|
async_response_generator = stream_handler( |
|
|
|
|
request, |
|
|
|
|
servicer_context, |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
# Case 3: Sync handler - normal generator |
|
|
|
|
# NOTE(lidiz) Streaming handler in sync stack is either a generator |
|
|
|
|
# function or a function returns a generator. |
|
|
|
|
sync_servicer_context = _SyncServicerContext(servicer_context) |
|
|
|
|
gen = stream_handler(request, sync_servicer_context) |
|
|
|
|
async_response_generator = generator_to_async_generator(gen, |
|
|
|
|
loop, |
|
|
|
|
rpc_state.server.thread_pool()) |
|
|
|
|
|
|
|
|
|
# Consumes messages from the generator |
|
|
|
|
async for response_message in async_response_generator: |
|
|
|
@ -438,6 +531,9 @@ cdef class _MessageReceiver: |
|
|
|
|
self._agen = self._async_message_receiver() |
|
|
|
|
return self._agen |
|
|
|
|
|
|
|
|
|
async def __anext__(self): |
|
|
|
|
return await self.__aiter__().__anext__() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _handle_stream_unary_rpc(object method_handler, |
|
|
|
|
RPCState rpc_state, |
|
|
|
@ -451,13 +547,20 @@ async def _handle_stream_unary_rpc(object method_handler, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Prepares the request generator |
|
|
|
|
cdef object request_async_iterator = _MessageReceiver(servicer_context) |
|
|
|
|
cdef object request_iterator |
|
|
|
|
if _is_async_handler(method_handler.stream_unary): |
|
|
|
|
request_iterator = _MessageReceiver(servicer_context) |
|
|
|
|
else: |
|
|
|
|
request_iterator = async_generator_to_generator( |
|
|
|
|
_MessageReceiver(servicer_context), |
|
|
|
|
loop |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Finishes the application handler |
|
|
|
|
await _finish_handler_with_unary_response( |
|
|
|
|
rpc_state, |
|
|
|
|
method_handler.stream_unary, |
|
|
|
|
request_async_iterator, |
|
|
|
|
request_iterator, |
|
|
|
|
servicer_context, |
|
|
|
|
method_handler.response_serializer, |
|
|
|
|
loop |
|
|
|
@ -476,13 +579,20 @@ async def _handle_stream_stream_rpc(object method_handler, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Prepares the request generator |
|
|
|
|
cdef object request_async_iterator = _MessageReceiver(servicer_context) |
|
|
|
|
cdef object request_iterator |
|
|
|
|
if _is_async_handler(method_handler.stream_stream): |
|
|
|
|
request_iterator = _MessageReceiver(servicer_context) |
|
|
|
|
else: |
|
|
|
|
request_iterator = async_generator_to_generator( |
|
|
|
|
_MessageReceiver(servicer_context), |
|
|
|
|
loop |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Finishes the application handler |
|
|
|
|
await _finish_handler_with_stream_responses( |
|
|
|
|
rpc_state, |
|
|
|
|
method_handler.stream_stream, |
|
|
|
|
request_async_iterator, |
|
|
|
|
request_iterator, |
|
|
|
|
servicer_context, |
|
|
|
|
loop, |
|
|
|
|
) |
|
|
|
@ -591,22 +701,22 @@ async def _handle_rpc(list generic_handlers, tuple interceptors, |
|
|
|
|
# Handles unary-unary case |
|
|
|
|
if not method_handler.request_streaming and not method_handler.response_streaming: |
|
|
|
|
await _handle_unary_unary_rpc(method_handler, |
|
|
|
|
rpc_state, |
|
|
|
|
loop) |
|
|
|
|
rpc_state, |
|
|
|
|
loop) |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
# Handles unary-stream case |
|
|
|
|
if not method_handler.request_streaming and method_handler.response_streaming: |
|
|
|
|
await _handle_unary_stream_rpc(method_handler, |
|
|
|
|
rpc_state, |
|
|
|
|
loop) |
|
|
|
|
rpc_state, |
|
|
|
|
loop) |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
# Handles stream-unary case |
|
|
|
|
if method_handler.request_streaming and not method_handler.response_streaming: |
|
|
|
|
await _handle_stream_unary_rpc(method_handler, |
|
|
|
|
rpc_state, |
|
|
|
|
loop) |
|
|
|
|
rpc_state, |
|
|
|
|
loop) |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
# Handles stream-stream case |
|
|
|
@ -648,7 +758,6 @@ cdef class AioServer: |
|
|
|
|
self._generic_handlers = [] |
|
|
|
|
self.add_generic_rpc_handlers(generic_handlers) |
|
|
|
|
self._serving_task = None |
|
|
|
|
self._ongoing_rpc_tasks = set() |
|
|
|
|
|
|
|
|
|
self._shutdown_lock = asyncio.Lock(loop=self._loop) |
|
|
|
|
self._shutdown_completed = self._loop.create_future() |
|
|
|
@ -658,17 +767,18 @@ cdef class AioServer: |
|
|
|
|
SERVER_SHUTDOWN_FAILURE_HANDLER) |
|
|
|
|
self._crash_exception = None |
|
|
|
|
|
|
|
|
|
self._interceptors = () |
|
|
|
|
if interceptors: |
|
|
|
|
self._interceptors = interceptors |
|
|
|
|
else: |
|
|
|
|
self._interceptors = () |
|
|
|
|
|
|
|
|
|
self._thread_pool = thread_pool |
|
|
|
|
|
|
|
|
|
if maximum_concurrent_rpcs: |
|
|
|
|
raise NotImplementedError() |
|
|
|
|
if thread_pool: |
|
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
|
def add_generic_rpc_handlers(self, generic_rpc_handlers): |
|
|
|
|
for h in generic_rpc_handlers: |
|
|
|
|
self._generic_handlers.append(h) |
|
|
|
|
def add_generic_rpc_handlers(self, object generic_rpc_handlers): |
|
|
|
|
self._generic_handlers.extend(generic_rpc_handlers) |
|
|
|
|
|
|
|
|
|
def add_insecure_port(self, address): |
|
|
|
|
return self._server.add_http2_port(address) |
|
|
|
@ -846,3 +956,10 @@ cdef class AioServer: |
|
|
|
|
self._status |
|
|
|
|
) |
|
|
|
|
shutdown_grpc_aio() |
|
|
|
|
|
|
|
|
|
cdef thread_pool(self): |
|
|
|
|
"""Access the thread pool instance.""" |
|
|
|
|
if self._thread_pool: |
|
|
|
|
return self._thread_pool |
|
|
|
|
else: |
|
|
|
|
raise UsageError('Please provide an Executor upon server creation.') |
|
|
|
|