Merge pull request #21232 from lidizheng/aio-streaming

[Aio] Streaming API - Server side streaming
pull/21442/head
Lidi Zheng 5 years ago committed by GitHub
commit 4955cda816
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 11
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 263
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pxd.pxi
  4. 74
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  5. 38
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  6. 35
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  7. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pxd.pxi
  8. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pyx.pxi
  9. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi
  10. 60
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi
  11. 12
      src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pxd.pxi
  12. 23
      src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pyx.pxi
  13. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  14. 205
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  15. 1
      src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi
  16. 2
      src/python/grpcio/grpc/_cython/cygrpc.pxd
  17. 6
      src/python/grpcio/grpc/_cython/cygrpc.pyx
  18. 2
      src/python/grpcio/grpc/experimental/BUILD.bazel
  19. 16
      src/python/grpcio/grpc/experimental/aio/__init__.py
  20. 157
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  21. 452
      src/python/grpcio/grpc/experimental/aio/_call.py
  22. 162
      src/python/grpcio/grpc/experimental/aio/_channel.py
  23. 17
      src/python/grpcio/grpc/experimental/aio/_typing.py
  24. 2
      src/python/grpcio_tests/commands.py
  25. 3
      src/python/grpcio_tests/tests/_runner.py
  26. 32
      src/python/grpcio_tests/tests_aio/benchmark/BUILD.bazel
  27. 9
      src/python/grpcio_tests/tests_aio/benchmark/server.py
  28. 5
      src/python/grpcio_tests/tests_aio/tests.json
  29. 49
      src/python/grpcio_tests/tests_aio/unit/_test_base.py
  30. 20
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  31. 50
      src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py
  32. 472
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  33. 153
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  34. 12
      src/python/grpcio_tests/tests_aio/unit/init_test.py
  35. 292
      src/python/grpcio_tests/tests_aio/unit/server_test.py
  36. 1
      tools/run_tests/artifacts/build_artifact_python.bat
  37. 9
      tools/run_tests/run_tests.py

@ -18,6 +18,17 @@ cdef class _AioCall:
AioChannel _channel
list _references
GrpcCallWrapper _grpc_call_wrapper
# Caches the picked event loop, so we can avoid the 30ns overhead each
# time we need access to the event loop.
object _loop
# Streaming call only attributes:
#
# A asyncio.Event that indicates if the status is received on the client side.
object _status_received
# A tuple of key value pairs representing the initial metadata sent by peer.
tuple _initial_metadata
cdef grpc_call* _create_grpc_call(self, object timeout, bytes method) except *
cdef void _destroy_grpc_call(self)
cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future)

