Merge pull request #21517 from lidizheng/aio-streaming

[Aio] Client Streaming and Bi-di Streaming
reviewable/pr21745/r1
Lidi Zheng 5 years ago committed by GitHub
commit 5a4a5a0088
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 103
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 11
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  3. 21
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  4. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  5. 290
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  6. 4
      src/python/grpcio/grpc/experimental/aio/__init__.py
  7. 90
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  8. 449
      src/python/grpcio/grpc/experimental/aio/_call.py
  9. 166
      src/python/grpcio/grpc/experimental/aio/_channel.py
  10. 2
      src/python/grpcio/grpc/experimental/aio/_typing.py
  11. 2
      src/python/grpcio_tests/tests_aio/tests.json
  12. 26
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  13. 310
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  14. 100
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  15. 297
      src/python/grpcio_tests/tests_aio/unit/server_test.py

@ -23,18 +23,18 @@ _EMPTY_METADATA = None
_UNKNOWN_CANCELLATION_DETAILS = 'RPC cancelled for unknown reason.' _UNKNOWN_CANCELLATION_DETAILS = 'RPC cancelled for unknown reason.'
cdef class _AioCall: cdef class _AioCall(GrpcCallWrapper):
def __cinit__(self, def __cinit__(self,
AioChannel channel, AioChannel channel,
object deadline, object deadline,
bytes method, bytes method,
CallCredentials credentials): CallCredentials call_credentials):
self.call = NULL self.call = NULL
self._channel = channel self._channel = channel
self._references = [] self._references = []
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._create_grpc_call(deadline, method, credentials) self._create_grpc_call(deadline, method, call_credentials)
self._is_locally_cancelled = False self._is_locally_cancelled = False
def __dealloc__(self): def __dealloc__(self):
@ -196,9 +196,25 @@ cdef class _AioCall:
self, self,
self._loop self._loop
) )
return received_message if received_message:
return received_message
else:
return EOF
async def send_serialized_message(self, bytes message):
"""Sends one single raw message in bytes."""
await _send_message(self,
message,
True,
self._loop)
async def unary_stream(self, async def send_receive_close(self):
"""Half close the RPC on the client-side."""
cdef SendCloseFromClientOperation op = SendCloseFromClientOperation(_EMPTY_FLAGS)
cdef tuple ops = (op,)
await execute_batch(self, ops, self._loop)
async def initiate_unary_stream(self,
bytes request, bytes request,
object initial_metadata_observer, object initial_metadata_observer,
object status_observer): object status_observer):
@ -233,3 +249,80 @@ cdef class _AioCall:
await _receive_initial_metadata(self, await _receive_initial_metadata(self,
self._loop), self._loop),
) )
async def stream_unary(self,
tuple metadata,
object metadata_sent_observer,
object initial_metadata_observer,
object status_observer):
"""Actual implementation of the complete unary-stream call.
Needs to pay extra attention to the raise mechanism. If we want to
propagate the final status exception, then we have to raise it.
Othersize, it would end normally and raise `StopAsyncIteration()`.
"""
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
metadata,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
# Receives initial metadata.
initial_metadata_observer(
await _receive_initial_metadata(self,
self._loop),
)
cdef tuple inbound_ops
cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS)
cdef ReceiveStatusOnClientOperation receive_status_on_client_op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
inbound_ops = (receive_message_op, receive_status_on_client_op)
# Executes all operations in one batch.
await execute_batch(self,
inbound_ops,
self._loop)
status = AioRpcStatus(
receive_status_on_client_op.code(),
receive_status_on_client_op.details(),
receive_status_on_client_op.trailing_metadata(),
receive_status_on_client_op.error_string(),
)
# Reports the final status of the RPC to Python layer. The observer
# pattern is used here to unify unary and streaming code path.
status_observer(status)
if status.code() == StatusCode.ok:
return receive_message_op.message()
else:
return None
async def initiate_stream_stream(self,
tuple metadata,
object metadata_sent_observer,
object initial_metadata_observer,
object status_observer):
"""Actual implementation of the complete stream-stream call.
Needs to pay extra attention to the raise mechanism. If we want to
propagate the final status exception, then we have to raise it.
Othersize, it would end normally and raise `StopAsyncIteration()`.
"""
# Peer may prematurely end this RPC at any point. We need a corutine
# that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received(status_observer))
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
metadata,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
# Receives initial metadata.
initial_metadata_observer(
await _receive_initial_metadata(self,
self._loop),
)

@ -96,7 +96,7 @@ cdef class AioChannel:
def call(self, def call(self,
bytes method, bytes method,
object deadline, object deadline,
CallCredentials credentials): object python_call_credentials):
"""Assembles a Cython Call object. """Assembles a Cython Call object.
Returns: Returns:
@ -105,5 +105,12 @@ cdef class AioChannel:
if self._status == AIO_CHANNEL_STATUS_DESTROYED: if self._status == AIO_CHANNEL_STATUS_DESTROYED:
# TODO(lidiz) switch to UsageError # TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.') raise RuntimeError('Channel is closed.')
cdef _AioCall call = _AioCall(self, deadline, method, credentials)
cdef CallCredentials cython_call_credentials
if python_call_credentials is not None:
cython_call_credentials = python_call_credentials._credentials
else:
cython_call_credentials = None
cdef _AioCall call = _AioCall(self, deadline, method, cython_call_credentials)
return call return call

@ -33,3 +33,24 @@ cdef bytes serialize(object serializer, object message):
return serializer(message) return serializer(message)
else: else:
return message return message
class _EOF:
def __bool__(self):
return False
def __len__(self):
return 0
def _repr(self) -> str:
return '<grpc.aio.EOF>'
def __repr__(self) -> str:
return self._repr()
def __str__(self) -> str:
return self._repr()
EOF = _EOF()

@ -21,6 +21,10 @@ cdef class RPCState(GrpcCallWrapper):
cdef grpc_call_details details cdef grpc_call_details details
cdef grpc_metadata_array request_metadata cdef grpc_metadata_array request_metadata
cdef AioServer server cdef AioServer server
# NOTE(lidiz) Under certain corner case, receiving the client close
# operation won't immediately fail ongoing RECV_MESSAGE operations. Here I
# added a flag to workaround this unexpected behavior.
cdef bint client_closed
cdef object abort_exception cdef object abort_exception
cdef bint metadata_sent cdef bint metadata_sent
cdef bint status_sent cdef bint status_sent

