Merge pull request #21517 from lidizheng/aio-streaming

[Aio] Client Streaming and Bi-di Streaming
reviewable/pr21745/r1
Lidi Zheng 5 years ago committed by GitHub
commit 5a4a5a0088
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 103
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 11
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  3. 21
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  4. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  5. 286
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  6. 4
      src/python/grpcio/grpc/experimental/aio/__init__.py
  7. 90
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  8. 449
      src/python/grpcio/grpc/experimental/aio/_call.py
  9. 166
      src/python/grpcio/grpc/experimental/aio/_channel.py
  10. 2
      src/python/grpcio/grpc/experimental/aio/_typing.py
  11. 2
      src/python/grpcio_tests/tests_aio/tests.json
  12. 26
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  13. 310
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  14. 100
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  15. 297
      src/python/grpcio_tests/tests_aio/unit/server_test.py

@ -23,18 +23,18 @@ _EMPTY_METADATA = None
_UNKNOWN_CANCELLATION_DETAILS = 'RPC cancelled for unknown reason.'
cdef class _AioCall:
cdef class _AioCall(GrpcCallWrapper):
def __cinit__(self,
AioChannel channel,
object deadline,
bytes method,
CallCredentials credentials):
CallCredentials call_credentials):
self.call = NULL
self._channel = channel
self._references = []
self._loop = asyncio.get_event_loop()
self._create_grpc_call(deadline, method, credentials)
self._create_grpc_call(deadline, method, call_credentials)
self._is_locally_cancelled = False
def __dealloc__(self):
@ -196,9 +196,25 @@ cdef class _AioCall:
self,
self._loop
)
return received_message
if received_message:
return received_message
else:
return EOF
async def send_serialized_message(self, bytes message):
"""Sends one single raw message in bytes."""
await _send_message(self,
message,
True,
self._loop)
async def unary_stream(self,
async def send_receive_close(self):
"""Half close the RPC on the client-side."""
cdef SendCloseFromClientOperation op = SendCloseFromClientOperation(_EMPTY_FLAGS)
cdef tuple ops = (op,)
await execute_batch(self, ops, self._loop)
async def initiate_unary_stream(self,
bytes request,
object initial_metadata_observer,
object status_observer):
@ -233,3 +249,80 @@ cdef class _AioCall:
await _receive_initial_metadata(self,
self._loop),
)
async def stream_unary(self,
tuple metadata,
object metadata_sent_observer,
object initial_metadata_observer,
object status_observer):
"""Actual implementation of the complete unary-stream call.
Needs to pay extra attention to the raise mechanism. If we want to
propagate the final status exception, then we have to raise it.
Othersize, it would end normally and raise `StopAsyncIteration()`.
"""
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
metadata,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
# Receives initial metadata.
initial_metadata_observer(
await _receive_initial_metadata(self,
self._loop),
)
cdef tuple inbound_ops
cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS)
cdef ReceiveStatusOnClientOperation receive_status_on_client_op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
inbound_ops = (receive_message_op, receive_status_on_client_op)
# Executes all operations in one batch.
await execute_batch(self,
inbound_ops,
self._loop)
status = AioRpcStatus(
receive_status_on_client_op.code(),
receive_status_on_client_op.details(),
receive_status_on_client_op.trailing_metadata(),
receive_status_on_client_op.error_string(),
)
# Reports the final status of the RPC to Python layer. The observer
# pattern is used here to unify unary and streaming code path.
status_observer(status)
if status.code() == StatusCode.ok:
return receive_message_op.message()
else:
return None
async def initiate_stream_stream(self,
tuple metadata,
object metadata_sent_observer,
object initial_metadata_observer,
object status_observer):
"""Actual implementation of the complete stream-stream call.
Needs to pay extra attention to the raise mechanism. If we want to
propagate the final status exception, then we have to raise it.
Othersize, it would end normally and raise `StopAsyncIteration()`.
"""
# Peer may prematurely end this RPC at any point. We need a corutine
# that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received(status_observer))
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
metadata,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
# Receives initial metadata.
initial_metadata_observer(
await _receive_initial_metadata(self,
self._loop),
)

