First step of adding streaming API:

* Refactors how Task wrapper work on the client-side
* Refactors final status propagation and unify similar classes
* Adds unary-stream API for both-sides
* Refactors each abstraction layer multicallable / call / channel
* Revisits the design of cancellation on client-side
* Makes server methods interuptable
* Fixes a zombie coroutine issue in shutdown path
pull/21232/head
Lidi Zheng 5 years ago
parent f5666958f9
commit 5e5781c961
  1. 9
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 244
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 72
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  4. 42
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  5. 35
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  6. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi
  7. 51
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi
  8. 8
      src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pxd.pxi
  9. 21
      src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pyx.pxi
  10. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  11. 199
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  12. 2
      src/python/grpcio/grpc/_cython/cygrpc.pxd
  13. 6
      src/python/grpcio/grpc/_cython/cygrpc.pyx
  14. 2
      src/python/grpcio/grpc/experimental/BUILD.bazel
  15. 9
      src/python/grpcio/grpc/experimental/aio/__init__.py
  16. 127
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  17. 451
      src/python/grpcio/grpc/experimental/aio/_call.py
  18. 157
      src/python/grpcio/grpc/experimental/aio/_channel.py
  19. 17
      src/python/grpcio/grpc/experimental/aio/_typing.py
  20. 32
      src/python/grpcio_tests/tests_aio/benchmark/BUILD.bazel
  21. 9
      src/python/grpcio_tests/tests_aio/benchmark/server.py
  22. 44
      src/python/grpcio_tests/tests_aio/unit/_test_base.py
  23. 19
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  24. 45
      src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py
  25. 278
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  26. 95
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  27. 6
      src/python/grpcio_tests/tests_aio/unit/init_test.py
  28. 153
      src/python/grpcio_tests/tests_aio/unit/server_test.py

@ -18,6 +18,15 @@ cdef class _AioCall:
AioChannel _channel
list _references
GrpcCallWrapper _grpc_call_wrapper
object _loop
# Streaming call only attributes:
#
# A asyncio.Event that indicates if the status is received on the client side.
object _status_received
# A tuple of key value pairs representing the initial metadata sent by peer.
tuple _initial_metadata
cdef grpc_call* _create_grpc_call(self, object timeout, bytes method) except *
cdef void _destroy_grpc_call(self)
cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future)

