diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 64f89bb5575..468a8f42ce7 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -23,18 +23,18 @@ _EMPTY_METADATA = None _UNKNOWN_CANCELLATION_DETAILS = 'RPC cancelled for unknown reason.' -cdef class _AioCall: +cdef class _AioCall(GrpcCallWrapper): def __cinit__(self, AioChannel channel, object deadline, bytes method, - CallCredentials credentials): + CallCredentials call_credentials): self.call = NULL self._channel = channel self._references = [] self._loop = asyncio.get_event_loop() - self._create_grpc_call(deadline, method, credentials) + self._create_grpc_call(deadline, method, call_credentials) self._is_locally_cancelled = False def __dealloc__(self): @@ -196,9 +196,25 @@ cdef class _AioCall: self, self._loop ) - return received_message + if received_message: + return received_message + else: + return EOF + + async def send_serialized_message(self, bytes message): + """Sends one single raw message in bytes.""" + await _send_message(self, + message, + True, + self._loop) - async def unary_stream(self, + async def send_receive_close(self): + """Half close the RPC on the client-side.""" + cdef SendCloseFromClientOperation op = SendCloseFromClientOperation(_EMPTY_FLAGS) + cdef tuple ops = (op,) + await execute_batch(self, ops, self._loop) + + async def initiate_unary_stream(self, bytes request, object initial_metadata_observer, object status_observer): @@ -233,3 +249,80 @@ cdef class _AioCall: await _receive_initial_metadata(self, self._loop), ) + + async def stream_unary(self, + tuple metadata, + object metadata_sent_observer, + object initial_metadata_observer, + object status_observer): + """Actual implementation of the complete unary-stream call. + + Needs to pay extra attention to the raise mechanism. If we want to + propagate the final status exception, then we have to raise it. + Othersize, it would end normally and raise `StopAsyncIteration()`. + """ + # Sends out initial_metadata ASAP. + await _send_initial_metadata(self, + metadata, + self._loop) + # Notify upper level that sending messages are allowed now. + metadata_sent_observer() + + # Receives initial metadata. + initial_metadata_observer( + await _receive_initial_metadata(self, + self._loop), + ) + + cdef tuple inbound_ops + cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS) + cdef ReceiveStatusOnClientOperation receive_status_on_client_op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS) + inbound_ops = (receive_message_op, receive_status_on_client_op) + + # Executes all operations in one batch. + await execute_batch(self, + inbound_ops, + self._loop) + + status = AioRpcStatus( + receive_status_on_client_op.code(), + receive_status_on_client_op.details(), + receive_status_on_client_op.trailing_metadata(), + receive_status_on_client_op.error_string(), + ) + # Reports the final status of the RPC to Python layer. The observer + # pattern is used here to unify unary and streaming code path. + status_observer(status) + + if status.code() == StatusCode.ok: + return receive_message_op.message() + else: + return None + + async def initiate_stream_stream(self, + tuple metadata, + object metadata_sent_observer, + object initial_metadata_observer, + object status_observer): + """Actual implementation of the complete stream-stream call. + + Needs to pay extra attention to the raise mechanism. If we want to + propagate the final status exception, then we have to raise it. + Othersize, it would end normally and raise `StopAsyncIteration()`. + """ + # Peer may prematurely end this RPC at any point. We need a corutine + # that watches if the server sends the final status. + self._loop.create_task(self._handle_status_once_received(status_observer)) + + # Sends out initial_metadata ASAP. + await _send_initial_metadata(self, + metadata, + self._loop) + # Notify upper level that sending messages are allowed now. + metadata_sent_observer() + + # Receives initial metadata. + initial_metadata_observer( + await _receive_initial_metadata(self, + self._loop), + ) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index 4022c892d20..c938f55adff 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -96,7 +96,7 @@ cdef class AioChannel: def call(self, bytes method, object deadline, - CallCredentials credentials): + object python_call_credentials): """Assembles a Cython Call object. Returns: @@ -105,5 +105,12 @@ cdef class AioChannel: if self._status == AIO_CHANNEL_STATUS_DESTROYED: # TODO(lidiz) switch to UsageError raise RuntimeError('Channel is closed.') - cdef _AioCall call = _AioCall(self, deadline, method, credentials) + + cdef CallCredentials cython_call_credentials + if python_call_credentials is not None: + cython_call_credentials = python_call_credentials._credentials + else: + cython_call_credentials = None + + cdef _AioCall call = _AioCall(self, deadline, method, cython_call_credentials) return call diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi index fbb65983f19..2fb5f04fbbc 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi @@ -33,3 +33,24 @@ cdef bytes serialize(object serializer, object message): return serializer(message) else: return message + + +class _EOF: + + def __bool__(self): + return False + + def __len__(self): + return 0 + + def _repr(self) -> str: + return '' + + def __repr__(self) -> str: + return self._repr() + + def __str__(self) -> str: + return self._repr() + + +EOF = _EOF() diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi index b8ae832bfc8..15f6bba0a80 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi @@ -21,6 +21,10 @@ cdef class RPCState(GrpcCallWrapper): cdef grpc_call_details details cdef grpc_metadata_array request_metadata cdef AioServer server + # NOTE(lidiz) Under certain corner case, receiving the client close + # operation won't immediately fail ongoing RECV_MESSAGE operations. Here I + # added a flag to workaround this unexpected behavior. + cdef bint client_closed cdef object abort_exception cdef bint metadata_sent cdef bint status_sent diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index 5e0108464a5..b8c635c4568 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -20,7 +20,8 @@ import traceback # TODO(https://github.com/grpc/grpc/issues/20850) refactor this. _LOGGER = logging.getLogger(__name__) cdef int _EMPTY_FLAG = 0 - +# TODO(lidiz) Use a designated value other than None. +cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.' cdef class _HandlerCallDetails: def __cinit__(self, str method, tuple invocation_metadata): @@ -35,6 +36,7 @@ cdef class RPCState: self.server = server grpc_metadata_array_init(&self.request_metadata) grpc_call_details_init(&self.details) + self.client_closed = False self.abort_exception = None self.metadata_sent = False self.status_sent = False @@ -83,13 +85,23 @@ cdef class _ServicerContext: self._loop = loop async def read(self): + cdef bytes raw_message + if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: + raise RuntimeError(_SERVER_STOPPED_DETAILS) if self._rpc_state.status_sent: raise RuntimeError('RPC already finished.') - cdef bytes raw_message = await _receive_message(self._rpc_state, self._loop) - return deserialize(self._request_deserializer, - raw_message) + if self._rpc_state.client_closed: + return EOF + raw_message = await _receive_message(self._rpc_state, self._loop) + if raw_message is None: + return EOF + else: + return deserialize(self._request_deserializer, + raw_message) async def write(self, object message): + if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: + raise RuntimeError(_SERVER_STOPPED_DETAILS) if self._rpc_state.status_sent: raise RuntimeError('RPC already finished.') await _send_message(self._rpc_state, @@ -102,6 +114,8 @@ cdef class _ServicerContext: async def send_initial_metadata(self, tuple metadata): if self._rpc_state.status_sent: raise RuntimeError('RPC already finished.') + elif self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: + raise RuntimeError(_SERVER_STOPPED_DETAILS) elif self._rpc_state.metadata_sent: raise RuntimeError('Send initial metadata failed: already sent') else: @@ -145,27 +159,23 @@ cdef _find_method_handler(str method, list generic_handlers): return None -async def _handle_unary_unary_rpc(object method_handler, - RPCState rpc_state, - object loop): - # Receives request message - cdef bytes request_raw = await _receive_message(rpc_state, loop) - - # Deserializes the request message - cdef object request_message = deserialize( - method_handler.request_deserializer, - request_raw, - ) - +async def _finish_handler_with_unary_response(RPCState rpc_state, + object unary_handler, + object request, + _ServicerContext servicer_context, + object response_serializer, + object loop): + """Finishes server method handler with a single response. + + This function executes the application handler, and handles response + sending, as well as errors. It is shared between unary-unary and + stream-unary handlers. + """ # Executes application logic - cdef object response_message = await method_handler.unary_unary( - request_message, - _ServicerContext( - rpc_state, - None, - None, - loop, - ), + + cdef object response_message = await unary_handler( + request, + servicer_context, ) # Raises exception if aborted @@ -173,50 +183,50 @@ async def _handle_unary_unary_rpc(object method_handler, # Serializes the response message cdef bytes response_raw = serialize( - method_handler.response_serializer, + response_serializer, response_message, ) - # Sends response message - cdef tuple send_ops = ( - SendStatusFromServerOperation( - tuple(), + # Assembles the batch operations + cdef Operation send_status_op = SendStatusFromServerOperation( + tuple(), StatusCode.ok, b'', _EMPTY_FLAGS, - ), - SendInitialMetadataOperation(None, _EMPTY_FLAGS), - SendMessageOperation(response_raw, _EMPTY_FLAGS), ) + cdef tuple finish_ops + if not rpc_state.metadata_sent: + finish_ops = ( + send_status_op, + SendInitialMetadataOperation(None, _EMPTY_FLAGS), + SendMessageOperation(response_raw, _EMPTY_FLAGS), + ) + else: + finish_ops = ( + send_status_op, + SendMessageOperation(response_raw, _EMPTY_FLAGS), + ) rpc_state.status_sent = True - await execute_batch(rpc_state, send_ops, loop) - - -async def _handle_unary_stream_rpc(object method_handler, - RPCState rpc_state, - object loop): - # Receives request message - cdef bytes request_raw = await _receive_message(rpc_state, loop) - - # Deserializes the request message - cdef object request_message = deserialize( - method_handler.request_deserializer, - request_raw, - ) + await execute_batch(rpc_state, finish_ops, loop) - cdef _ServicerContext servicer_context = _ServicerContext( - rpc_state, - method_handler.request_deserializer, - method_handler.response_serializer, - loop, - ) +async def _finish_handler_with_stream_responses(RPCState rpc_state, + object stream_handler, + object request, + _ServicerContext servicer_context, + object loop): + """Finishes server method handler with multiple responses. + + This function executes the application handler, and handles response + sending, as well as errors. It is shared between unary-stream and + stream-stream handlers. + """ cdef object async_response_generator cdef object response_message - if inspect.iscoroutinefunction(method_handler.unary_stream): + if inspect.iscoroutinefunction(stream_handler): # The handler uses reader / writer API, returns None. - await method_handler.unary_stream( - request_message, + await stream_handler( + request, servicer_context, ) @@ -224,8 +234,8 @@ async def _handle_unary_stream_rpc(object method_handler, _raise_if_aborted(rpc_state) else: # The handler uses async generator API - async_response_generator = method_handler.unary_stream( - request_message, + async_response_generator = stream_handler( + request, servicer_context, ) @@ -250,9 +260,132 @@ async def _handle_unary_stream_rpc(object method_handler, _EMPTY_FLAGS, ) - cdef tuple ops = (op,) + cdef tuple finish_ops = (op,) + if not rpc_state.metadata_sent: + finish_ops = (op, SendInitialMetadataOperation(None, _EMPTY_FLAGS)) rpc_state.status_sent = True - await execute_batch(rpc_state, ops, loop) + await execute_batch(rpc_state, finish_ops, loop) + + +async def _handle_unary_unary_rpc(object method_handler, + RPCState rpc_state, + object loop): + # Receives request message + cdef bytes request_raw = await _receive_message(rpc_state, loop) + + # Deserializes the request message + cdef object request_message = deserialize( + method_handler.request_deserializer, + request_raw, + ) + + # Creates a dedecated ServicerContext + cdef _ServicerContext servicer_context = _ServicerContext( + rpc_state, + None, + None, + loop, + ) + + # Finishes the application handler + await _finish_handler_with_unary_response( + rpc_state, + method_handler.unary_unary, + request_message, + servicer_context, + method_handler.response_serializer, + loop + ) + + +async def _handle_unary_stream_rpc(object method_handler, + RPCState rpc_state, + object loop): + # Receives request message + cdef bytes request_raw = await _receive_message(rpc_state, loop) + + # Deserializes the request message + cdef object request_message = deserialize( + method_handler.request_deserializer, + request_raw, + ) + + # Creates a dedecated ServicerContext + cdef _ServicerContext servicer_context = _ServicerContext( + rpc_state, + method_handler.request_deserializer, + method_handler.response_serializer, + loop, + ) + + # Finishes the application handler + await _finish_handler_with_stream_responses( + rpc_state, + method_handler.unary_stream, + request_message, + servicer_context, + loop, + ) + + +async def _message_receiver(_ServicerContext servicer_context): + """Bridge between the async generator API and the reader-writer API.""" + cdef object message + while True: + message = await servicer_context.read() + if message is not EOF: + yield message + else: + break + + +async def _handle_stream_unary_rpc(object method_handler, + RPCState rpc_state, + object loop): + # Creates a dedecated ServicerContext + cdef _ServicerContext servicer_context = _ServicerContext( + rpc_state, + method_handler.request_deserializer, + None, + loop, + ) + + # Prepares the request generator + cdef object request_async_iterator = _message_receiver(servicer_context) + + # Finishes the application handler + await _finish_handler_with_unary_response( + rpc_state, + method_handler.stream_unary, + request_async_iterator, + servicer_context, + method_handler.response_serializer, + loop + ) + + +async def _handle_stream_stream_rpc(object method_handler, + RPCState rpc_state, + object loop): + # Creates a dedecated ServicerContext + cdef _ServicerContext servicer_context = _ServicerContext( + rpc_state, + method_handler.request_deserializer, + method_handler.response_serializer, + loop, + ) + + # Prepares the request generator + cdef object request_async_iterator = _message_receiver(servicer_context) + + # Finishes the application handler + await _finish_handler_with_stream_responses( + rpc_state, + method_handler.stream_stream, + request_async_iterator, + servicer_context, + loop, + ) async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop): @@ -293,6 +426,7 @@ async def _handle_cancellation_from_core(object rpc_task, # Awaits cancellation from peer. await execute_batch(rpc_state, ops, loop) + rpc_state.client_closed = True if op.cancelled() and not rpc_task.done(): # Injects `CancelledError` to halt the RPC coroutine rpc_task.cancel() @@ -311,8 +445,9 @@ async def _schedule_rpc_coro(object rpc_coro, async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): + cdef object method_handler # Finds the method handler (application logic) - cdef object method_handler = _find_method_handler( + method_handler = _find_method_handler( rpc_state.method().decode(), generic_handlers, ) @@ -328,20 +463,33 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): ) return - # TODO(lidiz) extend to all 4 types of RPC + # Handles unary-unary case + if not method_handler.request_streaming and not method_handler.response_streaming: + await _handle_unary_unary_rpc(method_handler, + rpc_state, + loop) + return + + # Handles unary-stream case if not method_handler.request_streaming and method_handler.response_streaming: - try: - await _handle_unary_stream_rpc(method_handler, + await _handle_unary_stream_rpc(method_handler, rpc_state, loop) - except Exception as e: - raise - elif not method_handler.request_streaming and not method_handler.response_streaming: - await _handle_unary_unary_rpc(method_handler, - rpc_state, - loop) - else: - raise NotImplementedError() + return + + # Handles stream-unary case + if method_handler.request_streaming and not method_handler.response_streaming: + await _handle_stream_unary_rpc(method_handler, + rpc_state, + loop) + return + + # Handles stream-stream case + if method_handler.request_streaming and method_handler.response_streaming: + await _handle_stream_stream_rpc(method_handler, + rpc_state, + loop) + return class _RequestCallError(Exception): pass diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index 2f162d52922..8dc52b8b842 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -22,7 +22,7 @@ from typing import Any, Optional, Sequence, Text, Tuple import six import grpc -from grpc._cython.cygrpc import init_grpc_aio, AbortError +from grpc._cython.cygrpc import EOF, AbortError, init_grpc_aio from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall from ._call import AioRpcError @@ -86,5 +86,5 @@ __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall', 'init_grpc_aio', 'Channel', 'UnaryUnaryMultiCallable', 'ClientCallDetails', 'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall', - 'insecure_channel', 'secure_channel', 'server', 'Server', + 'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel', 'AbortError') diff --git a/src/python/grpcio/grpc/experimental/aio/_base_call.py b/src/python/grpcio/grpc/experimental/aio/_base_call.py index de63729a148..bdd6902d893 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_call.py @@ -19,11 +19,12 @@ RPC, e.g. cancellation. """ from abc import ABCMeta, abstractmethod -from typing import Any, AsyncIterable, Awaitable, Callable, Generic, Text, Optional +from typing import (Any, AsyncIterable, Awaitable, Callable, Generic, Optional, + Text, Union) import grpc -from ._typing import MetadataType, RequestType, ResponseType +from ._typing import EOFType, MetadataType, RequestType, ResponseType __all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -146,14 +147,85 @@ class UnaryStreamCall(Generic[RequestType, ResponseType], """ @abstractmethod - async def read(self) -> ResponseType: - """Reads one message from the RPC. + async def read(self) -> Union[EOFType, ResponseType]: + """Reads one message from the stream. - For each streaming RPC, concurrent reads in multiple coroutines are not - allowed. If you want to perform read in multiple coroutines, you needs - synchronization. So, you can start another read after current read is - finished. + Read operations must be serialized when called from multiple + coroutines. Returns: - A response message of the RPC. + A response message, or an `grpc.aio.EOF` to indicate the end of the + stream. + """ + + +class StreamUnaryCall(Generic[RequestType, ResponseType], + Call, + metaclass=ABCMeta): + + @abstractmethod + async def write(self, request: RequestType) -> None: + """Writes one message to the stream. + + Raises: + An RpcError exception if the write failed. + """ + + @abstractmethod + async def done_writing(self) -> None: + """Notifies server that the client is done sending messages. + + After done_writing is called, any additional invocation to the write + function will fail. This function is idempotent. + """ + + @abstractmethod + def __await__(self) -> Awaitable[ResponseType]: + """Await the response message to be ready. + + Returns: + The response message of the stream. + """ + + +class StreamStreamCall(Generic[RequestType, ResponseType], + Call, + metaclass=ABCMeta): + + @abstractmethod + def __aiter__(self) -> AsyncIterable[ResponseType]: + """Returns the async iterable representation that yields messages. + + Under the hood, it is calling the "read" method. + + Returns: + An async iterable object that yields messages. + """ + + @abstractmethod + async def read(self) -> Union[EOFType, ResponseType]: + """Reads one message from the stream. + + Read operations must be serialized when called from multiple + coroutines. + + Returns: + A response message, or an `grpc.aio.EOF` to indicate the end of the + stream. + """ + + @abstractmethod + async def write(self, request: RequestType) -> None: + """Writes one message to the stream. + + Raises: + An RpcError exception if the write failed. + """ + + @abstractmethod + async def done_writing(self) -> None: + """Notifies server that the client is done sending messages. + + After done_writing is called, any additional invocation to the write + function will fail. This function is idempotent. """ diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 3080d233847..25ea89ccbcc 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -29,6 +29,7 @@ __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' _LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' _GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!' _RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.' +_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".' _OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' '\tstatus = {}\n' @@ -146,31 +147,48 @@ def _create_rpc_error(initial_metadata: Optional[MetadataType], class Call(_base_call.Call): + """Base implementation of client RPC Call object. + + Implements logic around final status, metadata and cancellation. + """ _loop: asyncio.AbstractEventLoop _code: grpc.StatusCode _status: Awaitable[cygrpc.AioRpcStatus] _initial_metadata: Awaitable[MetadataType] _locally_cancelled: bool + _cython_call: cygrpc._AioCall - def __init__(self) -> None: + def __init__(self, cython_call: cygrpc._AioCall) -> None: self._loop = asyncio.get_event_loop() self._code = None self._status = self._loop.create_future() self._initial_metadata = self._loop.create_future() self._locally_cancelled = False + self._cython_call = cython_call - def cancel(self) -> bool: - """Placeholder cancellation method. - - The implementation of this method needs to pass the cancellation reason - into self._cancellation, using `set_result` instead of - `set_exception`. - """ - raise NotImplementedError() + def __del__(self) -> None: + if not self._status.done(): + self._cancel( + cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, + _GC_CANCELLATION_DETAILS, None, None)) def cancelled(self) -> bool: return self._code == grpc.StatusCode.CANCELLED + def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: + """Forwards the application cancellation reasoning.""" + if not self._status.done(): + self._set_status(status) + self._cython_call.cancel(status) + return True + else: + return False + + def cancel(self) -> bool: + return self._cancel( + cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, + _LOCAL_CANCELLATION_DETAILS, None, None)) + def done(self) -> bool: return self._status.done() @@ -247,6 +265,7 @@ class Call(_base_call.Call): return self._repr() +# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression # pylint: disable=abstract-method class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): """Object for managing unary-unary RPC calls. @@ -254,37 +273,29 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): Returned when an instance of `UnaryUnaryMultiCallable` object is called. """ _request: RequestType - _channel: cygrpc.AioChannel _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction _call: asyncio.Task - _cython_call: cygrpc._AioCall - def __init__( # pylint: disable=R0913 - self, request: RequestType, deadline: Optional[float], - credentials: Optional[grpc.CallCredentials], - channel: cygrpc.AioChannel, method: bytes, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction) -> None: - super().__init__() + # pylint: disable=too-many-arguments + def __init__(self, request: RequestType, deadline: Optional[float], + credentials: Optional[grpc.CallCredentials], + channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> None: + channel.call(method, deadline, credentials) + super().__init__(channel.call(method, deadline, credentials)) self._request = request - self._channel = channel self._request_serializer = request_serializer self._response_deserializer = response_deserializer - - if credentials is not None: - grpc_credentials = credentials._credentials - else: - grpc_credentials = None - self._cython_call = self._channel.call(method, deadline, - grpc_credentials) self._call = self._loop.create_task(self._invoke()) - def __del__(self) -> None: - if not self._call.done(): - self._cancel( - cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, - _GC_CANCELLATION_DETAILS, None, None)) + def cancel(self) -> bool: + if super().cancel(): + self._call.cancel() + return True + else: + return False async def _invoke(self) -> ResponseType: serialized_request = _common.serialize(self._request, @@ -300,7 +311,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): self._set_status, ) except asyncio.CancelledError: - if self._code != grpc.StatusCode.CANCELLED: + if not self.cancelled(): self.cancel() # Raises here if RPC failed or cancelled @@ -309,21 +320,6 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): return _common.deserialize(serialized_response, self._response_deserializer) - def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: - """Forwards the application cancellation reasoning.""" - if not self._status.done(): - self._set_status(status) - self._cython_call.cancel(status) - self._call.cancel() - return True - else: - return False - - def cancel(self) -> bool: - return self._cancel( - cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, - _LOCAL_CANCELLATION_DETAILS, None, None)) - def __await__(self) -> ResponseType: """Wait till the ongoing RPC request finishes.""" try: @@ -339,6 +335,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): return response +# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression # pylint: disable=abstract-method class UnaryStreamCall(Call, _base_call.UnaryStreamCall): """Object for managing unary-stream RPC calls. @@ -346,107 +343,346 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): Returned when an instance of `UnaryStreamMultiCallable` object is called. """ _request: RequestType - _channel: cygrpc.AioChannel _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction - _cython_call: cygrpc._AioCall _send_unary_request_task: asyncio.Task _message_aiter: AsyncIterable[ResponseType] - def __init__( # pylint: disable=R0913 - self, request: RequestType, deadline: Optional[float], - credentials: Optional[grpc.CallCredentials], - channel: cygrpc.AioChannel, method: bytes, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction) -> None: - super().__init__() + # pylint: disable=too-many-arguments + def __init__(self, request: RequestType, deadline: Optional[float], + credentials: Optional[grpc.CallCredentials], + channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> None: + super().__init__(channel.call(method, deadline, credentials)) self._request = request - self._channel = channel self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._send_unary_request_task = self._loop.create_task( self._send_unary_request()) - self._message_aiter = self._fetch_stream_responses() + self._message_aiter = None - if credentials is not None: - grpc_credentials = credentials._credentials + def cancel(self) -> bool: + if super().cancel(): + self._send_unary_request_task.cancel() + return True else: - grpc_credentials = None - - self._cython_call = self._channel.call(method, deadline, - grpc_credentials) - - def __del__(self) -> None: - if not self._status.done(): - self._cancel( - cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, - _GC_CANCELLATION_DETAILS, None, None)) + return False async def _send_unary_request(self) -> ResponseType: serialized_request = _common.serialize(self._request, self._request_serializer) try: - await self._cython_call.unary_stream(serialized_request, - self._set_initial_metadata, - self._set_status) + await self._cython_call.initiate_unary_stream( + serialized_request, self._set_initial_metadata, + self._set_status) except asyncio.CancelledError: - if self._code != grpc.StatusCode.CANCELLED: + if not self.cancelled(): self.cancel() raise async def _fetch_stream_responses(self) -> ResponseType: - await self._send_unary_request_task message = await self._read() - while message: + while message is not cygrpc.EOF: yield message message = await self._read() - def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: - """Forwards the application cancellation reasoning. + def __aiter__(self) -> AsyncIterable[ResponseType]: + if self._message_aiter is None: + self._message_aiter = self._fetch_stream_responses() + return self._message_aiter - Async generator will receive an exception. The cancellation will go - deep down into Core, and then propagates backup as the - `cygrpc.AioRpcStatus` exception. + async def _read(self) -> ResponseType: + # Wait for the request being sent + await self._send_unary_request_task - So, under race condition, e.g. the server sent out final state headers - and the client calling "cancel" at the same time, this method respects - the winner in Core. - """ - if not self._status.done(): - self._set_status(status) - self._cython_call.cancel(status) + # Reads response message from Core + try: + raw_response = await self._cython_call.receive_serialized_message() + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + await self._raise_for_status() + + if raw_response is cygrpc.EOF: + return cygrpc.EOF + else: + return _common.deserialize(raw_response, + self._response_deserializer) + + async def read(self) -> ResponseType: + if self._status.done(): + await self._raise_for_status() + return cygrpc.EOF + + response_message = await self._read() + + if response_message is cygrpc.EOF: + # If the read operation failed, Core should explain why. + await self._raise_for_status() + return response_message + + +# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression +# pylint: disable=abstract-method +class StreamUnaryCall(Call, _base_call.StreamUnaryCall): + """Object for managing stream-unary RPC calls. + + Returned when an instance of `StreamUnaryMultiCallable` object is called. + """ + _metadata: MetadataType + _request_serializer: SerializingFunction + _response_deserializer: DeserializingFunction + + _metadata_sent: asyncio.Event + _done_writing: bool + _call_finisher: asyncio.Task + _async_request_poller: asyncio.Task - if not self._send_unary_request_task.done(): - # Injects CancelledError to the Task. The exception will - # propagate to _fetch_stream_responses as well, if the sending - # is not done. - self._send_unary_request_task.cancel() + # pylint: disable=too-many-arguments + def __init__(self, + request_async_iterator: Optional[AsyncIterable[RequestType]], + deadline: Optional[float], + credentials: Optional[grpc.CallCredentials], + channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> None: + super().__init__(channel.call(method, deadline, credentials)) + self._metadata = _EMPTY_METADATA + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + self._metadata_sent = asyncio.Event(loop=self._loop) + self._done_writing = False + + self._call_finisher = self._loop.create_task(self._conduct_rpc()) + + # If user passes in an async iterator, create a consumer Task. + if request_async_iterator is not None: + self._async_request_poller = self._loop.create_task( + self._consume_request_iterator(request_async_iterator)) + else: + self._async_request_poller = None + + def cancel(self) -> bool: + if super().cancel(): + self._call_finisher.cancel() + if self._async_request_poller is not None: + self._async_request_poller.cancel() return True else: return False + def _metadata_sent_observer(self): + self._metadata_sent.set() + + async def _conduct_rpc(self) -> ResponseType: + try: + serialized_response = await self._cython_call.stream_unary( + self._metadata, + self._metadata_sent_observer, + self._set_initial_metadata, + self._set_status, + ) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + + # Raises RpcError if the RPC failed or cancelled + await self._raise_for_status() + + return _common.deserialize(serialized_response, + self._response_deserializer) + + async def _consume_request_iterator( + self, request_async_iterator: AsyncIterable[RequestType]) -> None: + async for request in request_async_iterator: + await self.write(request) + await self.done_writing() + + def __await__(self) -> ResponseType: + """Wait till the ongoing RPC request finishes.""" + try: + response = yield from self._call_finisher + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + raise + return response + + async def write(self, request: RequestType) -> None: + if self._status.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + if self._done_writing: + raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) + if not self._metadata_sent.is_set(): + await self._metadata_sent.wait() + + serialized_request = _common.serialize(request, + self._request_serializer) + + try: + await self._cython_call.send_serialized_message(serialized_request) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + await self._raise_for_status() + + async def done_writing(self) -> None: + """Implementation of done_writing is idempotent.""" + if self._status.done(): + # If the RPC is finished, do nothing. + return + if not self._done_writing: + # If the done writing is not sent before, try to send it. + self._done_writing = True + try: + await self._cython_call.send_receive_close() + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + await self._raise_for_status() + + +# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression +# pylint: disable=abstract-method +class StreamStreamCall(Call, _base_call.StreamStreamCall): + """Object for managing stream-stream RPC calls. + + Returned when an instance of `StreamStreamMultiCallable` object is called. + """ + _metadata: MetadataType + _request_serializer: SerializingFunction + _response_deserializer: DeserializingFunction + + _metadata_sent: asyncio.Event + _done_writing: bool + _initializer: asyncio.Task + _async_request_poller: asyncio.Task + _message_aiter: AsyncIterable[ResponseType] + + # pylint: disable=too-many-arguments + def __init__(self, + request_async_iterator: Optional[AsyncIterable[RequestType]], + deadline: Optional[float], + credentials: Optional[grpc.CallCredentials], + channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> None: + super().__init__(channel.call(method, deadline, credentials)) + self._metadata = _EMPTY_METADATA + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + self._metadata_sent = asyncio.Event(loop=self._loop) + self._done_writing = False + + self._initializer = self._loop.create_task(self._prepare_rpc()) + + # If user passes in an async iterator, create a consumer coroutine. + if request_async_iterator is not None: + self._async_request_poller = self._loop.create_task( + self._consume_request_iterator(request_async_iterator)) + else: + self._async_request_poller = None + self._message_aiter = None + def cancel(self) -> bool: - return self._cancel( - cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, - _LOCAL_CANCELLATION_DETAILS, None, None)) + if super().cancel(): + self._initializer.cancel() + if self._async_request_poller is not None: + self._async_request_poller.cancel() + return True + else: + return False + + def _metadata_sent_observer(self): + self._metadata_sent.set() + + async def _prepare_rpc(self): + """This method prepares the RPC for receiving/sending messages. + + All other operations around the stream should only happen after the + completion of this method. + """ + try: + await self._cython_call.initiate_stream_stream( + self._metadata, + self._metadata_sent_observer, + self._set_initial_metadata, + self._set_status, + ) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + # No need to raise RpcError here, because no one will `await` this task. + + async def _consume_request_iterator( + self, request_async_iterator: Optional[AsyncIterable[RequestType]] + ) -> None: + async for request in request_async_iterator: + await self.write(request) + await self.done_writing() + + async def write(self, request: RequestType) -> None: + if self._status.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + if self._done_writing: + raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) + if not self._metadata_sent.is_set(): + await self._metadata_sent.wait() + + serialized_request = _common.serialize(request, + self._request_serializer) + + try: + await self._cython_call.send_serialized_message(serialized_request) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + await self._raise_for_status() + + async def done_writing(self) -> None: + """Implementation of done_writing is idempotent.""" + if self._status.done(): + # If the RPC is finished, do nothing. + return + if not self._done_writing: + # If the done writing is not sent before, try to send it. + self._done_writing = True + try: + await self._cython_call.send_receive_close() + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + await self._raise_for_status() + + async def _fetch_stream_responses(self) -> ResponseType: + """The async generator that yields responses from peer.""" + message = await self._read() + while message is not cygrpc.EOF: + yield message + message = await self._read() def __aiter__(self) -> AsyncIterable[ResponseType]: + if self._message_aiter is None: + self._message_aiter = self._fetch_stream_responses() return self._message_aiter async def _read(self) -> ResponseType: - # Wait for the request being sent - await self._send_unary_request_task + # Wait for the setup + await self._initializer # Reads response message from Core try: raw_response = await self._cython_call.receive_serialized_message() except asyncio.CancelledError: - if self._code != grpc.StatusCode.CANCELLED: + if not self.cancelled(): self.cancel() - raise + await self._raise_for_status() - if raw_response is None: - return None + if raw_response is cygrpc.EOF: + return cygrpc.EOF else: return _common.deserialize(raw_response, self._response_deserializer) @@ -454,14 +690,11 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): async def read(self) -> ResponseType: if self._status.done(): await self._raise_for_status() - raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + return cygrpc.EOF response_message = await self._read() - if response_message is None: + if response_message is cygrpc.EOF: # If the read operation failed, Core should explain why. await self._raise_for_status() - # If no exception raised, there is something wrong internally. - assert False, 'Read operation failed with StatusCode.OK' - else: - return response_message + return response_message diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 2562f0f6d81..6d4fe9145b0 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -13,14 +13,15 @@ # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio -from typing import Any, Optional, Sequence, Text +from typing import Any, AsyncIterable, Optional, Sequence, Text import grpc from grpc import _common from grpc._cython import cygrpc from . import _base_call -from ._call import UnaryStreamCall, UnaryUnaryCall +from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, + UnaryUnaryCall) from ._interceptor import (InterceptedUnaryUnaryCall, UnaryUnaryClientInterceptor) from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, @@ -28,8 +29,16 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, from ._utils import _timeout_to_deadline -class UnaryUnaryMultiCallable: - """Factory an asynchronous unary-unary RPC stub call from client-side.""" +class _BaseMultiCallable: + """Base class of all multi callable objects. + + Handles the initialization logic and stores common attributes. + """ + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _method: bytes + _request_serializer: SerializingFunction + _response_deserializer: DeserializingFunction _channel: cygrpc.AioChannel _method: bytes @@ -50,6 +59,10 @@ class UnaryUnaryMultiCallable: self._response_deserializer = response_deserializer self._interceptors = interceptors + +class UnaryUnaryMultiCallable(_BaseMultiCallable): + """Factory an asynchronous unary-unary RPC stub call from client-side.""" + def __call__(self, request: Any, *, @@ -114,17 +127,8 @@ class UnaryUnaryMultiCallable: ) -class UnaryStreamMultiCallable: - """Afford invoking a unary-stream RPC from client-side in an asynchronous way.""" - - def __init__(self, channel: cygrpc.AioChannel, method: bytes, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction) -> None: - self._channel = channel - self._method = method - self._request_serializer = request_serializer - self._response_deserializer = response_deserializer - self._loop = asyncio.get_event_loop() +class UnaryStreamMultiCallable(_BaseMultiCallable): + """Affords invoking a unary-stream RPC from client-side in an asynchronous way.""" def __call__(self, request: Any, @@ -176,6 +180,122 @@ class UnaryStreamMultiCallable: ) +class StreamUnaryMultiCallable(_BaseMultiCallable): + """Affords invoking a stream-unary RPC from client-side in an asynchronous way.""" + + def __call__(self, + request_async_iterator: Optional[AsyncIterable[Any]] = None, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _base_call.StreamUnaryCall: + """Asynchronously invokes the underlying RPC. + + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow + for the RPC. + metadata: Optional :term:`metadata` to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Only valid for + secure Channel. + wait_for_ready: This is an EXPERIMENTAL argument. An optional + flag to enable wait for ready mechanism + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. + + Returns: + A Call object instance which is an awaitable object. + + Raises: + RpcError: Indicating that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. + """ + + if metadata: + raise NotImplementedError("TODO: metadata not implemented yet") + + if wait_for_ready: + raise NotImplementedError( + "TODO: wait_for_ready not implemented yet") + + if compression: + raise NotImplementedError("TODO: compression not implemented yet") + + deadline = _timeout_to_deadline(timeout) + + return StreamUnaryCall( + request_async_iterator, + deadline, + credentials, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + ) + + +class StreamStreamMultiCallable(_BaseMultiCallable): + """Affords invoking a stream-stream RPC from client-side in an asynchronous way.""" + + def __call__(self, + request_async_iterator: Optional[AsyncIterable[Any]] = None, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _base_call.StreamStreamCall: + """Asynchronously invokes the underlying RPC. + + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow + for the RPC. + metadata: Optional :term:`metadata` to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Only valid for + secure Channel. + wait_for_ready: This is an EXPERIMENTAL argument. An optional + flag to enable wait for ready mechanism + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. + + Returns: + A Call object instance which is an awaitable object. + + Raises: + RpcError: Indicating that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. + """ + + if metadata: + raise NotImplementedError("TODO: metadata not implemented yet") + + if wait_for_ready: + raise NotImplementedError( + "TODO: wait_for_ready not implemented yet") + + if compression: + raise NotImplementedError("TODO: compression not implemented yet") + + deadline = _timeout_to_deadline(timeout) + + return StreamStreamCall( + request_async_iterator, + deadline, + credentials, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + ) + + class Channel: """Asynchronous Channel implementation. @@ -301,21 +421,27 @@ class Channel: ) -> UnaryStreamMultiCallable: return UnaryStreamMultiCallable(self._channel, _common.encode(method), request_serializer, - response_deserializer) + response_deserializer, None) def stream_unary( self, method: Text, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None): - """Placeholder method for stream-unary calls.""" + response_deserializer: Optional[DeserializingFunction] = None + ) -> StreamUnaryMultiCallable: + return StreamUnaryMultiCallable(self._channel, _common.encode(method), + request_serializer, + response_deserializer, None) def stream_stream( self, method: Text, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None): - """Placeholder method for stream-stream calls.""" + response_deserializer: Optional[DeserializingFunction] = None + ) -> StreamStreamMultiCallable: + return StreamStreamMultiCallable(self._channel, _common.encode(method), + request_serializer, + response_deserializer, None) async def _close(self): # TODO: Send cancellation status diff --git a/src/python/grpcio/grpc/experimental/aio/_typing.py b/src/python/grpcio/grpc/experimental/aio/_typing.py index 7ff893893a9..6428fb72f98 100644 --- a/src/python/grpcio/grpc/experimental/aio/_typing.py +++ b/src/python/grpcio/grpc/experimental/aio/_typing.py @@ -14,6 +14,7 @@ """Common types for gRPC Async API""" from typing import Any, AnyStr, Callable, Sequence, Text, Tuple, TypeVar +from grpc._cython.cygrpc import EOF RequestType = TypeVar('RequestType') ResponseType = TypeVar('ResponseType') @@ -21,3 +22,4 @@ SerializingFunction = Callable[[Any], bytes] DeserializingFunction = Callable[[bytes], Any] MetadataType = Sequence[Tuple[Text, AnyStr]] ChannelArgumentType = Sequence[Tuple[Text, Any]] +EOFType = type(EOF) diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 024d5c877f2..605a088f204 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -2,6 +2,8 @@ "_sanity._sanity_test.AioSanityTest", "unit.abort_test.TestAbort", "unit.aio_rpc_error_test.TestAioRpcError", + "unit.call_test.TestStreamStreamCall", + "unit.call_test.TestStreamUnaryCall", "unit.call_test.TestUnaryStreamCall", "unit.call_test.TestUnaryUnaryCall", "unit.channel_argument_test.TestChannelArgument", diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index d99c46f05c2..ccb9f45fe4d 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -26,11 +26,12 @@ from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): - async def UnaryCall(self, request, context): + async def UnaryCall(self, unused_request, unused_context): return messages_pb2.SimpleResponse() async def StreamingOutputCall( - self, request: messages_pb2.StreamingOutputCallRequest, context): + self, request: messages_pb2.StreamingOutputCallRequest, + unused_context): for response_parameters in request.response_parameters: if response_parameters.interval_us != 0: await asyncio.sleep( @@ -44,11 +45,30 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): # Next methods are extra ones that are registred programatically # when the sever is instantiated. They are not being provided by # the proto file. - async def UnaryCallWithSleep(self, request, context): await asyncio.sleep(UNARY_CALL_WITH_SLEEP_VALUE) return messages_pb2.SimpleResponse() + async def StreamingInputCall(self, request_async_iterator, unused_context): + aggregate_size = 0 + async for request in request_async_iterator: + if request.payload is not None and request.payload.body: + aggregate_size += len(request.payload.body) + return messages_pb2.StreamingInputCallResponse( + aggregated_payload_size=aggregate_size) + + async def FullDuplexCall(self, request_async_iterator, unused_context): + async for request in request_async_iterator: + for response_parameters in request.response_parameters: + if response_parameters.interval_us != 0: + await asyncio.sleep( + datetime.timedelta(microseconds=response_parameters. + interval_us).total_seconds()) + yield messages_pb2.StreamingOutputCallResponse( + payload=messages_pb2.Payload(type=request.payload.type, + body=b'\x00' * + response_parameters.size)) + async def start_test_server(secure=False): server = aio.server(options=(('grpc.so_reuseport', 0),)) diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 59db225031c..209643e52d1 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -30,10 +30,10 @@ from src.proto.grpc.testing import messages_pb2 _NUM_STREAM_RESPONSES = 5 _RESPONSE_PAYLOAD_SIZE = 42 +_REQUEST_PAYLOAD_SIZE = 7 _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' _RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000 _UNREACHABLE_TARGET = '0.1:1111' - _INFINITE_INTERVAL_US = 2**31 - 1 @@ -286,7 +286,7 @@ class TestUnaryStreamCall(AioTestBase): [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED]) async def test_too_many_reads_unary_stream(self): - """Test cancellation after received all messages.""" + """Test calling read after received all messages fails.""" async with aio.insecure_channel(self._server_target) as channel: stub = test_pb2_grpc.TestServiceStub(channel) @@ -306,13 +306,14 @@ class TestUnaryStreamCall(AioTestBase): messages_pb2.StreamingOutputCallResponse) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + self.assertIs(await call.read(), aio.EOF) # After the RPC is finished, further reads will lead to exception. self.assertEqual(await call.code(), grpc.StatusCode.OK) - with self.assertRaises(asyncio.InvalidStateError): - await call.read() + self.assertIs(await call.read(), aio.EOF) async def test_unary_stream_async_generator(self): + """Sunny day test case for unary_stream.""" async with aio.insecure_channel(self._server_target) as channel: stub = test_pb2_grpc.TestServiceStub(channel) @@ -426,6 +427,307 @@ class TestUnaryStreamCall(AioTestBase): self.loop.run_until_complete(coro()) +class TestStreamUnaryCall(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + self._channel = aio.insecure_channel(self._server_target) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + async def test_cancel_stream_unary(self): + call = self._stub.StreamingInputCall() + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + # Sends out requests + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + + await call.done_writing() + + with self.assertRaises(asyncio.CancelledError): + await call + + async def test_early_cancel_stream_unary(self): + call = self._stub.StreamingInputCall() + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + + with self.assertRaises(asyncio.InvalidStateError): + await call.write(messages_pb2.StreamingInputCallRequest()) + + # Should be no-op + await call.done_writing() + + with self.assertRaises(asyncio.CancelledError): + await call + + async def test_write_after_done_writing(self): + call = self._stub.StreamingInputCall() + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + # Sends out requests + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + + # Should be no-op + await call.done_writing() + + with self.assertRaises(asyncio.InvalidStateError): + await call.write(messages_pb2.StreamingInputCallRequest()) + + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_error_in_async_generator(self): + # Server will pause between responses + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + # We expect the request iterator to receive the exception + request_iterator_received_the_exception = asyncio.Event() + + async def request_iterator(): + with self.assertRaises(asyncio.CancelledError): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + await asyncio.sleep(test_constants.SHORT_TIMEOUT) + request_iterator_received_the_exception.set() + + call = self._stub.StreamingInputCall(request_iterator()) + + # Cancel the RPC after at least one response + async def cancel_later(): + await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2) + call.cancel() + + cancel_later_task = self.loop.create_task(cancel_later()) + + # No exceptions here + with self.assertRaises(asyncio.CancelledError): + await call + + await request_iterator_received_the_exception.wait() + + # No failures in the cancel later task! + await cancel_later_task + + +# Prepares the request that stream in a ping-pong manner. +_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest() +_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + +class TestStreamStreamCall(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + self._channel = aio.insecure_channel(self._server_target) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + async def test_cancel(self): + # Invokes the actual RPC + call = self._stub.FullDuplexCall() + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) + response = await call.read() + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_cancel_with_pending_read(self): + call = self._stub.FullDuplexCall() + + await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_cancel_with_ongoing_read(self): + call = self._stub.FullDuplexCall() + coro_started = asyncio.Event() + + async def read_coro(): + coro_started.set() + await call.read() + + read_task = self.loop.create_task(read_coro()) + await coro_started.wait() + self.assertFalse(read_task.done()) + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_early_cancel(self): + call = self._stub.FullDuplexCall() + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_cancel_after_done_writing(self): + call = self._stub.FullDuplexCall() + await call.done_writing() + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_late_cancel(self): + call = self._stub.FullDuplexCall() + await call.done_writing() + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + # Cancels the RPC + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertFalse(call.cancel()) + self.assertFalse(call.cancelled()) + + # Status is still OK + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_async_generator(self): + + async def request_generator(): + yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE + yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE + + call = self._stub.FullDuplexCall(request_generator()) + async for response in call: + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_too_many_reads(self): + + async def request_generator(): + for _ in range(_NUM_STREAM_RESPONSES): + yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE + + call = self._stub.FullDuplexCall(request_generator()) + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + self.assertIs(await call.read(), aio.EOF) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + # After the RPC finished, the read should also produce EOF + self.assertIs(await call.read(), aio.EOF) + + async def test_read_write_after_done_writing(self): + call = self._stub.FullDuplexCall() + + # Writes two requests, and pending two requests + await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) + await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) + await call.done_writing() + + # Further write should fail + with self.assertRaises(asyncio.InvalidStateError): + await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) + + # But read should be unaffected + response = await call.read() + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + response = await call.read() + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_error_in_async_generator(self): + # Server will pause between responses + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + # We expect the request iterator to receive the exception + request_iterator_received_the_exception = asyncio.Event() + + async def request_iterator(): + with self.assertRaises(asyncio.CancelledError): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + await asyncio.sleep(test_constants.SHORT_TIMEOUT) + request_iterator_received_the_exception.set() + + call = self._stub.FullDuplexCall(request_iterator()) + + # Cancel the RPC after at least one response + async def cancel_later(): + await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2) + call.cancel() + + cancel_later_task = self.loop.create_task(cancel_later()) + + # No exceptions here + async for response in call: + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + await request_iterator_received_the_exception.wait() + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + # No failures in the cancel later task! + await cancel_later_task + + if __name__ == '__main__': logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 1ab372a0e8c..6267862d890 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -32,6 +32,7 @@ _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' _NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 _RESPONSE_PAYLOAD_SIZE = 42 @@ -121,7 +122,104 @@ class TestChannel(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.OK) await channel.close() + async def test_stream_unary_using_write(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Invokes the actual RPC + call = stub.StreamingInputCall() + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + # Sends out requests + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + await call.done_writing() + + # Validates the responses + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + await channel.close() + + async def test_stream_unary_using_async_gen(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + # Invokes the actual RPC + call = stub.StreamingInputCall(gen()) + + # Validates the responses + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + await channel.close() + + async def test_stream_stream_using_read_write(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Invokes the actual RPC + call = stub.FullDuplexCall() + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + await call.done_writing() + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + await channel.close() + + async def test_stream_stream_using_async_gen(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + # Invokes the actual RPC + call = stub.FullDuplexCall(gen()) + + async for response in call: + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + await channel.close() + if __name__ == '__main__': - logging.basicConfig(level=logging.WARN) + logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py index 265744b1f54..fff944f27bd 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -13,15 +13,16 @@ # limitations under the License. import asyncio +import gc import logging -import unittest import time -import gc +import unittest import grpc from grpc.experimental import aio -from tests_aio.unit._test_base import AioTestBase + from tests.unit.framework.common import test_constants +from tests_aio.unit._test_base import AioTestBase _SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary' _BLOCK_FOREVER = '/test/BlockForever' @@ -29,9 +30,16 @@ _BLOCK_BRIEFLY = '/test/BlockBriefly' _UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen' _UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter' _UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed' +_STREAM_UNARY_ASYNC_GEN = '/test/StreamUnaryAsyncGen' +_STREAM_UNARY_READER_WRITER = '/test/StreamUnaryReaderWriter' +_STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed' +_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen' +_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter' +_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed' _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x01\x01\x01' +_NUM_STREAM_REQUESTS = 3 _NUM_STREAM_RESPONSES = 5 @@ -39,6 +47,41 @@ class _GenericHandler(grpc.GenericRpcHandler): def __init__(self): self._called = asyncio.get_event_loop().create_future() + self._routing_table = { + _SIMPLE_UNARY_UNARY: + grpc.unary_unary_rpc_method_handler(self._unary_unary), + _BLOCK_FOREVER: + grpc.unary_unary_rpc_method_handler(self._block_forever), + _BLOCK_BRIEFLY: + grpc.unary_unary_rpc_method_handler(self._block_briefly), + _UNARY_STREAM_ASYNC_GEN: + grpc.unary_stream_rpc_method_handler( + self._unary_stream_async_gen), + _UNARY_STREAM_READER_WRITER: + grpc.unary_stream_rpc_method_handler( + self._unary_stream_reader_writer), + _UNARY_STREAM_EVILLY_MIXED: + grpc.unary_stream_rpc_method_handler( + self._unary_stream_evilly_mixed), + _STREAM_UNARY_ASYNC_GEN: + grpc.stream_unary_rpc_method_handler( + self._stream_unary_async_gen), + _STREAM_UNARY_READER_WRITER: + grpc.stream_unary_rpc_method_handler( + self._stream_unary_reader_writer), + _STREAM_UNARY_EVILLY_MIXED: + grpc.stream_unary_rpc_method_handler( + self._stream_unary_evilly_mixed), + _STREAM_STREAM_ASYNC_GEN: + grpc.stream_stream_rpc_method_handler( + self._stream_stream_async_gen), + _STREAM_STREAM_READER_WRITER: + grpc.stream_stream_rpc_method_handler( + self._stream_stream_reader_writer), + _STREAM_STREAM_EVILLY_MIXED: + grpc.stream_stream_rpc_method_handler( + self._stream_stream_evilly_mixed), + } @staticmethod async def _unary_unary(unused_request, unused_context): @@ -64,23 +107,59 @@ class _GenericHandler(grpc.GenericRpcHandler): for _ in range(_NUM_STREAM_RESPONSES - 1): await context.write(_RESPONSE) + async def _stream_unary_async_gen(self, request_iterator, unused_context): + request_count = 0 + async for request in request_iterator: + assert _REQUEST == request + request_count += 1 + assert _NUM_STREAM_REQUESTS == request_count + return _RESPONSE + + async def _stream_unary_reader_writer(self, unused_request, context): + for _ in range(_NUM_STREAM_REQUESTS): + assert _REQUEST == await context.read() + return _RESPONSE + + async def _stream_unary_evilly_mixed(self, request_iterator, context): + assert _REQUEST == await context.read() + request_count = 0 + async for request in request_iterator: + assert _REQUEST == request + request_count += 1 + assert _NUM_STREAM_REQUESTS - 1 == request_count + return _RESPONSE + + async def _stream_stream_async_gen(self, request_iterator, unused_context): + request_count = 0 + async for request in request_iterator: + assert _REQUEST == request + request_count += 1 + assert _NUM_STREAM_REQUESTS == request_count + + for _ in range(_NUM_STREAM_RESPONSES): + yield _RESPONSE + + async def _stream_stream_reader_writer(self, unused_request, context): + for _ in range(_NUM_STREAM_REQUESTS): + assert _REQUEST == await context.read() + for _ in range(_NUM_STREAM_RESPONSES): + await context.write(_RESPONSE) + + async def _stream_stream_evilly_mixed(self, request_iterator, context): + assert _REQUEST == await context.read() + request_count = 0 + async for request in request_iterator: + assert _REQUEST == request + request_count += 1 + assert _NUM_STREAM_REQUESTS - 1 == request_count + + yield _RESPONSE + for _ in range(_NUM_STREAM_RESPONSES - 1): + await context.write(_RESPONSE) + def service(self, handler_details): self._called.set_result(None) - if handler_details.method == _SIMPLE_UNARY_UNARY: - return grpc.unary_unary_rpc_method_handler(self._unary_unary) - if handler_details.method == _BLOCK_FOREVER: - return grpc.unary_unary_rpc_method_handler(self._block_forever) - if handler_details.method == _BLOCK_BRIEFLY: - return grpc.unary_unary_rpc_method_handler(self._block_briefly) - if handler_details.method == _UNARY_STREAM_ASYNC_GEN: - return grpc.unary_stream_rpc_method_handler( - self._unary_stream_async_gen) - if handler_details.method == _UNARY_STREAM_READER_WRITER: - return grpc.unary_stream_rpc_method_handler( - self._unary_stream_reader_writer) - if handler_details.method == _UNARY_STREAM_EVILLY_MIXED: - return grpc.unary_stream_rpc_method_handler( - self._unary_stream_evilly_mixed) + return self._routing_table[handler_details.method] async def wait_for_call(self): await self._called @@ -98,89 +177,152 @@ async def _start_test_server(): class TestServer(AioTestBase): async def setUp(self): - self._server_target, self._server, self._generic_handler = await _start_test_server( - ) + addr, self._server, self._generic_handler = await _start_test_server() + self._channel = aio.insecure_channel(addr) async def tearDown(self): + await self._channel.close() await self._server.stop(None) async def test_unary_unary(self): - async with aio.insecure_channel(self._server_target) as channel: - unary_unary_call = channel.unary_unary(_SIMPLE_UNARY_UNARY) - response = await unary_unary_call(_REQUEST) - self.assertEqual(response, _RESPONSE) + unary_unary_call = self._channel.unary_unary(_SIMPLE_UNARY_UNARY) + response = await unary_unary_call(_REQUEST) + self.assertEqual(response, _RESPONSE) async def test_unary_stream_async_generator(self): - async with aio.insecure_channel(self._server_target) as channel: - unary_stream_call = channel.unary_stream(_UNARY_STREAM_ASYNC_GEN) - call = unary_stream_call(_REQUEST) + unary_stream_call = self._channel.unary_stream(_UNARY_STREAM_ASYNC_GEN) + call = unary_stream_call(_REQUEST) - # Expecting the request message to reach server before retriving - # any responses. - await asyncio.wait_for(self._generic_handler.wait_for_call(), - test_constants.SHORT_TIMEOUT) + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertEqual(_RESPONSE, response) - response_cnt = 0 - async for response in call: - response_cnt += 1 - self.assertEqual(_RESPONSE, response) - - self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) - self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) + self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_unary_stream_reader_writer(self): - async with aio.insecure_channel(self._server_target) as channel: - unary_stream_call = channel.unary_stream( - _UNARY_STREAM_READER_WRITER) - call = unary_stream_call(_REQUEST) - - # Expecting the request message to reach server before retriving - # any responses. - await asyncio.wait_for(self._generic_handler.wait_for_call(), - test_constants.SHORT_TIMEOUT) + unary_stream_call = self._channel.unary_stream( + _UNARY_STREAM_READER_WRITER) + call = unary_stream_call(_REQUEST) - for _ in range(_NUM_STREAM_RESPONSES): - response = await call.read() - self.assertEqual(_RESPONSE, response) + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertEqual(_RESPONSE, response) - self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_unary_stream_evilly_mixed(self): - async with aio.insecure_channel(self._server_target) as channel: - unary_stream_call = channel.unary_stream(_UNARY_STREAM_EVILLY_MIXED) - call = unary_stream_call(_REQUEST) + unary_stream_call = self._channel.unary_stream( + _UNARY_STREAM_EVILLY_MIXED) + call = unary_stream_call(_REQUEST) + + # Uses reader API + self.assertEqual(_RESPONSE, await call.read()) + + # Uses async generator API + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertEqual(_RESPONSE, response) + + self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_unary_async_generator(self): + stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN) + call = stream_unary_call() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() + + response = await call + self.assertEqual(_RESPONSE, response) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_unary_reader_writer(self): + stream_unary_call = self._channel.stream_unary( + _STREAM_UNARY_READER_WRITER) + call = stream_unary_call() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() + + response = await call + self.assertEqual(_RESPONSE, response) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_unary_evilly_mixed(self): + stream_unary_call = self._channel.stream_unary( + _STREAM_UNARY_EVILLY_MIXED) + call = stream_unary_call() - # Expecting the request message to reach server before retriving - # any responses. - await asyncio.wait_for(self._generic_handler.wait_for_call(), - test_constants.SHORT_TIMEOUT) + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() - # Uses reader API - self.assertEqual(_RESPONSE, await call.read()) + response = await call + self.assertEqual(_RESPONSE, response) + self.assertEqual(await call.code(), grpc.StatusCode.OK) - # Uses async generator API - response_cnt = 0 - async for response in call: - response_cnt += 1 - self.assertEqual(_RESPONSE, response) + async def test_stream_stream_async_generator(self): + stream_stream_call = self._channel.stream_stream( + _STREAM_STREAM_ASYNC_GEN) + call = stream_stream_call() - self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt) + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() - self.assertEqual(await call.code(), grpc.StatusCode.OK) + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertEqual(_RESPONSE, response) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_stream_reader_writer(self): + stream_stream_call = self._channel.stream_stream( + _STREAM_STREAM_READER_WRITER) + call = stream_stream_call() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() + + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertEqual(_RESPONSE, response) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_stream_evilly_mixed(self): + stream_stream_call = self._channel.stream_stream( + _STREAM_STREAM_EVILLY_MIXED) + call = stream_stream_call() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() + + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertEqual(_RESPONSE, response) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_shutdown(self): await self._server.stop(None) # Ensures no SIGSEGV triggered, and ends within timeout. async def test_shutdown_after_call(self): - async with aio.insecure_channel(self._server_target) as channel: - await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST) + await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST) await self._server.stop(None) async def test_graceful_shutdown_success(self): - channel = aio.insecure_channel(self._server_target) - call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) + call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) await self._generic_handler.wait_for_call() shutdown_start_time = time.time() @@ -190,13 +332,11 @@ class TestServer(AioTestBase): test_constants.SHORT_TIMEOUT / 3) # Validates the states. - await channel.close() self.assertEqual(_RESPONSE, await call) self.assertTrue(call.done()) async def test_graceful_shutdown_failed(self): - channel = aio.insecure_channel(self._server_target) - call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) + call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) await self._generic_handler.wait_for_call() await self._server.stop(test_constants.SHORT_TIMEOUT) @@ -206,11 +346,9 @@ class TestServer(AioTestBase): self.assertEqual(grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()) self.assertIn('GOAWAY', exception_context.exception.details()) - await channel.close() async def test_concurrent_graceful_shutdown(self): - channel = aio.insecure_channel(self._server_target) - call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) + call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) await self._generic_handler.wait_for_call() # Expects the shortest grace period to be effective. @@ -224,13 +362,11 @@ class TestServer(AioTestBase): self.assertGreater(grace_period_length, test_constants.SHORT_TIMEOUT / 3) - await channel.close() self.assertEqual(_RESPONSE, await call) self.assertTrue(call.done()) async def test_concurrent_graceful_shutdown_immediate(self): - channel = aio.insecure_channel(self._server_target) - call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) + call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) await self._generic_handler.wait_for_call() # Expects no grace period, due to the "server.stop(None)". @@ -246,7 +382,6 @@ class TestServer(AioTestBase): self.assertEqual(grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()) self.assertIn('GOAWAY', exception_context.exception.details()) - await channel.close() @unittest.skip('https://github.com/grpc/grpc/issues/20818') async def test_shutdown_before_call(self):