@ -20,7 +20,8 @@ import traceback
# TODO(https://github.com/grpc/grpc/issues/20850) refactor this. # TODO(https://github.com/grpc/grpc/issues/20850) refactor this.
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
cdef int _EMPTY_FLAG = 0 cdef int _EMPTY_FLAG = 0
# TODO(lidiz) Use a designated value other than None.
cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.'
cdef class _HandlerCallDetails: cdef class _HandlerCallDetails:
def __cinit__(self, str method, tuple invocation_metadata): def __cinit__(self, str method, tuple invocation_metadata):
@ -35,6 +36,7 @@ cdef class RPCState:
self.server = server self.server = server
grpc_metadata_array_init(&self.request_metadata) grpc_metadata_array_init(&self.request_metadata)
grpc_call_details_init(&self.details) grpc_call_details_init(&self.details)
self.client_closed = False
self.abort_exception = None self.abort_exception = None
self.metadata_sent = False self.metadata_sent = False
self.status_sent = False self.status_sent = False
@ -83,13 +85,23 @@ cdef class _ServicerContext:
self._loop = loop self._loop = loop
async def read(self): async def read(self):
cdef bytes raw_message
if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
raise RuntimeError(_SERVER_STOPPED_DETAILS)
if self._rpc_state.status_sent: if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.') raise RuntimeError('RPC already finished.')
cdef bytes raw_message = await _receive_message(self._rpc_state, self._loop) if self._rpc_state.client_closed:
return deserialize(self._request_deserializer, return EOF
raw_message) raw_message = await _receive_message(self._rpc_state, self._loop)
if raw_message is None:
return EOF
else:
return deserialize(self._request_deserializer,
raw_message)
async def write(self, object message): async def write(self, object message):
if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
raise RuntimeError(_SERVER_STOPPED_DETAILS)
if self._rpc_state.status_sent: if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.') raise RuntimeError('RPC already finished.')
await _send_message(self._rpc_state, await _send_message(self._rpc_state,
@ -102,6 +114,8 @@ cdef class _ServicerContext:
async def send_initial_metadata(self, tuple metadata): async def send_initial_metadata(self, tuple metadata):
if self._rpc_state.status_sent: if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.') raise RuntimeError('RPC already finished.')
elif self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
raise RuntimeError(_SERVER_STOPPED_DETAILS)
elif self._rpc_state.metadata_sent: elif self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent') raise RuntimeError('Send initial metadata failed: already sent')
else: else:
@ -145,27 +159,23 @@ cdef _find_method_handler(str method, list generic_handlers):
return None return None
async def _handle_unary_unary_rpc(object method_handler, async def _finish_handler_with_unary_response(RPCState rpc_state,
RPCState rpc_state, object unary_handler,
object loop): object request,
# Receives request message _ServicerContext servicer_context,
cdef bytes request_raw = await _receive_message(rpc_state, loop) object response_serializer,
object loop):
# Deserializes the request message """Finishes server method handler with a single response.
cdef object request_message = deserialize(
method_handler.request_deserializer, This function executes the application handler, and handles response
request_raw, sending, as well as errors. It is shared between unary-unary and
) stream-unary handlers.
"""
# Executes application logic # Executes application logic
cdef object response_message = await method_handler.unary_unary(
request_message, cdef object response_message = await unary_handler(
_ServicerContext( request,
rpc_state, servicer_context,
None,
None,
loop,
),
) )
# Raises exception if aborted # Raises exception if aborted
@ -173,50 +183,50 @@ async def _handle_unary_unary_rpc(object method_handler,
# Serializes the response message # Serializes the response message
cdef bytes response_raw = serialize( cdef bytes response_raw = serialize(
method_handler.response_serializer, response_serializer,
response_message, response_message,
) )
# Sends response message # Assembles the batch operations
cdef tuple send_ops = ( cdef Operation send_status_op = SendStatusFromServerOperation(
SendStatusFromServerOperation( tuple(),
tuple(),
StatusCode.ok, StatusCode.ok,
b'', b'',
_EMPTY_FLAGS, _EMPTY_FLAGS,
),
SendInitialMetadataOperation(None, _EMPTY_FLAGS),
SendMessageOperation(response_raw, _EMPTY_FLAGS),
) )
cdef tuple finish_ops
if not rpc_state.metadata_sent:
finish_ops = (
send_status_op,
SendInitialMetadataOperation(None, _EMPTY_FLAGS),
SendMessageOperation(response_raw, _EMPTY_FLAGS),
)
else:
finish_ops = (
send_status_op,
SendMessageOperation(response_raw, _EMPTY_FLAGS),
)
rpc_state.status_sent = True rpc_state.status_sent = True
await execute_batch(rpc_state, send_ops, loop) await execute_batch(rpc_state, finish_ops, loop)
async def _handle_unary_stream_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef bytes request_raw = await _receive_message(rpc_state, loop)
# Deserializes the request message
cdef object request_message = deserialize(
method_handler.request_deserializer,
request_raw,
)
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
method_handler.request_deserializer,
method_handler.response_serializer,
loop,
)
async def _finish_handler_with_stream_responses(RPCState rpc_state,
object stream_handler,
object request,
_ServicerContext servicer_context,
object loop):
"""Finishes server method handler with multiple responses.
This function executes the application handler, and handles response
sending, as well as errors. It is shared between unary-stream and
stream-stream handlers.
"""
cdef object async_response_generator cdef object async_response_generator
cdef object response_message cdef object response_message
if inspect.iscoroutinefunction(method_handler.unary_stream): if inspect.iscoroutinefunction(stream_handler):
# The handler uses reader / writer API, returns None. # The handler uses reader / writer API, returns None.
await method_handler.unary_stream( await stream_handler(
request_message, request,
servicer_context, servicer_context,
) )
@ -224,8 +234,8 @@ async def _handle_unary_stream_rpc(object method_handler,
_raise_if_aborted(rpc_state) _raise_if_aborted(rpc_state)
else: else:
# The handler uses async generator API # The handler uses async generator API
async_response_generator = method_handler.unary_stream( async_response_generator = stream_handler(
request_message, request,
servicer_context, servicer_context,
) )
@ -250,9 +260,132 @@ async def _handle_unary_stream_rpc(object method_handler,
_EMPTY_FLAGS, _EMPTY_FLAGS,
) )
cdef tuple ops = (op,) cdef tuple finish_ops = (op,)
if not rpc_state.metadata_sent:
finish_ops = (op, SendInitialMetadataOperation(None, _EMPTY_FLAGS))
rpc_state.status_sent = True rpc_state.status_sent = True
await execute_batch(rpc_state, ops, loop) await execute_batch(rpc_state, finish_ops, loop)
async def _handle_unary_unary_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef bytes request_raw = await _receive_message(rpc_state, loop)
# Deserializes the request message
cdef object request_message = deserialize(
method_handler.request_deserializer,
request_raw,
)
# Creates a dedecated ServicerContext
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
None,
None,
loop,
)
# Finishes the application handler
await _finish_handler_with_unary_response(
rpc_state,
method_handler.unary_unary,
request_message,
servicer_context,
method_handler.response_serializer,
loop
)
async def _handle_unary_stream_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef bytes request_raw = await _receive_message(rpc_state, loop)
# Deserializes the request message
cdef object request_message = deserialize(
method_handler.request_deserializer,
request_raw,
)
# Creates a dedecated ServicerContext
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
method_handler.request_deserializer,
method_handler.response_serializer,
loop,
)
# Finishes the application handler
await _finish_handler_with_stream_responses(
rpc_state,
method_handler.unary_stream,
request_message,
servicer_context,
loop,
)
async def _message_receiver(_ServicerContext servicer_context):
"""Bridge between the async generator API and the reader-writer API."""
cdef object message
while True:
message = await servicer_context.read()
if message is not EOF:
yield message
else:
break
async def _handle_stream_unary_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Creates a dedecated ServicerContext
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
method_handler.request_deserializer,
None,
loop,
)
# Prepares the request generator
cdef object request_async_iterator = _message_receiver(servicer_context)
# Finishes the application handler
await _finish_handler_with_unary_response(
rpc_state,
method_handler.stream_unary,
request_async_iterator,
servicer_context,
method_handler.response_serializer,
loop
)
async def _handle_stream_stream_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Creates a dedecated ServicerContext
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
method_handler.request_deserializer,
method_handler.response_serializer,
loop,
)
# Prepares the request generator
cdef object request_async_iterator = _message_receiver(servicer_context)
# Finishes the application handler
await _finish_handler_with_stream_responses(
rpc_state,
method_handler.stream_stream,
request_async_iterator,
servicer_context,
loop,
)
async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop): async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
@ -293,6 +426,7 @@ async def _handle_cancellation_from_core(object rpc_task,
# Awaits cancellation from peer. # Awaits cancellation from peer.
await execute_batch(rpc_state, ops, loop) await execute_batch(rpc_state, ops, loop)
rpc_state.client_closed = True
if op.cancelled() and not rpc_task.done(): if op.cancelled() and not rpc_task.done():
# Injects `CancelledError` to halt the RPC coroutine # Injects `CancelledError` to halt the RPC coroutine
rpc_task.cancel() rpc_task.cancel()
@ -311,8 +445,9 @@ async def _schedule_rpc_coro(object rpc_coro,
async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
cdef object method_handler
# Finds the method handler (application logic) # Finds the method handler (application logic)
cdef object method_handler = _find_method_handler( method_handler = _find_method_handler(
rpc_state.method().decode(), rpc_state.method().decode(),
generic_handlers, generic_handlers,
) )
@ -328,20 +463,33 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
) )
return return
# TODO(lidiz) extend to all 4 types of RPC # Handles unary-unary case
if not method_handler.request_streaming and not method_handler.response_streaming:
await _handle_unary_unary_rpc(method_handler,
rpc_state,
loop)
return
# Handles unary-stream case
if not method_handler.request_streaming and method_handler.response_streaming: if not method_handler.request_streaming and method_handler.response_streaming:
try: await _handle_unary_stream_rpc(method_handler,
await _handle_unary_stream_rpc(method_handler,
rpc_state, rpc_state,
loop) loop)
except Exception as e: return
raise
elif not method_handler.request_streaming and not method_handler.response_streaming: # Handles stream-unary case
await _handle_unary_unary_rpc(method_handler, if method_handler.request_streaming and not method_handler.response_streaming:
rpc_state, await _handle_stream_unary_rpc(method_handler,
loop) rpc_state,
else: loop)
raise NotImplementedError() return
# Handles stream-stream case
if method_handler.request_streaming and method_handler.response_streaming:
await _handle_stream_stream_rpc(method_handler,
rpc_state,
loop)
return
class _RequestCallError(Exception): pass class _RequestCallError(Exception): pass

@ -22,7 +22,7 @@ from typing import Any, Optional, Sequence, Text, Tuple
import six import six
import grpc import grpc
from grpc._cython.cygrpc import init_grpc_aio, AbortError from grpc._cython.cygrpc import EOF, AbortError, init_grpc_aio
from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall
from ._call import AioRpcError from ._call import AioRpcError
@ -86,5 +86,5 @@ __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall',
'UnaryStreamCall', 'init_grpc_aio', 'Channel', 'UnaryStreamCall', 'init_grpc_aio', 'Channel',
'UnaryUnaryMultiCallable', 'ClientCallDetails', 'UnaryUnaryMultiCallable', 'ClientCallDetails',
'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall', 'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall',
'insecure_channel', 'secure_channel', 'server', 'Server', 'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel',
'AbortError') 'AbortError')

@ -19,11 +19,12 @@ RPC, e.g. cancellation.
""" """
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Any, AsyncIterable, Awaitable, Callable, Generic, Text, Optional from typing import (Any, AsyncIterable, Awaitable, Callable, Generic, Optional,
Text, Union)
import grpc import grpc
from ._typing import MetadataType, RequestType, ResponseType from ._typing import EOFType, MetadataType, RequestType, ResponseType
__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' __all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -146,14 +147,85 @@ class UnaryStreamCall(Generic[RequestType, ResponseType],
""" """
@abstractmethod @abstractmethod
async def read(self) -> ResponseType: async def read(self) -> Union[EOFType, ResponseType]:
"""Reads one message from the RPC. """Reads one message from the stream.
For each streaming RPC, concurrent reads in multiple coroutines are not Read operations must be serialized when called from multiple
allowed. If you want to perform read in multiple coroutines, you needs coroutines.
synchronization. So, you can start another read after current read is
finished.
Returns: Returns:
A response message of the RPC. A response message, or an `grpc.aio.EOF` to indicate the end of the
stream.
"""
class StreamUnaryCall(Generic[RequestType, ResponseType],
Call,
metaclass=ABCMeta):
@abstractmethod
async def write(self, request: RequestType) -> None:
"""Writes one message to the stream.
Raises:
An RpcError exception if the write failed.
"""
@abstractmethod
async def done_writing(self) -> None:
"""Notifies server that the client is done sending messages.
After done_writing is called, any additional invocation to the write
function will fail. This function is idempotent.
"""
@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
"""Await the response message to be ready.
Returns:
The response message of the stream.
"""
class StreamStreamCall(Generic[RequestType, ResponseType],
Call,
metaclass=ABCMeta):
@abstractmethod
def __aiter__(self) -> AsyncIterable[ResponseType]:
"""Returns the async iterable representation that yields messages.
Under the hood, it is calling the "read" method.
Returns:
An async iterable object that yields messages.
"""
@abstractmethod
async def read(self) -> Union[EOFType, ResponseType]:
"""Reads one message from the stream.
Read operations must be serialized when called from multiple
coroutines.
Returns:
A response message, or an `grpc.aio.EOF` to indicate the end of the
stream.
"""
@abstractmethod
async def write(self, request: RequestType) -> None:
"""Writes one message to the stream.
Raises:
An RpcError exception if the write failed.
"""
@abstractmethod
async def done_writing(self) -> None:
"""Notifies server that the client is done sending messages.
After done_writing is called, any additional invocation to the write
function will fail. This function is idempotent.
""" """

@ -29,6 +29,7 @@ __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' _LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!' _GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!'
_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.' _RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.'
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' _OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n' '\tstatus = {}\n'
@ -146,31 +147,48 @@ def _create_rpc_error(initial_metadata: Optional[MetadataType],
class Call(_base_call.Call): class Call(_base_call.Call):
"""Base implementation of client RPC Call object.
Implements logic around final status, metadata and cancellation.
"""
_loop: asyncio.AbstractEventLoop _loop: asyncio.AbstractEventLoop
_code: grpc.StatusCode _code: grpc.StatusCode
_status: Awaitable[cygrpc.AioRpcStatus] _status: Awaitable[cygrpc.AioRpcStatus]
_initial_metadata: Awaitable[MetadataType] _initial_metadata: Awaitable[MetadataType]
_locally_cancelled: bool _locally_cancelled: bool
_cython_call: cygrpc._AioCall
def __init__(self) -> None: def __init__(self, cython_call: cygrpc._AioCall) -> None:
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._code = None self._code = None
self._status = self._loop.create_future() self._status = self._loop.create_future()
self._initial_metadata = self._loop.create_future() self._initial_metadata = self._loop.create_future()
self._locally_cancelled = False self._locally_cancelled = False
self._cython_call = cython_call
def cancel(self) -> bool: def __del__(self) -> None:
"""Placeholder cancellation method. if not self._status.done():
self._cancel(
The implementation of this method needs to pass the cancellation reason cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
into self._cancellation, using `set_result` instead of _GC_CANCELLATION_DETAILS, None, None))
`set_exception`.
"""
raise NotImplementedError()
def cancelled(self) -> bool: def cancelled(self) -> bool:
return self._code == grpc.StatusCode.CANCELLED return self._code == grpc.StatusCode.CANCELLED
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning."""
if not self._status.done():
self._set_status(status)
self._cython_call.cancel(status)
return True
else:
return False
def cancel(self) -> bool:
return self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_LOCAL_CANCELLATION_DETAILS, None, None))
def done(self) -> bool: def done(self) -> bool:
return self._status.done() return self._status.done()
@ -247,6 +265,7 @@ class Call(_base_call.Call):
return self._repr() return self._repr()
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method # pylint: disable=abstract-method
class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls. """Object for managing unary-unary RPC calls.
@ -254,37 +273,29 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
Returned when an instance of `UnaryUnaryMultiCallable` object is called. Returned when an instance of `UnaryUnaryMultiCallable` object is called.
""" """
_request: RequestType _request: RequestType
_channel: cygrpc.AioChannel
_request_serializer: SerializingFunction _request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction _response_deserializer: DeserializingFunction
_call: asyncio.Task _call: asyncio.Task
_cython_call: cygrpc._AioCall
def __init__( # pylint: disable=R0913 # pylint: disable=too-many-arguments
self, request: RequestType, deadline: Optional[float], def __init__(self, request: RequestType, deadline: Optional[float],
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None: response_deserializer: DeserializingFunction) -> None:
super().__init__() channel.call(method, deadline, credentials)
super().__init__(channel.call(method, deadline, credentials))
self._request = request self._request = request
self._channel = channel
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
if credentials is not None:
grpc_credentials = credentials._credentials
else:
grpc_credentials = None
self._cython_call = self._channel.call(method, deadline,
grpc_credentials)
self._call = self._loop.create_task(self._invoke()) self._call = self._loop.create_task(self._invoke())
def __del__(self) -> None: def cancel(self) -> bool:
if not self._call.done(): if super().cancel():
self._cancel( self._call.cancel()
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, return True
_GC_CANCELLATION_DETAILS, None, None)) else:
return False
async def _invoke(self) -> ResponseType: async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request, serialized_request = _common.serialize(self._request,
@ -300,7 +311,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
self._set_status, self._set_status,
) )
except asyncio.CancelledError: except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED: if not self.cancelled():
self.cancel() self.cancel()
# Raises here if RPC failed or cancelled # Raises here if RPC failed or cancelled
@ -309,21 +320,6 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
return _common.deserialize(serialized_response, return _common.deserialize(serialized_response,
self._response_deserializer) self._response_deserializer)
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning."""
if not self._status.done():
self._set_status(status)
self._cython_call.cancel(status)
self._call.cancel()
return True
else:
return False
def cancel(self) -> bool:
return self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_LOCAL_CANCELLATION_DETAILS, None, None))
def __await__(self) -> ResponseType: def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes.""" """Wait till the ongoing RPC request finishes."""
try: try:
@ -339,6 +335,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
return response return response
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method # pylint: disable=abstract-method
class UnaryStreamCall(Call, _base_call.UnaryStreamCall): class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
"""Object for managing unary-stream RPC calls. """Object for managing unary-stream RPC calls.
@ -346,107 +343,346 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
Returned when an instance of `UnaryStreamMultiCallable` object is called. Returned when an instance of `UnaryStreamMultiCallable` object is called.
""" """
_request: RequestType _request: RequestType
_channel: cygrpc.AioChannel
_request_serializer: SerializingFunction _request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction _response_deserializer: DeserializingFunction
_cython_call: cygrpc._AioCall
_send_unary_request_task: asyncio.Task _send_unary_request_task: asyncio.Task
_message_aiter: AsyncIterable[ResponseType] _message_aiter: AsyncIterable[ResponseType]
def __init__( # pylint: disable=R0913 # pylint: disable=too-many-arguments
self, request: RequestType, deadline: Optional[float], def __init__(self, request: RequestType, deadline: Optional[float],
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None: response_deserializer: DeserializingFunction) -> None:
super().__init__() super().__init__(channel.call(method, deadline, credentials))
self._request = request self._request = request
self._channel = channel
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._send_unary_request_task = self._loop.create_task( self._send_unary_request_task = self._loop.create_task(
self._send_unary_request()) self._send_unary_request())
self._message_aiter = self._fetch_stream_responses() self._message_aiter = None
if credentials is not None: def cancel(self) -> bool:
grpc_credentials = credentials._credentials if super().cancel():
self._send_unary_request_task.cancel()
return True
else: else:
grpc_credentials = None return False
self._cython_call = self._channel.call(method, deadline,
grpc_credentials)
def __del__(self) -> None:
if not self._status.done():
self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None))
async def _send_unary_request(self) -> ResponseType: async def _send_unary_request(self) -> ResponseType:
serialized_request = _common.serialize(self._request, serialized_request = _common.serialize(self._request,
self._request_serializer) self._request_serializer)
try: try:
await self._cython_call.unary_stream(serialized_request, await self._cython_call.initiate_unary_stream(
self._set_initial_metadata, serialized_request, self._set_initial_metadata,
self._set_status) self._set_status)
except asyncio.CancelledError: except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED: if not self.cancelled():
self.cancel() self.cancel()
raise raise
async def _fetch_stream_responses(self) -> ResponseType: async def _fetch_stream_responses(self) -> ResponseType:
await self._send_unary_request_task
message = await self._read() message = await self._read()
while message: while message is not cygrpc.EOF:
yield message yield message
message = await self._read() message = await self._read()
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: def __aiter__(self) -> AsyncIterable[ResponseType]:
"""Forwards the application cancellation reasoning. if self._message_aiter is None:
self._message_aiter = self._fetch_stream_responses()
return self._message_aiter
Async generator will receive an exception. The cancellation will go async def _read(self) -> ResponseType:
deep down into Core, and then propagates backup as the # Wait for the request being sent
`cygrpc.AioRpcStatus` exception. await self._send_unary_request_task
So, under race condition, e.g. the server sent out final state headers # Reads response message from Core
and the client calling "cancel" at the same time, this method respects try:
the winner in Core. raw_response = await self._cython_call.receive_serialized_message()
""" except asyncio.CancelledError:
if not self._status.done(): if not self.cancelled():
self._set_status(status) self.cancel()
self._cython_call.cancel(status) await self._raise_for_status()
if raw_response is cygrpc.EOF:
return cygrpc.EOF
else:
return _common.deserialize(raw_response,
self._response_deserializer)
async def read(self) -> ResponseType:
if self._status.done():
await self._raise_for_status()
return cygrpc.EOF
response_message = await self._read()
if response_message is cygrpc.EOF:
# If the read operation failed, Core should explain why.
await self._raise_for_status()
return response_message
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
"""Object for managing stream-unary RPC calls.
Returned when an instance of `StreamUnaryMultiCallable` object is called.
"""
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_metadata_sent: asyncio.Event
_done_writing: bool
_call_finisher: asyncio.Task
_async_request_poller: asyncio.Task
if not self._send_unary_request_task.done(): # pylint: disable=too-many-arguments
# Injects CancelledError to the Task. The exception will def __init__(self,
# propagate to _fetch_stream_responses as well, if the sending request_async_iterator: Optional[AsyncIterable[RequestType]],
# is not done. deadline: Optional[float],
self._send_unary_request_task.cancel() credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
self._metadata = _EMPTY_METADATA
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing = False
self._call_finisher = self._loop.create_task(self._conduct_rpc())
# If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None:
self._async_request_poller = self._loop.create_task(
self._consume_request_iterator(request_async_iterator))
else:
self._async_request_poller = None
def cancel(self) -> bool:
if super().cancel():
self._call_finisher.cancel()
if self._async_request_poller is not None:
self._async_request_poller.cancel()
return True return True
else: else:
return False return False
def _metadata_sent_observer(self):
self._metadata_sent.set()
async def _conduct_rpc(self) -> ResponseType:
try:
serialized_response = await self._cython_call.stream_unary(
self._metadata,
self._metadata_sent_observer,
self._set_initial_metadata,
self._set_status,
)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# Raises RpcError if the RPC failed or cancelled
await self._raise_for_status()
return _common.deserialize(serialized_response,
self._response_deserializer)
async def _consume_request_iterator(
self, request_async_iterator: AsyncIterable[RequestType]) -> None:
async for request in request_async_iterator:
await self.write(request)
await self.done_writing()
def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes."""
try:
response = yield from self._call_finisher
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
raise
return response
async def write(self, request: RequestType) -> None:
if self._status.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set():
await self._metadata_sent.wait()
serialized_request = _common.serialize(request,
self._request_serializer)
try:
await self._cython_call.send_serialized_message(serialized_request)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
if self._status.done():
# If the RPC is finished, do nothing.
return
if not self._done_writing:
# If the done writing is not sent before, try to send it.
self._done_writing = True
try:
await self._cython_call.send_receive_close()
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class StreamStreamCall(Call, _base_call.StreamStreamCall):
"""Object for managing stream-stream RPC calls.
Returned when an instance of `StreamStreamMultiCallable` object is called.
"""
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_metadata_sent: asyncio.Event
_done_writing: bool
_initializer: asyncio.Task
_async_request_poller: asyncio.Task
_message_aiter: AsyncIterable[ResponseType]
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float],
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
self._metadata = _EMPTY_METADATA
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing = False
self._initializer = self._loop.create_task(self._prepare_rpc())
# If user passes in an async iterator, create a consumer coroutine.
if request_async_iterator is not None:
self._async_request_poller = self._loop.create_task(
self._consume_request_iterator(request_async_iterator))
else:
self._async_request_poller = None
self._message_aiter = None
def cancel(self) -> bool: def cancel(self) -> bool:
return self._cancel( if super().cancel():
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, self._initializer.cancel()
_LOCAL_CANCELLATION_DETAILS, None, None)) if self._async_request_poller is not None:
self._async_request_poller.cancel()
return True
else:
return False
def _metadata_sent_observer(self):
self._metadata_sent.set()
async def _prepare_rpc(self):
"""This method prepares the RPC for receiving/sending messages.
All other operations around the stream should only happen after the
completion of this method.
"""
try:
await self._cython_call.initiate_stream_stream(
self._metadata,
self._metadata_sent_observer,
self._set_initial_metadata,
self._set_status,
)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# No need to raise RpcError here, because no one will `await` this task.
async def _consume_request_iterator(
self, request_async_iterator: Optional[AsyncIterable[RequestType]]
) -> None:
async for request in request_async_iterator:
await self.write(request)
await self.done_writing()
async def write(self, request: RequestType) -> None:
if self._status.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set():
await self._metadata_sent.wait()
serialized_request = _common.serialize(request,
self._request_serializer)
try:
await self._cython_call.send_serialized_message(serialized_request)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
if self._status.done():
# If the RPC is finished, do nothing.
return
if not self._done_writing:
# If the done writing is not sent before, try to send it.
self._done_writing = True
try:
await self._cython_call.send_receive_close()
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
async def _fetch_stream_responses(self) -> ResponseType:
"""The async generator that yields responses from peer."""
message = await self._read()
while message is not cygrpc.EOF:
yield message
message = await self._read()
def __aiter__(self) -> AsyncIterable[ResponseType]: def __aiter__(self) -> AsyncIterable[ResponseType]:
if self._message_aiter is None:
self._message_aiter = self._fetch_stream_responses()
return self._message_aiter return self._message_aiter
async def _read(self) -> ResponseType: async def _read(self) -> ResponseType:
# Wait for the request being sent # Wait for the setup
await self._send_unary_request_task await self._initializer
# Reads response message from Core # Reads response message from Core
try: try:
raw_response = await self._cython_call.receive_serialized_message() raw_response = await self._cython_call.receive_serialized_message()
except asyncio.CancelledError: except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED: if not self.cancelled():
self.cancel() self.cancel()
raise await self._raise_for_status()
if raw_response is None: if raw_response is cygrpc.EOF:
return None return cygrpc.EOF
else: else:
return _common.deserialize(raw_response, return _common.deserialize(raw_response,
self._response_deserializer) self._response_deserializer)
@ -454,14 +690,11 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
async def read(self) -> ResponseType: async def read(self) -> ResponseType:
if self._status.done(): if self._status.done():
await self._raise_for_status() await self._raise_for_status()
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) return cygrpc.EOF
response_message = await self._read() response_message = await self._read()
if response_message is None: if response_message is cygrpc.EOF:
# If the read operation failed, Core should explain why. # If the read operation failed, Core should explain why.
await self._raise_for_status() await self._raise_for_status()
# If no exception raised, there is something wrong internally. return response_message
assert False, 'Read operation failed with StatusCode.OK'
else:
return response_message