@ -19,6 +19,8 @@ _EMPTY_FLAGS = 0
_EMPTY_MASK = 0
_EMPTY_METADATA = None
_UNKNOWN_CANCELLATION_DETAILS = 'RPC cancelled due to unknown reason.'
cdef class _AioCall:
@ -26,6 +28,9 @@ cdef class _AioCall:
self._channel = channel
self._references = []
self._grpc_call_wrapper = GrpcCallWrapper()
self._loop = asyncio.get_event_loop()
self._status_received = asyncio.Event(loop=self._loop)
def __repr__(self):
class_name = self.__class__.__name__
@ -33,7 +38,7 @@ cdef class _AioCall:
return f"<{class_name} {id_}>"
cdef grpc_call* _create_grpc_call(self,
object timeout,
object deadline,
bytes method) except *:
"""Creates the corresponding Core object for this RPC.
@ -44,7 +49,7 @@ cdef class _AioCall:
nature in Core.
"""
cdef grpc_slice method_slice
cdef gpr_timespec deadline = _timespec_from_time(timeout)
cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
method_slice = grpc_slice_from_copied_buffer(
<const char *> method,
@ -57,93 +62,204 @@ cdef class _AioCall:
self._channel.cq.c_ptr(),
method_slice,
NULL,
deadline,
c_deadline,
NULL
)
grpc_slice_unref(method_slice)
cdef void _destroy_grpc_call(self):
"""Destroys the corresponding Core object for this RPC."""
if self._grpc_call_wrapper.call != NULL:
grpc_call_unref(self._grpc_call_wrapper.call)
async def unary_unary(self, bytes method, bytes request, object timeout, AioCancelStatus cancel_status):
cdef object loop = asyncio.get_event_loop()
cdef tuple operations
cdef Operation initial_metadata_operation
cdef Operation send_message_operation
cdef Operation send_close_from_client_operation
cdef Operation receive_initial_metadata_operation
cdef Operation receive_message_operation
cdef Operation receive_status_on_client_operation
cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future):
"""Cancels the RPC in C-Core, and return the final RPC status."""
cdef AioRpcStatus status
cdef object details
cdef char *c_details
# Try to fetch application layer cancellation details in the future.
# * If calcellation details present, cancel with status;
# * If details not present, cancel with unknown reason.
if cancellation_future.done():
status = cancellation_future.result()
details = str_to_bytes(status.details())
self._references.append(details)
c_details = <char *>details
# By implementation, grpc_call_cancel_with_status always return OK
grpc_call_cancel_with_status(
self._grpc_call_wrapper.call,
status.c_code(),
c_details,
NULL,
)
return status
else:
# By implementation, grpc_call_cancel always return OK
grpc_call_cancel(self._grpc_call_wrapper.call, NULL)
return AioRpcStatus(
StatusCode.cancelled,
_UNKNOWN_CANCELLATION_DETAILS,
None,
None,
)
cdef char *c_details = NULL
async def unary_unary(self,
bytes method,
bytes request,
object deadline,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Performs a unary unary RPC.
initial_metadata_operation = SendInitialMetadataOperation(_EMPTY_METADATA, GRPC_INITIAL_METADATA_USED_MASK)
initial_metadata_operation.c()
Args:
method: name of the calling method in bytes.
request: the serialized requests in bytes.
deadline: optional deadline of the RPC in float.
cancellation_future: the future that meant to transport the
cancellation reason from the application layer.
initial_metadata_observer: a callback for received initial metadata.
status_observer: a callback for received final status.
"""
cdef tuple ops
send_message_operation = SendMessageOperation(request, _EMPTY_FLAGS)
send_message_operation.c()
cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation(
_EMPTY_METADATA,
GRPC_INITIAL_METADATA_USED_MASK)
cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS)
cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS)
cdef ReceiveInitialMetadataOperation receive_initial_metadata_op = ReceiveInitialMetadataOperation(_EMPTY_FLAGS)
cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS)
cdef ReceiveStatusOnClientOperation receive_status_on_client_op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
send_close_from_client_operation = SendCloseFromClientOperation(_EMPTY_FLAGS)
send_close_from_client_operation.c()
ops = (initial_metadata_op, send_message_op, send_close_op,
receive_initial_metadata_op, receive_message_op,
receive_status_on_client_op)
receive_initial_metadata_operation = ReceiveInitialMetadataOperation(_EMPTY_FLAGS)
receive_initial_metadata_operation.c()
try:
self._create_grpc_call(deadline, method)
try:
await callback_start_batch(self._grpc_call_wrapper,
ops,
self._loop)
except asyncio.CancelledError:
status = self._cancel_and_create_status(cancellation_future)
status_observer(status)
raise
finally:
# If the RPC failed, this method will return None instead of crash.
initial_metadata_observer(
receive_initial_metadata_op.initial_metadata()
)
self._destroy_grpc_call()
receive_message_operation = ReceiveMessageOperation(_EMPTY_FLAGS)
receive_message_operation.c()
status = AioRpcStatus(
receive_status_on_client_op.code(),
receive_status_on_client_op.details(),
receive_status_on_client_op.trailing_metadata(),
receive_status_on_client_op.error_string(),
)
# Reports the final status of the RPC to Python layer. The observer
# pattern is used here to unify unary and streaming code path.
status_observer(status)
receive_status_on_client_operation = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
receive_status_on_client_operation.c()
if status.code() == StatusCode.ok:
return receive_message_op.message()
else:
return None
operations = (
initial_metadata_operation,
send_message_operation,
send_close_from_client_operation,
receive_initial_metadata_operation,
receive_message_operation,
receive_status_on_client_operation,
async def _handle_status_once_received(self, object status_observer):
"""Handles the status sent by peer once received."""
cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
cdef tuple ops = (op,)
await callback_start_batch(self._grpc_call_wrapper, ops, self._loop)
cdef AioRpcStatus status = AioRpcStatus(
op.code(),
op.details(),
op.trailing_metadata(),
op.error_string(),
)
status_observer(status)
self._status_received.set()
try:
self._create_grpc_call(
timeout,
method,
def _handle_cancellation_from_application(self,
object cancellation_future,
object status_observer):
def _cancellation_action(finished_future):
status = self._cancel_and_create_status(finished_future)
status_observer(status)
cancellation_future.add_done_callback(_cancellation_action)
async def unary_stream(self,
bytes method,
bytes request,
object deadline,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Actual implementation of the complete unary-stream call.
Needs to pay extra attention to the raise mechanism. If we want to
propagate the final status exception, then we have to raise it.
Othersize, it would end normally and raise `StopAsyncIteration()`.
"""
cdef bytes received_message
cdef tuple outbound_ops
cdef Operation initial_metadata_op = SendInitialMetadataOperation(
_EMPTY_METADATA,
GRPC_INITIAL_METADATA_USED_MASK)
cdef Operation send_message_op = SendMessageOperation(
request,
_EMPTY_FLAGS)
cdef Operation send_close_op = SendCloseFromClientOperation(
_EMPTY_FLAGS)
outbound_ops = (
initial_metadata_op,
send_message_op,
send_close_op,
)
# NOTE(lidiz) Not catching CancelledError here, because async
# generators do not have "cancel" method.
try:
self._create_grpc_call(deadline, method)
await callback_start_batch(
self._grpc_call_wrapper,
operations,
loop
outbound_ops,
self._loop)
# Peer may prematurely end this RPC at any point. We need a mechanism
# that handles both the normal case and the error case.
self._loop.create_task(self._handle_status_once_received(status_observer))
self._handle_cancellation_from_application(cancellation_future,
status_observer)
# Receives initial metadata.
initial_metadata_observer(
await _receive_initial_metadata(self._grpc_call_wrapper,
self._loop),
)
except asyncio.CancelledError:
if cancel_status:
details = str_to_bytes(cancel_status.details())
self._references.append(details)
c_details = <char *>details
call_status = grpc_call_cancel_with_status(
self._grpc_call_wrapper.call,
cancel_status.code(),
c_details,
NULL,
# Infinitely receiving messages, until:
# * EOF, no more messages to read;
# * The client application cancells;
# * The server sends final status.
while True:
if self._status_received.is_set():
return
received_message = await _receive_message(
self._grpc_call_wrapper,
self._loop
)
if received_message is None:
# The read operation failed, wait for status from C-Core.
await self._status_received.wait()
return
else:
call_status = grpc_call_cancel(
self._grpc_call_wrapper.call, NULL)
if call_status != GRPC_CALL_OK:
raise Exception("RPC call couldn't be cancelled. Error {}".format(call_status))
raise
yield received_message
finally:
self._destroy_grpc_call()
if receive_status_on_client_operation.code() == StatusCode.ok:
return receive_message_operation.message()
raise AioRpcError(
receive_initial_metadata_operation.initial_metadata(),
receive_status_on_client_operation.code(),
receive_status_on_client_operation.details(),
receive_status_on_client_operation.trailing_metadata(),
)

@ -46,11 +46,13 @@ cdef class CallbackWrapper:
grpc_experimental_completion_queue_functor* functor,
int success):
cdef CallbackContext *context = <CallbackContext *>functor
cdef object waiter = <object>context.waiter
if waiter.cancelled():
return
if success == 0:
(<CallbackFailureHandler>context.failure_handler).handle(
<object>context.waiter)
(<CallbackFailureHandler>context.failure_handler).handle(waiter)
else:
(<object>context.waiter).set_result(None)
waiter.set_result(None)
cdef grpc_experimental_completion_queue_functor *c_functor(self):
return &self.context.functor
@ -83,6 +85,9 @@ cdef class CallbackCompletionQueue:
grpc_completion_queue_destroy(self._cq)
class CallbackStartBatchError(Exception): pass
async def callback_start_batch(GrpcCallWrapper grpc_call_wrapper,
tuple operations,
object loop):
@ -93,7 +98,7 @@ async def callback_start_batch(GrpcCallWrapper grpc_call_wrapper,
cdef object future = loop.create_future()
cdef CallbackWrapper wrapper = CallbackWrapper(
future,
CallbackFailureHandler('callback_start_batch', operations, RuntimeError))
CallbackFailureHandler('callback_start_batch', operations, CallbackStartBatchError))
# NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
# when calling "await". This is an over-optimization by Cython.
cpython.Py_INCREF(wrapper)
@ -104,10 +109,67 @@ async def callback_start_batch(GrpcCallWrapper grpc_call_wrapper,
wrapper.c_functor(), NULL)
if error != GRPC_CALL_OK:
raise RuntimeError("Failed grpc_call_start_batch: {}".format(error))
raise CallbackStartBatchError("Failed grpc_call_start_batch: {}".format(error))
await future
cpython.Py_DECREF(wrapper)
cdef grpc_event c_event
# Tag.event must be called, otherwise messages won't be parsed from C
batch_operation_tag.event(c_event)
async def _receive_message(GrpcCallWrapper grpc_call_wrapper,
object loop):
"""Retrives parsed messages from C-Core.
The messages maybe already in C-Core's buffer, so there isn't a 1-to-1
mapping between this and the underlying "socket.read()". Also, eventually,
this function will end with an EOF, which reads empty message.
"""
cdef ReceiveMessageOperation receive_op = ReceiveMessageOperation(_EMPTY_FLAG)
cdef tuple ops = (receive_op,)
try:
await callback_start_batch(grpc_call_wrapper, ops, loop)
except CallbackStartBatchError as e:
# NOTE(lidiz) The receive message operation has two ways to indicate
# finish state : 1) returns empty message due to EOF; 2) fails inside
# the callback (e.g. cancelled).
#
# Since they all indicates finish, they are better be merged.
_LOGGER.exception(e)
return receive_op.message()
async def _send_message(GrpcCallWrapper grpc_call_wrapper,
bytes message,
bint metadata_sent,
object loop):
cdef SendMessageOperation op = SendMessageOperation(message, _EMPTY_FLAG)
cdef tuple ops
if metadata_sent:
ops = (op,)
else:
ops = (
# Initial metadata must be sent before first outbound message.
SendInitialMetadataOperation(None, _EMPTY_FLAG),
op,
)
await callback_start_batch(grpc_call_wrapper, ops, loop)
async def _send_initial_metadata(GrpcCallWrapper grpc_call_wrapper,
tuple metadata,
object loop):
cdef SendInitialMetadataOperation op = SendInitialMetadataOperation(
metadata,
_EMPTY_FLAG)
cdef tuple ops = (op,)
await callback_start_batch(grpc_call_wrapper, ops, loop)
async def _receive_initial_metadata(GrpcCallWrapper grpc_call_wrapper,
object loop):
cdef ReceiveInitialMetadataOperation op = ReceiveInitialMetadataOperation(_EMPTY_FLAGS)
cdef tuple ops = (op,)
await callback_start_batch(grpc_call_wrapper, ops, loop)
return op.initial_metadata()

@ -26,6 +26,42 @@ cdef class AioChannel:
def close(self):
grpc_channel_destroy(self.channel)
async def unary_unary(self, method, request, timeout, cancel_status):
call = _AioCall(self)
return await call.unary_unary(method, request, timeout, cancel_status)
async def unary_unary(self,
bytes method,
bytes request,
object deadline,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Assembles a unary-unary RPC.
Returns:
The response message in bytes.
"""
cdef _AioCall call = _AioCall(self)
return await call.unary_unary(method,
request,
deadline,
cancellation_future,
initial_metadata_observer,
status_observer)
def unary_stream(self,
bytes method,
bytes request,
object deadline,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Assembles a unary-stream RPC.
Returns:
An async generator that yields raw responses.
"""
cdef _AioCall call = _AioCall(self)
return call.unary_stream(method,
request,
deadline,
cancellation_future,
initial_metadata_observer,
status_observer)

@ -1,4 +1,4 @@
# Copyright 2019 gRPC authors.
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,26 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Desired cancellation status for canceling an ongoing RPC call."""
cdef class AioCancelStatus:
cdef object deserialize(object deserializer, bytes raw_message):
"""Perform deserialization on raw bytes.
def __cinit__(self):
self._code = None
self._details = None
Failure to deserialize is a fatal error.
"""
if deserializer:
return deserializer(raw_message)
else:
return raw_message
def __len__(self):
if self._code is None:
return 0
return 1
def cancel(self, grpc_status_code code, str details=None):
self._code = code
self._details = details
cdef bytes serialize(object serializer, object message):
"""Perform serialization on a message.
cpdef object code(self):
return self._code
cpdef str details(self):
return self._details
Failure to serialize is a fatal error.
"""
if serializer:
return serializer(message)
else:
return message

@ -23,12 +23,14 @@ cdef class _AsyncioSocket:
object _task_read
object _task_connect
char * _read_buffer
object _loop
# Client-side attributes
grpc_custom_connect_callback _grpc_connect_cb
# Server-side attributes
grpc_custom_accept_callback _grpc_accept_cb
grpc_custom_write_callback _grpc_write_cb
grpc_custom_socket * _grpc_client_socket
object _server
object _py_socket

@ -16,11 +16,14 @@ import socket as native_socket
from libc cimport string
# TODO(https://github.com/grpc/grpc/issues/21348) Better flow control needed.
cdef class _AsyncioSocket:
def __cinit__(self):
self._grpc_socket = NULL
self._grpc_connect_cb = NULL
self._grpc_read_cb = NULL
self._grpc_write_cb = NULL
self._reader = None
self._writer = None
self._task_connect = None
@ -29,6 +32,7 @@ cdef class _AsyncioSocket:
self._server = None
self._py_socket = None
self._peername = None
self._loop = asyncio.get_event_loop()
@staticmethod
cdef _AsyncioSocket create(grpc_custom_socket * grpc_socket,
@ -56,16 +60,16 @@ cdef class _AsyncioSocket:
return f"<{class_name} {id_} connected={connected}>"
def _connect_cb(self, future):
error = False
try:
self._reader, self._writer = future.result()
except Exception as e:
error = True
error_msg = str(e)
self._grpc_connect_cb(
<grpc_custom_socket*>self._grpc_socket,
grpc_socket_error("Socket connect failed: {}".format(e).encode())
)
finally:
self._task_connect = None
if not error:
# gRPC default posix implementation disables nagle
# algorithm.
sock = self._writer.transport.get_extra_info('socket')
@ -75,11 +79,6 @@ cdef class _AsyncioSocket:
<grpc_custom_socket*>self._grpc_socket,
<grpc_error*>0
)
else:
self._grpc_connect_cb(
<grpc_custom_socket*>self._grpc_socket,
grpc_socket_error("connect {}".format(error_msg).encode())
)
def _read_cb(self, future):
error = False
@ -87,7 +86,8 @@ cdef class _AsyncioSocket:
buffer_ = future.result()
except Exception as e:
error = True
error_msg = str(e)
error_msg = "%s: %s" % (type(e), str(e))
_LOGGER.exception(e)
finally:
self._task_read = None
@ -106,7 +106,7 @@ cdef class _AsyncioSocket:
self._grpc_read_cb(
<grpc_custom_socket*>self._grpc_socket,
-1,
grpc_socket_error("read {}".format(error_msg).encode())
grpc_socket_error("Read failed: {}".format(error_msg).encode())
)
cdef void connect(self,
@ -125,23 +125,34 @@ cdef class _AsyncioSocket:
cdef void read(self, char * buffer_, size_t length, grpc_custom_read_callback grpc_read_cb):
assert not self._task_read
self._task_read = asyncio.ensure_future(
self._task_read = self._loop.create_task(
self._reader.read(n=length)
)
self._grpc_read_cb = grpc_read_cb
self._task_read.add_done_callback(self._read_cb)
self._read_buffer = buffer_
async def _async_write(self, bytearray buffer):
self._writer.write(buffer)
await self._writer.drain()
self._grpc_write_cb(
<grpc_custom_socket*>self._grpc_socket,
<grpc_error*>0
)
cdef void write(self, grpc_slice_buffer * g_slice_buffer, grpc_custom_write_callback grpc_write_cb):
# For each socket, C-Core guarantees there'll be only one ongoing write
self._grpc_write_cb = grpc_write_cb
cdef char* start
buffer_ = bytearray()
buffer = bytearray()
for i in range(g_slice_buffer.count):
start = grpc_slice_buffer_start(g_slice_buffer, i)
length = grpc_slice_buffer_length(g_slice_buffer, i)
buffer_.extend(<bytes>start[:length])
self._writer.write(buffer_)
buffer.extend(<bytes>start[:length])
self._writer.write(buffer)
grpc_write_cb(
<grpc_custom_socket*>self._grpc_socket,
<grpc_error*>0
@ -171,9 +182,9 @@ cdef class _AsyncioSocket:
self._grpc_client_socket.impl = <void*>client_socket
cpython.Py_INCREF(client_socket) # Py_DECREF in asyncio_socket_destroy
# Accept callback expects to be called with:
# grpc_custom_socket: A grpc custom socket for server
# grpc_custom_socket: A grpc custom socket for client (with new Socket instance)
# grpc_error: An error object
# * grpc_custom_socket: A grpc custom socket for server
# * grpc_custom_socket: A grpc custom socket for client (with new Socket instance)
# * grpc_error: An error object
self._grpc_accept_cb(self._grpc_socket, self._grpc_client_socket, grpc_error_none())
cdef listen(self):
@ -183,7 +194,7 @@ cdef class _AsyncioSocket:
sock=self._py_socket,
)
asyncio.get_event_loop().create_task(create_asyncio_server())
self._loop.create_task(create_asyncio_server())
cdef accept(self,
grpc_custom_socket* grpc_socket_client,

@ -14,14 +14,16 @@
"""Exceptions for the aio version of the RPC calls."""
cdef class _AioRpcError(Exception):
cdef class AioRpcStatus(Exception):
cdef readonly:
tuple _initial_metadata
int _code
str _details
# On spec, only client-side status has trailing metadata.
tuple _trailing_metadata
str _debug_error_string
cpdef tuple initial_metadata(self)
cpdef int code(self)
cpdef str details(self)
cpdef tuple trailing_metadata(self)
cpdef str debug_error_string(self)
cdef grpc_status_code c_code(self)

@ -14,16 +14,19 @@
"""Exceptions for the aio version of the RPC calls."""
cdef class AioRpcError(Exception):
cdef class AioRpcStatus(Exception):
def __cinit__(self, tuple initial_metadata, int code, str details, tuple trailing_metadata):
self._initial_metadata = initial_metadata
# The final status of gRPC is represented by three trailing metadata:
# `grpc-status`, `grpc-status-message`, abd `grpc-status-details`.
def __cinit__(self,
int code,
str details,
tuple trailing_metadata,
str debug_error_string):
self._code = code
self._details = details
self._trailing_metadata = trailing_metadata
cpdef tuple initial_metadata(self):
return self._initial_metadata
self._debug_error_string = debug_error_string
cpdef int code(self):
return self._code
@ -33,3 +36,9 @@ cdef class AioRpcError(Exception):
cpdef tuple trailing_metadata(self):
return self._trailing_metadata
cpdef str debug_error_string(self):
return self._debug_error_string
cdef grpc_status_code c_code(self):
return <grpc_status_code>self._code

@ -43,3 +43,4 @@ cdef class AioServer:
cdef object _shutdown_completed # asyncio.Future
cdef CallbackWrapper _shutdown_callback_wrapper
cdef object _crash_exception # Exception
cdef set _ongoing_rpc_tasks

@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
# TODO(https://github.com/grpc/grpc/issues/20850) refactor this.
_LOGGER = logging.getLogger(__name__)
cdef int _EMPTY_FLAG = 0
@ -23,9 +27,6 @@ cdef class _HandlerCallDetails:
self.invocation_metadata = invocation_metadata
class _ServicerContextPlaceHolder(object): pass
cdef class RPCState:
def __cinit__(self):
@ -43,12 +44,49 @@ cdef class RPCState:
grpc_call_unref(self.call)
cdef class _ServicerContext:
cdef RPCState _rpc_state
cdef object _loop
cdef bint _metadata_sent
cdef object _request_deserializer
cdef object _response_serializer
def __cinit__(self,
RPCState rpc_state,
object request_deserializer,
object response_serializer,
object loop):
self._rpc_state = rpc_state
self._request_deserializer = request_deserializer
self._response_serializer = response_serializer
self._loop = loop
self._metadata_sent = False
async def read(self):
cdef bytes raw_message = await _receive_message(self._rpc_state, self._loop)
return deserialize(self._request_deserializer,
raw_message)
async def write(self, object message):
await _send_message(self._rpc_state,
serialize(self._response_serializer, message),
self._metadata_sent,
self._loop)
if not self._metadata_sent:
self._metadata_sent = True
async def send_initial_metadata(self, tuple metadata):
if self._metadata_sent:
raise ValueError('Send initial metadata failed: already sent')
else:
_send_initial_metadata(self._rpc_state, self._loop)
self._metadata_sent = True
cdef _find_method_handler(str method, list generic_handlers):
# TODO(lidiz) connects Metadata to call details
cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(
method,
tuple()
)
cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method,
None)
for generic_handler in generic_handlers:
method_handler = generic_handler.service(handler_call_details)
@ -61,64 +99,130 @@ async def _handle_unary_unary_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef tuple receive_ops = (
ReceiveMessageOperation(_EMPTY_FLAGS),
)
await callback_start_batch(rpc_state, receive_ops, loop)
cdef bytes request_raw = await _receive_message(rpc_state, loop)
# Deserializes the request message
cdef bytes request_raw = receive_ops[0].message()
cdef object request_message
if method_handler.request_deserializer:
request_message = method_handler.request_deserializer(request_raw)
else:
request_message = request_raw
cdef object request_message = deserialize(
method_handler.request_deserializer,
request_raw,
)
# Executes application logic
cdef object response_message = await method_handler.unary_unary(request_message, _ServicerContextPlaceHolder())
cdef object response_message = await method_handler.unary_unary(
request_message,
_ServicerContext(
rpc_state,
None,
None,
loop,
),
)
# Serializes the response message
cdef bytes response_raw
if method_handler.response_serializer:
response_raw = method_handler.response_serializer(response_message)
else:
response_raw = response_message
cdef bytes response_raw = serialize(
method_handler.response_serializer,
response_message,
)
# Sends response message
cdef tuple send_ops = (
SendStatusFromServerOperation(
tuple(), StatusCode.ok, b'', _EMPTY_FLAGS),
SendInitialMetadataOperation(tuple(), _EMPTY_FLAGS),
tuple(),
StatusCode.ok,
b'',
_EMPTY_FLAGS,
),
SendInitialMetadataOperation(None, _EMPTY_FLAGS),
SendMessageOperation(response_raw, _EMPTY_FLAGS),
)
await callback_start_batch(rpc_state, send_ops, loop)
async def _handle_unary_stream_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef bytes request_raw = await _receive_message(rpc_state, loop)
# Deserializes the request message
cdef object request_message = deserialize(
method_handler.request_deserializer,
request_raw,
)
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
method_handler.request_deserializer,
method_handler.response_serializer,
loop,
)
if inspect.iscoroutinefunction(method_handler.unary_stream):
# The handler uses reader / writer API, returns None.
await method_handler.unary_stream(
request_message,
servicer_context,
)
return
# The handler uses async generator API
cdef object async_response_generator = method_handler.unary_stream(
request_message,
servicer_context,
)
# Consumes messages from the generator
cdef object response_message
async for response_message in async_response_generator:
await servicer_context.write(response_message)
cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
None,
StatusCode.ok,
b'',
_EMPTY_FLAGS,
)
cdef tuple ops = (op,)
await callback_start_batch(rpc_state, ops, loop)
async def _handle_cancellation_from_core(object rpc_task,
RPCState rpc_state,
object loop):
cdef ReceiveCloseOnServerOperation op = ReceiveCloseOnServerOperation(_EMPTY_FLAG)
cdef tuple ops = (op,)
await callback_start_batch(rpc_state, ops, loop)
if op.cancelled() and not rpc_task.done():
rpc_task.cancel()
async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
# Finds the method handler (application logic)
cdef object method_handler = _find_method_handler(
rpc_state.method().decode(),
generic_handlers
generic_handlers,
)
if method_handler is None:
# TODO(lidiz) return unimplemented error to client side
raise NotImplementedError()
# TODO(lidiz) extend to all 4 types of RPC
if method_handler.request_streaming or method_handler.response_streaming:
raise NotImplementedError()
else:
await _handle_unary_unary_rpc(
method_handler,
if not method_handler.request_streaming and method_handler.response_streaming:
await _handle_unary_stream_rpc(method_handler,
rpc_state,
loop
)
loop)
elif not method_handler.request_streaming and not method_handler.response_streaming:
await _handle_unary_unary_rpc(method_handler,
rpc_state,
loop)
else:
raise NotImplementedError()
class _RequestCallError(Exception): pass
cdef CallbackFailureHandler REQUEST_CALL_FAILURE_HANDLER = CallbackFailureHandler(
'grpc_server_request_call', 'server shutdown', _RequestCallError)
'grpc_server_request_call', None, _RequestCallError)
async def _server_call_request_call(Server server,
@ -147,19 +251,9 @@ async def _server_call_request_call(Server server,
return rpc_state
async def _handle_cancellation_from_core(object rpc_task,
RPCState rpc_state,
object loop):
cdef ReceiveCloseOnServerOperation op = ReceiveCloseOnServerOperation(_EMPTY_FLAG)
cdef tuple ops = (op,)
await callback_start_batch(rpc_state, ops, loop)
if op.cancelled() and not rpc_task.done():
rpc_task.cancel()
cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler(
'grpc_server_shutdown_and_notify',
'Unknown',
None,
RuntimeError)
@ -182,6 +276,7 @@ cdef class AioServer:
self._generic_handlers = []
self.add_generic_rpc_handlers(generic_handlers)
self._serving_task = None
self._ongoing_rpc_tasks = set()
self._shutdown_lock = asyncio.Lock(loop=self._loop)
self._shutdown_completed = self._loop.create_future()
@ -221,11 +316,13 @@ cdef class AioServer:
if self._status != AIO_SERVER_STATUS_RUNNING:
break
# Accepts new request from C-Core
rpc_state = await _server_call_request_call(
self._server,
self._cq,
self._loop)
# Schedules the RPC as a separate coroutine
rpc_task = self._loop.create_task(
_handle_rpc(
self._generic_handlers,
@ -233,6 +330,8 @@ cdef class AioServer:
self._loop
)
)
# Fires off a task that listening on the cancellation from client.
self._loop.create_task(
_handle_cancellation_from_core(
rpc_task,
@ -241,6 +340,10 @@ cdef class AioServer:
)
)
# Keeps track of created coroutines, so we can clean them up properly.
self._ongoing_rpc_tasks.add(rpc_task)
rpc_task.add_done_callback(lambda _: self._ongoing_rpc_tasks.remove(rpc_task))
def _serving_task_crash_handler(self, object task):
"""Shutdown the server immediately if unexpectedly exited."""
if task.exception() is None:
@ -318,6 +421,10 @@ cdef class AioServer:
grpc_server_cancel_all_calls(self._server.c_server)
await self._shutdown_completed
# Cancels all Python layer tasks
for rpc_task in self._ongoing_rpc_tasks:
rpc_task.cancel()
async with self._shutdown_lock:
if self._status == AIO_SERVER_STATUS_STOPPING:
grpc_server_destroy(self._server.c_server)
@ -328,7 +435,7 @@ cdef class AioServer:
# Shuts down the completion queue
await self._cq.shutdown()
async def wait_for_termination(self, float timeout):
async def wait_for_termination(self, object timeout):
if timeout is None:
await self._shutdown_completed
else:

@ -43,9 +43,9 @@ IF UNAME_SYSNAME != "Windows":
include "_cygrpc/aio/iomgr/socket.pxd.pxi"
include "_cygrpc/aio/iomgr/timer.pxd.pxi"
include "_cygrpc/aio/iomgr/resolver.pxd.pxi"
include "_cygrpc/aio/rpc_status.pxd.pxi"
include "_cygrpc/aio/grpc_aio.pxd.pxi"
include "_cygrpc/aio/callback_common.pxd.pxi"
include "_cygrpc/aio/call.pxd.pxi"
include "_cygrpc/aio/cancel_status.pxd.pxi"
include "_cygrpc/aio/channel.pxd.pxi"
include "_cygrpc/aio/server.pxd.pxi"

@ -60,12 +60,12 @@ include "_cygrpc/aio/iomgr/iomgr.pyx.pxi"
include "_cygrpc/aio/iomgr/socket.pyx.pxi"
include "_cygrpc/aio/iomgr/timer.pyx.pxi"
include "_cygrpc/aio/iomgr/resolver.pyx.pxi"
include "_cygrpc/aio/common.pyx.pxi"
include "_cygrpc/aio/rpc_status.pyx.pxi"
include "_cygrpc/aio/callback_common.pyx.pxi"
include "_cygrpc/aio/grpc_aio.pyx.pxi"
include "_cygrpc/aio/call.pyx.pxi"
include "_cygrpc/aio/callback_common.pyx.pxi"
include "_cygrpc/aio/cancel_status.pyx.pxi"
include "_cygrpc/aio/channel.pyx.pxi"
include "_cygrpc/aio/rpc_error.pyx.pxi"
include "_cygrpc/aio/server.pyx.pxi"

@ -7,6 +7,8 @@ py_library(
"aio/_call.py",
"aio/_channel.py",
"aio/_server.py",
"aio/_typing.py",
"aio/_base_call.py",
],
deps = [
"//src/python/grpcio/grpc/_cython:cygrpc",

@ -17,12 +17,9 @@ import abc
import six
import grpc
from grpc import _common
from grpc._cython import cygrpc
from grpc._cython.cygrpc import init_grpc_aio
from ._call import AioRpcError
from ._call import Call
from ._base_call import Call, UnaryUnaryCall, UnaryStreamCall
from ._channel import Channel
from ._channel import UnaryUnaryMultiCallable
from ._server import server
@ -47,5 +44,5 @@ def insecure_channel(target, options=None, compression=None):
################################### __all__ #################################
__all__ = ('AioRpcError', 'Call', 'init_grpc_aio', 'Channel',
'UnaryUnaryMultiCallable', 'insecure_channel', 'server')
__all__ = ('Call', 'UnaryUnaryCall', 'UnaryStreamCall', 'init_grpc_aio',
'Channel', 'UnaryUnaryMultiCallable', 'insecure_channel', 'server')

@ -0,0 +1,127 @@
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base classes for client-side Call objects.
Call objects represents the RPC itself, and offer methods to access / modify
its information. They also offer methods to manipulate the life-cycle of the
RPC, e.g. cancellation.
"""
from abc import ABCMeta, abstractmethod
from typing import AsyncIterable, Awaitable, Generic, Text
import grpc
from ._typing import MetadataType, RequestType, ResponseType
__all__ = 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
class Call(grpc.RpcContext, metaclass=ABCMeta):
"""The abstract base class of an RPC on the client-side."""
@abstractmethod
def cancelled(self) -> bool:
"""Return True if the RPC is cancelled.
The RPC is cancelled when the cancellation was requested with cancel().
Returns:
A bool indicates whether the RPC is cancelled or not.
"""
@abstractmethod
def done(self) -> bool:
"""Return True if the RPC is done.
An RPC is done if the RPC is completed, cancelled or aborted.
Returns:
A bool indicates if the RPC is done.
"""
@abstractmethod
async def initial_metadata(self) -> MetadataType:
"""Accesses the initial metadata sent by the server.
Coroutine continues once the value is available.
Returns:
The initial :term:`metadata`.
"""
@abstractmethod
async def trailing_metadata(self) -> MetadataType:
"""Accesses the trailing metadata sent by the server.
Coroutine continues once the value is available.
Returns:
The trailing :term:`metadata`.
"""
@abstractmethod
async def code(self) -> grpc.StatusCode:
"""Accesses the status code sent by the server.
Coroutine continues once the value is available.
Returns:
The StatusCode value for the RPC.
"""
@abstractmethod
async def details(self) -> Text:
"""Accesses the details sent by the server.
Coroutine continues once the value is available.
Returns:
The details string of the RPC.
"""
class UnaryUnaryCall(Generic[RequestType, ResponseType], Call, metaclass=ABCMeta):
"""The abstract base class of an unary-unary RPC on the client-side."""
@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
"""Await the response message to be ready.
Returns:
The response message of the RPC.
"""
class UnaryStreamCall(Generic[RequestType, ResponseType], Call, metaclass=ABCMeta):
@abstractmethod
def __aiter__(self) -> AsyncIterable[ResponseType]:
"""Returns the async iterable representation that yields messages.
Under the hood, it is calling the "read" method.
Returns:
An async iterable object that yields messages.
"""
@abstractmethod
async def read(self) -> ResponseType:
"""Reads one message from the RPC.
Parallel read operations is not allowed.
Returns:
A response message of the RPC.
"""

@ -12,19 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
import enum
from typing import Callable, Dict, Optional, ClassVar
from typing import AsyncIterable, Awaitable, Dict, Optional
import grpc
from grpc import _common
from grpc._cython import cygrpc
DeserializingFunction = Callable[[bytes], str]
from . import _base_call
from ._typing import (DeserializingFunction, MetadataType, RequestType,
ResponseType, SerializingFunction)
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!'
_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.'
_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
'\tdetails = "{}"\n'
'>')
_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
'\tdetails = "{}"\n'
'\tdebug_error_string = "{}"\n'
'>')
class AioRpcError(grpc.RpcError):
"""An RpcError to be used by the asynchronous API."""
"""An implementation of RpcError to be used by the asynchronous API.
Raised RpcError is a snapshot of the final status of the RPC, values are
determined. Hence, its methods no longer needs to be coroutines.
"""
# TODO(https://github.com/grpc/grpc/issues/20144) Metadata
# type returned by `initial_metadata` and `trailing_metadata`
@ -33,14 +56,16 @@ class AioRpcError(grpc.RpcError):
_code: grpc.StatusCode
_details: Optional[str]
_initial_metadata: Optional[Dict]
_trailing_metadata: Optional[Dict]
_initial_metadata: Optional[MetadataType]
_trailing_metadata: Optional[MetadataType]
_debug_error_string: Optional[str]
def __init__(self,
code: grpc.StatusCode,
details: Optional[str] = None,
initial_metadata: Optional[Dict] = None,
trailing_metadata: Optional[Dict] = None):
initial_metadata: Optional[MetadataType] = None,
trailing_metadata: Optional[MetadataType] = None,
debug_error_string: Optional[str] = None) -> None:
"""Constructor.
Args:
@ -56,207 +81,335 @@ class AioRpcError(grpc.RpcError):
self._details = details
self._initial_metadata = initial_metadata
self._trailing_metadata = trailing_metadata
self._debug_error_string = debug_error_string
def code(self) -> grpc.StatusCode:
"""
"""Accesses the status code sent by the server.
Returns:
The `grpc.StatusCode` status code.
"""
return self._code
def details(self) -> Optional[str]:
"""
"""Accesses the details sent by the server.
Returns:
The description of the error.
"""
return self._details
def initial_metadata(self) -> Optional[Dict]:
"""
"""Accesses the initial metadata sent by the server.
Returns:
The inital metadata received.
The initial metadata received.
"""
return self._initial_metadata
def trailing_metadata(self) -> Optional[Dict]:
"""
"""Accesses the trailing metadata sent by the server.
Returns:
The trailing metadata received.
"""
return self._trailing_metadata
def debug_error_string(self) -> str:
"""Accesses the debug error string sent by the server.
@enum.unique
class _RpcState(enum.Enum):
"""Identifies the state of the RPC."""
ONGOING = 1
CANCELLED = 2
FINISHED = 3
ABORT = 4
Returns:
The debug error string received.
"""
return self._debug_error_string
def _repr(self) -> str:
"""Assembles the error string for the RPC error."""
return _NON_OK_CALL_REPRESENTATION.format(self.__class__.__name__,
self._code, self._details,
self._debug_error_string)
class Call:
"""Object for managing RPC calls,
returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
def __repr__(self) -> str:
return self._repr()
_cancellation_details: ClassVar[str] = 'Locally cancelled by application!'
def __str__(self) -> str:
return self._repr()
_state: _RpcState
_exception: Optional[Exception]
_response: Optional[bytes]
_code: grpc.StatusCode
_details: Optional[str]
_initial_metadata: Optional[Dict]
_trailing_metadata: Optional[Dict]
_call: asyncio.Task
_call_cancel_status: cygrpc.AioCancelStatus
_response_deserializer: DeserializingFunction
def __init__(self, call: asyncio.Task,
response_deserializer: DeserializingFunction,
call_cancel_status: cygrpc.AioCancelStatus) -> None:
"""Constructor.
def _create_rpc_error(initial_metadata: Optional[MetadataType],
status: cygrpc.AioRpcStatus) -> AioRpcError:
return AioRpcError(_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()],
status.details(), initial_metadata,
status.trailing_metadata())
Args:
call: Asyncio Task that holds the RPC execution.
response_deserializer: Deserializer used for parsing the reponse.
call_cancel_status: A cygrpc.AioCancelStatus used for giving a
specific error when the RPC is canceled.
"""
self._state = _RpcState.ONGOING
self._exception = None
self._response = None
self._code = grpc.StatusCode.UNKNOWN
self._details = None
self._initial_metadata = None
self._trailing_metadata = None
self._call = call
self._call_cancel_status = call_cancel_status
self._response_deserializer = response_deserializer
class Call(_base_call.Call):
_loop: asyncio.AbstractEventLoop
_code: grpc.StatusCode
_status: Awaitable[cygrpc.AioRpcStatus]
_initial_metadata: Awaitable[MetadataType]
_cancellation_future: asyncio.Future
def __del__(self):
self.cancel()
def __init__(self) -> None:
self._loop = asyncio.get_event_loop()
self._code = None
self._status = self._loop.create_future()
self._initial_metadata = self._loop.create_future()
self._cancellation_future = self._loop.create_future()
def cancel(self) -> bool:
"""Cancels the ongoing RPC request.
"""Virtual cancellation method.
Returns:
True if the RPC can be canceled, False if was already cancelled or terminated.
The implementation of this method needs to pass the cancellation reason
into self._cancellation_future, using `set_result` instead of
`set_exception`.
"""
if self.cancelled() or self.done():
return False
code = grpc.StatusCode.CANCELLED
self._call_cancel_status.cancel(
_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code],
details=Call._cancellation_details)
self._call.cancel()
self._details = Call._cancellation_details
self._code = code
self._state = _RpcState.CANCELLED
return True
raise NotImplementedError()
def cancelled(self) -> bool:
"""Returns if the RPC was cancelled.
Returns:
True if the requests was cancelled, False if not.
"""
return self._state is _RpcState.CANCELLED
def running(self) -> bool:
"""Returns if the RPC is running.
Returns:
True if the requests is running, False if it already terminated.
"""
return not self.done()
return self._cancellation_future.done(
) or self._code == grpc.StatusCode.CANCELLED
def done(self) -> bool:
"""Returns if the RPC has finished.
return self._status.done()
Returns:
True if the requests has finished, False is if still ongoing.
"""
return self._state is not _RpcState.ONGOING
def add_callback(self, unused_callback) -> None:
pass
async def initial_metadata(self):
raise NotImplementedError()
def is_active(self) -> bool:
return self.done()
async def trailing_metadata(self):
raise NotImplementedError()
def time_remaining(self) -> float:
pass
async def code(self) -> grpc.StatusCode:
"""Returns the `grpc.StatusCode` if the RPC is finished,
otherwise first waits until the RPC finishes.
async def initial_metadata(self) -> MetadataType:
return await self._initial_metadata
Returns:
The `grpc.StatusCode` status code.
"""
if not self.done():
try:
await self
except (asyncio.CancelledError, AioRpcError):
pass
async def trailing_metadata(self) -> MetadataType:
return (await self._status).trailing_metadata()
async def code(self) -> grpc.StatusCode:
await self._status
return self._code
async def details(self) -> str:
"""Returns the details if the RPC is finished, otherwise first waits till the
RPC finishes.
return (await self._status).details()
Returns:
The details.
async def debug_error_string(self) -> str:
return (await self._status).debug_error_string()
def _set_initial_metadata(self, metadata: MetadataType) -> None:
self._initial_metadata.set_result(metadata)
def _set_status(self, status: cygrpc.AioRpcStatus) -> None:
"""Private method to set final status of the RPC.
This method may be called multiple time due to data race between local
cancellation (by application) and C-Core receiving status from peer. We
make no promise here which one will win.
"""
if not self.done():
try:
await self
except (asyncio.CancelledError, AioRpcError):
pass
if self._status.done():
return
else:
self._status.set_result(status)
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[
status.code()]
async def _raise_rpc_error_if_not_ok(self) -> None:
if self._code != grpc.StatusCode.OK:
raise _create_rpc_error(await self.initial_metadata(),
self._status.result())
def _repr(self) -> str:
"""Assembles the RPC representation string."""
if not self._status.done():
return '<{} object>'.format(self.__class__.__name__)
if self._code is grpc.StatusCode.OK:
return _OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code,
self._status.result().self._status.result().details())
else:
return _NON_OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code,
self._status.result().details(),
self._status.result().debug_error_string())
def __repr__(self) -> str:
return self._repr()
def __str__(self) -> str:
return self._repr()
class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls.
Returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
_loop: asyncio.AbstractEventLoop
_request: RequestType
_deadline: Optional[float]
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: asyncio.Task
return self._details
def __init__(self, request: RequestType, deadline: Optional[float],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__()
self._loop = asyncio.get_event_loop()
self._request = request
self._deadline = deadline
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._call = self._loop.create_task(self._invoke())
def __del__(self) -> None:
if not self._call.done():
self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None))
async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
# NOTE(lidiz) asyncio.CancelledError is not a good transport for
# status, since the Task class do not cache the exact
# asyncio.CancelledError object. So, the solution is catching the error
# in Cython layer, then cancel the RPC and update the status, finally
# re-raise the CancelledError.
serialized_response = await self._channel.unary_unary(
self._method,
serialized_request,
self._deadline,
self._cancellation_future,
self._set_initial_metadata,
self._set_status,
)
await self._raise_rpc_error_if_not_ok()
return _common.deserialize(serialized_response,
self._response_deserializer)
def __await__(self):
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning."""
if not self._status.done() and not self._cancellation_future.done():
self._cancellation_future.set_result(status)
self._call.cancel()
return True
else:
return False
def cancel(self) -> bool:
return self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_LOCAL_CANCELLATION_DETAILS, None, None))
def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes.
Returns:
Response of the RPC call.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
RpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
# We can not relay on the `done()` method since some exceptions
# might be pending to be catched, like `asyncio.CancelledError`.
if self._response:
return self._response
elif self._exception:
raise self._exception
try:
buffer_ = yield from self._call.__await__()
except cygrpc.AioRpcError as aio_rpc_error:
self._state = _RpcState.ABORT
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[
aio_rpc_error.code()]
self._details = aio_rpc_error.details()
self._initial_metadata = aio_rpc_error.initial_metadata()
self._trailing_metadata = aio_rpc_error.trailing_metadata()
# Propagates the pure Python class
self._exception = AioRpcError(self._code, self._details,
self._initial_metadata,
self._trailing_metadata)
raise self._exception from aio_rpc_error
except asyncio.CancelledError as cancel_error:
# _state, _code, _details are managed in the `cancel` method
self._exception = cancel_error
raise
self._response = _common.deserialize(buffer_,
response = yield from self._call
return response
class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
"""Object for managing unary-stream RPC calls.
Returned when an instance of `UnaryStreamMultiCallable` object is called.
"""
_loop: asyncio.AbstractEventLoop
_request: RequestType
_deadline: Optional[float]
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: AsyncIterable[ResponseType]
def __init__(self, request: RequestType, deadline: Optional[float],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__()
self._loop = asyncio.get_event_loop()
self._request = request
self._deadline = deadline
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._call = self._invoke()
def __del__(self) -> None:
if not self._status.done():
self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None))
async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
async_gen = self._channel.unary_stream(
self._method,
serialized_request,
self._deadline,
self._cancellation_future,
self._set_initial_metadata,
self._set_status,
)
async for serialized_response in async_gen:
if self._cancellation_future.done():
await self._status
if self._status.done():
# Raises pre-maturely if final status received here. Generates
# more helpful stack trace for end users.
await self._raise_rpc_error_if_not_ok()
yield _common.deserialize(serialized_response,
self._response_deserializer)
self._code = grpc.StatusCode.OK
self._state = _RpcState.FINISHED
return self._response
await self._raise_rpc_error_if_not_ok()
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning.
Async generator will receive an exception. The cancellation will go
deep down into C-Core, and then propagates backup as the
`cygrpc.AioRpcStatus` exception.
So, under race condition, e.g. the server sent out final state headers
and the client calling "cancel" at the same time, this method respects
the winner in C-Core.
"""
if not self._status.done() and not self._cancellation_future.done():
self._cancellation_future.set_result(status)
return True
else:
return False
def cancel(self) -> bool:
return self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_LOCAL_CANCELLATION_DETAILS, None, None))
def __aiter__(self) -> AsyncIterable[ResponseType]:
return self._call
async def read(self) -> ResponseType:
if self._status.done():
await self._raise_rpc_error_if_not_ok()
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
return await self._call.__anext__()