@ -19,13 +19,25 @@ _EMPTY_FLAGS = 0
_EMPTY_MASK = 0
_EMPTY_METADATA = None
_UNKNOWN_CANCELLATION_DETAILS = 'RPC cancelled for unknown reason.'
cdef class _AioCall:
def __cinit__(self, AioChannel channel):
def __cinit__(self,
AioChannel channel,
object deadline,
bytes method):
self._channel = channel
self._references = []
self._grpc_call_wrapper = GrpcCallWrapper()
self._loop = asyncio.get_event_loop()
self._create_grpc_call(deadline, method)
self._status_received = asyncio.Event(loop=self._loop)
def __dealloc__(self):
self._destroy_grpc_call()
def __repr__(self):
class_name = self.__class__.__name__
@ -33,7 +45,7 @@ cdef class _AioCall:
return f"<{class_name} {id_}>"
cdef grpc_call* _create_grpc_call(self,
object timeout,
object deadline,
bytes method) except *:
"""Creates the corresponding Core object for this RPC.
@ -44,7 +56,7 @@ cdef class _AioCall:
nature in Core.
"""
cdef grpc_slice method_slice
cdef gpr_timespec deadline = _timespec_from_time(timeout)
cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
method_slice = grpc_slice_from_copied_buffer(
<const char *> method,
@ -57,7 +69,7 @@ cdef class _AioCall:
self._channel.cq.c_ptr(),
method_slice,
NULL,
deadline,
c_deadline,
NULL
)
grpc_slice_unref(method_slice)
@ -66,84 +78,191 @@ cdef class _AioCall:
"""Destroys the corresponding Core object for this RPC."""
grpc_call_unref(self._grpc_call_wrapper.call)
async def unary_unary(self, bytes method, bytes request, object timeout, AioCancelStatus cancel_status):
cdef object loop = asyncio.get_event_loop()
cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future):
"""Cancels the RPC in Core, and return the final RPC status."""
cdef AioRpcStatus status
cdef object details
cdef char *c_details
cdef grpc_call_error error
# Try to fetch application layer cancellation details in the future.
# * If cancellation details present, cancel with status;
# * If details not present, cancel with unknown reason.
if cancellation_future.done():
status = cancellation_future.result()
details = str_to_bytes(status.details())
self._references.append(details)
c_details = <char *>details
# By implementation, grpc_call_cancel_with_status always return OK
error = grpc_call_cancel_with_status(
self._grpc_call_wrapper.call,
status.c_code(),
c_details,
NULL,
)
assert error == GRPC_CALL_OK
return status
else:
# By implementation, grpc_call_cancel always return OK
error = grpc_call_cancel(self._grpc_call_wrapper.call, NULL)
assert error == GRPC_CALL_OK
status = AioRpcStatus(
StatusCode.cancelled,
_UNKNOWN_CANCELLATION_DETAILS,
None,
None,
)
cancellation_future.set_result(status)
return status
cdef tuple operations
cdef Operation initial_metadata_operation
cdef Operation send_message_operation
cdef Operation send_close_from_client_operation
cdef Operation receive_initial_metadata_operation
cdef Operation receive_message_operation
cdef Operation receive_status_on_client_operation
async def unary_unary(self,
bytes request,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Performs a unary unary RPC.
Args:
method: name of the calling method in bytes.
request: the serialized requests in bytes.
deadline: optional deadline of the RPC in float.
cancellation_future: the future that meant to transport the
cancellation reason from the application layer.
initial_metadata_observer: a callback for received initial metadata.
status_observer: a callback for received final status.
"""
cdef tuple ops
cdef char *c_details = NULL
cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation(
_EMPTY_METADATA,
GRPC_INITIAL_METADATA_USED_MASK)
cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS)
cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS)
cdef ReceiveInitialMetadataOperation receive_initial_metadata_op = ReceiveInitialMetadataOperation(_EMPTY_FLAGS)
cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS)
cdef ReceiveStatusOnClientOperation receive_status_on_client_op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
initial_metadata_operation = SendInitialMetadataOperation(_EMPTY_METADATA, GRPC_INITIAL_METADATA_USED_MASK)
initial_metadata_operation.c()
ops = (initial_metadata_op, send_message_op, send_close_op,
receive_initial_metadata_op, receive_message_op,
receive_status_on_client_op)
send_message_operation = SendMessageOperation(request, _EMPTY_FLAGS)
send_message_operation.c()
try:
await execute_batch(self._grpc_call_wrapper,
ops,
self._loop)
except asyncio.CancelledError:
status = self._cancel_and_create_status(cancellation_future)
initial_metadata_observer(None)
status_observer(status)
raise
else:
initial_metadata_observer(
receive_initial_metadata_op.initial_metadata()
)
send_close_from_client_operation = SendCloseFromClientOperation(_EMPTY_FLAGS)
send_close_from_client_operation.c()
status = AioRpcStatus(
receive_status_on_client_op.code(),
receive_status_on_client_op.details(),
receive_status_on_client_op.trailing_metadata(),
receive_status_on_client_op.error_string(),
)
# Reports the final status of the RPC to Python layer. The observer
# pattern is used here to unify unary and streaming code path.
status_observer(status)
receive_initial_metadata_operation = ReceiveInitialMetadataOperation(_EMPTY_FLAGS)
receive_initial_metadata_operation.c()
if status.code() == StatusCode.ok:
return receive_message_op.message()
else:
return None
receive_message_operation = ReceiveMessageOperation(_EMPTY_FLAGS)
receive_message_operation.c()
async def _handle_status_once_received(self, object status_observer):
"""Handles the status sent by peer once received."""
cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
cdef tuple ops = (op,)
await execute_batch(self._grpc_call_wrapper, ops, self._loop)
cdef AioRpcStatus status = AioRpcStatus(
op.code(),
op.details(),
op.trailing_metadata(),
op.error_string(),
)
status_observer(status)
self._status_received.set()
receive_status_on_client_operation = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
receive_status_on_client_operation.c()
def _handle_cancellation_from_application(self,
object cancellation_future,
object status_observer):
def _cancellation_action(finished_future):
if not self._status_received.set():
status = self._cancel_and_create_status(finished_future)
status_observer(status)
self._status_received.set()
operations = (
initial_metadata_operation,
send_message_operation,
send_close_from_client_operation,
receive_initial_metadata_operation,
receive_message_operation,
receive_status_on_client_operation,
)
cancellation_future.add_done_callback(_cancellation_action)
try:
self._create_grpc_call(
timeout,
method,
async def _message_async_generator(self):
cdef bytes received_message
# Infinitely receiving messages, until:
# * EOF, no more messages to read;
# * The client application cancells;
# * The server sends final status.
while True:
if self._status_received.is_set():
return
received_message = await _receive_message(
self._grpc_call_wrapper,
self._loop
)
if received_message is None:
# The read operation failed, Core should explain why it fails
await self._status_received.wait()
return
else:
yield received_message
async def unary_stream(self,
bytes request,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Actual implementation of the complete unary-stream call.
Needs to pay extra attention to the raise mechanism. If we want to
propagate the final status exception, then we have to raise it.
Othersize, it would end normally and raise `StopAsyncIteration()`.
"""
cdef tuple outbound_ops
cdef Operation initial_metadata_op = SendInitialMetadataOperation(
_EMPTY_METADATA,
GRPC_INITIAL_METADATA_USED_MASK)
cdef Operation send_message_op = SendMessageOperation(
request,
_EMPTY_FLAGS)
cdef Operation send_close_op = SendCloseFromClientOperation(
_EMPTY_FLAGS)
try:
await callback_start_batch(
self._grpc_call_wrapper,
operations,
loop
)
except asyncio.CancelledError:
if cancel_status:
details = str_to_bytes(cancel_status.details())
self._references.append(details)
c_details = <char *>details
call_status = grpc_call_cancel_with_status(
self._grpc_call_wrapper.call,
cancel_status.code(),
c_details,
NULL,
)
else:
call_status = grpc_call_cancel(
self._grpc_call_wrapper.call, NULL)
if call_status != GRPC_CALL_OK:
raise Exception("RPC call couldn't be cancelled. Error {}".format(call_status))
raise
finally:
self._destroy_grpc_call()
if receive_status_on_client_operation.code() == StatusCode.ok:
return receive_message_operation.message()
raise AioRpcError(
receive_initial_metadata_operation.initial_metadata(),
receive_status_on_client_operation.code(),
receive_status_on_client_operation.details(),
receive_status_on_client_operation.trailing_metadata(),
outbound_ops = (
initial_metadata_op,
send_message_op,
send_close_op,
)
# Actually sends out the request message.
await execute_batch(self._grpc_call_wrapper,
outbound_ops,
self._loop)
# Peer may prematurely end this RPC at any point. We need a mechanism
# that handles both the normal case and the error case.
self._loop.create_task(self._handle_status_once_received(status_observer))
self._handle_cancellation_from_application(cancellation_future,
status_observer)
# Receives initial metadata.
initial_metadata_observer(
await _receive_initial_metadata(self._grpc_call_wrapper,
self._loop),
)
return self._message_async_generator()

@ -28,10 +28,10 @@ cdef struct CallbackContext:
#
# Attributes:
# functor: A grpc_experimental_completion_queue_functor represents the
# callback function in the only way C-Core understands.
# callback function in the only way Core understands.
# waiter: An asyncio.Future object that fulfills when the callback is
# invoked by C-Core.
# failure_handler: A CallbackFailureHandler object that called when C-Core
# invoked by Core.
# failure_handler: A CallbackFailureHandler object that called when Core
# returns 'success == 0' state.
grpc_experimental_completion_queue_functor functor
cpython.PyObject *waiter

@ -46,11 +46,13 @@ cdef class CallbackWrapper:
grpc_experimental_completion_queue_functor* functor,
int success):
cdef CallbackContext *context = <CallbackContext *>functor
cdef object waiter = <object>context.waiter
if waiter.cancelled():
return
if success == 0:
(<CallbackFailureHandler>context.failure_handler).handle(
<object>context.waiter)
(<CallbackFailureHandler>context.failure_handler).handle(waiter)
else:
(<object>context.waiter).set_result(None)
waiter.set_result(None)
cdef grpc_experimental_completion_queue_functor *c_functor(self):
return &self.context.functor
@ -83,7 +85,10 @@ cdef class CallbackCompletionQueue:
grpc_completion_queue_destroy(self._cq)
async def callback_start_batch(GrpcCallWrapper grpc_call_wrapper,
class ExecuteBatchError(Exception): pass
async def execute_batch(GrpcCallWrapper grpc_call_wrapper,
tuple operations,
object loop):
"""The callback version of start batch operations."""
@ -93,7 +98,7 @@ async def callback_start_batch(GrpcCallWrapper grpc_call_wrapper,
cdef object future = loop.create_future()
cdef CallbackWrapper wrapper = CallbackWrapper(
future,
CallbackFailureHandler('callback_start_batch', operations, RuntimeError))
CallbackFailureHandler('execute_batch', operations, ExecuteBatchError))
# NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
# when calling "await". This is an over-optimization by Cython.
cpython.Py_INCREF(wrapper)
@ -104,10 +109,67 @@ async def callback_start_batch(GrpcCallWrapper grpc_call_wrapper,
wrapper.c_functor(), NULL)
if error != GRPC_CALL_OK:
raise RuntimeError("Failed grpc_call_start_batch: {}".format(error))
raise ExecuteBatchError("Failed grpc_call_start_batch: {}".format(error))
await future
cpython.Py_DECREF(wrapper)
cdef grpc_event c_event
# Tag.event must be called, otherwise messages won't be parsed from C
batch_operation_tag.event(c_event)
async def _receive_message(GrpcCallWrapper grpc_call_wrapper,
object loop):
"""Retrives parsed messages from Core.
The messages maybe already in Core's buffer, so there isn't a 1-to-1
mapping between this and the underlying "socket.read()". Also, eventually,
this function will end with an EOF, which reads empty message.
"""
cdef ReceiveMessageOperation receive_op = ReceiveMessageOperation(_EMPTY_FLAG)
cdef tuple ops = (receive_op,)
try:
await execute_batch(grpc_call_wrapper, ops, loop)
except ExecuteBatchError as e:
# NOTE(lidiz) The receive message operation has two ways to indicate
# finish state : 1) returns empty message due to EOF; 2) fails inside
# the callback (e.g. cancelled).
#
# Since they all indicates finish, they are better be merged.
_LOGGER.debug(e)
return receive_op.message()
async def _send_message(GrpcCallWrapper grpc_call_wrapper,
bytes message,
bint metadata_sent,
object loop):
cdef SendMessageOperation op = SendMessageOperation(message, _EMPTY_FLAG)
cdef tuple ops
if metadata_sent:
ops = (op,)
else:
ops = (
# Initial metadata must be sent before first outbound message.
SendInitialMetadataOperation(None, _EMPTY_FLAG),
op,
)
await execute_batch(grpc_call_wrapper, ops, loop)
async def _send_initial_metadata(GrpcCallWrapper grpc_call_wrapper,
tuple metadata,
object loop):
cdef SendInitialMetadataOperation op = SendInitialMetadataOperation(
metadata,
_EMPTY_FLAG)
cdef tuple ops = (op,)
await execute_batch(grpc_call_wrapper, ops, loop)
async def _receive_initial_metadata(GrpcCallWrapper grpc_call_wrapper,
object loop):
cdef ReceiveInitialMetadataOperation op = ReceiveInitialMetadataOperation(_EMPTY_FLAGS)
cdef tuple ops = (op,)
await execute_batch(grpc_call_wrapper, ops, loop)
return op.initial_metadata()

@ -26,6 +26,38 @@ cdef class AioChannel:
def close(self):
grpc_channel_destroy(self.channel)
async def unary_unary(self, method, request, timeout, cancel_status):
call = _AioCall(self)
return await call.unary_unary(method, request, timeout, cancel_status)
async def unary_unary(self,
bytes method,
bytes request,
object deadline,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Assembles a unary-unary RPC.
Returns:
The response message in bytes.
"""
cdef _AioCall call = _AioCall(self, deadline, method)
return await call.unary_unary(request,
cancellation_future,
initial_metadata_observer,
status_observer)
def unary_stream(self,
bytes method,
bytes request,
object deadline,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Assembles a unary-stream RPC.
Returns:
An async generator that yields raw responses.
"""
cdef _AioCall call = _AioCall(self, deadline, method)
return call.unary_stream(request,
cancellation_future,
initial_metadata_observer,
status_observer)

@ -1,4 +1,4 @@
# Copyright 2019 gRPC authors.
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,26 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Desired cancellation status for canceling an ongoing RPC call."""
cdef class AioCancelStatus:
cdef object deserialize(object deserializer, bytes raw_message):
"""Perform deserialization on raw bytes.
def __cinit__(self):
self._code = None
self._details = None
Failure to deserialize is a fatal error.
"""
if deserializer:
return deserializer(raw_message)
else:
return raw_message
def __len__(self):
if self._code is None:
return 0
return 1
def cancel(self, grpc_status_code code, str details=None):
self._code = code
self._details = details
cdef bytes serialize(object serializer, object message):
"""Perform serialization on a message.
cpdef object code(self):
return self._code
cpdef str details(self):
return self._details
Failure to serialize is a fatal error.
"""
if serializer:
return serializer(message)
else:
return message

@ -13,6 +13,7 @@
# limitations under the License.
# distutils: language=c++
cdef extern from "src/core/lib/iomgr/timer_manager.h":
void grpc_timer_manager_set_threading(bint enabled);

@ -28,10 +28,10 @@ def init_grpc_aio():
# Timers are triggered by the Asyncio loop. We disable
# the background thread that is being used by the native
# gRPC iomgr.
grpc_timer_manager_set_threading(0)
grpc_timer_manager_set_threading(False)
# gRPC callbaks are executed within the same thread used by the Asyncio
# event loop, as it is being done by the other Asyncio callbacks.
Executor.SetThreadingAll(0)
Executor.SetThreadingAll(False)
_grpc_aio_initialized = 1

@ -23,6 +23,9 @@ cdef class _AsyncioSocket:
object _task_read
object _task_connect
char * _read_buffer
# Caches the picked event loop, so we can avoid the 30ns overhead each
# time we need access to the event loop.
object _loop
# Client-side attributes
grpc_custom_connect_callback _grpc_connect_cb

@ -16,6 +16,8 @@ import socket as native_socket
from libc cimport string
# TODO(https://github.com/grpc/grpc/issues/21348) Better flow control needed.
cdef class _AsyncioSocket:
def __cinit__(self):
self._grpc_socket = NULL
@ -29,6 +31,7 @@ cdef class _AsyncioSocket:
self._server = None
self._py_socket = None
self._peername = None
self._loop = asyncio.get_event_loop()
@staticmethod
cdef _AsyncioSocket create(grpc_custom_socket * grpc_socket,
@ -56,30 +59,25 @@ cdef class _AsyncioSocket:
return f"<{class_name} {id_} connected={connected}>"
def _connect_cb(self, future):
error = False
try:
self._reader, self._writer = future.result()
except Exception as e:
error = True
error_msg = str(e)
self._grpc_connect_cb(
<grpc_custom_socket*>self._grpc_socket,
grpc_socket_error("Socket connect failed: {}".format(e).encode())
)
finally:
self._task_connect = None
if not error:
# gRPC default posix implementation disables nagle
# algorithm.
sock = self._writer.transport.get_extra_info('socket')
sock.setsockopt(native_socket.IPPROTO_TCP, native_socket.TCP_NODELAY, True)
# gRPC default posix implementation disables nagle
# algorithm.
sock = self._writer.transport.get_extra_info('socket')
sock.setsockopt(native_socket.IPPROTO_TCP, native_socket.TCP_NODELAY, True)
self._grpc_connect_cb(
<grpc_custom_socket*>self._grpc_socket,
<grpc_error*>0
)
else:
self._grpc_connect_cb(
<grpc_custom_socket*>self._grpc_socket,
grpc_socket_error("connect {}".format(error_msg).encode())
)
self._grpc_connect_cb(
<grpc_custom_socket*>self._grpc_socket,
<grpc_error*>0
)
def _read_cb(self, future):
error = False
@ -87,7 +85,8 @@ cdef class _AsyncioSocket:
buffer_ = future.result()
except Exception as e:
error = True
error_msg = str(e)
error_msg = "%s: %s" % (type(e), str(e))
_LOGGER.exception(e)
finally:
self._task_read = None
@ -106,7 +105,7 @@ cdef class _AsyncioSocket:
self._grpc_read_cb(
<grpc_custom_socket*>self._grpc_socket,
-1,
grpc_socket_error("read {}".format(error_msg).encode())
grpc_socket_error("Read failed: {}".format(error_msg).encode())
)
cdef void connect(self,
@ -125,7 +124,7 @@ cdef class _AsyncioSocket:
cdef void read(self, char * buffer_, size_t length, grpc_custom_read_callback grpc_read_cb):
assert not self._task_read
self._task_read = asyncio.ensure_future(
self._task_read = self._loop.create_task(
self._reader.read(n=length)
)
self._grpc_read_cb = grpc_read_cb
@ -133,15 +132,20 @@ cdef class _AsyncioSocket:
self._read_buffer = buffer_
cdef void write(self, grpc_slice_buffer * g_slice_buffer, grpc_custom_write_callback grpc_write_cb):
"""Performs write to network socket in AsyncIO.
For each socket, Core guarantees there'll be only one ongoing write.
When the write is finished, we need to call grpc_write_cb to notify
Core that the work is done.
"""
cdef char* start
buffer_ = bytearray()
cdef bytearray outbound_buffer = bytearray()
for i in range(g_slice_buffer.count):
start = grpc_slice_buffer_start(g_slice_buffer, i)
length = grpc_slice_buffer_length(g_slice_buffer, i)
buffer_.extend(<bytes>start[:length])
self._writer.write(buffer_)
outbound_buffer.extend(<bytes>start[:length])
self._writer.write(outbound_buffer)
grpc_write_cb(
<grpc_custom_socket*>self._grpc_socket,
<grpc_error*>0
@ -171,9 +175,9 @@ cdef class _AsyncioSocket:
self._grpc_client_socket.impl = <void*>client_socket
cpython.Py_INCREF(client_socket) # Py_DECREF in asyncio_socket_destroy
# Accept callback expects to be called with:
# grpc_custom_socket: A grpc custom socket for server
# grpc_custom_socket: A grpc custom socket for client (with new Socket instance)
# grpc_error: An error object
# * grpc_custom_socket: A grpc custom socket for server
# * grpc_custom_socket: A grpc custom socket for client (with new Socket instance)
# * grpc_error: An error object
self._grpc_accept_cb(self._grpc_socket, self._grpc_client_socket, grpc_error_none())
cdef listen(self):
@ -183,7 +187,7 @@ cdef class _AsyncioSocket:
sock=self._py_socket,
)
asyncio.get_event_loop().create_task(create_asyncio_server())
self._loop.create_task(create_asyncio_server())
cdef accept(self,
grpc_custom_socket* grpc_socket_client,

@ -14,14 +14,16 @@
"""Exceptions for the aio version of the RPC calls."""
cdef class _AioRpcError(Exception):
cdef class AioRpcStatus(Exception):
cdef readonly:
tuple _initial_metadata
int _code
grpc_status_code _code
str _details
# Per the spec, only client-side status has trailing metadata.
tuple _trailing_metadata
str _debug_error_string
cpdef tuple initial_metadata(self)
cpdef int code(self)
cpdef grpc_status_code code(self)
cpdef str details(self)
cpdef tuple trailing_metadata(self)
cpdef str debug_error_string(self)
cdef grpc_status_code c_code(self)

@ -14,18 +14,21 @@
"""Exceptions for the aio version of the RPC calls."""
cdef class AioRpcError(Exception):
cdef class AioRpcStatus(Exception):
def __cinit__(self, tuple initial_metadata, int code, str details, tuple trailing_metadata):
self._initial_metadata = initial_metadata
# The final status of gRPC is represented by three trailing metadata:
# `grpc-status`, `grpc-status-message`, abd `grpc-status-details`.
def __cinit__(self,
grpc_status_code code,
str details,
tuple trailing_metadata,
str debug_error_string):
self._code = code
self._details = details
self._trailing_metadata = trailing_metadata
self._debug_error_string = debug_error_string
cpdef tuple initial_metadata(self):
return self._initial_metadata
cpdef int code(self):
cpdef grpc_status_code code(self):
return self._code
cpdef str details(self):
@ -33,3 +36,9 @@ cdef class AioRpcError(Exception):
cpdef tuple trailing_metadata(self):
return self._trailing_metadata
cpdef str debug_error_string(self):
return self._debug_error_string
cdef grpc_status_code c_code(self):
return <grpc_status_code>self._code

@ -43,3 +43,4 @@ cdef class AioServer:
cdef object _shutdown_completed # asyncio.Future
cdef CallbackWrapper _shutdown_callback_wrapper
cdef object _crash_exception # Exception
cdef set _ongoing_rpc_tasks

@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
# TODO(https://github.com/grpc/grpc/issues/20850) refactor this.
_LOGGER = logging.getLogger(__name__)
cdef int _EMPTY_FLAG = 0
@ -23,9 +27,6 @@ cdef class _HandlerCallDetails:
self.invocation_metadata = invocation_metadata
class _ServicerContextPlaceHolder(object): pass
cdef class RPCState:
def __cinit__(self):
@ -43,12 +44,49 @@ cdef class RPCState:
grpc_call_unref(self.call)
cdef class _ServicerContext:
cdef RPCState _rpc_state
cdef object _loop
cdef bint _metadata_sent
cdef object _request_deserializer
cdef object _response_serializer
def __cinit__(self,
RPCState rpc_state,
object request_deserializer,
object response_serializer,
object loop):
self._rpc_state = rpc_state
self._request_deserializer = request_deserializer
self._response_serializer = response_serializer
self._loop = loop
self._metadata_sent = False
async def read(self):
cdef bytes raw_message = await _receive_message(self._rpc_state, self._loop)
return deserialize(self._request_deserializer,
raw_message)
async def write(self, object message):
await _send_message(self._rpc_state,
serialize(self._response_serializer, message),
self._metadata_sent,
self._loop)
if not self._metadata_sent:
self._metadata_sent = True
async def send_initial_metadata(self, tuple metadata):
if self._metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent')
else:
_send_initial_metadata(self._rpc_state, self._loop)
self._metadata_sent = True
cdef _find_method_handler(str method, list generic_handlers):
# TODO(lidiz) connects Metadata to call details
cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(
method,
tuple()
)
cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method,
None)
for generic_handler in generic_handlers:
method_handler = generic_handler.service(handler_call_details)
@ -61,64 +99,132 @@ async def _handle_unary_unary_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef tuple receive_ops = (
ReceiveMessageOperation(_EMPTY_FLAGS),
)
await callback_start_batch(rpc_state, receive_ops, loop)
cdef bytes request_raw = await _receive_message(rpc_state, loop)
# Deserializes the request message
cdef bytes request_raw = receive_ops[0].message()
cdef object request_message
if method_handler.request_deserializer:
request_message = method_handler.request_deserializer(request_raw)
else:
request_message = request_raw
cdef object request_message = deserialize(
method_handler.request_deserializer,
request_raw,
)
# Executes application logic
cdef object response_message = await method_handler.unary_unary(request_message, _ServicerContextPlaceHolder())
cdef object response_message = await method_handler.unary_unary(
request_message,
_ServicerContext(
rpc_state,
None,
None,
loop,
),
)
# Serializes the response message
cdef bytes response_raw
if method_handler.response_serializer:
response_raw = method_handler.response_serializer(response_message)
else:
response_raw = response_message
cdef bytes response_raw = serialize(
method_handler.response_serializer,
response_message,
)
# Sends response message
cdef tuple send_ops = (
SendStatusFromServerOperation(
tuple(), StatusCode.ok, b'', _EMPTY_FLAGS),
SendInitialMetadataOperation(tuple(), _EMPTY_FLAGS),
tuple(),
StatusCode.ok,
b'',
_EMPTY_FLAGS,
),
SendInitialMetadataOperation(None, _EMPTY_FLAGS),
SendMessageOperation(response_raw, _EMPTY_FLAGS),
)
await callback_start_batch(rpc_state, send_ops, loop)
await execute_batch(rpc_state, send_ops, loop)
async def _handle_unary_stream_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef bytes request_raw = await _receive_message(rpc_state, loop)
# Deserializes the request message
cdef object request_message = deserialize(
method_handler.request_deserializer,
request_raw,
)
cdef _ServicerContext servicer_context = _ServicerContext(
rpc_state,
method_handler.request_deserializer,
method_handler.response_serializer,
loop,
)
cdef object async_response_generator
cdef object response_message
if inspect.iscoroutinefunction(method_handler.unary_stream):
# The handler uses reader / writer API, returns None.
await method_handler.unary_stream(
request_message,
servicer_context,
)
else:
# The handler uses async generator API
async_response_generator = method_handler.unary_stream(
request_message,
servicer_context,
)
# Consumes messages from the generator
async for response_message in async_response_generator:
await servicer_context.write(response_message)
# Sends the final status of this RPC
cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
None,
StatusCode.ok,
b'',
_EMPTY_FLAGS,
)
cdef tuple ops = (op,)
await execute_batch(rpc_state, ops, loop)
async def _handle_cancellation_from_core(object rpc_task,
RPCState rpc_state,
object loop):
cdef ReceiveCloseOnServerOperation op = ReceiveCloseOnServerOperation(_EMPTY_FLAG)
cdef tuple ops = (op,)
await execute_batch(rpc_state, ops, loop)
if op.cancelled() and not rpc_task.done():
rpc_task.cancel()
async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
# Finds the method handler (application logic)
cdef object method_handler = _find_method_handler(
rpc_state.method().decode(),
generic_handlers
generic_handlers,
)
if method_handler is None:
# TODO(lidiz) return unimplemented error to client side
raise NotImplementedError()
# TODO(lidiz) extend to all 4 types of RPC
if method_handler.request_streaming or method_handler.response_streaming:
raise NotImplementedError()
if not method_handler.request_streaming and method_handler.response_streaming:
await _handle_unary_stream_rpc(method_handler,
rpc_state,
loop)
elif not method_handler.request_streaming and not method_handler.response_streaming:
await _handle_unary_unary_rpc(method_handler,
rpc_state,
loop)
else:
await _handle_unary_unary_rpc(
method_handler,
rpc_state,
loop
)
raise NotImplementedError()
class _RequestCallError(Exception): pass
cdef CallbackFailureHandler REQUEST_CALL_FAILURE_HANDLER = CallbackFailureHandler(
'grpc_server_request_call', 'server shutdown', _RequestCallError)
'grpc_server_request_call', None, _RequestCallError)
async def _server_call_request_call(Server server,
@ -147,19 +253,9 @@ async def _server_call_request_call(Server server,
return rpc_state
async def _handle_cancellation_from_core(object rpc_task,
RPCState rpc_state,
object loop):
cdef ReceiveCloseOnServerOperation op = ReceiveCloseOnServerOperation(_EMPTY_FLAG)
cdef tuple ops = (op,)
await callback_start_batch(rpc_state, ops, loop)
if op.cancelled() and not rpc_task.done():
rpc_task.cancel()
cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler(
'grpc_server_shutdown_and_notify',
'Unknown',
None,
RuntimeError)
@ -182,6 +278,7 @@ cdef class AioServer:
self._generic_handlers = []
self.add_generic_rpc_handlers(generic_handlers)
self._serving_task = None
self._ongoing_rpc_tasks = set()
self._shutdown_lock = asyncio.Lock(loop=self._loop)
self._shutdown_completed = self._loop.create_future()
@ -221,11 +318,13 @@ cdef class AioServer:
if self._status != AIO_SERVER_STATUS_RUNNING:
break
# Accepts new request from Core
rpc_state = await _server_call_request_call(
self._server,
self._cq,
self._loop)
# Schedules the RPC as a separate coroutine
rpc_task = self._loop.create_task(
_handle_rpc(
self._generic_handlers,
@ -233,6 +332,8 @@ cdef class AioServer:
self._loop
)
)
# Fires off a task that listens on the cancellation from client.
self._loop.create_task(
_handle_cancellation_from_core(
rpc_task,
@ -241,6 +342,10 @@ cdef class AioServer:
)
)
# Keeps track of created coroutines, so we can clean them up properly.
self._ongoing_rpc_tasks.add(rpc_task)
rpc_task.add_done_callback(lambda _: self._ongoing_rpc_tasks.remove(rpc_task))
def _serving_task_crash_handler(self, object task):
"""Shutdown the server immediately if unexpectedly exited."""
if task.exception() is None:
@ -282,7 +387,7 @@ cdef class AioServer:
pass
async def shutdown(self, grace):
"""Gracefully shutdown the C-Core server.
"""Gracefully shutdown the Core server.
Application should only call shutdown once.
@ -318,6 +423,10 @@ cdef class AioServer:
grpc_server_cancel_all_calls(self._server.c_server)
await self._shutdown_completed
# Cancels all Python layer tasks
for rpc_task in self._ongoing_rpc_tasks:
rpc_task.cancel()
async with self._shutdown_lock:
if self._status == AIO_SERVER_STATUS_STOPPING:
grpc_server_destroy(self._server.c_server)
@ -328,7 +437,7 @@ cdef class AioServer:
# Shuts down the completion queue
await self._cq.shutdown()
async def wait_for_termination(self, float timeout):
async def wait_for_termination(self, object timeout):
if timeout is None:
await self._shutdown_completed
else:

@ -32,6 +32,7 @@ _TRUE_VALUES = ['yes', 'Yes', 'YES', 'true', 'True', 'TRUE', '1']
# must not block and should execute quickly.
#
# This flag is not supported on Windows.
# This flag is also not supported for non-native IO manager.
_GRPC_ENABLE_FORK_SUPPORT = (
os.environ.get('GRPC_ENABLE_FORK_SUPPORT', '0')
.lower() in _TRUE_VALUES)

@ -43,9 +43,9 @@ IF UNAME_SYSNAME != "Windows":
include "_cygrpc/aio/iomgr/socket.pxd.pxi"
include "_cygrpc/aio/iomgr/timer.pxd.pxi"
include "_cygrpc/aio/iomgr/resolver.pxd.pxi"
include "_cygrpc/aio/rpc_status.pxd.pxi"
include "_cygrpc/aio/grpc_aio.pxd.pxi"
include "_cygrpc/aio/callback_common.pxd.pxi"
include "_cygrpc/aio/call.pxd.pxi"
include "_cygrpc/aio/cancel_status.pxd.pxi"
include "_cygrpc/aio/channel.pxd.pxi"
include "_cygrpc/aio/server.pxd.pxi"

@ -60,12 +60,12 @@ include "_cygrpc/aio/iomgr/iomgr.pyx.pxi"
include "_cygrpc/aio/iomgr/socket.pyx.pxi"
include "_cygrpc/aio/iomgr/timer.pyx.pxi"
include "_cygrpc/aio/iomgr/resolver.pyx.pxi"
include "_cygrpc/aio/common.pyx.pxi"
include "_cygrpc/aio/rpc_status.pyx.pxi"
include "_cygrpc/aio/callback_common.pyx.pxi"
include "_cygrpc/aio/grpc_aio.pyx.pxi"
include "_cygrpc/aio/call.pyx.pxi"
include "_cygrpc/aio/callback_common.pyx.pxi"
include "_cygrpc/aio/cancel_status.pyx.pxi"
include "_cygrpc/aio/channel.pyx.pxi"
include "_cygrpc/aio/rpc_error.pyx.pxi"
include "_cygrpc/aio/server.pyx.pxi"

@ -4,9 +4,11 @@ py_library(
name = "aio",
srcs = [
"aio/__init__.py",
"aio/_base_call.py",
"aio/_call.py",
"aio/_channel.py",
"aio/_server.py",
"aio/_typing.py",
],
deps = [
"//src/python/grpcio/grpc/_cython:cygrpc",

@ -11,18 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""gRPC's Asynchronous Python API."""
"""gRPC's Asynchronous Python API.
gRPC Async API objects may only be used on the thread on which they were
created. AsyncIO doesn't provide thread safety for most of its APIs.
"""
import abc
import six
import grpc
from grpc import _common
from grpc._cython import cygrpc
from grpc._cython.cygrpc import init_grpc_aio
from ._call import AioRpcError
from ._call import Call
from ._base_call import RpcContext, Call, UnaryUnaryCall, UnaryStreamCall
from ._channel import Channel
from ._channel import UnaryUnaryMultiCallable
from ._server import server
@ -47,5 +48,6 @@ def insecure_channel(target, options=None, compression=None):
################################### __all__ #################################
__all__ = ('AioRpcError', 'Call', 'init_grpc_aio', 'Channel',
'UnaryUnaryMultiCallable', 'insecure_channel', 'server')
__all__ = ('RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall',
'init_grpc_aio', 'Channel', 'UnaryUnaryMultiCallable',
'insecure_channel', 'server')

@ -0,0 +1,157 @@
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base classes for client-side Call objects.
Call objects represents the RPC itself, and offer methods to access / modify
its information. They also offer methods to manipulate the life-cycle of the
RPC, e.g. cancellation.
"""
from abc import ABCMeta, abstractmethod
from typing import Any, AsyncIterable, Awaitable, Callable, Generic, Text, Optional
import grpc
from ._typing import MetadataType, RequestType, ResponseType
__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
class RpcContext(metaclass=ABCMeta):
"""Provides RPC-related information and action."""
@abstractmethod
def cancelled(self) -> bool:
"""Return True if the RPC is cancelled.
The RPC is cancelled when the cancellation was requested with cancel().
Returns:
A bool indicates whether the RPC is cancelled or not.
"""
@abstractmethod
def done(self) -> bool:
"""Return True if the RPC is done.
An RPC is done if the RPC is completed, cancelled or aborted.
Returns:
A bool indicates if the RPC is done.
"""
@abstractmethod
def time_remaining(self) -> Optional[float]:
"""Describes the length of allowed time remaining for the RPC.
Returns:
A nonnegative float indicating the length of allowed time in seconds
remaining for the RPC to complete before it is considered to have
timed out, or None if no deadline was specified for the RPC.
"""
@abstractmethod
def cancel(self) -> bool:
"""Cancels the RPC.
Idempotent and has no effect if the RPC has already terminated.
Returns:
A bool indicates if the cancellation is performed or not.
"""
@abstractmethod
def add_done_callback(self, callback: Callable[[Any], None]) -> None:
"""Registers a callback to be called on RPC termination.
Args:
callback: A callable object will be called with the context object as
its only argument.
"""
class Call(RpcContext, metaclass=ABCMeta):
"""The abstract base class of an RPC on the client-side."""
@abstractmethod
async def initial_metadata(self) -> MetadataType:
"""Accesses the initial metadata sent by the server.
Returns:
The initial :term:`metadata`.
"""
@abstractmethod
async def trailing_metadata(self) -> MetadataType:
"""Accesses the trailing metadata sent by the server.
Returns:
The trailing :term:`metadata`.
"""
@abstractmethod
async def code(self) -> grpc.StatusCode:
"""Accesses the status code sent by the server.
Returns:
The StatusCode value for the RPC.
"""
@abstractmethod
async def details(self) -> Text:
"""Accesses the details sent by the server.
Returns:
The details string of the RPC.
"""
class UnaryUnaryCall(
Generic[RequestType, ResponseType], Call, metaclass=ABCMeta):
"""The abstract base class of an unary-unary RPC on the client-side."""
@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
"""Await the response message to be ready.
Returns:
The response message of the RPC.
"""
class UnaryStreamCall(
Generic[RequestType, ResponseType], Call, metaclass=ABCMeta):
@abstractmethod
def __aiter__(self) -> AsyncIterable[ResponseType]:
"""Returns the async iterable representation that yields messages.
Under the hood, it is calling the "read" method.
Returns:
An async iterable object that yields messages.
"""
@abstractmethod
async def read(self) -> ResponseType:
"""Reads one message from the RPC.
For each streaming RPC, concurrent reads in multiple coroutines are not
allowed. If you want to perform read in multiple coroutines, you needs
synchronization. So, you can start another read after current read is
finished.
Returns:
A response message of the RPC.
"""

@ -12,19 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
import enum
from typing import Callable, Dict, Optional, ClassVar
from typing import AsyncIterable, Awaitable, Dict, Optional
import grpc
from grpc import _common
from grpc._cython import cygrpc
DeserializingFunction = Callable[[bytes], str]
from . import _base_call
from ._typing import (DeserializingFunction, MetadataType, RequestType,
ResponseType, SerializingFunction)
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!'
_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.'
_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
'\tdetails = "{}"\n'
'>')
_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
'\tdetails = "{}"\n'
'\tdebug_error_string = "{}"\n'
'>')
class AioRpcError(grpc.RpcError):
"""An RpcError to be used by the asynchronous API."""
"""An implementation of RpcError to be used by the asynchronous API.
Raised RpcError is a snapshot of the final status of the RPC, values are
determined. Hence, its methods no longer needs to be coroutines.
"""
# TODO(https://github.com/grpc/grpc/issues/20144) Metadata
# type returned by `initial_metadata` and `trailing_metadata`
@ -33,14 +56,16 @@ class AioRpcError(grpc.RpcError):
_code: grpc.StatusCode
_details: Optional[str]
_initial_metadata: Optional[Dict]
_trailing_metadata: Optional[Dict]
_initial_metadata: Optional[MetadataType]
_trailing_metadata: Optional[MetadataType]
_debug_error_string: Optional[str]
def __init__(self,
code: grpc.StatusCode,
details: Optional[str] = None,
initial_metadata: Optional[Dict] = None,
trailing_metadata: Optional[Dict] = None):
initial_metadata: Optional[MetadataType] = None,
trailing_metadata: Optional[MetadataType] = None,
debug_error_string: Optional[str] = None) -> None:
"""Constructor.
Args:
@ -56,207 +81,336 @@ class AioRpcError(grpc.RpcError):
self._details = details
self._initial_metadata = initial_metadata
self._trailing_metadata = trailing_metadata
self._debug_error_string = debug_error_string
def code(self) -> grpc.StatusCode:
"""
"""Accesses the status code sent by the server.
Returns:
The `grpc.StatusCode` status code.
"""
return self._code
def details(self) -> Optional[str]:
"""
"""Accesses the details sent by the server.
Returns:
The description of the error.
"""
return self._details
def initial_metadata(self) -> Optional[Dict]:
"""
"""Accesses the initial metadata sent by the server.
Returns:
The inital metadata received.
The initial metadata received.
"""
return self._initial_metadata
def trailing_metadata(self) -> Optional[Dict]:
"""
"""Accesses the trailing metadata sent by the server.
Returns:
The trailing metadata received.
"""
return self._trailing_metadata
def debug_error_string(self) -> str:
"""Accesses the debug error string sent by the server.
@enum.unique
class _RpcState(enum.Enum):
"""Identifies the state of the RPC."""
ONGOING = 1
CANCELLED = 2
FINISHED = 3
ABORT = 4
Returns:
The debug error string received.
"""
return self._debug_error_string
def _repr(self) -> str:
"""Assembles the error string for the RPC error."""
return _NON_OK_CALL_REPRESENTATION.format(self.__class__.__name__,
self._code, self._details,
self._debug_error_string)
class Call:
"""Object for managing RPC calls,
returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
def __repr__(self) -> str:
return self._repr()
_cancellation_details: ClassVar[str] = 'Locally cancelled by application!'
def __str__(self) -> str:
return self._repr()
_state: _RpcState
_exception: Optional[Exception]
_response: Optional[bytes]
_code: grpc.StatusCode
_details: Optional[str]
_initial_metadata: Optional[Dict]
_trailing_metadata: Optional[Dict]
_call: asyncio.Task
_call_cancel_status: cygrpc.AioCancelStatus
_response_deserializer: DeserializingFunction
def __init__(self, call: asyncio.Task,
response_deserializer: DeserializingFunction,
call_cancel_status: cygrpc.AioCancelStatus) -> None:
"""Constructor.
def _create_rpc_error(initial_metadata: Optional[MetadataType],
status: cygrpc.AioRpcStatus) -> AioRpcError:
return AioRpcError(_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()],
status.details(), initial_metadata,
status.trailing_metadata())
Args:
call: Asyncio Task that holds the RPC execution.
response_deserializer: Deserializer used for parsing the reponse.
call_cancel_status: A cygrpc.AioCancelStatus used for giving a
specific error when the RPC is canceled.
"""
self._state = _RpcState.ONGOING
self._exception = None
self._response = None
self._code = grpc.StatusCode.UNKNOWN
self._details = None
self._initial_metadata = None
self._trailing_metadata = None
self._call = call
self._call_cancel_status = call_cancel_status
self._response_deserializer = response_deserializer
class Call(_base_call.Call):
_loop: asyncio.AbstractEventLoop
_code: grpc.StatusCode
_status: Awaitable[cygrpc.AioRpcStatus]
_initial_metadata: Awaitable[MetadataType]
_cancellation: asyncio.Future
def __del__(self):
self.cancel()
def __init__(self) -> None:
self._loop = asyncio.get_event_loop()
self._code = None
self._status = self._loop.create_future()
self._initial_metadata = self._loop.create_future()
self._cancellation = self._loop.create_future()
def cancel(self) -> bool:
"""Cancels the ongoing RPC request.
"""Placeholder cancellation method.
Returns:
True if the RPC can be canceled, False if was already cancelled or terminated.
The implementation of this method needs to pass the cancellation reason
into self._cancellation, using `set_result` instead of
`set_exception`.
"""
if self.cancelled() or self.done():
return False
code = grpc.StatusCode.CANCELLED
self._call_cancel_status.cancel(
_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code],
details=Call._cancellation_details)
self._call.cancel()
self._details = Call._cancellation_details
self._code = code
self._state = _RpcState.CANCELLED
return True
raise NotImplementedError()
def cancelled(self) -> bool:
"""Returns if the RPC was cancelled.
Returns:
True if the requests was cancelled, False if not.
"""
return self._state is _RpcState.CANCELLED
def running(self) -> bool:
"""Returns if the RPC is running.
Returns:
True if the requests is running, False if it already terminated.
"""
return not self.done()
return self._cancellation.done(
) or self._code == grpc.StatusCode.CANCELLED
def done(self) -> bool:
"""Returns if the RPC has finished.
return self._status.done()
Returns:
True if the requests has finished, False is if still ongoing.
"""
return self._state is not _RpcState.ONGOING
async def initial_metadata(self):
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
async def trailing_metadata(self):
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def code(self) -> grpc.StatusCode:
"""Returns the `grpc.StatusCode` if the RPC is finished,
otherwise first waits until the RPC finishes.
async def initial_metadata(self) -> MetadataType:
return await self._initial_metadata
Returns:
The `grpc.StatusCode` status code.
"""
if not self.done():
try:
await self
except (asyncio.CancelledError, AioRpcError):
pass
async def trailing_metadata(self) -> MetadataType:
return (await self._status).trailing_metadata()
async def code(self) -> grpc.StatusCode:
await self._status
return self._code
async def details(self) -> str:
"""Returns the details if the RPC is finished, otherwise first waits till the
RPC finishes.
return (await self._status).details()
Returns:
The details.
async def debug_error_string(self) -> str:
return (await self._status).debug_error_string()
def _set_initial_metadata(self, metadata: MetadataType) -> None:
self._initial_metadata.set_result(metadata)
def _set_status(self, status: cygrpc.AioRpcStatus) -> None:
"""Private method to set final status of the RPC.
This method may be called multiple time due to data race between local
cancellation (by application) and Core receiving status from peer. We
make no promise here which one will win.
"""
if not self.done():
try:
await self
except (asyncio.CancelledError, AioRpcError):
pass
if self._status.done():
return
else:
self._status.set_result(status)
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[
status.code()]
async def _raise_rpc_error_if_not_ok(self) -> None:
if self._code != grpc.StatusCode.OK:
raise _create_rpc_error(await self.initial_metadata(),
self._status.result())
def _repr(self) -> str:
"""Assembles the RPC representation string."""
if not self._status.done():
return '<{} object>'.format(self.__class__.__name__)
if self._code is grpc.StatusCode.OK:
return _OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code,
self._status.result().self._status.result().details())
else:
return _NON_OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code,
self._status.result().details(),
self._status.result().debug_error_string())
def __repr__(self) -> str:
return self._repr()
def __str__(self) -> str:
return self._repr()
# pylint: disable=abstract-method
class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls.
Returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
_request: RequestType
_deadline: Optional[float]
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: asyncio.Task
return self._details
def __init__(self, request: RequestType, deadline: Optional[float],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__()
self._request = request
self._deadline = deadline
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._call = self._loop.create_task(self._invoke())
def __del__(self) -> None:
if not self._call.done():
self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None))
async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
# NOTE(lidiz) asyncio.CancelledError is not a good transport for
# status, since the Task class do not cache the exact
# asyncio.CancelledError object. So, the solution is catching the error
# in Cython layer, then cancel the RPC and update the status, finally
# re-raise the CancelledError.
serialized_response = await self._channel.unary_unary(
self._method,
serialized_request,
self._deadline,
self._cancellation,
self._set_initial_metadata,
self._set_status,
)
await self._raise_rpc_error_if_not_ok()
return _common.deserialize(serialized_response,
self._response_deserializer)
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning."""
if not self._status.done() and not self._cancellation.done():
self._cancellation.set_result(status)
self._call.cancel()
return True
else:
return False
def __await__(self):
def cancel(self) -> bool:
return self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_LOCAL_CANCELLATION_DETAILS, None, None))
def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes.
Returns:
Response of the RPC call.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
RpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
# We can not relay on the `done()` method since some exceptions
# might be pending to be catched, like `asyncio.CancelledError`.
if self._response:
return self._response
elif self._exception:
raise self._exception
try:
buffer_ = yield from self._call.__await__()
except cygrpc.AioRpcError as aio_rpc_error:
self._state = _RpcState.ABORT
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[
aio_rpc_error.code()]
self._details = aio_rpc_error.details()
self._initial_metadata = aio_rpc_error.initial_metadata()
self._trailing_metadata = aio_rpc_error.trailing_metadata()
# Propagates the pure Python class
self._exception = AioRpcError(self._code, self._details,
self._initial_metadata,
self._trailing_metadata)
raise self._exception from aio_rpc_error
except asyncio.CancelledError as cancel_error:
# _state, _code, _details are managed in the `cancel` method
self._exception = cancel_error
raise
self._response = _common.deserialize(buffer_,
self._response_deserializer)
self._code = grpc.StatusCode.OK
self._state = _RpcState.FINISHED
return self._response
response = yield from self._call
return response
# pylint: disable=abstract-method
class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
"""Object for managing unary-stream RPC calls.
Returned when an instance of `UnaryStreamMultiCallable` object is called.
"""
_request: RequestType
_deadline: Optional[float]
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: asyncio.Task
_bytes_aiter: AsyncIterable[bytes]
_message_aiter: AsyncIterable[ResponseType]
def __init__(self, request: RequestType, deadline: Optional[float],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__()
self._request = request
self._deadline = deadline
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._call = self._loop.create_task(self._invoke())
self._message_aiter = self._process()
def __del__(self) -> None:
if not self._status.done():
self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None))
async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
self._bytes_aiter = await self._channel.unary_stream(
self._method,
serialized_request,
self._deadline,
self._cancellation,
self._set_initial_metadata,
self._set_status,
)
async def _process(self) -> ResponseType:
await self._call
async for serialized_response in self._bytes_aiter:
if self._cancellation.done():
await self._status
if self._status.done():
# Raises pre-maturely if final status received here. Generates
# more helpful stack trace for end users.
await self._raise_rpc_error_if_not_ok()
yield _common.deserialize(serialized_response,
self._response_deserializer)
await self._raise_rpc_error_if_not_ok()
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning.
Async generator will receive an exception. The cancellation will go
deep down into Core, and then propagates backup as the
`cygrpc.AioRpcStatus` exception.
So, under race condition, e.g. the server sent out final state headers
and the client calling "cancel" at the same time, this method respects
the winner in Core.
"""
if not self._status.done() and not self._cancellation.done():
self._cancellation.set_result(status)
return True
else:
return False
def cancel(self) -> bool:
return self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_LOCAL_CANCELLATION_DETAILS, None, None))
def __aiter__(self) -> AsyncIterable[ResponseType]:
return self._message_aiter
async def read(self) -> ResponseType:
if self._status.done():
await self._raise_rpc_error_if_not_ok()
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
return await self._message_aiter.__anext__()

@ -13,42 +13,114 @@
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
from typing import Callable, Optional
from typing import Any, Optional, Sequence, Text, Tuple
import grpc
from grpc import _common
from grpc._cython import cygrpc
from . import _base_call
from ._call import UnaryUnaryCall, UnaryStreamCall
from ._typing import (DeserializingFunction, MetadataType, SerializingFunction)
from ._call import Call
SerializingFunction = Callable[[str], bytes]
DeserializingFunction = Callable[[bytes], str]
def _timeout_to_deadline(loop: asyncio.AbstractEventLoop,
timeout: Optional[float]) -> Optional[float]:
if timeout is None:
return None
return loop.time() + timeout
class UnaryUnaryMultiCallable:
"""Afford invoking a unary-unary RPC from client-side in an asynchronous way."""
"""Factory an asynchronous unary-unary RPC stub call from client-side."""
def __init__(self, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
self._loop = asyncio.get_event_loop()
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._loop = asyncio.get_event_loop()
def _timeout_to_deadline(self, timeout: int) -> Optional[int]:
if timeout is None:
return None
return self._loop.time() + timeout
def __call__(self,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A Call object instance which is an awaitable object.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if credentials:
raise NotImplementedError("TODO: credentials not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(self._loop, timeout)
return UnaryUnaryCall(
request,
deadline,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
class UnaryStreamMultiCallable:
"""Afford invoking a unary-stream RPC from client-side in an asynchronous way."""
def __init__(self, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._loop = asyncio.get_event_loop()
def __call__(self,
request,
request: Any,
*,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None) -> Call:
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
"""Asynchronously invokes the underlying RPC.
Args:
@ -86,15 +158,16 @@ class UnaryUnaryMultiCallable:
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
serialized_request = _common.serialize(request,
self._request_serializer)
timeout = self._timeout_to_deadline(timeout)
aio_cancel_status = cygrpc.AioCancelStatus()
aio_call = asyncio.ensure_future(
self._channel.unary_unary(self._method, serialized_request, timeout,
aio_cancel_status),
loop=self._loop)
return Call(aio_call, self._response_deserializer, aio_cancel_status)
deadline = _timeout_to_deadline(self._loop, timeout)
return UnaryStreamCall(
request,
deadline,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
class Channel:
@ -103,7 +176,10 @@ class Channel:
A cygrpc.AioChannel-backed implementation.
"""
def __init__(self, target, options, credentials, compression):
def __init__(self, target: Text,
options: Optional[Sequence[Tuple[Text, Any]]],
credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]):
"""Constructor.
Args:
@ -125,10 +201,12 @@ class Channel:
self._channel = cygrpc.AioChannel(_common.encode(target))
def unary_unary(self,
method,
request_serializer=None,
response_deserializer=None):
def unary_unary(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryUnaryMultiCallable:
"""Creates a UnaryUnaryMultiCallable for a unary-unary method.
Args:
@ -146,6 +224,30 @@ class Channel:
request_serializer,
response_deserializer)
def unary_stream(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer)
def stream_unary(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None):
"""Placeholder method for stream-unary calls."""
def stream_stream(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None):
"""Placeholder method for stream-stream calls."""
async def _close(self):
# TODO: Send cancellation status
self._channel.close()

@ -1,4 +1,4 @@
# Copyright 2019 gRPC authors.
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,13 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Desired cancellation status for canceling an ongoing RPC calls."""
"""Common types for gRPC Async API"""
from typing import Any, AnyStr, Callable, Sequence, Text, Tuple, TypeVar
cdef class AioCancelStatus:
cdef readonly:
object _code
str _details
cpdef object code(self)
cpdef str details(self)
RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType')
SerializingFunction = Callable[[Any], bytes]
DeserializingFunction = Callable[[bytes], Any]
MetadataType = Sequence[Tuple[Text, AnyStr]]

@ -120,6 +120,8 @@ class TestAio(setuptools.Command):
def run(self):
self._add_eggs_to_path()
from grpc.experimental.aio import init_grpc_aio
init_grpc_aio()
import tests
loader = tests.Loader()

@ -15,7 +15,6 @@
from __future__ import absolute_import
import collections
import multiprocessing
import os
import select
import signal
@ -115,6 +114,8 @@ class AugmentedCase(collections.namedtuple('AugmentedCase', ['case', 'id'])):
return super(cls, AugmentedCase).__new__(cls, case, id)
# NOTE(lidiz) This complex wrapper is not triggering setUpClass nor
# tearDownClass. Do not use those methods, or fix this wrapper!
class Runner(object):
def __init__(self, dedicated_threads=False):

@ -0,0 +1,32 @@
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_testonly = 1,
default_visibility = ["//visibility:public"],
)
py_binary(
name = "server",
srcs = ["server.py"],
python_version = "PY3",
deps = [
"//external:six",
"//src/proto/grpc/testing:benchmark_service_py_pb2",
"//src/proto/grpc/testing:benchmark_service_py_pb2_grpc",
"//src/proto/grpc/testing:py_messages_proto",
"//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_tests/tests/unit/framework/common",
],
)

@ -27,6 +27,12 @@ class BenchmarkServer(benchmark_service_pb2_grpc.BenchmarkServiceServicer):
payload = messages_pb2.Payload(body=b'\0' * request.response_size)
return messages_pb2.SimpleResponse(payload=payload)
async def StreamingFromServer(self, request, context):
payload = messages_pb2.Payload(body=b'\0' * request.response_size)
# Sends response at full capacity!
while True:
yield messages_pb2.SimpleResponse(payload=payload)
async def _start_async_server():
server = aio.server()
@ -37,6 +43,7 @@ async def _start_async_server():
servicer, server)
await server.start()
logging.info('Benchmark server started at :%d' % port)
await server.wait_for_termination()
@ -48,5 +55,5 @@ def main():
if __name__ == '__main__':
logging.basicConfig()
logging.basicConfig(level=logging.DEBUG)
main()