@ -96,7 +96,7 @@ cdef class AioChannel:
def call(self,
bytes method,
object deadline,
CallCredentials credentials):
object python_call_credentials):
"""Assembles a Cython Call object.
Returns:
@ -105,5 +105,12 @@ cdef class AioChannel:
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
cdef _AioCall call = _AioCall(self, deadline, method, credentials)
cdef CallCredentials cython_call_credentials
if python_call_credentials is not None:
cython_call_credentials = python_call_credentials._credentials
else:
cython_call_credentials = None
cdef _AioCall call = _AioCall(self, deadline, method, cython_call_credentials)
return call

@ -33,3 +33,24 @@ cdef bytes serialize(object serializer, object message):
return serializer(message)
else:
return message
class _EOF:
def __bool__(self):
return False
def __len__(self):
return 0
def _repr(self) -> str:
return '<grpc.aio.EOF>'
def __repr__(self) -> str:
return self._repr()
def __str__(self) -> str:
return self._repr()
EOF = _EOF()

@ -21,6 +21,10 @@ cdef class RPCState(GrpcCallWrapper):
cdef grpc_call_details details
cdef grpc_metadata_array request_metadata
cdef AioServer server
# NOTE(lidiz) Under certain corner case, receiving the client close
# operation won't immediately fail ongoing RECV_MESSAGE operations. Here I
# added a flag to workaround this unexpected behavior.
cdef bint client_closed
cdef object abort_exception
cdef bint metadata_sent
cdef bint status_sent

@ -20,7 +20,8 @@ import traceback
# TODO(https://github.com/grpc/grpc/issues/20850) refactor this.
_LOGGER = logging.getLogger(__name__)
cdef int _EMPTY_FLAG = 0
# TODO(lidiz) Use a designated value other than None.
cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.'
cdef class _HandlerCallDetails:
def __cinit__(self, str method, tuple invocation_metadata):
@ -35,6 +36,7 @@ cdef class RPCState:
self.server = server
grpc_metadata_array_init(&self.request_metadata)
grpc_call_details_init(&self.details)
self.client_closed = False
self.abort_exception = None
self.metadata_sent = False
self.status_sent = False
@ -83,13 +85,23 @@ cdef class _ServicerContext:
self._loop = loop
async def read(self):
cdef bytes raw_message
if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
raise RuntimeError(_SERVER_STOPPED_DETAILS)
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
cdef bytes raw_message = await _receive_message(self._rpc_state, self._loop)
return deserialize(self._request_deserializer,
raw_message)
if self._rpc_state.client_closed:
return EOF
raw_message = await _receive_message(self._rpc_state, self._loop)
if raw_message is None:
return EOF
else:
return deserialize(self._request_deserializer,
raw_message)
async def write(self, object message):
if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
raise RuntimeError(_SERVER_STOPPED_DETAILS)
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
await _send_message(self._rpc_state,
@ -102,6 +114,8 @@ cdef class _ServicerContext:
async def send_initial_metadata(self, tuple metadata):
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
elif self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
raise RuntimeError(_SERVER_STOPPED_DETAILS)
elif self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent')
else:
@ -145,27 +159,23 @@ cdef _find_method_handler(str method, list generic_handlers):
return None
async def _handle_unary_unary_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef bytes request_raw = await _receive_message(rpc_state, loop)
# Deserializes the request message
cdef object request_message = deserialize(
method_handler.request_deserializer,
request_raw,
)
async def _finish_handler_with_unary_response(RPCState rpc_state,
object unary_handler,
object request,
_ServicerContext servicer_context,
object response_serializer,
object loop):
"""Finishes server method handler with a single response.
This function executes the application handler, and handles response
sending, as well as errors. It is shared between unary-unary and
stream-unary handlers.
"""
# Executes application logic
cdef object response_message = await method_handler.unary_unary(
request_message,
_ServicerContext(
rpc_state,
None,
None,
loop,
),
cdef object response_message = await unary_handler(
request,
servicer_context,
)
# Raises exception if aborted
@ -173,50 +183,50 @@ async def _handle_unary_unary_rpc(object method_handler,
# Serializes the response message
cdef bytes response_raw = serialize(
method_handler.response_serializer,
response_serializer,
response_message,
)
# Sends response message
cdef tuple send_ops = (
SendStatusFromServerOperation(
tuple(),
# Assembles the batch operations
cdef Operation send_status_op = SendStatusFromServerOperation(
tuple(),
StatusCode.ok,
b'',
_EMPTY_FLAGS,
),
SendInitialMetadataOperation(None, _EMPTY_FLAGS),
SendMessageOperation(response_raw, _EMPTY_FLAGS),
)
cdef tuple finish_ops
if not rpc_state.metadata_sent:
finish_ops = (
send_status_op,
SendInitialMetadataOperation(None, _EMPTY_FLAGS),
SendMessageOperation(response_raw, _EMPTY_FLAGS),
)
else:
finish_ops = (
send_status_op,
SendMessageOperation(response_raw, _EMPTY_FLAGS),
)
rpc_state.status_sent = True
await execute_batch(rpc_state, send_ops, loop)
await execute_batch(rpc_state, finish_ops, loop)
async def _handle_unary_stream_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef bytes request_raw = await _receive_message(rpc_state, loop)
# Deserializes the request message
cdef object request_message = deserialize(
method_handler.request_deserializer,
request_raw,
)
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
method_handler.request_deserializer,
method_handler.response_serializer,
loop,
)
async def _finish_handler_with_stream_responses(RPCState rpc_state,
object stream_handler,
object request,
_ServicerContext servicer_context,
object loop):
"""Finishes server method handler with multiple responses.
This function executes the application handler, and handles response
sending, as well as errors. It is shared between unary-stream and
stream-stream handlers.
"""
cdef object async_response_generator
cdef object response_message
if inspect.iscoroutinefunction(method_handler.unary_stream):
if inspect.iscoroutinefunction(stream_handler):
# The handler uses reader / writer API, returns None.
await method_handler.unary_stream(
request_message,
await stream_handler(
request,
servicer_context,
)
@ -224,8 +234,8 @@ async def _handle_unary_stream_rpc(object method_handler,
_raise_if_aborted(rpc_state)
else:
# The handler uses async generator API
async_response_generator = method_handler.unary_stream(
request_message,
async_response_generator = stream_handler(
request,
servicer_context,
)
@ -250,9 +260,132 @@ async def _handle_unary_stream_rpc(object method_handler,
_EMPTY_FLAGS,
)
cdef tuple ops = (op,)
cdef tuple finish_ops = (op,)
if not rpc_state.metadata_sent:
finish_ops = (op, SendInitialMetadataOperation(None, _EMPTY_FLAGS))
rpc_state.status_sent = True
await execute_batch(rpc_state, ops, loop)
await execute_batch(rpc_state, finish_ops, loop)
async def _handle_unary_unary_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef bytes request_raw = await _receive_message(rpc_state, loop)
# Deserializes the request message
cdef object request_message = deserialize(
method_handler.request_deserializer,
request_raw,
)
# Creates a dedecated ServicerContext
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
None,
None,
loop,
)
# Finishes the application handler
await _finish_handler_with_unary_response(
rpc_state,
method_handler.unary_unary,
request_message,
servicer_context,
method_handler.response_serializer,
loop
)
async def _handle_unary_stream_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef bytes request_raw = await _receive_message(rpc_state, loop)
# Deserializes the request message
cdef object request_message = deserialize(
method_handler.request_deserializer,
request_raw,
)
# Creates a dedecated ServicerContext
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
method_handler.request_deserializer,
method_handler.response_serializer,
loop,
)
# Finishes the application handler
await _finish_handler_with_stream_responses(
rpc_state,
method_handler.unary_stream,
request_message,
servicer_context,
loop,
)
async def _message_receiver(_ServicerContext servicer_context):
"""Bridge between the async generator API and the reader-writer API."""
cdef object message
while True:
message = await servicer_context.read()
if message is not EOF:
yield message
else:
break
async def _handle_stream_unary_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Creates a dedecated ServicerContext
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
method_handler.request_deserializer,
None,
loop,
)
# Prepares the request generator
cdef object request_async_iterator = _message_receiver(servicer_context)
# Finishes the application handler
await _finish_handler_with_unary_response(
rpc_state,
method_handler.stream_unary,
request_async_iterator,
servicer_context,
method_handler.response_serializer,
loop
)
async def _handle_stream_stream_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Creates a dedecated ServicerContext
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
method_handler.request_deserializer,
method_handler.response_serializer,
loop,
)
# Prepares the request generator
cdef object request_async_iterator = _message_receiver(servicer_context)
# Finishes the application handler
await _finish_handler_with_stream_responses(
rpc_state,
method_handler.stream_stream,
request_async_iterator,
servicer_context,
loop,
)
async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
@ -293,6 +426,7 @@ async def _handle_cancellation_from_core(object rpc_task,
# Awaits cancellation from peer.
await execute_batch(rpc_state, ops, loop)
rpc_state.client_closed = True
if op.cancelled() and not rpc_task.done():
# Injects `CancelledError` to halt the RPC coroutine
rpc_task.cancel()
@ -311,8 +445,9 @@ async def _schedule_rpc_coro(object rpc_coro,
async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
cdef object method_handler
# Finds the method handler (application logic)
cdef object method_handler = _find_method_handler(
method_handler = _find_method_handler(
rpc_state.method().decode(),
generic_handlers,
)
@ -328,20 +463,33 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
)
return
# TODO(lidiz) extend to all 4 types of RPC
# Handles unary-unary case
if not method_handler.request_streaming and not method_handler.response_streaming:
await _handle_unary_unary_rpc(method_handler,
rpc_state,
loop)
return
# Handles unary-stream case
if not method_handler.request_streaming and method_handler.response_streaming:
try:
await _handle_unary_stream_rpc(method_handler,
await _handle_unary_stream_rpc(method_handler,
rpc_state,
loop)
except Exception as e:
raise
elif not method_handler.request_streaming and not method_handler.response_streaming:
await _handle_unary_unary_rpc(method_handler,
rpc_state,
loop)
else:
raise NotImplementedError()
return
# Handles stream-unary case
if method_handler.request_streaming and not method_handler.response_streaming:
await _handle_stream_unary_rpc(method_handler,
rpc_state,
loop)
return
# Handles stream-stream case
if method_handler.request_streaming and method_handler.response_streaming:
await _handle_stream_stream_rpc(method_handler,
rpc_state,
loop)
return
class _RequestCallError(Exception): pass

@ -22,7 +22,7 @@ from typing import Any, Optional, Sequence, Text, Tuple
import six
import grpc
from grpc._cython.cygrpc import init_grpc_aio, AbortError
from grpc._cython.cygrpc import EOF, AbortError, init_grpc_aio
from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall
from ._call import AioRpcError
@ -86,5 +86,5 @@ __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall',
'UnaryStreamCall', 'init_grpc_aio', 'Channel',
'UnaryUnaryMultiCallable', 'ClientCallDetails',
'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall',
'insecure_channel', 'secure_channel', 'server', 'Server',
'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel',
'AbortError')

@ -19,11 +19,12 @@ RPC, e.g. cancellation.
"""
from abc import ABCMeta, abstractmethod
from typing import Any, AsyncIterable, Awaitable, Callable, Generic, Text, Optional
from typing import (Any, AsyncIterable, Awaitable, Callable, Generic, Optional,
Text, Union)
import grpc
from ._typing import MetadataType, RequestType, ResponseType
from ._typing import EOFType, MetadataType, RequestType, ResponseType
__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -146,14 +147,85 @@ class UnaryStreamCall(Generic[RequestType, ResponseType],
"""
@abstractmethod
async def read(self) -> ResponseType:
"""Reads one message from the RPC.
async def read(self) -> Union[EOFType, ResponseType]:
"""Reads one message from the stream.
For each streaming RPC, concurrent reads in multiple coroutines are not
allowed. If you want to perform read in multiple coroutines, you needs
synchronization. So, you can start another read after current read is
finished.
Read operations must be serialized when called from multiple
coroutines.
Returns:
A response message of the RPC.
A response message, or an `grpc.aio.EOF` to indicate the end of the
stream.
"""
class StreamUnaryCall(Generic[RequestType, ResponseType],
Call,
metaclass=ABCMeta):
@abstractmethod
async def write(self, request: RequestType) -> None:
"""Writes one message to the stream.
Raises:
An RpcError exception if the write failed.
"""
@abstractmethod
async def done_writing(self) -> None:
"""Notifies server that the client is done sending messages.
After done_writing is called, any additional invocation to the write
function will fail. This function is idempotent.
"""
@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
"""Await the response message to be ready.
Returns:
The response message of the stream.
"""
class StreamStreamCall(Generic[RequestType, ResponseType],
Call,
metaclass=ABCMeta):
@abstractmethod
def __aiter__(self) -> AsyncIterable[ResponseType]:
"""Returns the async iterable representation that yields messages.
Under the hood, it is calling the "read" method.
Returns:
An async iterable object that yields messages.
"""
@abstractmethod
async def read(self) -> Union[EOFType, ResponseType]:
"""Reads one message from the stream.
Read operations must be serialized when called from multiple
coroutines.
Returns:
A response message, or an `grpc.aio.EOF` to indicate the end of the
stream.
"""
@abstractmethod
async def write(self, request: RequestType) -> None:
"""Writes one message to the stream.
Raises:
An RpcError exception if the write failed.
"""
@abstractmethod
async def done_writing(self) -> None:
"""Notifies server that the client is done sending messages.
After done_writing is called, any additional invocation to the write
function will fail. This function is idempotent.
"""

@ -29,6 +29,7 @@ __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!'
_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.'
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
@ -146,31 +147,48 @@ def _create_rpc_error(initial_metadata: Optional[MetadataType],
class Call(_base_call.Call):
"""Base implementation of client RPC Call object.
Implements logic around final status, metadata and cancellation.
"""
_loop: asyncio.AbstractEventLoop
_code: grpc.StatusCode
_status: Awaitable[cygrpc.AioRpcStatus]
_initial_metadata: Awaitable[MetadataType]
_locally_cancelled: bool
_cython_call: cygrpc._AioCall
def __init__(self) -> None:
def __init__(self, cython_call: cygrpc._AioCall) -> None:
self._loop = asyncio.get_event_loop()
self._code = None
self._status = self._loop.create_future()
self._initial_metadata = self._loop.create_future()
self._locally_cancelled = False
self._cython_call = cython_call
def cancel(self) -> bool:
"""Placeholder cancellation method.
The implementation of this method needs to pass the cancellation reason
into self._cancellation, using `set_result` instead of
`set_exception`.
"""
raise NotImplementedError()
def __del__(self) -> None:
if not self._status.done():
self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None))
def cancelled(self) -> bool:
return self._code == grpc.StatusCode.CANCELLED
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning."""
if not self._status.done():
self._set_status(status)
self._cython_call.cancel(status)
return True
else:
return False
def cancel(self) -> bool:
return self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_LOCAL_CANCELLATION_DETAILS, None, None))
def done(self) -> bool:
return self._status.done()
@ -247,6 +265,7 @@ class Call(_base_call.Call):
return self._repr()
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls.
@ -254,37 +273,29 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
Returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
_request: RequestType
_channel: cygrpc.AioChannel
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: asyncio.Task
_cython_call: cygrpc._AioCall
def __init__( # pylint: disable=R0913
self, request: RequestType, deadline: Optional[float],
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__()
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
channel.call(method, deadline, credentials)
super().__init__(channel.call(method, deadline, credentials))
self._request = request
self._channel = channel
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
if credentials is not None:
grpc_credentials = credentials._credentials
else:
grpc_credentials = None
self._cython_call = self._channel.call(method, deadline,
grpc_credentials)
self._call = self._loop.create_task(self._invoke())
def __del__(self) -> None:
if not self._call.done():
self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None))
def cancel(self) -> bool:
if super().cancel():
self._call.cancel()
return True
else:
return False
async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
@ -300,7 +311,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
self._set_status,
)
except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED:
if not self.cancelled():
self.cancel()
# Raises here if RPC failed or cancelled
@ -309,21 +320,6 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
return _common.deserialize(serialized_response,
self._response_deserializer)
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning."""
if not self._status.done():
self._set_status(status)
self._cython_call.cancel(status)
self._call.cancel()
return True
else:
return False
def cancel(self) -> bool:
return self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_LOCAL_CANCELLATION_DETAILS, None, None))
def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes."""
try:
@ -339,6 +335,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
return response
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
"""Object for managing unary-stream RPC calls.
@ -346,107 +343,346 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
Returned when an instance of `UnaryStreamMultiCallable` object is called.
"""
_request: RequestType
_channel: cygrpc.AioChannel
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_cython_call: cygrpc._AioCall
_send_unary_request_task: asyncio.Task
_message_aiter: AsyncIterable[ResponseType]
def __init__( # pylint: disable=R0913
self, request: RequestType, deadline: Optional[float],
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__()
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
self._request = request
self._channel = channel
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._send_unary_request_task = self._loop.create_task(
self._send_unary_request())
self._message_aiter = self._fetch_stream_responses()
self._message_aiter = None
if credentials is not None:
grpc_credentials = credentials._credentials
def cancel(self) -> bool:
if super().cancel():
self._send_unary_request_task.cancel()
return True
else:
grpc_credentials = None
self._cython_call = self._channel.call(method, deadline,
grpc_credentials)
def __del__(self) -> None:
if not self._status.done():
self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None))
return False
async def _send_unary_request(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
try:
await self._cython_call.unary_stream(serialized_request,
self._set_initial_metadata,
self._set_status)
await self._cython_call.initiate_unary_stream(
serialized_request, self._set_initial_metadata,
self._set_status)
except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED:
if not self.cancelled():
self.cancel()
raise
async def _fetch_stream_responses(self) -> ResponseType:
await self._send_unary_request_task
message = await self._read()
while message:
while message is not cygrpc.EOF:
yield message
message = await self._read()
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning.
def __aiter__(self) -> AsyncIterable[ResponseType]:
if self._message_aiter is None:
self._message_aiter = self._fetch_stream_responses()
return self._message_aiter
Async generator will receive an exception. The cancellation will go
deep down into Core, and then propagates backup as the
`cygrpc.AioRpcStatus` exception.
async def _read(self) -> ResponseType:
# Wait for the request being sent
await self._send_unary_request_task
So, under race condition, e.g. the server sent out final state headers
and the client calling "cancel" at the same time, this method respects
the winner in Core.
"""
if not self._status.done():
self._set_status(status)
self._cython_call.cancel(status)
# Reads response message from Core
try:
raw_response = await self._cython_call.receive_serialized_message()
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
if raw_response is cygrpc.EOF:
return cygrpc.EOF
else:
return _common.deserialize(raw_response,
self._response_deserializer)
async def read(self) -> ResponseType:
if self._status.done():
await self._raise_for_status()
return cygrpc.EOF
response_message = await self._read()
if response_message is cygrpc.EOF:
# If the read operation failed, Core should explain why.
await self._raise_for_status()
return response_message
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
"""Object for managing stream-unary RPC calls.
Returned when an instance of `StreamUnaryMultiCallable` object is called.
"""
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_metadata_sent: asyncio.Event
_done_writing: bool
_call_finisher: asyncio.Task
_async_request_poller: asyncio.Task
if not self._send_unary_request_task.done():
# Injects CancelledError to the Task. The exception will
# propagate to _fetch_stream_responses as well, if the sending
# is not done.
self._send_unary_request_task.cancel()
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float],
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
self._metadata = _EMPTY_METADATA
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing = False
self._call_finisher = self._loop.create_task(self._conduct_rpc())
# If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None:
self._async_request_poller = self._loop.create_task(
self._consume_request_iterator(request_async_iterator))
else:
self._async_request_poller = None
def cancel(self) -> bool:
if super().cancel():
self._call_finisher.cancel()
if self._async_request_poller is not None:
self._async_request_poller.cancel()
return True
else:
return False
def _metadata_sent_observer(self):
self._metadata_sent.set()
async def _conduct_rpc(self) -> ResponseType:
try:
serialized_response = await self._cython_call.stream_unary(
self._metadata,
self._metadata_sent_observer,
self._set_initial_metadata,
self._set_status,
)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# Raises RpcError if the RPC failed or cancelled
await self._raise_for_status()
return _common.deserialize(serialized_response,
self._response_deserializer)
async def _consume_request_iterator(
self, request_async_iterator: AsyncIterable[RequestType]) -> None:
async for request in request_async_iterator:
await self.write(request)
await self.done_writing()
def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes."""
try:
response = yield from self._call_finisher
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
raise
return response
async def write(self, request: RequestType) -> None:
if self._status.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set():
await self._metadata_sent.wait()
serialized_request = _common.serialize(request,
self._request_serializer)
try:
await self._cython_call.send_serialized_message(serialized_request)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
if self._status.done():
# If the RPC is finished, do nothing.
return
if not self._done_writing:
# If the done writing is not sent before, try to send it.
self._done_writing = True
try:
await self._cython_call.send_receive_close()
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class StreamStreamCall(Call, _base_call.StreamStreamCall):
"""Object for managing stream-stream RPC calls.
Returned when an instance of `StreamStreamMultiCallable` object is called.
"""
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_metadata_sent: asyncio.Event
_done_writing: bool
_initializer: asyncio.Task
_async_request_poller: asyncio.Task
_message_aiter: AsyncIterable[ResponseType]
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float],
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
self._metadata = _EMPTY_METADATA
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing = False
self._initializer = self._loop.create_task(self._prepare_rpc())
# If user passes in an async iterator, create a consumer coroutine.
if request_async_iterator is not None:
self._async_request_poller = self._loop.create_task(
self._consume_request_iterator(request_async_iterator))
else:
self._async_request_poller = None
self._message_aiter = None
def cancel(self) -> bool:
return self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_LOCAL_CANCELLATION_DETAILS, None, None))
if super().cancel():
self._initializer.cancel()
if self._async_request_poller is not None:
self._async_request_poller.cancel()
return True
else:
return False
def _metadata_sent_observer(self):
self._metadata_sent.set()
async def _prepare_rpc(self):
"""This method prepares the RPC for receiving/sending messages.
All other operations around the stream should only happen after the
completion of this method.
"""
try:
await self._cython_call.initiate_stream_stream(
self._metadata,
self._metadata_sent_observer,
self._set_initial_metadata,
self._set_status,
)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# No need to raise RpcError here, because no one will `await` this task.
async def _consume_request_iterator(
self, request_async_iterator: Optional[AsyncIterable[RequestType]]
) -> None:
async for request in request_async_iterator:
await self.write(request)
await self.done_writing()
async def write(self, request: RequestType) -> None:
if self._status.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set():
await self._metadata_sent.wait()
serialized_request = _common.serialize(request,
self._request_serializer)
try:
await self._cython_call.send_serialized_message(serialized_request)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
if self._status.done():
# If the RPC is finished, do nothing.
return
if not self._done_writing:
# If the done writing is not sent before, try to send it.
self._done_writing = True
try:
await self._cython_call.send_receive_close()
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
async def _fetch_stream_responses(self) -> ResponseType:
"""The async generator that yields responses from peer."""
message = await self._read()
while message is not cygrpc.EOF:
yield message
message = await self._read()
def __aiter__(self) -> AsyncIterable[ResponseType]:
if self._message_aiter is None:
self._message_aiter = self._fetch_stream_responses()
return self._message_aiter
async def _read(self) -> ResponseType:
# Wait for the request being sent
await self._send_unary_request_task
# Wait for the setup
await self._initializer
# Reads response message from Core
try:
raw_response = await self._cython_call.receive_serialized_message()
except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED:
if not self.cancelled():
self.cancel()
raise
await self._raise_for_status()
if raw_response is None:
return None
if raw_response is cygrpc.EOF:
return cygrpc.EOF
else:
return _common.deserialize(raw_response,
self._response_deserializer)
@ -454,14 +690,11 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
async def read(self) -> ResponseType:
if self._status.done():
await self._raise_for_status()
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
return cygrpc.EOF
response_message = await self._read()
if response_message is None:
if response_message is cygrpc.EOF:
# If the read operation failed, Core should explain why.
await self._raise_for_status()
# If no exception raised, there is something wrong internally.
assert False, 'Read operation failed with StatusCode.OK'
else:
return response_message
return response_message

@ -13,14 +13,15 @@
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
from typing import Any, Optional, Sequence, Text
from typing import Any, AsyncIterable, Optional, Sequence, Text
import grpc
from grpc import _common
from grpc._cython import cygrpc
from . import _base_call
from ._call import UnaryStreamCall, UnaryUnaryCall
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
UnaryUnaryCall)
from ._interceptor import (InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor)
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
@ -28,8 +29,16 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
from ._utils import _timeout_to_deadline
class UnaryUnaryMultiCallable:
"""Factory an asynchronous unary-unary RPC stub call from client-side."""
class _BaseMultiCallable:
"""Base class of all multi callable objects.
Handles the initialization logic and stores common attributes.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_channel: cygrpc.AioChannel
_method: bytes
@ -50,6 +59,10 @@ class UnaryUnaryMultiCallable:
self._response_deserializer = response_deserializer
self._interceptors = interceptors
class UnaryUnaryMultiCallable(_BaseMultiCallable):
"""Factory an asynchronous unary-unary RPC stub call from client-side."""
def __call__(self,
request: Any,
*,
@ -114,17 +127,8 @@ class UnaryUnaryMultiCallable:
)
class UnaryStreamMultiCallable:
"""Afford invoking a unary-stream RPC from client-side in an asynchronous way."""
def __init__(self, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._loop = asyncio.get_event_loop()
class UnaryStreamMultiCallable(_BaseMultiCallable):
"""Affords invoking a unary-stream RPC from client-side in an asynchronous way."""
def __call__(self,
request: Any,
@ -176,6 +180,122 @@ class UnaryStreamMultiCallable:
)
class StreamUnaryMultiCallable(_BaseMultiCallable):
"""Affords invoking a stream-unary RPC from client-side in an asynchronous way."""
def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamUnaryCall:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A Call object instance which is an awaitable object.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
return StreamUnaryCall(
request_async_iterator,
deadline,
credentials,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
class StreamStreamMultiCallable(_BaseMultiCallable):
"""Affords invoking a stream-stream RPC from client-side in an asynchronous way."""
def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamStreamCall:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A Call object instance which is an awaitable object.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
return StreamStreamCall(
request_async_iterator,
deadline,
credentials,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
class Channel:
"""Asynchronous Channel implementation.
@ -301,21 +421,27 @@ class Channel:
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer)
response_deserializer, None)
def stream_unary(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None):
"""Placeholder method for stream-unary calls."""
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None)
def stream_stream(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None):
"""Placeholder method for stream-stream calls."""
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None)
async def _close(self):
# TODO: Send cancellation status

@ -14,6 +14,7 @@
"""Common types for gRPC Async API"""
from typing import Any, AnyStr, Callable, Sequence, Text, Tuple, TypeVar
from grpc._cython.cygrpc import EOF
RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType')
@ -21,3 +22,4 @@ SerializingFunction = Callable[[Any], bytes]
DeserializingFunction = Callable[[bytes], Any]
MetadataType = Sequence[Tuple[Text, AnyStr]]
ChannelArgumentType = Sequence[Tuple[Text, Any]]
EOFType = type(EOF)

@ -2,6 +2,8 @@
"_sanity._sanity_test.AioSanityTest",
"unit.abort_test.TestAbort",
"unit.aio_rpc_error_test.TestAioRpcError",
"unit.call_test.TestStreamStreamCall",
"unit.call_test.TestStreamUnaryCall",
"unit.call_test.TestUnaryStreamCall",
"unit.call_test.TestUnaryUnaryCall",
"unit.channel_argument_test.TestChannelArgument",

@ -26,11 +26,12 @@ from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
async def UnaryCall(self, request, context):
async def UnaryCall(self, unused_request, unused_context):
return messages_pb2.SimpleResponse()
async def StreamingOutputCall(
self, request: messages_pb2.StreamingOutputCallRequest, context):
self, request: messages_pb2.StreamingOutputCallRequest,
unused_context):
for response_parameters in request.response_parameters:
if response_parameters.interval_us != 0:
await asyncio.sleep(
@ -44,11 +45,30 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
# Next methods are extra ones that are registred programatically
# when the sever is instantiated. They are not being provided by
# the proto file.
async def UnaryCallWithSleep(self, request, context):
await asyncio.sleep(UNARY_CALL_WITH_SLEEP_VALUE)
return messages_pb2.SimpleResponse()
async def StreamingInputCall(self, request_async_iterator, unused_context):
aggregate_size = 0
async for request in request_async_iterator:
if request.payload is not None and request.payload.body:
aggregate_size += len(request.payload.body)
return messages_pb2.StreamingInputCallResponse(
aggregated_payload_size=aggregate_size)
async def FullDuplexCall(self, request_async_iterator, unused_context):
async for request in request_async_iterator:
for response_parameters in request.response_parameters:
if response_parameters.interval_us != 0:
await asyncio.sleep(
datetime.timedelta(microseconds=response_parameters.
interval_us).total_seconds())
yield messages_pb2.StreamingOutputCallResponse(
payload=messages_pb2.Payload(type=request.payload.type,
body=b'\x00' *
response_parameters.size))
async def start_test_server(secure=False):
server = aio.server(options=(('grpc.so_reuseport', 0),))

@ -30,10 +30,10 @@ from src.proto.grpc.testing import messages_pb2
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000
_UNREACHABLE_TARGET = '0.1:1111'
_INFINITE_INTERVAL_US = 2**31 - 1
@ -286,7 +286,7 @@ class TestUnaryStreamCall(AioTestBase):
[grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])
async def test_too_many_reads_unary_stream(self):
"""Test cancellation after received all messages."""
"""Test calling read after received all messages fails."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
@ -306,13 +306,14 @@ class TestUnaryStreamCall(AioTestBase):
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertIs(await call.read(), aio.EOF)
# After the RPC is finished, further reads will lead to exception.
self.assertEqual(await call.code(), grpc.StatusCode.OK)
with self.assertRaises(asyncio.InvalidStateError):
await call.read()
self.assertIs(await call.read(), aio.EOF)
async def test_unary_stream_async_generator(self):
"""Sunny day test case for unary_stream."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
@ -426,6 +427,307 @@ class TestUnaryStreamCall(AioTestBase):
self.loop.run_until_complete(coro())
class TestStreamUnaryCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
self._channel = aio.insecure_channel(self._server_target)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_cancel_stream_unary(self):
call = self._stub.StreamingInputCall()
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
# Sends out requests
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
await call.done_writing()
with self.assertRaises(asyncio.CancelledError):
await call
async def test_early_cancel_stream_unary(self):
call = self._stub.StreamingInputCall()
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
with self.assertRaises(asyncio.InvalidStateError):
await call.write(messages_pb2.StreamingInputCallRequest())
# Should be no-op
await call.done_writing()
with self.assertRaises(asyncio.CancelledError):
await call
async def test_write_after_done_writing(self):
call = self._stub.StreamingInputCall()
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
# Sends out requests
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
# Should be no-op
await call.done_writing()
with self.assertRaises(asyncio.InvalidStateError):
await call.write(messages_pb2.StreamingInputCallRequest())
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_error_in_async_generator(self):
# Server will pause between responses
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# We expect the request iterator to receive the exception
request_iterator_received_the_exception = asyncio.Event()
async def request_iterator():
with self.assertRaises(asyncio.CancelledError):
for _ in range(_NUM_STREAM_RESPONSES):
yield request
await asyncio.sleep(test_constants.SHORT_TIMEOUT)
request_iterator_received_the_exception.set()
call = self._stub.StreamingInputCall(request_iterator())
# Cancel the RPC after at least one response
async def cancel_later():
await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
call.cancel()
cancel_later_task = self.loop.create_task(cancel_later())
# No exceptions here
with self.assertRaises(asyncio.CancelledError):
await call
await request_iterator_received_the_exception.wait()
# No failures in the cancel later task!
await cancel_later_task
# Prepares the request that stream in a ping-pong manner.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
class TestStreamStreamCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
self._channel = aio.insecure_channel(self._server_target)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_cancel(self):
# Invokes the actual RPC
call = self._stub.FullDuplexCall()
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
response = await call.read()
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_cancel_with_pending_read(self):
call = self._stub.FullDuplexCall()
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_cancel_with_ongoing_read(self):
call = self._stub.FullDuplexCall()
coro_started = asyncio.Event()
async def read_coro():
coro_started.set()
await call.read()
read_task = self.loop.create_task(read_coro())
await coro_started.wait()
self.assertFalse(read_task.done())
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_early_cancel(self):
call = self._stub.FullDuplexCall()
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_cancel_after_done_writing(self):
call = self._stub.FullDuplexCall()
await call.done_writing()
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_late_cancel(self):
call = self._stub.FullDuplexCall()
await call.done_writing()
self.assertEqual(grpc.StatusCode.OK, await call.code())
# Cancels the RPC
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertFalse(call.cancel())
self.assertFalse(call.cancelled())
# Status is still OK
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_async_generator(self):
async def request_generator():
yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
call = self._stub.FullDuplexCall(request_generator())
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_too_many_reads(self):
async def request_generator():
for _ in range(_NUM_STREAM_RESPONSES):
yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
call = self._stub.FullDuplexCall(request_generator())
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertIs(await call.read(), aio.EOF)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# After the RPC finished, the read should also produce EOF
self.assertIs(await call.read(), aio.EOF)
async def test_read_write_after_done_writing(self):
call = self._stub.FullDuplexCall()
# Writes two requests, and pending two requests
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
await call.done_writing()
# Further write should fail
with self.assertRaises(asyncio.InvalidStateError):
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
# But read should be unaffected
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_error_in_async_generator(self):
# Server will pause between responses
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# We expect the request iterator to receive the exception
request_iterator_received_the_exception = asyncio.Event()
async def request_iterator():
with self.assertRaises(asyncio.CancelledError):
for _ in range(_NUM_STREAM_RESPONSES):
yield request
await asyncio.sleep(test_constants.SHORT_TIMEOUT)
request_iterator_received_the_exception.set()
call = self._stub.FullDuplexCall(request_iterator())
# Cancel the RPC after at least one response
async def cancel_later():
await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
call.cancel()
cancel_later_task = self.loop.create_task(cancel_later())
# No exceptions here
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
await request_iterator_received_the_exception.wait()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
# No failures in the cancel later task!
await cancel_later_task
if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)

@ -32,6 +32,7 @@ _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
@ -121,7 +122,104 @@ class TestChannel(AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close()
async def test_stream_unary_using_write(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
# Invokes the actual RPC
call = stub.StreamingInputCall()
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
# Sends out requests
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
await call.done_writing()
# Validates the responses
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close()
async def test_stream_unary_using_async_gen(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
async def gen():
for _ in range(_NUM_STREAM_RESPONSES):
yield request
# Invokes the actual RPC
call = stub.StreamingInputCall(gen())
# Validates the responses
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close()
async def test_stream_stream_using_read_write(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
# Invokes the actual RPC
call = stub.FullDuplexCall()
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
response = await call.read()
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
await call.done_writing()
self.assertEqual(grpc.StatusCode.OK, await call.code())
await channel.close()
async def test_stream_stream_using_async_gen(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
async def gen():
for _ in range(_NUM_STREAM_RESPONSES):
yield request
# Invokes the actual RPC
call = stub.FullDuplexCall(gen())
async for response in call:
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(grpc.StatusCode.OK, await call.code())
await channel.close()
if __name__ == '__main__':
logging.basicConfig(level=logging.WARN)
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

@ -13,15 +13,16 @@
# limitations under the License.
import asyncio
import gc
import logging
import unittest
import time
import gc
import unittest
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_base import AioTestBase
_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary'
_BLOCK_FOREVER = '/test/BlockForever'
@ -29,9 +30,16 @@ _BLOCK_BRIEFLY = '/test/BlockBriefly'
_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen'
_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter'
_UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed'
_STREAM_UNARY_ASYNC_GEN = '/test/StreamUnaryAsyncGen'
_STREAM_UNARY_READER_WRITER = '/test/StreamUnaryReaderWriter'
_STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed'
_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen'
_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter'
_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
_NUM_STREAM_REQUESTS = 3
_NUM_STREAM_RESPONSES = 5
@ -39,6 +47,41 @@ class _GenericHandler(grpc.GenericRpcHandler):
def __init__(self):
self._called = asyncio.get_event_loop().create_future()
self._routing_table = {
_SIMPLE_UNARY_UNARY:
grpc.unary_unary_rpc_method_handler(self._unary_unary),
_BLOCK_FOREVER:
grpc.unary_unary_rpc_method_handler(self._block_forever),
_BLOCK_BRIEFLY:
grpc.unary_unary_rpc_method_handler(self._block_briefly),
_UNARY_STREAM_ASYNC_GEN:
grpc.unary_stream_rpc_method_handler(
self._unary_stream_async_gen),
_UNARY_STREAM_READER_WRITER:
grpc.unary_stream_rpc_method_handler(
self._unary_stream_reader_writer),
_UNARY_STREAM_EVILLY_MIXED:
grpc.unary_stream_rpc_method_handler(
self._unary_stream_evilly_mixed),
_STREAM_UNARY_ASYNC_GEN:
grpc.stream_unary_rpc_method_handler(
self._stream_unary_async_gen),
_STREAM_UNARY_READER_WRITER:
grpc.stream_unary_rpc_method_handler(
self._stream_unary_reader_writer),
_STREAM_UNARY_EVILLY_MIXED:
grpc.stream_unary_rpc_method_handler(
self._stream_unary_evilly_mixed),
_STREAM_STREAM_ASYNC_GEN:
grpc.stream_stream_rpc_method_handler(
self._stream_stream_async_gen),
_STREAM_STREAM_READER_WRITER:
grpc.stream_stream_rpc_method_handler(
self._stream_stream_reader_writer),
_STREAM_STREAM_EVILLY_MIXED:
grpc.stream_stream_rpc_method_handler(
self._stream_stream_evilly_mixed),
}
@staticmethod
async def _unary_unary(unused_request, unused_context):
@ -64,23 +107,59 @@ class _GenericHandler(grpc.GenericRpcHandler):
for _ in range(_NUM_STREAM_RESPONSES - 1):
await context.write(_RESPONSE)
async def _stream_unary_async_gen(self, request_iterator, unused_context):
request_count = 0
async for request in request_iterator:
assert _REQUEST == request
request_count += 1
assert _NUM_STREAM_REQUESTS == request_count
return _RESPONSE
async def _stream_unary_reader_writer(self, unused_request, context):
for _ in range(_NUM_STREAM_REQUESTS):
assert _REQUEST == await context.read()
return _RESPONSE
async def _stream_unary_evilly_mixed(self, request_iterator, context):
assert _REQUEST == await context.read()
request_count = 0
async for request in request_iterator:
assert _REQUEST == request
request_count += 1
assert _NUM_STREAM_REQUESTS - 1 == request_count
return _RESPONSE
async def _stream_stream_async_gen(self, request_iterator, unused_context):
request_count = 0
async for request in request_iterator:
assert _REQUEST == request
request_count += 1
assert _NUM_STREAM_REQUESTS == request_count
for _ in range(_NUM_STREAM_RESPONSES):
yield _RESPONSE
async def _stream_stream_reader_writer(self, unused_request, context):
for _ in range(_NUM_STREAM_REQUESTS):
assert _REQUEST == await context.read()
for _ in range(_NUM_STREAM_RESPONSES):
await context.write(_RESPONSE)
async def _stream_stream_evilly_mixed(self, request_iterator, context):
assert _REQUEST == await context.read()
request_count = 0
async for request in request_iterator:
assert _REQUEST == request
request_count += 1
assert _NUM_STREAM_REQUESTS - 1 == request_count
yield _RESPONSE
for _ in range(_NUM_STREAM_RESPONSES - 1):
await context.write(_RESPONSE)
def service(self, handler_details):
self._called.set_result(None)
if handler_details.method == _SIMPLE_UNARY_UNARY:
return grpc.unary_unary_rpc_method_handler(self._unary_unary)
if handler_details.method == _BLOCK_FOREVER:
return grpc.unary_unary_rpc_method_handler(self._block_forever)
if handler_details.method == _BLOCK_BRIEFLY:
return grpc.unary_unary_rpc_method_handler(self._block_briefly)
if handler_details.method == _UNARY_STREAM_ASYNC_GEN:
return grpc.unary_stream_rpc_method_handler(
self._unary_stream_async_gen)
if handler_details.method == _UNARY_STREAM_READER_WRITER:
return grpc.unary_stream_rpc_method_handler(
self._unary_stream_reader_writer)
if handler_details.method == _UNARY_STREAM_EVILLY_MIXED:
return grpc.unary_stream_rpc_method_handler(
self._unary_stream_evilly_mixed)
return self._routing_table[handler_details.method]
async def wait_for_call(self):
await self._called
@ -98,89 +177,152 @@ async def _start_test_server():
class TestServer(AioTestBase):
async def setUp(self):
self._server_target, self._server, self._generic_handler = await _start_test_server(
)
addr, self._server, self._generic_handler = await _start_test_server()
self._channel = aio.insecure_channel(addr)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel:
unary_unary_call = channel.unary_unary(_SIMPLE_UNARY_UNARY)
response = await unary_unary_call(_REQUEST)
self.assertEqual(response, _RESPONSE)
unary_unary_call = self._channel.unary_unary(_SIMPLE_UNARY_UNARY)
response = await unary_unary_call(_REQUEST)
self.assertEqual(response, _RESPONSE)
async def test_unary_stream_async_generator(self):
async with aio.insecure_channel(self._server_target) as channel:
unary_stream_call = channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
call = unary_stream_call(_REQUEST)
unary_stream_call = self._channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
call = unary_stream_call(_REQUEST)
# Expecting the request message to reach server before retriving
# any responses.
await asyncio.wait_for(self._generic_handler.wait_for_call(),
test_constants.SHORT_TIMEOUT)
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertEqual(_RESPONSE, response)
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertEqual(_RESPONSE, response)
self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_unary_stream_reader_writer(self):
async with aio.insecure_channel(self._server_target) as channel:
unary_stream_call = channel.unary_stream(
_UNARY_STREAM_READER_WRITER)
call = unary_stream_call(_REQUEST)
# Expecting the request message to reach server before retriving
# any responses.
await asyncio.wait_for(self._generic_handler.wait_for_call(),
test_constants.SHORT_TIMEOUT)
unary_stream_call = self._channel.unary_stream(
_UNARY_STREAM_READER_WRITER)
call = unary_stream_call(_REQUEST)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE, response)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_unary_stream_evilly_mixed(self):
async with aio.insecure_channel(self._server_target) as channel:
unary_stream_call = channel.unary_stream(_UNARY_STREAM_EVILLY_MIXED)
call = unary_stream_call(_REQUEST)
unary_stream_call = self._channel.unary_stream(
_UNARY_STREAM_EVILLY_MIXED)
call = unary_stream_call(_REQUEST)
# Uses reader API
self.assertEqual(_RESPONSE, await call.read())
# Uses async generator API
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertEqual(_RESPONSE, response)
self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_unary_async_generator(self):
stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)
call = stream_unary_call()
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(_REQUEST)
await call.done_writing()
response = await call
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_unary_reader_writer(self):
stream_unary_call = self._channel.stream_unary(
_STREAM_UNARY_READER_WRITER)
call = stream_unary_call()
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(_REQUEST)
await call.done_writing()
response = await call
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_unary_evilly_mixed(self):
stream_unary_call = self._channel.stream_unary(
_STREAM_UNARY_EVILLY_MIXED)
call = stream_unary_call()
# Expecting the request message to reach server before retriving
# any responses.
await asyncio.wait_for(self._generic_handler.wait_for_call(),
test_constants.SHORT_TIMEOUT)
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(_REQUEST)
await call.done_writing()
# Uses reader API
self.assertEqual(_RESPONSE, await call.read())
response = await call
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# Uses async generator API
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertEqual(_RESPONSE, response)
async def test_stream_stream_async_generator(self):
stream_stream_call = self._channel.stream_stream(
_STREAM_STREAM_ASYNC_GEN)
call = stream_stream_call()
self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt)
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(_REQUEST)
await call.done_writing()
self.assertEqual(await call.code(), grpc.StatusCode.OK)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_stream_reader_writer(self):
stream_stream_call = self._channel.stream_stream(
_STREAM_STREAM_READER_WRITER)
call = stream_stream_call()
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(_REQUEST)
await call.done_writing()
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_stream_evilly_mixed(self):
stream_stream_call = self._channel.stream_stream(
_STREAM_STREAM_EVILLY_MIXED)
call = stream_stream_call()
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(_REQUEST)
await call.done_writing()
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_shutdown(self):
await self._server.stop(None)
# Ensures no SIGSEGV triggered, and ends within timeout.
async def test_shutdown_after_call(self):
async with aio.insecure_channel(self._server_target) as channel:
await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
await self._server.stop(None)
async def test_graceful_shutdown_success(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await self._generic_handler.wait_for_call()
shutdown_start_time = time.time()
@ -190,13 +332,11 @@ class TestServer(AioTestBase):
test_constants.SHORT_TIMEOUT / 3)
# Validates the states.
await channel.close()
self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done())
async def test_graceful_shutdown_failed(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await self._generic_handler.wait_for_call()
await self._server.stop(test_constants.SHORT_TIMEOUT)
@ -206,11 +346,9 @@ class TestServer(AioTestBase):
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details())
await channel.close()
async def test_concurrent_graceful_shutdown(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await self._generic_handler.wait_for_call()
# Expects the shortest grace period to be effective.
@ -224,13 +362,11 @@ class TestServer(AioTestBase):
self.assertGreater(grace_period_length,
test_constants.SHORT_TIMEOUT / 3)
await channel.close()
self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done())
async def test_concurrent_graceful_shutdown_immediate(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await self._generic_handler.wait_for_call()
# Expects no grace period, due to the "server.stop(None)".
@ -246,7 +382,6 @@ class TestServer(AioTestBase):
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details())
await channel.close()
@unittest.skip('https://github.com/grpc/grpc/issues/20818')
async def test_shutdown_before_call(self):

Loading…
Cancel
Save