@ -13,42 +13,111 @@
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
from typing import Callable, Optional
from typing import Any, Optional, Sequence, Text, Tuple
import grpc
from grpc import _common
from grpc._cython import cygrpc
from ._call import Call, UnaryUnaryCall, UnaryStreamCall
from ._typing import (DeserializingFunction, MetadataType, SerializingFunction)
from ._call import Call
SerializingFunction = Callable[[str], bytes]
DeserializingFunction = Callable[[bytes], str]
def _timeout_to_deadline(loop: asyncio.AbstractEventLoop,
timeout: int) -> Optional[int]:
if timeout is None:
return None
return loop.time() + timeout
class UnaryUnaryMultiCallable:
"""Afford invoking a unary-unary RPC from client-side in an asynchronous way."""
"""Factory an asynchronous unary-unary RPC stub call from client-side."""
def __init__(self, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
self._loop = asyncio.get_event_loop()
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._loop = asyncio.get_event_loop()
def _timeout_to_deadline(self, timeout: int) -> Optional[int]:
if timeout is None:
return None
return self._loop.time() + timeout
def __call__(self,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None) -> Call:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A Call object instance which is an awaitable object.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if credentials:
raise NotImplementedError("TODO: credentials not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(self._loop, timeout)
return UnaryUnaryCall(
request,
deadline,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
class UnaryStreamMultiCallable:
"""Afford invoking a unary-stream RPC from client-side in an asynchronous way."""
def __init__(self, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._loop = asyncio.get_event_loop()
def __call__(self,
request: Any,
*,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None) -> Call:
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None) -> Call:
"""Asynchronously invokes the underlying RPC.
Args:
@ -86,15 +155,16 @@ class UnaryUnaryMultiCallable:
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
serialized_request = _common.serialize(request,
self._request_serializer)
timeout = self._timeout_to_deadline(timeout)
aio_cancel_status = cygrpc.AioCancelStatus()
aio_call = asyncio.ensure_future(
self._channel.unary_unary(self._method, serialized_request, timeout,
aio_cancel_status),
loop=self._loop)
return Call(aio_call, self._response_deserializer, aio_cancel_status)
deadline = _timeout_to_deadline(self._loop, timeout)
return UnaryStreamCall(
request,
deadline,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
class Channel:
@ -103,7 +173,10 @@ class Channel:
A cygrpc.AioChannel-backed implementation.
"""
def __init__(self, target, options, credentials, compression):
def __init__(self, target: Text,
options: Optional[Sequence[Tuple[Text, Any]]],
credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]):
"""Constructor.
Args:
@ -125,10 +198,11 @@ class Channel:
self._channel = cygrpc.AioChannel(_common.encode(target))
def unary_unary(self,
method,
request_serializer=None,
response_deserializer=None):
def unary_unary(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None):
"""Creates a UnaryUnaryMultiCallable for a unary-unary method.
Args:
@ -146,6 +220,29 @@ class Channel:
request_serializer,
response_deserializer)
def unary_stream(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None):
return UnaryStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer)
def stream_unary(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None):
"""Placeholder method for stream-unary calls."""
def stream_stream(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None):
"""Placeholder method for stream-stream calls."""
async def _close(self):
# TODO: Send cancellation status
self._channel.close()