@ -13,14 +13,15 @@
# limitations under the License. # limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python.""" """Invocation-side implementation of gRPC Asyncio Python."""
import asyncio import asyncio
from typing import Any, Optional, Sequence, Text from typing import Any, AsyncIterable, Optional, Sequence, Text
import grpc import grpc
from grpc import _common from grpc import _common
from grpc._cython import cygrpc from grpc._cython import cygrpc
from . import _base_call from . import _base_call
from ._call import UnaryStreamCall, UnaryUnaryCall from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
UnaryUnaryCall)
from ._interceptor import (InterceptedUnaryUnaryCall, from ._interceptor import (InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor) UnaryUnaryClientInterceptor)
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
@ -28,8 +29,16 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
from ._utils import _timeout_to_deadline from ._utils import _timeout_to_deadline
class UnaryUnaryMultiCallable: class _BaseMultiCallable:
"""Factory an asynchronous unary-unary RPC stub call from client-side.""" """Base class of all multi callable objects.
Handles the initialization logic and stores common attributes.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_channel: cygrpc.AioChannel _channel: cygrpc.AioChannel
_method: bytes _method: bytes
@ -50,6 +59,10 @@ class UnaryUnaryMultiCallable:
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._interceptors = interceptors self._interceptors = interceptors
class UnaryUnaryMultiCallable(_BaseMultiCallable):
"""Factory an asynchronous unary-unary RPC stub call from client-side."""
def __call__(self, def __call__(self,
request: Any, request: Any,
*, *,
@ -114,17 +127,8 @@ class UnaryUnaryMultiCallable:
) )
class UnaryStreamMultiCallable: class UnaryStreamMultiCallable(_BaseMultiCallable):
"""Afford invoking a unary-stream RPC from client-side in an asynchronous way.""" """Affords invoking a unary-stream RPC from client-side in an asynchronous way."""
def __init__(self, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._loop = asyncio.get_event_loop()
def __call__(self, def __call__(self,
request: Any, request: Any,
@ -176,6 +180,122 @@ class UnaryStreamMultiCallable:
) )
class StreamUnaryMultiCallable(_BaseMultiCallable):
"""Affords invoking a stream-unary RPC from client-side in an asynchronous way."""
def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamUnaryCall:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A Call object instance which is an awaitable object.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
return StreamUnaryCall(
request_async_iterator,
deadline,
credentials,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
class StreamStreamMultiCallable(_BaseMultiCallable):
"""Affords invoking a stream-stream RPC from client-side in an asynchronous way."""
def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamStreamCall:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A Call object instance which is an awaitable object.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
return StreamStreamCall(
request_async_iterator,
deadline,
credentials,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
class Channel: class Channel:
"""Asynchronous Channel implementation. """Asynchronous Channel implementation.
@ -301,21 +421,27 @@ class Channel:
) -> UnaryStreamMultiCallable: ) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, _common.encode(method), return UnaryStreamMultiCallable(self._channel, _common.encode(method),
request_serializer, request_serializer,
response_deserializer) response_deserializer, None)
def stream_unary( def stream_unary(
self, self,
method: Text, method: Text,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None): response_deserializer: Optional[DeserializingFunction] = None
"""Placeholder method for stream-unary calls.""" ) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None)
def stream_stream( def stream_stream(
self, self,
method: Text, method: Text,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None): response_deserializer: Optional[DeserializingFunction] = None
"""Placeholder method for stream-stream calls.""" ) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None)
async def _close(self): async def _close(self):
# TODO: Send cancellation status # TODO: Send cancellation status

@ -14,6 +14,7 @@
"""Common types for gRPC Async API""" """Common types for gRPC Async API"""
from typing import Any, AnyStr, Callable, Sequence, Text, Tuple, TypeVar from typing import Any, AnyStr, Callable, Sequence, Text, Tuple, TypeVar
from grpc._cython.cygrpc import EOF
RequestType = TypeVar('RequestType') RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType') ResponseType = TypeVar('ResponseType')
@ -21,3 +22,4 @@ SerializingFunction = Callable[[Any], bytes]
DeserializingFunction = Callable[[bytes], Any] DeserializingFunction = Callable[[bytes], Any]
MetadataType = Sequence[Tuple[Text, AnyStr]] MetadataType = Sequence[Tuple[Text, AnyStr]]
ChannelArgumentType = Sequence[Tuple[Text, Any]] ChannelArgumentType = Sequence[Tuple[Text, Any]]
EOFType = type(EOF)

@ -2,6 +2,8 @@
"_sanity._sanity_test.AioSanityTest", "_sanity._sanity_test.AioSanityTest",
"unit.abort_test.TestAbort", "unit.abort_test.TestAbort",
"unit.aio_rpc_error_test.TestAioRpcError", "unit.aio_rpc_error_test.TestAioRpcError",
"unit.call_test.TestStreamStreamCall",
"unit.call_test.TestStreamUnaryCall",
"unit.call_test.TestUnaryStreamCall", "unit.call_test.TestUnaryStreamCall",
"unit.call_test.TestUnaryUnaryCall", "unit.call_test.TestUnaryUnaryCall",
"unit.channel_argument_test.TestChannelArgument", "unit.channel_argument_test.TestChannelArgument",

@ -26,11 +26,12 @@ from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
async def UnaryCall(self, request, context): async def UnaryCall(self, unused_request, unused_context):
return messages_pb2.SimpleResponse() return messages_pb2.SimpleResponse()
async def StreamingOutputCall( async def StreamingOutputCall(
self, request: messages_pb2.StreamingOutputCallRequest, context): self, request: messages_pb2.StreamingOutputCallRequest,
unused_context):
for response_parameters in request.response_parameters: for response_parameters in request.response_parameters:
if response_parameters.interval_us != 0: if response_parameters.interval_us != 0:
await asyncio.sleep( await asyncio.sleep(
@ -44,11 +45,30 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
# Next methods are extra ones that are registred programatically # Next methods are extra ones that are registred programatically
# when the sever is instantiated. They are not being provided by # when the sever is instantiated. They are not being provided by
# the proto file. # the proto file.
async def UnaryCallWithSleep(self, request, context): async def UnaryCallWithSleep(self, request, context):
await asyncio.sleep(UNARY_CALL_WITH_SLEEP_VALUE) await asyncio.sleep(UNARY_CALL_WITH_SLEEP_VALUE)
return messages_pb2.SimpleResponse() return messages_pb2.SimpleResponse()
async def StreamingInputCall(self, request_async_iterator, unused_context):
aggregate_size = 0
async for request in request_async_iterator:
if request.payload is not None and request.payload.body:
aggregate_size += len(request.payload.body)
return messages_pb2.StreamingInputCallResponse(
aggregated_payload_size=aggregate_size)
async def FullDuplexCall(self, request_async_iterator, unused_context):
async for request in request_async_iterator:
for response_parameters in request.response_parameters:
if response_parameters.interval_us != 0:
await asyncio.sleep(
datetime.timedelta(microseconds=response_parameters.
interval_us).total_seconds())
yield messages_pb2.StreamingOutputCallResponse(
payload=messages_pb2.Payload(type=request.payload.type,
body=b'\x00' *
response_parameters.size))
async def start_test_server(secure=False): async def start_test_server(secure=False):
server = aio.server(options=(('grpc.so_reuseport', 0),)) server = aio.server(options=(('grpc.so_reuseport', 0),))

@ -30,10 +30,10 @@ from src.proto.grpc.testing import messages_pb2
_NUM_STREAM_RESPONSES = 5 _NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42 _RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000 _RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000
_UNREACHABLE_TARGET = '0.1:1111' _UNREACHABLE_TARGET = '0.1:1111'
_INFINITE_INTERVAL_US = 2**31 - 1 _INFINITE_INTERVAL_US = 2**31 - 1
@ -286,7 +286,7 @@ class TestUnaryStreamCall(AioTestBase):
[grpc.StatusCode.OK, grpc.StatusCode.CANCELLED]) [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])
async def test_too_many_reads_unary_stream(self): async def test_too_many_reads_unary_stream(self):
"""Test cancellation after received all messages.""" """Test calling read after received all messages fails."""
async with aio.insecure_channel(self._server_target) as channel: async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)
@ -306,13 +306,14 @@ class TestUnaryStreamCall(AioTestBase):
messages_pb2.StreamingOutputCallResponse) messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body)) len(response.payload.body))
self.assertIs(await call.read(), aio.EOF)
# After the RPC is finished, further reads will lead to exception. # After the RPC is finished, further reads will lead to exception.
self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.code(), grpc.StatusCode.OK)
with self.assertRaises(asyncio.InvalidStateError): self.assertIs(await call.read(), aio.EOF)
await call.read()
async def test_unary_stream_async_generator(self): async def test_unary_stream_async_generator(self):
"""Sunny day test case for unary_stream."""
async with aio.insecure_channel(self._server_target) as channel: async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)
@ -426,6 +427,307 @@ class TestUnaryStreamCall(AioTestBase):
self.loop.run_until_complete(coro()) self.loop.run_until_complete(coro())
class TestStreamUnaryCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
self._channel = aio.insecure_channel(self._server_target)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_cancel_stream_unary(self):
call = self._stub.StreamingInputCall()
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
# Sends out requests
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
await call.done_writing()
with self.assertRaises(asyncio.CancelledError):
await call
async def test_early_cancel_stream_unary(self):
call = self._stub.StreamingInputCall()
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
with self.assertRaises(asyncio.InvalidStateError):
await call.write(messages_pb2.StreamingInputCallRequest())
# Should be no-op
await call.done_writing()
with self.assertRaises(asyncio.CancelledError):
await call
async def test_write_after_done_writing(self):
call = self._stub.StreamingInputCall()
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
# Sends out requests
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
# Should be no-op
await call.done_writing()
with self.assertRaises(asyncio.InvalidStateError):
await call.write(messages_pb2.StreamingInputCallRequest())
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_error_in_async_generator(self):
# Server will pause between responses
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# We expect the request iterator to receive the exception
request_iterator_received_the_exception = asyncio.Event()
async def request_iterator():
with self.assertRaises(asyncio.CancelledError):
for _ in range(_NUM_STREAM_RESPONSES):
yield request
await asyncio.sleep(test_constants.SHORT_TIMEOUT)
request_iterator_received_the_exception.set()
call = self._stub.StreamingInputCall(request_iterator())
# Cancel the RPC after at least one response
async def cancel_later():
await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
call.cancel()
cancel_later_task = self.loop.create_task(cancel_later())
# No exceptions here
with self.assertRaises(asyncio.CancelledError):
await call
await request_iterator_received_the_exception.wait()
# No failures in the cancel later task!
await cancel_later_task
# Prepares the request that stream in a ping-pong manner.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
class TestStreamStreamCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
self._channel = aio.insecure_channel(self._server_target)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_cancel(self):
# Invokes the actual RPC
call = self._stub.FullDuplexCall()
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
response = await call.read()
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_cancel_with_pending_read(self):
call = self._stub.FullDuplexCall()
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_cancel_with_ongoing_read(self):
call = self._stub.FullDuplexCall()
coro_started = asyncio.Event()
async def read_coro():
coro_started.set()
await call.read()
read_task = self.loop.create_task(read_coro())
await coro_started.wait()
self.assertFalse(read_task.done())
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_early_cancel(self):
call = self._stub.FullDuplexCall()
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_cancel_after_done_writing(self):
call = self._stub.FullDuplexCall()
await call.done_writing()
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_late_cancel(self):
call = self._stub.FullDuplexCall()
await call.done_writing()
self.assertEqual(grpc.StatusCode.OK, await call.code())
# Cancels the RPC
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertFalse(call.cancel())
self.assertFalse(call.cancelled())
# Status is still OK
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_async_generator(self):
async def request_generator():
yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
call = self._stub.FullDuplexCall(request_generator())
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_too_many_reads(self):
async def request_generator():
for _ in range(_NUM_STREAM_RESPONSES):
yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
call = self._stub.FullDuplexCall(request_generator())
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertIs(await call.read(), aio.EOF)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# After the RPC finished, the read should also produce EOF
self.assertIs(await call.read(), aio.EOF)
async def test_read_write_after_done_writing(self):
call = self._stub.FullDuplexCall()
# Writes two requests, and pending two requests
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
await call.done_writing()
# Further write should fail
with self.assertRaises(asyncio.InvalidStateError):
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
# But read should be unaffected
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_error_in_async_generator(self):
# Server will pause between responses
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# We expect the request iterator to receive the exception
request_iterator_received_the_exception = asyncio.Event()
async def request_iterator():
with self.assertRaises(asyncio.CancelledError):
for _ in range(_NUM_STREAM_RESPONSES):
yield request
await asyncio.sleep(test_constants.SHORT_TIMEOUT)
request_iterator_received_the_exception.set()
call = self._stub.FullDuplexCall(request_iterator())
# Cancel the RPC after at least one response
async def cancel_later():
await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
call.cancel()
cancel_later_task = self.loop.create_task(cancel_later())
# No exceptions here
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
await request_iterator_received_the_exception.wait()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
# No failures in the cancel later task!
await cancel_later_task
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig() logging.basicConfig()
unittest.main(verbosity=2) unittest.main(verbosity=2)

@ -32,6 +32,7 @@ _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_NUM_STREAM_RESPONSES = 5 _NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42 _RESPONSE_PAYLOAD_SIZE = 42
@ -121,7 +122,104 @@ class TestChannel(AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close() await channel.close()
async def test_stream_unary_using_write(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
# Invokes the actual RPC
call = stub.StreamingInputCall()
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
# Sends out requests
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
await call.done_writing()
# Validates the responses
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close()
async def test_stream_unary_using_async_gen(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
async def gen():
for _ in range(_NUM_STREAM_RESPONSES):
yield request
# Invokes the actual RPC
call = stub.StreamingInputCall(gen())
# Validates the responses
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close()
async def test_stream_stream_using_read_write(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
# Invokes the actual RPC
call = stub.FullDuplexCall()
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
response = await call.read()
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
await call.done_writing()
self.assertEqual(grpc.StatusCode.OK, await call.code())
await channel.close()
async def test_stream_stream_using_async_gen(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
async def gen():
for _ in range(_NUM_STREAM_RESPONSES):
yield request
# Invokes the actual RPC
call = stub.FullDuplexCall(gen())
async for response in call:
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(grpc.StatusCode.OK, await call.code())
await channel.close()
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.WARN) logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2) unittest.main(verbosity=2)

@ -13,15 +13,16 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import gc
import logging import logging
import unittest
import time import time
import gc import unittest
import grpc import grpc
from grpc.experimental import aio from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from tests_aio.unit._test_base import AioTestBase
_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary' _SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary'
_BLOCK_FOREVER = '/test/BlockForever' _BLOCK_FOREVER = '/test/BlockForever'
@ -29,9 +30,16 @@ _BLOCK_BRIEFLY = '/test/BlockBriefly'
_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen' _UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen'
_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter' _UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter'
_UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed' _UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed'
_STREAM_UNARY_ASYNC_GEN = '/test/StreamUnaryAsyncGen'
_STREAM_UNARY_READER_WRITER = '/test/StreamUnaryReaderWriter'
_STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed'
_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen'
_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter'
_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed'
_REQUEST = b'\x00\x00\x00' _REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01' _RESPONSE = b'\x01\x01\x01'
_NUM_STREAM_REQUESTS = 3
_NUM_STREAM_RESPONSES = 5 _NUM_STREAM_RESPONSES = 5
@ -39,6 +47,41 @@ class _GenericHandler(grpc.GenericRpcHandler):
def __init__(self): def __init__(self):
self._called = asyncio.get_event_loop().create_future() self._called = asyncio.get_event_loop().create_future()
self._routing_table = {
_SIMPLE_UNARY_UNARY:
grpc.unary_unary_rpc_method_handler(self._unary_unary),
_BLOCK_FOREVER:
grpc.unary_unary_rpc_method_handler(self._block_forever),
_BLOCK_BRIEFLY:
grpc.unary_unary_rpc_method_handler(self._block_briefly),
_UNARY_STREAM_ASYNC_GEN:
grpc.unary_stream_rpc_method_handler(
self._unary_stream_async_gen),
_UNARY_STREAM_READER_WRITER:
grpc.unary_stream_rpc_method_handler(
self._unary_stream_reader_writer),
_UNARY_STREAM_EVILLY_MIXED:
grpc.unary_stream_rpc_method_handler(
self._unary_stream_evilly_mixed),
_STREAM_UNARY_ASYNC_GEN:
grpc.stream_unary_rpc_method_handler(
self._stream_unary_async_gen),
_STREAM_UNARY_READER_WRITER:
grpc.stream_unary_rpc_method_handler(
self._stream_unary_reader_writer),
_STREAM_UNARY_EVILLY_MIXED:
grpc.stream_unary_rpc_method_handler(
self._stream_unary_evilly_mixed),
_STREAM_STREAM_ASYNC_GEN:
grpc.stream_stream_rpc_method_handler(
self._stream_stream_async_gen),
_STREAM_STREAM_READER_WRITER:
grpc.stream_stream_rpc_method_handler(
self._stream_stream_reader_writer),
_STREAM_STREAM_EVILLY_MIXED:
grpc.stream_stream_rpc_method_handler(
self._stream_stream_evilly_mixed),
}
@staticmethod @staticmethod
async def _unary_unary(unused_request, unused_context): async def _unary_unary(unused_request, unused_context):
@ -64,23 +107,59 @@ class _GenericHandler(grpc.GenericRpcHandler):
for _ in range(_NUM_STREAM_RESPONSES - 1): for _ in range(_NUM_STREAM_RESPONSES - 1):
await context.write(_RESPONSE) await context.write(_RESPONSE)
async def _stream_unary_async_gen(self, request_iterator, unused_context):
request_count = 0
async for request in request_iterator:
assert _REQUEST == request
request_count += 1
assert _NUM_STREAM_REQUESTS == request_count
return _RESPONSE
async def _stream_unary_reader_writer(self, unused_request, context):
for _ in range(_NUM_STREAM_REQUESTS):
assert _REQUEST == await context.read()
return _RESPONSE
async def _stream_unary_evilly_mixed(self, request_iterator, context):
assert _REQUEST == await context.read()
request_count = 0
async for request in request_iterator:
assert _REQUEST == request
request_count += 1
assert _NUM_STREAM_REQUESTS - 1 == request_count
return _RESPONSE
async def _stream_stream_async_gen(self, request_iterator, unused_context):
request_count = 0
async for request in request_iterator:
assert _REQUEST == request
request_count += 1
assert _NUM_STREAM_REQUESTS == request_count
for _ in range(_NUM_STREAM_RESPONSES):
yield _RESPONSE
async def _stream_stream_reader_writer(self, unused_request, context):
for _ in range(_NUM_STREAM_REQUESTS):
assert _REQUEST == await context.read()
for _ in range(_NUM_STREAM_RESPONSES):
await context.write(_RESPONSE)
async def _stream_stream_evilly_mixed(self, request_iterator, context):
assert _REQUEST == await context.read()
request_count = 0
async for request in request_iterator:
assert _REQUEST == request
request_count += 1
assert _NUM_STREAM_REQUESTS - 1 == request_count
yield _RESPONSE
for _ in range(_NUM_STREAM_RESPONSES - 1):
await context.write(_RESPONSE)
def service(self, handler_details): def service(self, handler_details):
self._called.set_result(None) self._called.set_result(None)
if handler_details.method == _SIMPLE_UNARY_UNARY: return self._routing_table[handler_details.method]
return grpc.unary_unary_rpc_method_handler(self._unary_unary)
if handler_details.method == _BLOCK_FOREVER:
return grpc.unary_unary_rpc_method_handler(self._block_forever)
if handler_details.method == _BLOCK_BRIEFLY:
return grpc.unary_unary_rpc_method_handler(self._block_briefly)
if handler_details.method == _UNARY_STREAM_ASYNC_GEN:
return grpc.unary_stream_rpc_method_handler(
self._unary_stream_async_gen)
if handler_details.method == _UNARY_STREAM_READER_WRITER:
return grpc.unary_stream_rpc_method_handler(
self._unary_stream_reader_writer)
if handler_details.method == _UNARY_STREAM_EVILLY_MIXED:
return grpc.unary_stream_rpc_method_handler(
self._unary_stream_evilly_mixed)
async def wait_for_call(self): async def wait_for_call(self):
await self._called await self._called
@ -98,89 +177,152 @@ async def _start_test_server():
class TestServer(AioTestBase): class TestServer(AioTestBase):
async def setUp(self): async def setUp(self):
self._server_target, self._server, self._generic_handler = await _start_test_server( addr, self._server, self._generic_handler = await _start_test_server()
) self._channel = aio.insecure_channel(addr)
async def tearDown(self): async def tearDown(self):
await self._channel.close()
await self._server.stop(None) await self._server.stop(None)
async def test_unary_unary(self): async def test_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel: unary_unary_call = self._channel.unary_unary(_SIMPLE_UNARY_UNARY)
unary_unary_call = channel.unary_unary(_SIMPLE_UNARY_UNARY) response = await unary_unary_call(_REQUEST)
response = await unary_unary_call(_REQUEST) self.assertEqual(response, _RESPONSE)
self.assertEqual(response, _RESPONSE)
async def test_unary_stream_async_generator(self): async def test_unary_stream_async_generator(self):
async with aio.insecure_channel(self._server_target) as channel: unary_stream_call = self._channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
unary_stream_call = channel.unary_stream(_UNARY_STREAM_ASYNC_GEN) call = unary_stream_call(_REQUEST)
call = unary_stream_call(_REQUEST)
# Expecting the request message to reach server before retriving response_cnt = 0
# any responses. async for response in call:
await asyncio.wait_for(self._generic_handler.wait_for_call(), response_cnt += 1
test_constants.SHORT_TIMEOUT) self.assertEqual(_RESPONSE, response)
response_cnt = 0 self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
async for response in call: self.assertEqual(await call.code(), grpc.StatusCode.OK)
response_cnt += 1
self.assertEqual(_RESPONSE, response)
self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_unary_stream_reader_writer(self): async def test_unary_stream_reader_writer(self):
async with aio.insecure_channel(self._server_target) as channel: unary_stream_call = self._channel.unary_stream(
unary_stream_call = channel.unary_stream( _UNARY_STREAM_READER_WRITER)
_UNARY_STREAM_READER_WRITER) call = unary_stream_call(_REQUEST)
call = unary_stream_call(_REQUEST)
# Expecting the request message to reach server before retriving
# any responses.
await asyncio.wait_for(self._generic_handler.wait_for_call(),
test_constants.SHORT_TIMEOUT)
for _ in range(_NUM_STREAM_RESPONSES): for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read() response = await call.read()
self.assertEqual(_RESPONSE, response) self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_unary_stream_evilly_mixed(self): async def test_unary_stream_evilly_mixed(self):
async with aio.insecure_channel(self._server_target) as channel: unary_stream_call = self._channel.unary_stream(
unary_stream_call = channel.unary_stream(_UNARY_STREAM_EVILLY_MIXED) _UNARY_STREAM_EVILLY_MIXED)
call = unary_stream_call(_REQUEST) call = unary_stream_call(_REQUEST)
# Uses reader API
self.assertEqual(_RESPONSE, await call.read())
# Uses async generator API
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertEqual(_RESPONSE, response)
self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_unary_async_generator(self):
stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)
call = stream_unary_call()
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(_REQUEST)
await call.done_writing()
response = await call
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_unary_reader_writer(self):
stream_unary_call = self._channel.stream_unary(
_STREAM_UNARY_READER_WRITER)
call = stream_unary_call()
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(_REQUEST)
await call.done_writing()
response = await call
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_unary_evilly_mixed(self):
stream_unary_call = self._channel.stream_unary(
_STREAM_UNARY_EVILLY_MIXED)
call = stream_unary_call()
# Expecting the request message to reach server before retriving for _ in range(_NUM_STREAM_REQUESTS):
# any responses. await call.write(_REQUEST)
await asyncio.wait_for(self._generic_handler.wait_for_call(), await call.done_writing()
test_constants.SHORT_TIMEOUT)
# Uses reader API response = await call
self.assertEqual(_RESPONSE, await call.read()) self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# Uses async generator API async def test_stream_stream_async_generator(self):
response_cnt = 0 stream_stream_call = self._channel.stream_stream(
async for response in call: _STREAM_STREAM_ASYNC_GEN)
response_cnt += 1 call = stream_stream_call()
self.assertEqual(_RESPONSE, response)
self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt) for _ in range(_NUM_STREAM_REQUESTS):
await call.write(_REQUEST)
await call.done_writing()
self.assertEqual(await call.code(), grpc.StatusCode.OK) for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_stream_reader_writer(self):
stream_stream_call = self._channel.stream_stream(
_STREAM_STREAM_READER_WRITER)
call = stream_stream_call()
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(_REQUEST)
await call.done_writing()
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_stream_evilly_mixed(self):
stream_stream_call = self._channel.stream_stream(
_STREAM_STREAM_EVILLY_MIXED)
call = stream_stream_call()
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(_REQUEST)
await call.done_writing()
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE, response)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_shutdown(self): async def test_shutdown(self):
await self._server.stop(None) await self._server.stop(None)
# Ensures no SIGSEGV triggered, and ends within timeout. # Ensures no SIGSEGV triggered, and ends within timeout.
async def test_shutdown_after_call(self): async def test_shutdown_after_call(self):
async with aio.insecure_channel(self._server_target) as channel: await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
await self._server.stop(None) await self._server.stop(None)
async def test_graceful_shutdown_success(self): async def test_graceful_shutdown_success(self):
channel = aio.insecure_channel(self._server_target) call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await self._generic_handler.wait_for_call() await self._generic_handler.wait_for_call()
shutdown_start_time = time.time() shutdown_start_time = time.time()
@ -190,13 +332,11 @@ class TestServer(AioTestBase):
test_constants.SHORT_TIMEOUT / 3) test_constants.SHORT_TIMEOUT / 3)
# Validates the states. # Validates the states.
await channel.close()
self.assertEqual(_RESPONSE, await call) self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done()) self.assertTrue(call.done())
async def test_graceful_shutdown_failed(self): async def test_graceful_shutdown_failed(self):
channel = aio.insecure_channel(self._server_target) call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await self._generic_handler.wait_for_call() await self._generic_handler.wait_for_call()
await self._server.stop(test_constants.SHORT_TIMEOUT) await self._server.stop(test_constants.SHORT_TIMEOUT)
@ -206,11 +346,9 @@ class TestServer(AioTestBase):
self.assertEqual(grpc.StatusCode.UNAVAILABLE, self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code()) exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details()) self.assertIn('GOAWAY', exception_context.exception.details())
await channel.close()
async def test_concurrent_graceful_shutdown(self): async def test_concurrent_graceful_shutdown(self):
channel = aio.insecure_channel(self._server_target) call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await self._generic_handler.wait_for_call() await self._generic_handler.wait_for_call()
# Expects the shortest grace period to be effective. # Expects the shortest grace period to be effective.
@ -224,13 +362,11 @@ class TestServer(AioTestBase):
self.assertGreater(grace_period_length, self.assertGreater(grace_period_length,
test_constants.SHORT_TIMEOUT / 3) test_constants.SHORT_TIMEOUT / 3)
await channel.close()
self.assertEqual(_RESPONSE, await call) self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done()) self.assertTrue(call.done())
async def test_concurrent_graceful_shutdown_immediate(self): async def test_concurrent_graceful_shutdown_immediate(self):
channel = aio.insecure_channel(self._server_target) call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await self._generic_handler.wait_for_call() await self._generic_handler.wait_for_call()
# Expects no grace period, due to the "server.stop(None)". # Expects no grace period, due to the "server.stop(None)".
@ -246,7 +382,6 @@ class TestServer(AioTestBase):
self.assertEqual(grpc.StatusCode.UNAVAILABLE, self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code()) exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details()) self.assertIn('GOAWAY', exception_context.exception.details())
await channel.close()
@unittest.skip('https://github.com/grpc/grpc/issues/20818') @unittest.skip('https://github.com/grpc/grpc/issues/20818')
async def test_shutdown_before_call(self): async def test_shutdown_before_call(self):

Loading…
Cancel
Save