@ -1,7 +1,8 @@
[
"_sanity._sanity_test.AioSanityTest",
"unit.call_test.TestAioRpcError",
"unit.call_test.TestCall",
"unit.aio_rpc_error_test.TestAioRpcError",
"unit.call_test.TestUnaryStreamCall",
"unit.call_test.TestUnaryUnaryCall",
"unit.channel_test.TestChannel",
"unit.init_test.TestInsecureChannel",
"unit.server_test.TestServer"

@ -12,18 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import functools
import asyncio
from typing import Callable
import unittest
from grpc.experimental import aio
__all__ = 'AioTestBase'
class AioTestBase(unittest.TestCase):
_COROUTINE_FUNCTION_ALLOWLIST = ['setUp', 'tearDown']
def _async_to_sync_decorator(f: Callable, loop: asyncio.AbstractEventLoop):
@functools.wraps(f)
def wrapper(*args, **kwargs):
return loop.run_until_complete(f(*args, **kwargs))
return wrapper
def _get_default_loop(debug=True):
try:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
finally:
loop.set_debug(debug)
return loop
def setUp(self):
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
aio.init_grpc_aio()
# NOTE(gnossen) this test class can also be implemented with metaclass.
class AioTestBase(unittest.TestCase):
@property
def loop(self):
return self._loop
return _get_default_loop()
def __getattribute__(self, name):
"""Overrides the loading logic to support coroutine functions."""
attr = super().__getattribute__(name)
# If possible, converts the coroutine into a sync function.
if name.startswith('test_') or name in _COROUTINE_FUNCTION_ALLOWLIST:
if asyncio.iscoroutinefunction(attr):
return _async_to_sync_decorator(attr, _get_default_loop())
# For other attributes, let them pass.
return attr
aio.init_grpc_aio()

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from time import sleep
import asyncio
import logging
import datetime
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
@ -25,9 +27,23 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
async def UnaryCall(self, request, context):
return messages_pb2.SimpleResponse()
# TODO(lidizheng) The semantic of this call is not matching its description
# See src/proto/grpc/testing/test.proto
async def EmptyCall(self, request, context):
while True:
sleep(test_constants.LONG_TIMEOUT)
await asyncio.sleep(test_constants.LONG_TIMEOUT)
async def StreamingOutputCall(
self, request: messages_pb2.StreamingOutputCallRequest, context):
for response_parameters in request.response_parameters:
if response_parameters.interval_us != 0:
await asyncio.sleep(
datetime.timedelta(microseconds=response_parameters.
interval_us).total_seconds())
yield messages_pb2.StreamingOutputCallResponse(
payload=messages_pb2.Payload(
type=request.response_type,
body=b'\x00' * response_parameters.size))
async def start_test_server():

@ -0,0 +1,50 @@
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests AioRpcError class."""
import logging
import unittest
import grpc
from grpc.experimental.aio._call import AioRpcError
from tests_aio.unit._test_base import AioTestBase
_TEST_INITIAL_METADATA = ('initial metadata',)
_TEST_TRAILING_METADATA = ('trailing metadata',)
_TEST_DEBUG_ERROR_STRING = '{This is a debug string}'
class TestAioRpcError(unittest.TestCase):
def test_attributes(self):
aio_rpc_error = AioRpcError(
grpc.StatusCode.CANCELLED,
'details',
initial_metadata=_TEST_INITIAL_METADATA,
trailing_metadata=_TEST_TRAILING_METADATA,
debug_error_string=_TEST_DEBUG_ERROR_STRING)
self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(aio_rpc_error.details(), 'details')
self.assertEqual(aio_rpc_error.initial_metadata(),
_TEST_INITIAL_METADATA)
self.assertEqual(aio_rpc_error.trailing_metadata(),
_TEST_TRAILING_METADATA)
self.assertEqual(aio_rpc_error.debug_error_string(),
_TEST_DEBUG_ERROR_STRING)
if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)

@ -11,186 +11,324 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the grpc.aio.UnaryUnaryCall class."""
import asyncio
import logging
import unittest
import datetime
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
class TestAioRpcError(unittest.TestCase):
_TEST_INITIAL_METADATA = ("initial metadata",)
_TEST_TRAILING_METADATA = ("trailing metadata",)
def test_attributes(self):
aio_rpc_error = aio.AioRpcError(
grpc.StatusCode.CANCELLED,
"details",
initial_metadata=self._TEST_INITIAL_METADATA,
trailing_metadata=self._TEST_TRAILING_METADATA)
self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(aio_rpc_error.details(), "details")
self.assertEqual(aio_rpc_error.initial_metadata(),
self._TEST_INITIAL_METADATA)
self.assertEqual(aio_rpc_error.trailing_metadata(),
self._TEST_TRAILING_METADATA)
class TestCall(AioTestBase):
def test_call_ok(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString
)
call = hi(messages_pb2.SimpleRequest())
self.assertFalse(call.done())
response = await call
self.assertTrue(call.done())
self.assertEqual(type(response), messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# Response is cached at call object level, reentrance
# returns again the same response
response_retry = await call
self.assertIs(response, response_retry)
self.loop.run_until_complete(coro())
def test_call_rpc_error(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
empty_call_with_sleep = channel.unary_unary(
"/grpc.testing.TestService/EmptyCall",
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.
FromString,
)
timeout = test_constants.SHORT_TIMEOUT / 2
# TODO(https://github.com/grpc/grpc/issues/20869
# Update once the async server is ready, change the
# synchronization mechanism by removing the sleep(<timeout>)
# as both components (client & server) will be on the same
# process.
call = empty_call_with_sleep(
messages_pb2.SimpleRequest(), timeout=timeout)
with self.assertRaises(grpc.RpcError) as exception_context:
await call
self.assertTrue(call.done())
self.assertEqual(await call.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
# Exception is cached at call object level, reentrance
# returns again the same exception
with self.assertRaises(
grpc.RpcError) as exception_context_retry:
await call
self.assertIs(exception_context.exception,
exception_context_retry.exception)
self.loop.run_until_complete(coro())
def test_call_code_awaitable(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString
)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.loop.run_until_complete(coro())
def test_call_details_awaitable(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString
)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual(await call.details(), None)
self.loop.run_until_complete(coro())
def test_cancel(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString
)
call = hi(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled())
# TODO(https://github.com/grpc/grpc/issues/20869) remove sleep.
# Force the loop to execute the RPC task.
await asyncio.sleep(0)
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertFalse(call.cancel())
with self.assertRaises(
asyncio.CancelledError) as exception_context:
await call
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
'Locally cancelled by application!')
# Exception is cached at call object level, reentrance
# returns again the same exception
with self.assertRaises(
asyncio.CancelledError) as exception_context_retry:
await call
self.assertIs(exception_context.exception,
exception_context_retry.exception)
self.loop.run_until_complete(coro())
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000
class TestUnaryUnaryCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_call_ok(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertFalse(call.done())
response = await call
self.assertTrue(call.done())
self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# Response is cached at call object level, reentrance
# returns again the same response
response_retry = await call
self.assertIs(response, response_retry)
async def test_call_rpc_error(self):
async with aio.insecure_channel(self._server_target) as channel:
empty_call_with_sleep = channel.unary_unary(
"/grpc.testing.TestService/EmptyCall",
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
)
timeout = test_constants.SHORT_TIMEOUT / 2
# TODO(https://github.com/grpc/grpc/issues/20869
# Update once the async server is ready, change the
# synchronization mechanism by removing the sleep(<timeout>)
# as both components (client & server) will be on the same
# process.
call = empty_call_with_sleep(
messages_pb2.SimpleRequest(), timeout=timeout)
with self.assertRaises(grpc.RpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
exception_context.exception.code())
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await
call.code())
# Exception is cached at call object level, reentrance
# returns again the same exception
with self.assertRaises(grpc.RpcError) as exception_context_retry:
await call
self.assertIs(exception_context.exception,
exception_context_retry.exception)
async def test_call_code_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_call_details_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual('', await call.details())
async def test_cancel_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled())
# TODO(https://github.com/grpc/grpc/issues/20869) remove sleep.
# Force the loop to execute the RPC task.
await asyncio.sleep(0)
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError) as exception_context:
await call
self.assertTrue(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
'Locally cancelled by application!')
# NOTE(lidiz) The CancelledError is almost always re-created,
# so we might not want to use it to transmit data.
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
class TestUnaryStreamCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_cancel_unary_stream(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
response = await call.read()
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertTrue(call.cancel())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
await call.read()
self.assertTrue(call.cancelled())
async def test_multiple_cancel_unary_stream(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
response = await call.read()
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
await call.read()
async def test_early_cancel_unary_stream(self):
"""Test cancellation before receiving messages."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
await call.read()
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION,
exception_context.exception.details())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
async def test_late_cancel_unary_stream(self):
"""Test cancellation after received all messages."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
# After all messages received, it is possible that the final state
# is received or on its way. It's basically a data race, so our
# expectation here is do not crash :)
call.cancel()
self.assertIn(await call.code(),
[grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])
async def test_too_many_reads_unary_stream(self):
"""Test cancellation after received all messages."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
# After the RPC is finished, further reads will lead to exception.
self.assertEqual(await call.code(), grpc.StatusCode.OK)
with self.assertRaises(asyncio.InvalidStateError):
await call.read()
async def test_unary_stream_async_generator(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
async for response in call:
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
if __name__ == '__main__':
logging.basicConfig()
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

@ -11,110 +11,121 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the grpc.aio.Channel class."""
import logging
import threading
import unittest
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_EMPTY_CALL_METHOD = '/grpc.testing.TestService/EmptyCall'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
class TestChannel(AioTestBase):
def test_async_context(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
hi = channel.unary_unary(
_UNARY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString
)
await hi(messages_pb2.SimpleRequest())
self.loop.run_until_complete(coro())
class TestChannel(AioTestBase):
def test_unary_unary(self):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async def tearDown(self):
await self._server.stop(None)
channel = aio.insecure_channel(server_target)
async def test_async_context(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
_UNARY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
response = await hi(messages_pb2.SimpleRequest())
self.assertIs(type(response), messages_pb2.SimpleResponse)
await channel.close()
self.loop.run_until_complete(coro())
def test_unary_call_times_out(self):
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
empty_call_with_sleep = channel.unary_unary(
_EMPTY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.
SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.
FromString,
)
timeout = test_constants.SHORT_TIMEOUT / 2
# TODO(https://github.com/grpc/grpc/issues/20869)
# Update once the async server is ready, change the
# synchronization mechanism by removing the sleep(<timeout>)
# as both components (client & server) will be on the same
# process.
with self.assertRaises(grpc.RpcError) as exception_context:
await empty_call_with_sleep(
messages_pb2.SimpleRequest(), timeout=timeout)
_, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(exception_context.exception.details(),
details.title())
self.assertIsNotNone(
exception_context.exception.initial_metadata())
self.assertIsNotNone(
exception_context.exception.trailing_metadata())
self.loop.run_until_complete(coro())
@unittest.skip('https://github.com/grpc/grpc/issues/20818')
def test_call_to_the_void(self):
await hi(messages_pb2.SimpleRequest())
async def coro():
channel = aio.insecure_channel('0.1.1.1:1111')
async def test_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel:
channel = aio.insecure_channel(self._server_target)
hi = channel.unary_unary(
_UNARY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
response = await hi(messages_pb2.SimpleRequest())
self.assertIs(type(response), messages_pb2.SimpleResponse)
self.assertIsInstance(response, messages_pb2.SimpleResponse)
async def test_unary_call_times_out(self):
async with aio.insecure_channel(self._server_target) as channel:
empty_call_with_sleep = channel.unary_unary(
_EMPTY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
)
timeout = test_constants.SHORT_TIMEOUT / 2
# TODO(https://github.com/grpc/grpc/issues/20869)
# Update once the async server is ready, change the
# synchronization mechanism by removing the sleep(<timeout>)
# as both components (client & server) will be on the same
# process.
with self.assertRaises(grpc.RpcError) as exception_context:
await empty_call_with_sleep(
messages_pb2.SimpleRequest(), timeout=timeout)
_, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
exception_context.exception.code())
self.assertEqual(details.title(),
exception_context.exception.details())
self.assertIsNotNone(exception_context.exception.initial_metadata())
self.assertIsNotNone(
exception_context.exception.trailing_metadata())
@unittest.skip('https://github.com/grpc/grpc/issues/20818')
async def test_call_to_the_void(self):
channel = aio.insecure_channel('0.1.1.1:1111')
hi = channel.unary_unary(
_UNARY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
response = await hi(messages_pb2.SimpleRequest())
self.assertIsInstance(response, messages_pb2.SimpleResponse)
await channel.close()
async def test_unary_stream(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
await channel.close()
# Validates the responses
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.loop.run_until_complete(coro())
self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close()
if __name__ == '__main__':
logging.basicConfig()
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

@ -21,15 +21,11 @@ from tests_aio.unit._test_base import AioTestBase
class TestInsecureChannel(AioTestBase):
def test_insecure_channel(self):
async def test_insecure_channel(self):
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async def coro():
server_target, _ = await start_test_server() # pylint: disable=unused-variable
channel = aio.insecure_channel(server_target)
self.assertIsInstance(channel, aio.Channel)
self.loop.run_until_complete(coro())
channel = aio.insecure_channel(server_target)
self.assertIsInstance(channel, aio.Channel)
if __name__ == '__main__':

@ -26,9 +26,13 @@ from tests.unit.framework.common import test_constants
_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary'
_BLOCK_FOREVER = '/test/BlockForever'
_BLOCK_BRIEFLY = '/test/BlockBriefly'
_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen'
_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter'
_UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
_NUM_STREAM_RESPONSES = 5
class _GenericHandler(grpc.GenericRpcHandler):
@ -43,10 +47,23 @@ class _GenericHandler(grpc.GenericRpcHandler):
async def _block_forever(self, unused_request, unused_context):
await asyncio.get_event_loop().create_future()
async def _BLOCK_BRIEFLY(self, unused_request, unused_context):
async def _block_briefly(self, unused_request, unused_context):
await asyncio.sleep(test_constants.SHORT_TIMEOUT / 2)
return _RESPONSE
async def _unary_stream_async_gen(self, unused_request, unused_context):
for _ in range(_NUM_STREAM_RESPONSES):
yield _RESPONSE
async def _unary_stream_reader_writer(self, unused_request, context):
for _ in range(_NUM_STREAM_RESPONSES):
await context.write(_RESPONSE)
async def _unary_stream_evilly_mixed(self, unused_request, context):
yield _RESPONSE
for _ in range(_NUM_STREAM_RESPONSES - 1):
await context.write(_RESPONSE)
def service(self, handler_details):
self._called.set_result(None)
if handler_details.method == _SIMPLE_UNARY_UNARY:
@ -54,7 +71,16 @@ class _GenericHandler(grpc.GenericRpcHandler):
if handler_details.method == _BLOCK_FOREVER:
return grpc.unary_unary_rpc_method_handler(self._block_forever)
if handler_details.method == _BLOCK_BRIEFLY:
return grpc.unary_unary_rpc_method_handler(self._BLOCK_BRIEFLY)
return grpc.unary_unary_rpc_method_handler(self._block_briefly)
if handler_details.method == _UNARY_STREAM_ASYNC_GEN:
return grpc.unary_stream_rpc_method_handler(
self._unary_stream_async_gen)
if handler_details.method == _UNARY_STREAM_READER_WRITER:
return grpc.unary_stream_rpc_method_handler(
self._unary_stream_reader_writer)
if handler_details.method == _UNARY_STREAM_EVILLY_MIXED:
return grpc.unary_stream_rpc_method_handler(
self._unary_stream_evilly_mixed)
async def wait_for_call(self):
await self._called
@ -71,150 +97,168 @@ async def _start_test_server():
class TestServer(AioTestBase):
def test_unary_unary(self):
async def test_unary_unary_body():
result = await _start_test_server()
server_target = result[0]
async with aio.insecure_channel(server_target) as channel:
unary_call = channel.unary_unary(_SIMPLE_UNARY_UNARY)
response = await unary_call(_REQUEST)
self.assertEqual(response, _RESPONSE)
self.loop.run_until_complete(test_unary_unary_body())
def test_shutdown(self):
async def test_shutdown_body():
_, server, _ = await _start_test_server()
await server.stop(None)
self.loop.run_until_complete(test_shutdown_body())
# Ensures no SIGSEGV triggered, and ends within timeout.
def test_shutdown_after_call(self):
async def test_shutdown_body():
server_target, server, _ = await _start_test_server()
async def setUp(self):
self._server_target, self._server, self._generic_handler = await _start_test_server(
)
async with aio.insecure_channel(server_target) as channel:
await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
async def tearDown(self):
await self._server.stop(None)
await server.stop(None)
async def test_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel:
unary_unary_call = channel.unary_unary(_SIMPLE_UNARY_UNARY)
response = await unary_unary_call(_REQUEST)
self.assertEqual(response, _RESPONSE)
self.loop.run_until_complete(test_shutdown_body())
async def test_unary_stream_async_generator(self):
async with aio.insecure_channel(self._server_target) as channel:
unary_stream_call = channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
call = unary_stream_call(_REQUEST)
def test_graceful_shutdown_success(self):
# Expecting the request message to reach server before retriving
# any responses.
await asyncio.wait_for(self._generic_handler.wait_for_call(),
test_constants.SHORT_TIMEOUT)
async def test_graceful_shutdown_success_body():
server_target, server, generic_handler = await _start_test_server()
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertEqual(_RESPONSE, response)
channel = aio.insecure_channel(server_target)
call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await generic_handler.wait_for_call()
self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
shutdown_start_time = time.time()
await server.stop(test_constants.SHORT_TIMEOUT)
grace_period_length = time.time() - shutdown_start_time
self.assertGreater(grace_period_length,
test_constants.SHORT_TIMEOUT / 3)
async def test_unary_stream_reader_writer(self):
async with aio.insecure_channel(self._server_target) as channel:
unary_stream_call = channel.unary_stream(
_UNARY_STREAM_READER_WRITER)
call = unary_stream_call(_REQUEST)
# Validates the states.
await channel.close()
self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done())
# Expecting the request message to reach server before retriving
# any responses.
await asyncio.wait_for(self._generic_handler.wait_for_call(),
test_constants.SHORT_TIMEOUT)
self.loop.run_until_complete(test_graceful_shutdown_success_body())
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE, response)
def test_graceful_shutdown_failed(self):
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_graceful_shutdown_failed_body():
server_target, server, generic_handler = await _start_test_server()
async def test_unary_stream_evilly_mixed(self):
async with aio.insecure_channel(self._server_target) as channel:
unary_stream_call = channel.unary_stream(_UNARY_STREAM_EVILLY_MIXED)
call = unary_stream_call(_REQUEST)
channel = aio.insecure_channel(server_target)
call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await generic_handler.wait_for_call()
# Expecting the request message to reach server before retriving
# any responses.
await asyncio.wait_for(self._generic_handler.wait_for_call(),
test_constants.SHORT_TIMEOUT)
await server.stop(test_constants.SHORT_TIMEOUT)
# Uses reader API
self.assertEqual(_RESPONSE, await call.read())
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details())
await channel.close()
# Uses async generator API
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertEqual(_RESPONSE, response)
self.loop.run_until_complete(test_graceful_shutdown_failed_body())
self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt)
def test_concurrent_graceful_shutdown(self):
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_concurrent_graceful_shutdown_body():
server_target, server, generic_handler = await _start_test_server()
channel = aio.insecure_channel(server_target)
call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await generic_handler.wait_for_call()
# Expects the shortest grace period to be effective.
shutdown_start_time = time.time()
await asyncio.gather(
server.stop(test_constants.LONG_TIMEOUT),
server.stop(test_constants.SHORT_TIMEOUT),
server.stop(test_constants.LONG_TIMEOUT),
)
grace_period_length = time.time() - shutdown_start_time
self.assertGreater(grace_period_length,
test_constants.SHORT_TIMEOUT / 3)
await channel.close()
self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done())
self.loop.run_until_complete(test_concurrent_graceful_shutdown_body())
def test_concurrent_graceful_shutdown_immediate(self):
async def test_concurrent_graceful_shutdown_immediate_body():
server_target, server, generic_handler = await _start_test_server()
channel = aio.insecure_channel(server_target)
call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await generic_handler.wait_for_call()
# Expects no grace period, due to the "server.stop(None)".
await asyncio.gather(
server.stop(test_constants.LONG_TIMEOUT),
server.stop(None),
server.stop(test_constants.SHORT_TIMEOUT),
server.stop(test_constants.LONG_TIMEOUT),
)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details())
await channel.close()
async def test_shutdown(self):
await self._server.stop(None)
# Ensures no SIGSEGV triggered, and ends within timeout.
self.loop.run_until_complete(
test_concurrent_graceful_shutdown_immediate_body())
async def test_shutdown_after_call(self):
async with aio.insecure_channel(self._server_target) as channel:
await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
await self._server.stop(None)
async def test_graceful_shutdown_success(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await self._generic_handler.wait_for_call()
shutdown_start_time = time.time()
await self._server.stop(test_constants.SHORT_TIMEOUT)
grace_period_length = time.time() - shutdown_start_time
self.assertGreater(grace_period_length,
test_constants.SHORT_TIMEOUT / 3)
# Validates the states.
await channel.close()
self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done())
async def test_graceful_shutdown_failed(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await self._generic_handler.wait_for_call()
await self._server.stop(test_constants.SHORT_TIMEOUT)
with self.assertRaises(grpc.RpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details())
await channel.close()
async def test_concurrent_graceful_shutdown(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await self._generic_handler.wait_for_call()
# Expects the shortest grace period to be effective.
shutdown_start_time = time.time()
await asyncio.gather(
self._server.stop(test_constants.LONG_TIMEOUT),
self._server.stop(test_constants.SHORT_TIMEOUT),
self._server.stop(test_constants.LONG_TIMEOUT),
)
grace_period_length = time.time() - shutdown_start_time
self.assertGreater(grace_period_length,
test_constants.SHORT_TIMEOUT / 3)
await channel.close()
self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done())
async def test_concurrent_graceful_shutdown_immediate(self):
channel = aio.insecure_channel(self._server_target)
call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await self._generic_handler.wait_for_call()
# Expects no grace period, due to the "server.stop(None)".
await asyncio.gather(
self._server.stop(test_constants.LONG_TIMEOUT),
self._server.stop(None),
self._server.stop(test_constants.SHORT_TIMEOUT),
self._server.stop(test_constants.LONG_TIMEOUT),
)
with self.assertRaises(grpc.RpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details())
await channel.close()
@unittest.skip('https://github.com/grpc/grpc/issues/20818')
def test_shutdown_before_call(self):
async def test_shutdown_body():
server_target, server, _ = _start_test_server()
await server.stop(None)
# Ensures the server is cleaned up at this point.
# Some proper exception should be raised.
async with aio.insecure_channel('localhost:%d' % port) as channel:
await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
async def test_shutdown_before_call(self):
server_target, server, _ = _start_test_server()
await server.stop(None)
self.loop.run_until_complete(test_shutdown_body())
# Ensures the server is cleaned up at this point.
# Some proper exception should be raised.
async with aio.insecure_channel('localhost:%d' % port) as channel:
await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
if __name__ == '__main__':
logging.basicConfig()
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

@ -18,6 +18,7 @@ set PATH=C:\%1;C:\%1\scripts;C:\msys64\mingw%2\bin;C:\tools\msys64\mingw%2\bin;%
python -m pip install --upgrade six
@rem some artifacts are broken for setuptools 38.5.0. See https://github.com/grpc/grpc/issues/14317
python -m pip install --upgrade setuptools==38.2.4
python -m pip install --upgrade cython
python -m pip install -rrequirements.txt
set GRPC_PYTHON_BUILD_WITH_CYTHON=1

@ -727,13 +727,18 @@ class PythonLanguage(object):
self.args.iomgr_platform]) as tests_json_file:
tests_json = json.load(tests_json_file)
environment = dict(_FORCE_ENVIRON_FOR_WRAPPERS)
# TODO(https://github.com/grpc/grpc/issues/21401) Fork handlers is not
# designed for non-native IO manager. It has a side-effect that
# overrides threading settings in C-Core.
if args.iomgr_platform != 'native':
environment['GRPC_ENABLE_FORK_SUPPORT'] = '0'
return [
self.config.job_spec(
config.run,
timeout_seconds=5 * 60,
environ=dict(
list(environment.items()) + [(
'GRPC_PYTHON_TESTRUNNER_FILTER', str(suite_name))]),
GRPC_PYTHON_TESTRUNNER_FILTER=str(suite_name),
**environment),
shortname='%s.%s.%s' %
(config.name, self._TEST_FOLDER[self.args.iomgr_platform],
suite_name),

Loading…
Cancel
Save