@ -1,4 +1,4 @@
# Copyright 2019 gRPC authors.
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,13 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Desired cancellation status for canceling an ongoing RPC calls."""
"""Common types for gRPC Async API"""
from typing import Any, AnyStr, Callable, Sequence, Text, Tuple, TypeVar
cdef class AioCancelStatus:
cdef readonly:
object _code
str _details
cpdef object code(self)
cpdef str details(self)
RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType')
SerializingFunction = Callable[[Any], bytes]
DeserializingFunction = Callable[[bytes], Any]
MetadataType = Sequence[Tuple[Text, AnyStr]]

@ -0,0 +1,32 @@
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_testonly = 1,
default_visibility = ["//visibility:public"],
)
py_binary(
name = "server",
srcs = ["server.py"],
python_version = "PY3",
deps = [
"//external:six",
"//src/proto/grpc/testing:benchmark_service_py_pb2",
"//src/proto/grpc/testing:benchmark_service_py_pb2_grpc",
"//src/proto/grpc/testing:py_messages_proto",
"//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_tests/tests/unit/framework/common",
],
)

@ -27,6 +27,12 @@ class BenchmarkServer(benchmark_service_pb2_grpc.BenchmarkServiceServicer):
payload = messages_pb2.Payload(body=b'\0' * request.response_size)
return messages_pb2.SimpleResponse(payload=payload)
async def StreamingFromServer(self, request, context):
payload = messages_pb2.Payload(body=b'\0' * request.response_size)
# Sends response at full capacity!
while True:
yield messages_pb2.SimpleResponse(payload=payload)
async def _start_async_server():
server = aio.server()
@ -37,6 +43,7 @@ async def _start_async_server():
servicer, server)
await server.start()
logging.info('Benchmark server started at :%d' % port)
await server.wait_for_termination()
@ -48,5 +55,5 @@ def main():
if __name__ == '__main__':
logging.basicConfig()
logging.basicConfig(level=logging.DEBUG)
main()

@ -12,18 +12,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import functools
import asyncio
from typing import Callable
import unittest
from grpc.experimental import aio
__all__ = 'AioTestBase'
_COROUTINE_FUNCTION_ALLOWLIST = ['setUp', 'tearDown']
def _async_to_sync_decorator(f: Callable, loop: asyncio.AbstractEventLoop):
@functools.wraps(f)
def wrapper(*args, **kwargs):
return loop.run_until_complete(f(*args, **kwargs))
return wrapper
def _get_default_loop(debug=True):
try:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
finally:
loop.set_debug(debug)
return loop
class AioTestBase(unittest.TestCase):
def setUp(self):
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
@classmethod
def setUpClass(cls):
cls._loop = _get_default_loop()
aio.init_grpc_aio()
@property
def loop(self):
return self._loop
def __getattribute__(self, name):
"""Overrides the loading logic to support coroutine functions."""
attr = super().__getattribute__(name)
# If possible, converts the coroutine into a sync function.
if name.startswith('test_') or name in _COROUTINE_FUNCTION_ALLOWLIST:
if asyncio.iscoroutinefunction(attr):
return _async_to_sync_decorator(attr, _get_default_loop())
# For other attributes, let them pass.
return attr

@ -12,22 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from time import sleep
import logging
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
_US_IN_A_SECOND = 1000 * 1000
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
async def UnaryCall(self, request, context):
return messages_pb2.SimpleResponse()
# TODO(lidizheng) The semantic of this call is not matching its description
# See src/proto/grpc/testing/test.proto
async def EmptyCall(self, request, context):
while True:
sleep(test_constants.LONG_TIMEOUT)
await asyncio.sleep(test_constants.LONG_TIMEOUT)
async def StreamingOutputCall(
self, request: messages_pb2.StreamingOutputCallRequest, context):
for response_parameters in request.response_parameters:
if response_parameters.interval_us != 0:
await asyncio.sleep(
response_parameters.interval_us / _US_IN_A_SECOND)
yield messages_pb2.StreamingOutputCallResponse(
payload=messages_pb2.Payload(
type=request.response_type,
body=b'\x00' * response_parameters.size))
async def start_test_server():

@ -0,0 +1,45 @@
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests AioRpcError class."""
import logging
import unittest
import grpc
from grpc.experimental.aio._call import AioRpcError
from tests_aio.unit._test_base import AioTestBase
class TestAioRpcError(unittest.TestCase):
_TEST_INITIAL_METADATA = ("initial metadata",)
_TEST_TRAILING_METADATA = ("trailing metadata",)
def test_attributes(self):
aio_rpc_error = AioRpcError(
grpc.StatusCode.CANCELLED,
"details",
initial_metadata=self._TEST_INITIAL_METADATA,
trailing_metadata=self._TEST_TRAILING_METADATA)
self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(aio_rpc_error.details(), "details")
self.assertEqual(aio_rpc_error.initial_metadata(),
self._TEST_INITIAL_METADATA)
self.assertEqual(aio_rpc_error.trailing_metadata(),
self._TEST_TRAILING_METADATA)
if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)

@ -11,51 +11,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the grpc.aio.UnaryUnaryCall class."""
import asyncio
import logging
import unittest
import datetime
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
class TestAioRpcError(unittest.TestCase):
_TEST_INITIAL_METADATA = ("initial metadata",)
_TEST_TRAILING_METADATA = ("trailing metadata",)
def test_attributes(self):
aio_rpc_error = aio.AioRpcError(
grpc.StatusCode.CANCELLED,
"details",
initial_metadata=self._TEST_INITIAL_METADATA,
trailing_metadata=self._TEST_TRAILING_METADATA)
self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(aio_rpc_error.details(), "details")
self.assertEqual(aio_rpc_error.initial_metadata(),
self._TEST_INITIAL_METADATA)
self.assertEqual(aio_rpc_error.trailing_metadata(),
self._TEST_TRAILING_METADATA)
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
# _RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000
_RESPONSE_INTERVAL_US = 200 * 1000
class TestCall(AioTestBase):
def test_call_ok(self):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async def tearDown(self):
await self._server.stop(None)
async with aio.insecure_channel(server_target) as channel:
async def test_call_ok(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString
)
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertFalse(call.done())
@ -71,20 +63,12 @@ class TestCall(AioTestBase):
response_retry = await call
self.assertIs(response, response_retry)
self.loop.run_until_complete(coro())
def test_call_rpc_error(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
async def test_call_rpc_error(self):
async with aio.insecure_channel(self._server_target) as channel:
empty_call_with_sleep = channel.unary_unary(
"/grpc.testing.TestService/EmptyCall",
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.
FromString,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
)
timeout = test_constants.SHORT_TIMEOUT / 2
# TODO(https://github.com/grpc/grpc/issues/20869
@ -104,61 +88,36 @@ class TestCall(AioTestBase):
# Exception is cached at call object level, reentrance
# returns again the same exception
with self.assertRaises(
grpc.RpcError) as exception_context_retry:
with self.assertRaises(grpc.RpcError) as exception_context_retry:
await call
self.assertIs(exception_context.exception,
exception_context_retry.exception)
self.loop.run_until_complete(coro())
def test_call_code_awaitable(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
async def test_call_code_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString
)
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.loop.run_until_complete(coro())
def test_call_details_awaitable(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
async def test_call_details_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString
)
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual(await call.details(), None)
self.loop.run_until_complete(coro())
def test_cancel(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
self.assertEqual('', await call.details())
async with aio.insecure_channel(server_target) as channel:
async def test_cancel_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString
)
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled())
@ -168,29 +127,172 @@ class TestCall(AioTestBase):
await asyncio.sleep(0)
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertFalse(call.cancel())
with self.assertRaises(
asyncio.CancelledError) as exception_context:
with self.assertRaises(asyncio.CancelledError) as exception_context:
await call
self.assertTrue(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
'Locally cancelled by application!')
# Exception is cached at call object level, reentrance
# returns again the same exception
with self.assertRaises(
asyncio.CancelledError) as exception_context_retry:
await call
# NOTE(lidiz) The CancelledError is almost always re-created,
# so we might not want to use it to transmit data.
# https://github.com/python/cpython/blob/master/Lib/asyncio/tasks.py#L785
async def test_cancel_unary_stream(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
self.assertIs(exception_context.exception,
exception_context_retry.exception)
response = await call.read()
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.loop.run_until_complete(coro())
self.assertTrue(call.cancel())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
await call.read()
self.assertTrue(call.cancelled())
async def test_multiple_cancel_unary_stream(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
response = await call.read()
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
await call.read()
async def test_early_cancel_unary_stream(self):
"""Test cancellation before receiving messages."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
await call.read()
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION,
exception_context.exception.details())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
async def test_late_cancel_unary_stream(self):
"""Test cancellation after received all messages."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
# After all messages received, it is possible that the final state
# is received or on its way. It's basically a data race, so our
# expectation here is do not crash :)
call.cancel()
async def test_too_many_reads_unary_stream(self):
"""Test cancellation after received all messages."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
# After the RPC is finished, further reads will lead to exception.
self.assertEqual(await call.code(), grpc.StatusCode.OK)
with self.assertRaises(asyncio.InvalidStateError):
await call.read()
if __name__ == '__main__':
logging.basicConfig()
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the grpc.aio.Channel class."""
import logging
import unittest
@ -18,38 +20,38 @@ import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_EMPTY_CALL_METHOD = '/grpc.testing.TestService/EmptyCall'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
class TestChannel(AioTestBase):
def test_async_context(self):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async def tearDown(self):
await self._server.stop(None)
async with aio.insecure_channel(server_target) as channel:
async def test_async_context(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
_UNARY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString
)
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
await hi(messages_pb2.SimpleRequest())
self.loop.run_until_complete(coro())
def test_unary_unary(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
channel = aio.insecure_channel(server_target)
async def test_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel:
channel = aio.insecure_channel(self._server_target)
hi = channel.unary_unary(
_UNARY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
@ -58,22 +60,12 @@ class TestChannel(AioTestBase):
self.assertIs(type(response), messages_pb2.SimpleResponse)
await channel.close()
self.loop.run_until_complete(coro())
def test_unary_call_times_out(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
async def test_unary_call_times_out(self):
async with aio.insecure_channel(self._server_target) as channel:
empty_call_with_sleep = channel.unary_unary(
_EMPTY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.
FromString,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
)
timeout = test_constants.SHORT_TIMEOUT / 2
# TODO(https://github.com/grpc/grpc/issues/20869)
@ -86,21 +78,16 @@ class TestChannel(AioTestBase):
messages_pb2.SimpleRequest(), timeout=timeout)
_, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(exception_context.exception.details(),
details.title())
self.assertIsNotNone(
exception_context.exception.initial_metadata())
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
exception_context.exception.code())
self.assertEqual(details.title(),
exception_context.exception.details())
self.assertIsNotNone(exception_context.exception.initial_metadata())
self.assertIsNotNone(
exception_context.exception.trailing_metadata())
self.loop.run_until_complete(coro())
@unittest.skip('https://github.com/grpc/grpc/issues/20818')
def test_call_to_the_void(self):
async def coro():
async def test_call_to_the_void(self):
channel = aio.insecure_channel('0.1.1.1:1111')
hi = channel.unary_unary(
_UNARY_CALL_METHOD,
@ -112,9 +99,31 @@ class TestChannel(AioTestBase):
await channel.close()
self.loop.run_until_complete(coro())
async def test_unary_stream(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
# Validates the responses
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
await channel.close()
if __name__ == '__main__':
logging.basicConfig()
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

@ -21,16 +21,12 @@ from tests_aio.unit._test_base import AioTestBase
class TestInsecureChannel(AioTestBase):
def test_insecure_channel(self):
async def coro():
async def test_insecure_channel(self):
server_target, _ = await start_test_server() # pylint: disable=unused-variable
channel = aio.insecure_channel(server_target)
self.assertIsInstance(channel, aio.Channel)
self.loop.run_until_complete(coro())
if __name__ == '__main__':
logging.basicConfig()

@ -26,9 +26,12 @@ from tests.unit.framework.common import test_constants
_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary'
_BLOCK_FOREVER = '/test/BlockForever'
_BLOCK_BRIEFLY = '/test/BlockBriefly'
_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen'
_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
_NUM_STREAM_RESPONSES = 5
class _GenericHandler(grpc.GenericRpcHandler):
@ -43,10 +46,18 @@ class _GenericHandler(grpc.GenericRpcHandler):
async def _block_forever(self, unused_request, unused_context):
await asyncio.get_event_loop().create_future()
async def _BLOCK_BRIEFLY(self, unused_request, unused_context):
async def _block_briefly(self, unused_request, unused_context):
await asyncio.sleep(test_constants.SHORT_TIMEOUT / 2)
return _RESPONSE
async def _unary_stream_async_gen(self, unused_request, unused_context):
for _ in range(_NUM_STREAM_RESPONSES):
yield _RESPONSE
async def _unary_stream_reader_writer(self, unused_request, context):
for _ in range(_NUM_STREAM_RESPONSES):
context.write(_RESPONSE)
def service(self, handler_details):
self._called.set_result(None)
if handler_details.method == _SIMPLE_UNARY_UNARY:
@ -54,7 +65,13 @@ class _GenericHandler(grpc.GenericRpcHandler):
if handler_details.method == _BLOCK_FOREVER:
return grpc.unary_unary_rpc_method_handler(self._block_forever)
if handler_details.method == _BLOCK_BRIEFLY:
return grpc.unary_unary_rpc_method_handler(self._BLOCK_BRIEFLY)
return grpc.unary_unary_rpc_method_handler(self._block_briefly)
if handler_details.method == _UNARY_STREAM_ASYNC_GEN:
return grpc.unary_stream_rpc_method_handler(
self._unary_stream_async_gen)
if handler_details.method == _UNARY_STREAM_READER_WRITER:
return grpc.unary_stream_rpc_method_handler(
self._unary_stream_reader_writer)
async def wait_for_call(self):
await self._called
@ -71,51 +88,60 @@ async def _start_test_server():
class TestServer(AioTestBase):
def test_unary_unary(self):
async def setUp(self):
self._server_target, self._server, self._generic_handler = await _start_test_server(
)
async def test_unary_unary_body():
result = await _start_test_server()
server_target = result[0]
async def tearDown(self):
await self._server.stop(None)
async with aio.insecure_channel(server_target) as channel:
unary_call = channel.unary_unary(_SIMPLE_UNARY_UNARY)
response = await unary_call(_REQUEST)
async def test_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel:
unary_unary_call = channel.unary_unary(_SIMPLE_UNARY_UNARY)
response = await unary_unary_call(_REQUEST)
self.assertEqual(response, _RESPONSE)
self.loop.run_until_complete(test_unary_unary_body())
def test_shutdown(self):
async def test_shutdown_body():
_, server, _ = await _start_test_server()
await server.stop(None)
async def test_unary_stream_async_generator(self):
async with aio.insecure_channel(self._server_target) as channel:
unary_stream_call = channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
call = unary_stream_call(_REQUEST)
self.loop.run_until_complete(test_shutdown_body())
# Ensures no SIGSEGV triggered, and ends within timeout.
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertEqual(_RESPONSE, response)
def test_shutdown_after_call(self):
self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_shutdown_body():
server_target, server, _ = await _start_test_server()
async def test_unary_stream_reader_writer(self):
async with aio.insecure_channel(self._server_target) as channel:
unary_stream_call = channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
call = unary_stream_call(_REQUEST)
async with aio.insecure_channel(server_target) as channel:
await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE, response)
await server.stop(None)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.loop.run_until_complete(test_shutdown_body())
async def test_shutdown(self):
await self._server.stop(None)
# Ensures no SIGSEGV triggered, and ends within timeout.
def test_graceful_shutdown_success(self):
async def test_shutdown_after_call(self):
async with aio.insecure_channel(self._server_target) as channel:
await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
async def test_graceful_shutdown_success_body():
server_target, server, generic_handler = await _start_test_server()
await self._server.stop(None)
channel = aio.insecure_channel(server_target)
async def test_graceful_shutdown_success(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await generic_handler.wait_for_call()
await self._generic_handler.wait_for_call()
shutdown_start_time = time.time()
await server.stop(test_constants.SHORT_TIMEOUT)
await self._server.stop(test_constants.SHORT_TIMEOUT)
grace_period_length = time.time() - shutdown_start_time
self.assertGreater(grace_period_length,
test_constants.SHORT_TIMEOUT / 3)
@ -125,43 +151,31 @@ class TestServer(AioTestBase):
self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done())
self.loop.run_until_complete(test_graceful_shutdown_success_body())
def test_graceful_shutdown_failed(self):
async def test_graceful_shutdown_failed_body():
server_target, server, generic_handler = await _start_test_server()
channel = aio.insecure_channel(server_target)
async def test_graceful_shutdown_failed(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await generic_handler.wait_for_call()
await self._generic_handler.wait_for_call()
await server.stop(test_constants.SHORT_TIMEOUT)
await self._server.stop(test_constants.SHORT_TIMEOUT)
with self.assertRaises(aio.AioRpcError) as exception_context:
with self.assertRaises(grpc.RpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details())
await channel.close()
self.loop.run_until_complete(test_graceful_shutdown_failed_body())
def test_concurrent_graceful_shutdown(self):
async def test_concurrent_graceful_shutdown_body():
server_target, server, generic_handler = await _start_test_server()
channel = aio.insecure_channel(server_target)
async def test_concurrent_graceful_shutdown(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await generic_handler.wait_for_call()
await self._generic_handler.wait_for_call()
# Expects the shortest grace period to be effective.
shutdown_start_time = time.time()
await asyncio.gather(
server.stop(test_constants.LONG_TIMEOUT),
server.stop(test_constants.SHORT_TIMEOUT),
server.stop(test_constants.LONG_TIMEOUT),
self._server.stop(test_constants.LONG_TIMEOUT),
self._server.stop(test_constants.SHORT_TIMEOUT),
self._server.stop(test_constants.LONG_TIMEOUT),
)
grace_period_length = time.time() - shutdown_start_time
self.assertGreater(grace_period_length,
@ -171,39 +185,28 @@ class TestServer(AioTestBase):
self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done())
self.loop.run_until_complete(test_concurrent_graceful_shutdown_body())
def test_concurrent_graceful_shutdown_immediate(self):
async def test_concurrent_graceful_shutdown_immediate_body():
server_target, server, generic_handler = await _start_test_server()
channel = aio.insecure_channel(server_target)
async def test_concurrent_graceful_shutdown_immediate(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await generic_handler.wait_for_call()
await self._generic_handler.wait_for_call()
# Expects no grace period, due to the "server.stop(None)".
await asyncio.gather(
server.stop(test_constants.LONG_TIMEOUT),
server.stop(None),
server.stop(test_constants.SHORT_TIMEOUT),
server.stop(test_constants.LONG_TIMEOUT),
self._server.stop(test_constants.LONG_TIMEOUT),
self._server.stop(None),
self._server.stop(test_constants.SHORT_TIMEOUT),
self._server.stop(test_constants.LONG_TIMEOUT),
)
with self.assertRaises(aio.AioRpcError) as exception_context:
with self.assertRaises(grpc.RpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details())
await channel.close()
self.loop.run_until_complete(
test_concurrent_graceful_shutdown_immediate_body())
@unittest.skip('https://github.com/grpc/grpc/issues/20818')
def test_shutdown_before_call(self):
async def test_shutdown_body():
async def test_shutdown_before_call(self):
server_target, server, _ = _start_test_server()
await server.stop(None)
@ -212,8 +215,6 @@ class TestServer(AioTestBase):
async with aio.insecure_channel('localhost:%d' % port) as channel:
await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
self.loop.run_until_complete(test_shutdown_body())
if __name__ == '__main__':
logging.basicConfig()

Loading…
Cancel
Save