diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi index 9d1f81e72f5..bdf5996bd37 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi @@ -18,6 +18,15 @@ cdef class _AioCall: AioChannel _channel list _references GrpcCallWrapper _grpc_call_wrapper + object _loop + + # Streaming call only attributes: + # + # A asyncio.Event that indicates if the status is received on the client side. + object _status_received + # A tuple of key value pairs representing the initial metadata sent by peer. + tuple _initial_metadata cdef grpc_call* _create_grpc_call(self, object timeout, bytes method) except * cdef void _destroy_grpc_call(self) + cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index fba1244a813..960f5317446 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -19,6 +19,8 @@ _EMPTY_FLAGS = 0 _EMPTY_MASK = 0 _EMPTY_METADATA = None +_UNKNOWN_CANCELLATION_DETAILS = 'RPC cancelled due to unknown reason.' + cdef class _AioCall: @@ -26,6 +28,9 @@ cdef class _AioCall: self._channel = channel self._references = [] self._grpc_call_wrapper = GrpcCallWrapper() + self._loop = asyncio.get_event_loop() + + self._status_received = asyncio.Event(loop=self._loop) def __repr__(self): class_name = self.__class__.__name__ @@ -33,7 +38,7 @@ cdef class _AioCall: return f"<{class_name} {id_}>" cdef grpc_call* _create_grpc_call(self, - object timeout, + object deadline, bytes method) except *: """Creates the corresponding Core object for this RPC. @@ -44,7 +49,7 @@ cdef class _AioCall: nature in Core. """ cdef grpc_slice method_slice - cdef gpr_timespec deadline = _timespec_from_time(timeout) + cdef gpr_timespec c_deadline = _timespec_from_time(deadline) method_slice = grpc_slice_from_copied_buffer( method, @@ -57,93 +62,204 @@ cdef class _AioCall: self._channel.cq.c_ptr(), method_slice, NULL, - deadline, + c_deadline, NULL ) grpc_slice_unref(method_slice) cdef void _destroy_grpc_call(self): """Destroys the corresponding Core object for this RPC.""" - grpc_call_unref(self._grpc_call_wrapper.call) + if self._grpc_call_wrapper.call != NULL: + grpc_call_unref(self._grpc_call_wrapper.call) - async def unary_unary(self, bytes method, bytes request, object timeout, AioCancelStatus cancel_status): - cdef object loop = asyncio.get_event_loop() + cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future): + """Cancels the RPC in C-Core, and return the final RPC status.""" + cdef AioRpcStatus status + cdef object details + cdef char *c_details + # Try to fetch application layer cancellation details in the future. + # * If calcellation details present, cancel with status; + # * If details not present, cancel with unknown reason. + if cancellation_future.done(): + status = cancellation_future.result() + details = str_to_bytes(status.details()) + self._references.append(details) + c_details = details + # By implementation, grpc_call_cancel_with_status always return OK + grpc_call_cancel_with_status( + self._grpc_call_wrapper.call, + status.c_code(), + c_details, + NULL, + ) + return status + else: + # By implementation, grpc_call_cancel always return OK + grpc_call_cancel(self._grpc_call_wrapper.call, NULL) + return AioRpcStatus( + StatusCode.cancelled, + _UNKNOWN_CANCELLATION_DETAILS, + None, + None, + ) - cdef tuple operations - cdef Operation initial_metadata_operation - cdef Operation send_message_operation - cdef Operation send_close_from_client_operation - cdef Operation receive_initial_metadata_operation - cdef Operation receive_message_operation - cdef Operation receive_status_on_client_operation + async def unary_unary(self, + bytes method, + bytes request, + object deadline, + object cancellation_future, + object initial_metadata_observer, + object status_observer): + """Performs a unary unary RPC. + + Args: + method: name of the calling method in bytes. + request: the serialized requests in bytes. + deadline: optional deadline of the RPC in float. + cancellation_future: the future that meant to transport the + cancellation reason from the application layer. + initial_metadata_observer: a callback for received initial metadata. + status_observer: a callback for received final status. + """ + cdef tuple ops - cdef char *c_details = NULL + cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation( + _EMPTY_METADATA, + GRPC_INITIAL_METADATA_USED_MASK) + cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS) + cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS) + cdef ReceiveInitialMetadataOperation receive_initial_metadata_op = ReceiveInitialMetadataOperation(_EMPTY_FLAGS) + cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS) + cdef ReceiveStatusOnClientOperation receive_status_on_client_op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS) - initial_metadata_operation = SendInitialMetadataOperation(_EMPTY_METADATA, GRPC_INITIAL_METADATA_USED_MASK) - initial_metadata_operation.c() + ops = (initial_metadata_op, send_message_op, send_close_op, + receive_initial_metadata_op, receive_message_op, + receive_status_on_client_op) - send_message_operation = SendMessageOperation(request, _EMPTY_FLAGS) - send_message_operation.c() + try: + self._create_grpc_call(deadline, method) + try: + await callback_start_batch(self._grpc_call_wrapper, + ops, + self._loop) + except asyncio.CancelledError: + status = self._cancel_and_create_status(cancellation_future) + status_observer(status) + raise + finally: + # If the RPC failed, this method will return None instead of crash. + initial_metadata_observer( + receive_initial_metadata_op.initial_metadata() + ) + self._destroy_grpc_call() - send_close_from_client_operation = SendCloseFromClientOperation(_EMPTY_FLAGS) - send_close_from_client_operation.c() + status = AioRpcStatus( + receive_status_on_client_op.code(), + receive_status_on_client_op.details(), + receive_status_on_client_op.trailing_metadata(), + receive_status_on_client_op.error_string(), + ) + # Reports the final status of the RPC to Python layer. The observer + # pattern is used here to unify unary and streaming code path. + status_observer(status) - receive_initial_metadata_operation = ReceiveInitialMetadataOperation(_EMPTY_FLAGS) - receive_initial_metadata_operation.c() + if status.code() == StatusCode.ok: + return receive_message_op.message() + else: + return None - receive_message_operation = ReceiveMessageOperation(_EMPTY_FLAGS) - receive_message_operation.c() + async def _handle_status_once_received(self, object status_observer): + """Handles the status sent by peer once received.""" + cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS) + cdef tuple ops = (op,) + await callback_start_batch(self._grpc_call_wrapper, ops, self._loop) + cdef AioRpcStatus status = AioRpcStatus( + op.code(), + op.details(), + op.trailing_metadata(), + op.error_string(), + ) + status_observer(status) + self._status_received.set() - receive_status_on_client_operation = ReceiveStatusOnClientOperation(_EMPTY_FLAGS) - receive_status_on_client_operation.c() + def _handle_cancellation_from_application(self, + object cancellation_future, + object status_observer): + def _cancellation_action(finished_future): + status = self._cancel_and_create_status(finished_future) + status_observer(status) - operations = ( - initial_metadata_operation, - send_message_operation, - send_close_from_client_operation, - receive_initial_metadata_operation, - receive_message_operation, - receive_status_on_client_operation, + cancellation_future.add_done_callback(_cancellation_action) + + async def unary_stream(self, + bytes method, + bytes request, + object deadline, + object cancellation_future, + object initial_metadata_observer, + object status_observer): + """Actual implementation of the complete unary-stream call. + + Needs to pay extra attention to the raise mechanism. If we want to + propagate the final status exception, then we have to raise it. + Othersize, it would end normally and raise `StopAsyncIteration()`. + """ + cdef bytes received_message + cdef tuple outbound_ops + cdef Operation initial_metadata_op = SendInitialMetadataOperation( + _EMPTY_METADATA, + GRPC_INITIAL_METADATA_USED_MASK) + cdef Operation send_message_op = SendMessageOperation( + request, + _EMPTY_FLAGS) + cdef Operation send_close_op = SendCloseFromClientOperation( + _EMPTY_FLAGS) + + outbound_ops = ( + initial_metadata_op, + send_message_op, + send_close_op, ) + # NOTE(lidiz) Not catching CancelledError here, because async + # generators do not have "cancel" method. try: - self._create_grpc_call( - timeout, - method, + self._create_grpc_call(deadline, method) + + await callback_start_batch( + self._grpc_call_wrapper, + outbound_ops, + self._loop) + + # Peer may prematurely end this RPC at any point. We need a mechanism + # that handles both the normal case and the error case. + self._loop.create_task(self._handle_status_once_received(status_observer)) + self._handle_cancellation_from_application(cancellation_future, + status_observer) + + # Receives initial metadata. + initial_metadata_observer( + await _receive_initial_metadata(self._grpc_call_wrapper, + self._loop), ) - try: - await callback_start_batch( + # Infinitely receiving messages, until: + # * EOF, no more messages to read; + # * The client application cancells; + # * The server sends final status. + while True: + if self._status_received.is_set(): + return + + received_message = await _receive_message( self._grpc_call_wrapper, - operations, - loop + self._loop ) - except asyncio.CancelledError: - if cancel_status: - details = str_to_bytes(cancel_status.details()) - self._references.append(details) - c_details = details - call_status = grpc_call_cancel_with_status( - self._grpc_call_wrapper.call, - cancel_status.code(), - c_details, - NULL, - ) + if received_message is None: + # The read operation failed, wait for status from C-Core. + await self._status_received.wait() + return else: - call_status = grpc_call_cancel( - self._grpc_call_wrapper.call, NULL) - if call_status != GRPC_CALL_OK: - raise Exception("RPC call couldn't be cancelled. Error {}".format(call_status)) - raise + yield received_message finally: self._destroy_grpc_call() - - if receive_status_on_client_operation.code() == StatusCode.ok: - return receive_message_operation.message() - - raise AioRpcError( - receive_initial_metadata_operation.initial_metadata(), - receive_status_on_client_operation.code(), - receive_status_on_client_operation.details(), - receive_status_on_client_operation.trailing_metadata(), - ) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi index d7c550cefa3..796a6ab9bf5 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi @@ -46,11 +46,13 @@ cdef class CallbackWrapper: grpc_experimental_completion_queue_functor* functor, int success): cdef CallbackContext *context = functor + cdef object waiter = context.waiter + if waiter.cancelled(): + return if success == 0: - (context.failure_handler).handle( - context.waiter) + (context.failure_handler).handle(waiter) else: - (context.waiter).set_result(None) + waiter.set_result(None) cdef grpc_experimental_completion_queue_functor *c_functor(self): return &self.context.functor @@ -83,6 +85,9 @@ cdef class CallbackCompletionQueue: grpc_completion_queue_destroy(self._cq) +class CallbackStartBatchError(Exception): pass + + async def callback_start_batch(GrpcCallWrapper grpc_call_wrapper, tuple operations, object loop): @@ -93,7 +98,7 @@ async def callback_start_batch(GrpcCallWrapper grpc_call_wrapper, cdef object future = loop.create_future() cdef CallbackWrapper wrapper = CallbackWrapper( future, - CallbackFailureHandler('callback_start_batch', operations, RuntimeError)) + CallbackFailureHandler('callback_start_batch', operations, CallbackStartBatchError)) # NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed # when calling "await". This is an over-optimization by Cython. cpython.Py_INCREF(wrapper) @@ -104,10 +109,67 @@ async def callback_start_batch(GrpcCallWrapper grpc_call_wrapper, wrapper.c_functor(), NULL) if error != GRPC_CALL_OK: - raise RuntimeError("Failed grpc_call_start_batch: {}".format(error)) + raise CallbackStartBatchError("Failed grpc_call_start_batch: {}".format(error)) await future cpython.Py_DECREF(wrapper) cdef grpc_event c_event # Tag.event must be called, otherwise messages won't be parsed from C batch_operation_tag.event(c_event) + + +async def _receive_message(GrpcCallWrapper grpc_call_wrapper, + object loop): + """Retrives parsed messages from C-Core. + + The messages maybe already in C-Core's buffer, so there isn't a 1-to-1 + mapping between this and the underlying "socket.read()". Also, eventually, + this function will end with an EOF, which reads empty message. + """ + cdef ReceiveMessageOperation receive_op = ReceiveMessageOperation(_EMPTY_FLAG) + cdef tuple ops = (receive_op,) + try: + await callback_start_batch(grpc_call_wrapper, ops, loop) + except CallbackStartBatchError as e: + # NOTE(lidiz) The receive message operation has two ways to indicate + # finish state : 1) returns empty message due to EOF; 2) fails inside + # the callback (e.g. cancelled). + # + # Since they all indicates finish, they are better be merged. + _LOGGER.exception(e) + return receive_op.message() + + +async def _send_message(GrpcCallWrapper grpc_call_wrapper, + bytes message, + bint metadata_sent, + object loop): + cdef SendMessageOperation op = SendMessageOperation(message, _EMPTY_FLAG) + cdef tuple ops + if metadata_sent: + ops = (op,) + else: + ops = ( + # Initial metadata must be sent before first outbound message. + SendInitialMetadataOperation(None, _EMPTY_FLAG), + op, + ) + await callback_start_batch(grpc_call_wrapper, ops, loop) + + +async def _send_initial_metadata(GrpcCallWrapper grpc_call_wrapper, + tuple metadata, + object loop): + cdef SendInitialMetadataOperation op = SendInitialMetadataOperation( + metadata, + _EMPTY_FLAG) + cdef tuple ops = (op,) + await callback_start_batch(grpc_call_wrapper, ops, loop) + + +async def _receive_initial_metadata(GrpcCallWrapper grpc_call_wrapper, + object loop): + cdef ReceiveInitialMetadataOperation op = ReceiveInitialMetadataOperation(_EMPTY_FLAGS) + cdef tuple ops = (op,) + await callback_start_batch(grpc_call_wrapper, ops, loop) + return op.initial_metadata() diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index d31df71a316..0f44135301f 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -26,6 +26,42 @@ cdef class AioChannel: def close(self): grpc_channel_destroy(self.channel) - async def unary_unary(self, method, request, timeout, cancel_status): - call = _AioCall(self) - return await call.unary_unary(method, request, timeout, cancel_status) + async def unary_unary(self, + bytes method, + bytes request, + object deadline, + object cancellation_future, + object initial_metadata_observer, + object status_observer): + """Assembles a unary-unary RPC. + + Returns: + The response message in bytes. + """ + cdef _AioCall call = _AioCall(self) + return await call.unary_unary(method, + request, + deadline, + cancellation_future, + initial_metadata_observer, + status_observer) + + def unary_stream(self, + bytes method, + bytes request, + object deadline, + object cancellation_future, + object initial_metadata_observer, + object status_observer): + """Assembles a unary-stream RPC. + + Returns: + An async generator that yields raw responses. + """ + cdef _AioCall call = _AioCall(self) + return call.unary_stream(method, + request, + deadline, + cancellation_future, + initial_metadata_observer, + status_observer) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi similarity index 50% rename from src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pyx.pxi rename to src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi index e2026458e3c..fbb65983f19 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi @@ -1,4 +1,4 @@ -# Copyright 2019 gRPC authors. +# Copyright 2019 The gRPC Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,26 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Desired cancellation status for canceling an ongoing RPC call.""" -cdef class AioCancelStatus: +cdef object deserialize(object deserializer, bytes raw_message): + """Perform deserialization on raw bytes. - def __cinit__(self): - self._code = None - self._details = None + Failure to deserialize is a fatal error. + """ + if deserializer: + return deserializer(raw_message) + else: + return raw_message - def __len__(self): - if self._code is None: - return 0 - return 1 - def cancel(self, grpc_status_code code, str details=None): - self._code = code - self._details = details +cdef bytes serialize(object serializer, object message): + """Perform serialization on a message. - cpdef object code(self): - return self._code - - cpdef str details(self): - return self._details + Failure to serialize is a fatal error. + """ + if serializer: + return serializer(message) + else: + return message diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi index 285fbdcea09..363d9b6eea5 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi @@ -23,12 +23,14 @@ cdef class _AsyncioSocket: object _task_read object _task_connect char * _read_buffer + object _loop # Client-side attributes grpc_custom_connect_callback _grpc_connect_cb # Server-side attributes grpc_custom_accept_callback _grpc_accept_cb + grpc_custom_write_callback _grpc_write_cb grpc_custom_socket * _grpc_client_socket object _server object _py_socket diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi index 4ef755dfaa0..4b14166be50 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi @@ -16,11 +16,14 @@ import socket as native_socket from libc cimport string + +# TODO(https://github.com/grpc/grpc/issues/21348) Better flow control needed. cdef class _AsyncioSocket: def __cinit__(self): self._grpc_socket = NULL self._grpc_connect_cb = NULL self._grpc_read_cb = NULL + self._grpc_write_cb = NULL self._reader = None self._writer = None self._task_connect = None @@ -29,6 +32,7 @@ cdef class _AsyncioSocket: self._server = None self._py_socket = None self._peername = None + self._loop = asyncio.get_event_loop() @staticmethod cdef _AsyncioSocket create(grpc_custom_socket * grpc_socket, @@ -56,30 +60,25 @@ cdef class _AsyncioSocket: return f"<{class_name} {id_} connected={connected}>" def _connect_cb(self, future): - error = False try: self._reader, self._writer = future.result() except Exception as e: - error = True - error_msg = str(e) + self._grpc_connect_cb( + self._grpc_socket, + grpc_socket_error("Socket connect failed: {}".format(e).encode()) + ) finally: self._task_connect = None - if not error: - # gRPC default posix implementation disables nagle - # algorithm. - sock = self._writer.transport.get_extra_info('socket') - sock.setsockopt(native_socket.IPPROTO_TCP, native_socket.TCP_NODELAY, True) + # gRPC default posix implementation disables nagle + # algorithm. + sock = self._writer.transport.get_extra_info('socket') + sock.setsockopt(native_socket.IPPROTO_TCP, native_socket.TCP_NODELAY, True) - self._grpc_connect_cb( - self._grpc_socket, - 0 - ) - else: - self._grpc_connect_cb( - self._grpc_socket, - grpc_socket_error("connect {}".format(error_msg).encode()) - ) + self._grpc_connect_cb( + self._grpc_socket, + 0 + ) def _read_cb(self, future): error = False @@ -87,7 +86,8 @@ cdef class _AsyncioSocket: buffer_ = future.result() except Exception as e: error = True - error_msg = str(e) + error_msg = "%s: %s" % (type(e), str(e)) + _LOGGER.exception(e) finally: self._task_read = None @@ -106,7 +106,7 @@ cdef class _AsyncioSocket: self._grpc_read_cb( self._grpc_socket, -1, - grpc_socket_error("read {}".format(error_msg).encode()) + grpc_socket_error("Read failed: {}".format(error_msg).encode()) ) cdef void connect(self, @@ -125,23 +125,34 @@ cdef class _AsyncioSocket: cdef void read(self, char * buffer_, size_t length, grpc_custom_read_callback grpc_read_cb): assert not self._task_read - self._task_read = asyncio.ensure_future( + self._task_read = self._loop.create_task( self._reader.read(n=length) ) self._grpc_read_cb = grpc_read_cb self._task_read.add_done_callback(self._read_cb) self._read_buffer = buffer_ + + async def _async_write(self, bytearray buffer): + self._writer.write(buffer) + await self._writer.drain() + + self._grpc_write_cb( + self._grpc_socket, + 0 + ) cdef void write(self, grpc_slice_buffer * g_slice_buffer, grpc_custom_write_callback grpc_write_cb): + # For each socket, C-Core guarantees there'll be only one ongoing write + self._grpc_write_cb = grpc_write_cb + cdef char* start - buffer_ = bytearray() + buffer = bytearray() for i in range(g_slice_buffer.count): start = grpc_slice_buffer_start(g_slice_buffer, i) length = grpc_slice_buffer_length(g_slice_buffer, i) - buffer_.extend(start[:length]) - - self._writer.write(buffer_) + buffer.extend(start[:length]) + self._writer.write(buffer) grpc_write_cb( self._grpc_socket, 0 @@ -171,9 +182,9 @@ cdef class _AsyncioSocket: self._grpc_client_socket.impl = client_socket cpython.Py_INCREF(client_socket) # Py_DECREF in asyncio_socket_destroy # Accept callback expects to be called with: - # grpc_custom_socket: A grpc custom socket for server - # grpc_custom_socket: A grpc custom socket for client (with new Socket instance) - # grpc_error: An error object + # * grpc_custom_socket: A grpc custom socket for server + # * grpc_custom_socket: A grpc custom socket for client (with new Socket instance) + # * grpc_error: An error object self._grpc_accept_cb(self._grpc_socket, self._grpc_client_socket, grpc_error_none()) cdef listen(self): @@ -183,7 +194,7 @@ cdef class _AsyncioSocket: sock=self._py_socket, ) - asyncio.get_event_loop().create_task(create_asyncio_server()) + self._loop.create_task(create_asyncio_server()) cdef accept(self, grpc_custom_socket* grpc_socket_client, diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_error.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pxd.pxi similarity index 79% rename from src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_error.pxd.pxi rename to src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pxd.pxi index 5772751a3b6..62add7f33d1 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_error.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pxd.pxi @@ -14,14 +14,16 @@ """Exceptions for the aio version of the RPC calls.""" -cdef class _AioRpcError(Exception): +cdef class AioRpcStatus(Exception): cdef readonly: - tuple _initial_metadata int _code str _details + # On spec, only client-side status has trailing metadata. tuple _trailing_metadata + str _debug_error_string - cpdef tuple initial_metadata(self) cpdef int code(self) cpdef str details(self) cpdef tuple trailing_metadata(self) + cpdef str debug_error_string(self) + cdef grpc_status_code c_code(self) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_error.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pyx.pxi similarity index 62% rename from src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_error.pyx.pxi rename to src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pyx.pxi index ca8a584d7a7..9784db19a1f 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_error.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pyx.pxi @@ -14,16 +14,19 @@ """Exceptions for the aio version of the RPC calls.""" -cdef class AioRpcError(Exception): +cdef class AioRpcStatus(Exception): - def __cinit__(self, tuple initial_metadata, int code, str details, tuple trailing_metadata): - self._initial_metadata = initial_metadata + # The final status of gRPC is represented by three trailing metadata: + # `grpc-status`, `grpc-status-message`, abd `grpc-status-details`. + def __cinit__(self, + int code, + str details, + tuple trailing_metadata, + str debug_error_string): self._code = code self._details = details self._trailing_metadata = trailing_metadata - - cpdef tuple initial_metadata(self): - return self._initial_metadata + self._debug_error_string = debug_error_string cpdef int code(self): return self._code @@ -33,3 +36,9 @@ cdef class AioRpcError(Exception): cpdef tuple trailing_metadata(self): return self._trailing_metadata + + cpdef str debug_error_string(self): + return self._debug_error_string + + cdef grpc_status_code c_code(self): + return self._code diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi index 4d85bafc338..f41e5f395d3 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi @@ -43,3 +43,4 @@ cdef class AioServer: cdef object _shutdown_completed # asyncio.Future cdef CallbackWrapper _shutdown_callback_wrapper cdef object _crash_exception # Exception + cdef set _ongoing_rpc_tasks diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index d332ecbd384..8330f64e4cb 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import inspect + + # TODO(https://github.com/grpc/grpc/issues/20850) refactor this. _LOGGER = logging.getLogger(__name__) cdef int _EMPTY_FLAG = 0 @@ -23,9 +27,6 @@ cdef class _HandlerCallDetails: self.invocation_metadata = invocation_metadata -class _ServicerContextPlaceHolder(object): pass - - cdef class RPCState: def __cinit__(self): @@ -43,12 +44,49 @@ cdef class RPCState: grpc_call_unref(self.call) +cdef class _ServicerContext: + cdef RPCState _rpc_state + cdef object _loop + cdef bint _metadata_sent + cdef object _request_deserializer + cdef object _response_serializer + + def __cinit__(self, + RPCState rpc_state, + object request_deserializer, + object response_serializer, + object loop): + self._rpc_state = rpc_state + self._request_deserializer = request_deserializer + self._response_serializer = response_serializer + self._loop = loop + self._metadata_sent = False + + async def read(self): + cdef bytes raw_message = await _receive_message(self._rpc_state, self._loop) + return deserialize(self._request_deserializer, + raw_message) + + async def write(self, object message): + await _send_message(self._rpc_state, + serialize(self._response_serializer, message), + self._metadata_sent, + self._loop) + if not self._metadata_sent: + self._metadata_sent = True + + async def send_initial_metadata(self, tuple metadata): + if self._metadata_sent: + raise ValueError('Send initial metadata failed: already sent') + else: + _send_initial_metadata(self._rpc_state, self._loop) + self._metadata_sent = True + + cdef _find_method_handler(str method, list generic_handlers): # TODO(lidiz) connects Metadata to call details - cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails( - method, - tuple() - ) + cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method, + None) for generic_handler in generic_handlers: method_handler = generic_handler.service(handler_call_details) @@ -61,64 +99,130 @@ async def _handle_unary_unary_rpc(object method_handler, RPCState rpc_state, object loop): # Receives request message - cdef tuple receive_ops = ( - ReceiveMessageOperation(_EMPTY_FLAGS), - ) - await callback_start_batch(rpc_state, receive_ops, loop) + cdef bytes request_raw = await _receive_message(rpc_state, loop) # Deserializes the request message - cdef bytes request_raw = receive_ops[0].message() - cdef object request_message - if method_handler.request_deserializer: - request_message = method_handler.request_deserializer(request_raw) - else: - request_message = request_raw + cdef object request_message = deserialize( + method_handler.request_deserializer, + request_raw, + ) # Executes application logic - cdef object response_message = await method_handler.unary_unary(request_message, _ServicerContextPlaceHolder()) + cdef object response_message = await method_handler.unary_unary( + request_message, + _ServicerContext( + rpc_state, + None, + None, + loop, + ), + ) # Serializes the response message - cdef bytes response_raw - if method_handler.response_serializer: - response_raw = method_handler.response_serializer(response_message) - else: - response_raw = response_message + cdef bytes response_raw = serialize( + method_handler.response_serializer, + response_message, + ) # Sends response message cdef tuple send_ops = ( SendStatusFromServerOperation( - tuple(), StatusCode.ok, b'', _EMPTY_FLAGS), - SendInitialMetadataOperation(tuple(), _EMPTY_FLAGS), + tuple(), + StatusCode.ok, + b'', + _EMPTY_FLAGS, + ), + SendInitialMetadataOperation(None, _EMPTY_FLAGS), SendMessageOperation(response_raw, _EMPTY_FLAGS), ) await callback_start_batch(rpc_state, send_ops, loop) +async def _handle_unary_stream_rpc(object method_handler, + RPCState rpc_state, + object loop): + # Receives request message + cdef bytes request_raw = await _receive_message(rpc_state, loop) + + # Deserializes the request message + cdef object request_message = deserialize( + method_handler.request_deserializer, + request_raw, + ) + + cdef _ServicerContext servicer_context = _ServicerContext( + rpc_state, + method_handler.request_deserializer, + method_handler.response_serializer, + loop, + ) + + if inspect.iscoroutinefunction(method_handler.unary_stream): + # The handler uses reader / writer API, returns None. + await method_handler.unary_stream( + request_message, + servicer_context, + ) + return + + # The handler uses async generator API + cdef object async_response_generator = method_handler.unary_stream( + request_message, + servicer_context, + ) + + # Consumes messages from the generator + cdef object response_message + async for response_message in async_response_generator: + await servicer_context.write(response_message) + + cdef SendStatusFromServerOperation op = SendStatusFromServerOperation( + None, + StatusCode.ok, + b'', + _EMPTY_FLAGS, + ) + cdef tuple ops = (op,) + await callback_start_batch(rpc_state, ops, loop) + + +async def _handle_cancellation_from_core(object rpc_task, + RPCState rpc_state, + object loop): + cdef ReceiveCloseOnServerOperation op = ReceiveCloseOnServerOperation(_EMPTY_FLAG) + cdef tuple ops = (op,) + await callback_start_batch(rpc_state, ops, loop) + if op.cancelled() and not rpc_task.done(): + rpc_task.cancel() + + async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): # Finds the method handler (application logic) cdef object method_handler = _find_method_handler( rpc_state.method().decode(), - generic_handlers + generic_handlers, ) if method_handler is None: # TODO(lidiz) return unimplemented error to client side raise NotImplementedError() # TODO(lidiz) extend to all 4 types of RPC - if method_handler.request_streaming or method_handler.response_streaming: - raise NotImplementedError() + if not method_handler.request_streaming and method_handler.response_streaming: + await _handle_unary_stream_rpc(method_handler, + rpc_state, + loop) + elif not method_handler.request_streaming and not method_handler.response_streaming: + await _handle_unary_unary_rpc(method_handler, + rpc_state, + loop) else: - await _handle_unary_unary_rpc( - method_handler, - rpc_state, - loop - ) + raise NotImplementedError() class _RequestCallError(Exception): pass cdef CallbackFailureHandler REQUEST_CALL_FAILURE_HANDLER = CallbackFailureHandler( - 'grpc_server_request_call', 'server shutdown', _RequestCallError) + 'grpc_server_request_call', None, _RequestCallError) async def _server_call_request_call(Server server, @@ -147,19 +251,9 @@ async def _server_call_request_call(Server server, return rpc_state -async def _handle_cancellation_from_core(object rpc_task, - RPCState rpc_state, - object loop): - cdef ReceiveCloseOnServerOperation op = ReceiveCloseOnServerOperation(_EMPTY_FLAG) - cdef tuple ops = (op,) - await callback_start_batch(rpc_state, ops, loop) - if op.cancelled() and not rpc_task.done(): - rpc_task.cancel() - - cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler( 'grpc_server_shutdown_and_notify', - 'Unknown', + None, RuntimeError) @@ -182,6 +276,7 @@ cdef class AioServer: self._generic_handlers = [] self.add_generic_rpc_handlers(generic_handlers) self._serving_task = None + self._ongoing_rpc_tasks = set() self._shutdown_lock = asyncio.Lock(loop=self._loop) self._shutdown_completed = self._loop.create_future() @@ -221,11 +316,13 @@ cdef class AioServer: if self._status != AIO_SERVER_STATUS_RUNNING: break + # Accepts new request from C-Core rpc_state = await _server_call_request_call( self._server, self._cq, self._loop) + # Schedules the RPC as a separate coroutine rpc_task = self._loop.create_task( _handle_rpc( self._generic_handlers, @@ -233,6 +330,8 @@ cdef class AioServer: self._loop ) ) + + # Fires off a task that listening on the cancellation from client. self._loop.create_task( _handle_cancellation_from_core( rpc_task, @@ -241,6 +340,10 @@ cdef class AioServer: ) ) + # Keeps track of created coroutines, so we can clean them up properly. + self._ongoing_rpc_tasks.add(rpc_task) + rpc_task.add_done_callback(lambda _: self._ongoing_rpc_tasks.remove(rpc_task)) + def _serving_task_crash_handler(self, object task): """Shutdown the server immediately if unexpectedly exited.""" if task.exception() is None: @@ -318,6 +421,10 @@ cdef class AioServer: grpc_server_cancel_all_calls(self._server.c_server) await self._shutdown_completed + # Cancels all Python layer tasks + for rpc_task in self._ongoing_rpc_tasks: + rpc_task.cancel() + async with self._shutdown_lock: if self._status == AIO_SERVER_STATUS_STOPPING: grpc_server_destroy(self._server.c_server) @@ -328,7 +435,7 @@ cdef class AioServer: # Shuts down the completion queue await self._cq.shutdown() - async def wait_for_termination(self, float timeout): + async def wait_for_termination(self, object timeout): if timeout is None: await self._shutdown_completed else: diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pxd b/src/python/grpcio/grpc/_cython/cygrpc.pxd index b9a85e3619b..0ffd586104e 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pxd +++ b/src/python/grpcio/grpc/_cython/cygrpc.pxd @@ -43,9 +43,9 @@ IF UNAME_SYSNAME != "Windows": include "_cygrpc/aio/iomgr/socket.pxd.pxi" include "_cygrpc/aio/iomgr/timer.pxd.pxi" include "_cygrpc/aio/iomgr/resolver.pxd.pxi" +include "_cygrpc/aio/rpc_status.pxd.pxi" include "_cygrpc/aio/grpc_aio.pxd.pxi" include "_cygrpc/aio/callback_common.pxd.pxi" include "_cygrpc/aio/call.pxd.pxi" -include "_cygrpc/aio/cancel_status.pxd.pxi" include "_cygrpc/aio/channel.pxd.pxi" include "_cygrpc/aio/server.pxd.pxi" diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx index b123a9b7bd6..09f2bfef42c 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pyx +++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx @@ -60,12 +60,12 @@ include "_cygrpc/aio/iomgr/iomgr.pyx.pxi" include "_cygrpc/aio/iomgr/socket.pyx.pxi" include "_cygrpc/aio/iomgr/timer.pyx.pxi" include "_cygrpc/aio/iomgr/resolver.pyx.pxi" +include "_cygrpc/aio/common.pyx.pxi" +include "_cygrpc/aio/rpc_status.pyx.pxi" +include "_cygrpc/aio/callback_common.pyx.pxi" include "_cygrpc/aio/grpc_aio.pyx.pxi" include "_cygrpc/aio/call.pyx.pxi" -include "_cygrpc/aio/callback_common.pyx.pxi" -include "_cygrpc/aio/cancel_status.pyx.pxi" include "_cygrpc/aio/channel.pyx.pxi" -include "_cygrpc/aio/rpc_error.pyx.pxi" include "_cygrpc/aio/server.pyx.pxi" diff --git a/src/python/grpcio/grpc/experimental/BUILD.bazel b/src/python/grpcio/grpc/experimental/BUILD.bazel index 36340079e56..ad54eda480e 100644 --- a/src/python/grpcio/grpc/experimental/BUILD.bazel +++ b/src/python/grpcio/grpc/experimental/BUILD.bazel @@ -7,6 +7,8 @@ py_library( "aio/_call.py", "aio/_channel.py", "aio/_server.py", + "aio/_typing.py", + "aio/_base_call.py", ], deps = [ "//src/python/grpcio/grpc/_cython:cygrpc", diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index 3f6b96eaa54..3633ad2b598 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -17,12 +17,9 @@ import abc import six import grpc -from grpc import _common -from grpc._cython import cygrpc from grpc._cython.cygrpc import init_grpc_aio -from ._call import AioRpcError -from ._call import Call +from ._base_call import Call, UnaryUnaryCall, UnaryStreamCall from ._channel import Channel from ._channel import UnaryUnaryMultiCallable from ._server import server @@ -47,5 +44,5 @@ def insecure_channel(target, options=None, compression=None): ################################### __all__ ################################# -__all__ = ('AioRpcError', 'Call', 'init_grpc_aio', 'Channel', - 'UnaryUnaryMultiCallable', 'insecure_channel', 'server') +__all__ = ('Call', 'UnaryUnaryCall', 'UnaryStreamCall', 'init_grpc_aio', + 'Channel', 'UnaryUnaryMultiCallable', 'insecure_channel', 'server') diff --git a/src/python/grpcio/grpc/experimental/aio/_base_call.py b/src/python/grpcio/grpc/experimental/aio/_base_call.py new file mode 100644 index 00000000000..95692a0718d --- /dev/null +++ b/src/python/grpcio/grpc/experimental/aio/_base_call.py @@ -0,0 +1,127 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Abstract base classes for client-side Call objects. + +Call objects represents the RPC itself, and offer methods to access / modify +its information. They also offer methods to manipulate the life-cycle of the +RPC, e.g. cancellation. +""" + +from abc import ABCMeta, abstractmethod +from typing import AsyncIterable, Awaitable, Generic, Text + +import grpc + +from ._typing import MetadataType, RequestType, ResponseType + +__all__ = 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' + + +class Call(grpc.RpcContext, metaclass=ABCMeta): + """The abstract base class of an RPC on the client-side.""" + + @abstractmethod + def cancelled(self) -> bool: + """Return True if the RPC is cancelled. + + The RPC is cancelled when the cancellation was requested with cancel(). + + Returns: + A bool indicates whether the RPC is cancelled or not. + """ + + @abstractmethod + def done(self) -> bool: + """Return True if the RPC is done. + + An RPC is done if the RPC is completed, cancelled or aborted. + + Returns: + A bool indicates if the RPC is done. + """ + + @abstractmethod + async def initial_metadata(self) -> MetadataType: + """Accesses the initial metadata sent by the server. + + Coroutine continues once the value is available. + + Returns: + The initial :term:`metadata`. + """ + + @abstractmethod + async def trailing_metadata(self) -> MetadataType: + """Accesses the trailing metadata sent by the server. + + Coroutine continues once the value is available. + + Returns: + The trailing :term:`metadata`. + """ + + @abstractmethod + async def code(self) -> grpc.StatusCode: + """Accesses the status code sent by the server. + + Coroutine continues once the value is available. + + Returns: + The StatusCode value for the RPC. + """ + + @abstractmethod + async def details(self) -> Text: + """Accesses the details sent by the server. + + Coroutine continues once the value is available. + + Returns: + The details string of the RPC. + """ + + +class UnaryUnaryCall(Generic[RequestType, ResponseType], Call, metaclass=ABCMeta): + """The abstract base class of an unary-unary RPC on the client-side.""" + + @abstractmethod + def __await__(self) -> Awaitable[ResponseType]: + """Await the response message to be ready. + + Returns: + The response message of the RPC. + """ + + +class UnaryStreamCall(Generic[RequestType, ResponseType], Call, metaclass=ABCMeta): + + @abstractmethod + def __aiter__(self) -> AsyncIterable[ResponseType]: + """Returns the async iterable representation that yields messages. + + Under the hood, it is calling the "read" method. + + Returns: + An async iterable object that yields messages. + """ + + @abstractmethod + async def read(self) -> ResponseType: + """Reads one message from the RPC. + + Parallel read operations is not allowed. + + Returns: + A response message of the RPC. + """ diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 70ac3628971..0e7334501e7 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -12,19 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" + import asyncio -import enum -from typing import Callable, Dict, Optional, ClassVar +from typing import AsyncIterable, Awaitable, Dict, Optional import grpc from grpc import _common from grpc._cython import cygrpc -DeserializingFunction = Callable[[bytes], str] +from . import _base_call +from ._typing import (DeserializingFunction, MetadataType, RequestType, + ResponseType, SerializingFunction) + +__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' + +_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' +_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!' +_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.' + +_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' + '\tstatus = {}\n' + '\tdetails = "{}"\n' + '>') + +_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' + '\tstatus = {}\n' + '\tdetails = "{}"\n' + '\tdebug_error_string = "{}"\n' + '>') class AioRpcError(grpc.RpcError): - """An RpcError to be used by the asynchronous API.""" + """An implementation of RpcError to be used by the asynchronous API. + + Raised RpcError is a snapshot of the final status of the RPC, values are + determined. Hence, its methods no longer needs to be coroutines. + """ # TODO(https://github.com/grpc/grpc/issues/20144) Metadata # type returned by `initial_metadata` and `trailing_metadata` @@ -33,14 +56,16 @@ class AioRpcError(grpc.RpcError): _code: grpc.StatusCode _details: Optional[str] - _initial_metadata: Optional[Dict] - _trailing_metadata: Optional[Dict] + _initial_metadata: Optional[MetadataType] + _trailing_metadata: Optional[MetadataType] + _debug_error_string: Optional[str] def __init__(self, code: grpc.StatusCode, details: Optional[str] = None, - initial_metadata: Optional[Dict] = None, - trailing_metadata: Optional[Dict] = None): + initial_metadata: Optional[MetadataType] = None, + trailing_metadata: Optional[MetadataType] = None, + debug_error_string: Optional[str] = None) -> None: """Constructor. Args: @@ -56,207 +81,335 @@ class AioRpcError(grpc.RpcError): self._details = details self._initial_metadata = initial_metadata self._trailing_metadata = trailing_metadata + self._debug_error_string = debug_error_string def code(self) -> grpc.StatusCode: - """ + """Accesses the status code sent by the server. + Returns: The `grpc.StatusCode` status code. """ return self._code def details(self) -> Optional[str]: - """ + """Accesses the details sent by the server. + Returns: The description of the error. """ return self._details def initial_metadata(self) -> Optional[Dict]: - """ + """Accesses the initial metadata sent by the server. + Returns: - The inital metadata received. + The initial metadata received. """ return self._initial_metadata def trailing_metadata(self) -> Optional[Dict]: - """ + """Accesses the trailing metadata sent by the server. + Returns: The trailing metadata received. """ return self._trailing_metadata + def debug_error_string(self) -> str: + """Accesses the debug error string sent by the server. -@enum.unique -class _RpcState(enum.Enum): - """Identifies the state of the RPC.""" - ONGOING = 1 - CANCELLED = 2 - FINISHED = 3 - ABORT = 4 + Returns: + The debug error string received. + """ + return self._debug_error_string + def _repr(self) -> str: + """Assembles the error string for the RPC error.""" + return _NON_OK_CALL_REPRESENTATION.format(self.__class__.__name__, + self._code, self._details, + self._debug_error_string) -class Call: - """Object for managing RPC calls, - returned when an instance of `UnaryUnaryMultiCallable` object is called. - """ + def __repr__(self) -> str: + return self._repr() - _cancellation_details: ClassVar[str] = 'Locally cancelled by application!' + def __str__(self) -> str: + return self._repr() - _state: _RpcState - _exception: Optional[Exception] - _response: Optional[bytes] - _code: grpc.StatusCode - _details: Optional[str] - _initial_metadata: Optional[Dict] - _trailing_metadata: Optional[Dict] - _call: asyncio.Task - _call_cancel_status: cygrpc.AioCancelStatus - _response_deserializer: DeserializingFunction - def __init__(self, call: asyncio.Task, - response_deserializer: DeserializingFunction, - call_cancel_status: cygrpc.AioCancelStatus) -> None: - """Constructor. +def _create_rpc_error(initial_metadata: Optional[MetadataType], + status: cygrpc.AioRpcStatus) -> AioRpcError: + return AioRpcError(_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()], + status.details(), initial_metadata, + status.trailing_metadata()) - Args: - call: Asyncio Task that holds the RPC execution. - response_deserializer: Deserializer used for parsing the reponse. - call_cancel_status: A cygrpc.AioCancelStatus used for giving a - specific error when the RPC is canceled. - """ - self._state = _RpcState.ONGOING - self._exception = None - self._response = None - self._code = grpc.StatusCode.UNKNOWN - self._details = None - self._initial_metadata = None - self._trailing_metadata = None - self._call = call - self._call_cancel_status = call_cancel_status - self._response_deserializer = response_deserializer +class Call(_base_call.Call): + _loop: asyncio.AbstractEventLoop + _code: grpc.StatusCode + _status: Awaitable[cygrpc.AioRpcStatus] + _initial_metadata: Awaitable[MetadataType] + _cancellation_future: asyncio.Future - def __del__(self): - self.cancel() + def __init__(self) -> None: + self._loop = asyncio.get_event_loop() + self._code = None + self._status = self._loop.create_future() + self._initial_metadata = self._loop.create_future() + self._cancellation_future = self._loop.create_future() def cancel(self) -> bool: - """Cancels the ongoing RPC request. + """Virtual cancellation method. - Returns: - True if the RPC can be canceled, False if was already cancelled or terminated. + The implementation of this method needs to pass the cancellation reason + into self._cancellation_future, using `set_result` instead of + `set_exception`. """ - if self.cancelled() or self.done(): - return False - - code = grpc.StatusCode.CANCELLED - self._call_cancel_status.cancel( - _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], - details=Call._cancellation_details) - self._call.cancel() - self._details = Call._cancellation_details - self._code = code - self._state = _RpcState.CANCELLED - return True + raise NotImplementedError() def cancelled(self) -> bool: - """Returns if the RPC was cancelled. - - Returns: - True if the requests was cancelled, False if not. - """ - return self._state is _RpcState.CANCELLED - - def running(self) -> bool: - """Returns if the RPC is running. - - Returns: - True if the requests is running, False if it already terminated. - """ - return not self.done() + return self._cancellation_future.done( + ) or self._code == grpc.StatusCode.CANCELLED def done(self) -> bool: - """Returns if the RPC has finished. + return self._status.done() - Returns: - True if the requests has finished, False is if still ongoing. - """ - return self._state is not _RpcState.ONGOING + def add_callback(self, unused_callback) -> None: + pass - async def initial_metadata(self): - raise NotImplementedError() + def is_active(self) -> bool: + return self.done() - async def trailing_metadata(self): - raise NotImplementedError() + def time_remaining(self) -> float: + pass - async def code(self) -> grpc.StatusCode: - """Returns the `grpc.StatusCode` if the RPC is finished, - otherwise first waits until the RPC finishes. + async def initial_metadata(self) -> MetadataType: + return await self._initial_metadata - Returns: - The `grpc.StatusCode` status code. - """ - if not self.done(): - try: - await self - except (asyncio.CancelledError, AioRpcError): - pass + async def trailing_metadata(self) -> MetadataType: + return (await self._status).trailing_metadata() + async def code(self) -> grpc.StatusCode: + await self._status return self._code async def details(self) -> str: - """Returns the details if the RPC is finished, otherwise first waits till the - RPC finishes. + return (await self._status).details() - Returns: - The details. + async def debug_error_string(self) -> str: + return (await self._status).debug_error_string() + + def _set_initial_metadata(self, metadata: MetadataType) -> None: + self._initial_metadata.set_result(metadata) + + def _set_status(self, status: cygrpc.AioRpcStatus) -> None: + """Private method to set final status of the RPC. + + This method may be called multiple time due to data race between local + cancellation (by application) and C-Core receiving status from peer. We + make no promise here which one will win. """ - if not self.done(): - try: - await self - except (asyncio.CancelledError, AioRpcError): - pass + if self._status.done(): + return + else: + self._status.set_result(status) + self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[ + status.code()] + + async def _raise_rpc_error_if_not_ok(self) -> None: + if self._code != grpc.StatusCode.OK: + raise _create_rpc_error(await self.initial_metadata(), + self._status.result()) + + def _repr(self) -> str: + """Assembles the RPC representation string.""" + if not self._status.done(): + return '<{} object>'.format(self.__class__.__name__) + if self._code is grpc.StatusCode.OK: + return _OK_CALL_REPRESENTATION.format( + self.__class__.__name__, self._code, + self._status.result().self._status.result().details()) + else: + return _NON_OK_CALL_REPRESENTATION.format( + self.__class__.__name__, self._code, + self._status.result().details(), + self._status.result().debug_error_string()) + + def __repr__(self) -> str: + return self._repr() + + def __str__(self) -> str: + return self._repr() + + +class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): + """Object for managing unary-unary RPC calls. + + Returned when an instance of `UnaryUnaryMultiCallable` object is called. + """ + _loop: asyncio.AbstractEventLoop + _request: RequestType + _deadline: Optional[float] + _channel: cygrpc.AioChannel + _method: bytes + _request_serializer: SerializingFunction + _response_deserializer: DeserializingFunction + _call: asyncio.Task - return self._details + def __init__(self, request: RequestType, deadline: Optional[float], + channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> None: + super().__init__() + self._loop = asyncio.get_event_loop() + self._request = request + self._deadline = deadline + self._channel = channel + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + self._call = self._loop.create_task(self._invoke()) + + def __del__(self) -> None: + if not self._call.done(): + self._cancel( + cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, + _GC_CANCELLATION_DETAILS, None, None)) + + async def _invoke(self) -> ResponseType: + serialized_request = _common.serialize(self._request, + self._request_serializer) + + # NOTE(lidiz) asyncio.CancelledError is not a good transport for + # status, since the Task class do not cache the exact + # asyncio.CancelledError object. So, the solution is catching the error + # in Cython layer, then cancel the RPC and update the status, finally + # re-raise the CancelledError. + serialized_response = await self._channel.unary_unary( + self._method, + serialized_request, + self._deadline, + self._cancellation_future, + self._set_initial_metadata, + self._set_status, + ) + await self._raise_rpc_error_if_not_ok() + + return _common.deserialize(serialized_response, + self._response_deserializer) + + def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: + """Forwards the application cancellation reasoning.""" + if not self._status.done() and not self._cancellation_future.done(): + self._cancellation_future.set_result(status) + self._call.cancel() + return True + else: + return False - def __await__(self): + def cancel(self) -> bool: + return self._cancel( + cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, + _LOCAL_CANCELLATION_DETAILS, None, None)) + + def __await__(self) -> ResponseType: """Wait till the ongoing RPC request finishes. Returns: Response of the RPC call. Raises: - AioRpcError: Indicating that the RPC terminated with non-OK status. + RpcError: Indicating that the RPC terminated with non-OK status. asyncio.CancelledError: Indicating that the RPC was canceled. """ - # We can not relay on the `done()` method since some exceptions - # might be pending to be catched, like `asyncio.CancelledError`. - if self._response: - return self._response - elif self._exception: - raise self._exception - - try: - buffer_ = yield from self._call.__await__() - except cygrpc.AioRpcError as aio_rpc_error: - self._state = _RpcState.ABORT - self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[ - aio_rpc_error.code()] - self._details = aio_rpc_error.details() - self._initial_metadata = aio_rpc_error.initial_metadata() - self._trailing_metadata = aio_rpc_error.trailing_metadata() - - # Propagates the pure Python class - self._exception = AioRpcError(self._code, self._details, - self._initial_metadata, - self._trailing_metadata) - raise self._exception from aio_rpc_error - except asyncio.CancelledError as cancel_error: - # _state, _code, _details are managed in the `cancel` method - self._exception = cancel_error - raise - - self._response = _common.deserialize(buffer_, - self._response_deserializer) - self._code = grpc.StatusCode.OK - self._state = _RpcState.FINISHED - return self._response + response = yield from self._call + return response + + +class UnaryStreamCall(Call, _base_call.UnaryStreamCall): + """Object for managing unary-stream RPC calls. + + Returned when an instance of `UnaryStreamMultiCallable` object is called. + """ + _loop: asyncio.AbstractEventLoop + _request: RequestType + _deadline: Optional[float] + _channel: cygrpc.AioChannel + _method: bytes + _request_serializer: SerializingFunction + _response_deserializer: DeserializingFunction + _call: AsyncIterable[ResponseType] + + def __init__(self, request: RequestType, deadline: Optional[float], + channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> None: + super().__init__() + self._loop = asyncio.get_event_loop() + self._request = request + self._deadline = deadline + self._channel = channel + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + self._call = self._invoke() + + def __del__(self) -> None: + if not self._status.done(): + self._cancel( + cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, + _GC_CANCELLATION_DETAILS, None, None)) + + async def _invoke(self) -> ResponseType: + serialized_request = _common.serialize(self._request, + self._request_serializer) + + async_gen = self._channel.unary_stream( + self._method, + serialized_request, + self._deadline, + self._cancellation_future, + self._set_initial_metadata, + self._set_status, + ) + async for serialized_response in async_gen: + if self._cancellation_future.done(): + await self._status + if self._status.done(): + # Raises pre-maturely if final status received here. Generates + # more helpful stack trace for end users. + await self._raise_rpc_error_if_not_ok() + yield _common.deserialize(serialized_response, + self._response_deserializer) + + await self._raise_rpc_error_if_not_ok() + + def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: + """Forwards the application cancellation reasoning. + + Async generator will receive an exception. The cancellation will go + deep down into C-Core, and then propagates backup as the + `cygrpc.AioRpcStatus` exception. + + So, under race condition, e.g. the server sent out final state headers + and the client calling "cancel" at the same time, this method respects + the winner in C-Core. + """ + if not self._status.done() and not self._cancellation_future.done(): + self._cancellation_future.set_result(status) + return True + else: + return False + + def cancel(self) -> bool: + return self._cancel( + cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, + _LOCAL_CANCELLATION_DETAILS, None, None)) + + def __aiter__(self) -> AsyncIterable[ResponseType]: + return self._call + + async def read(self) -> ResponseType: + if self._status.done(): + await self._raise_rpc_error_if_not_ok() + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + return await self._call.__anext__() diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 389b952b0b2..6227f8bc11c 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -13,42 +13,111 @@ # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio -from typing import Callable, Optional +from typing import Any, Optional, Sequence, Text, Tuple +import grpc from grpc import _common from grpc._cython import cygrpc +from ._call import Call, UnaryUnaryCall, UnaryStreamCall +from ._typing import (DeserializingFunction, MetadataType, SerializingFunction) -from ._call import Call -SerializingFunction = Callable[[str], bytes] -DeserializingFunction = Callable[[bytes], str] +def _timeout_to_deadline(loop: asyncio.AbstractEventLoop, + timeout: int) -> Optional[int]: + if timeout is None: + return None + return loop.time() + timeout class UnaryUnaryMultiCallable: - """Afford invoking a unary-unary RPC from client-side in an asynchronous way.""" + """Factory an asynchronous unary-unary RPC stub call from client-side.""" def __init__(self, channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction) -> None: + self._loop = asyncio.get_event_loop() self._channel = channel self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer - self._loop = asyncio.get_event_loop() - def _timeout_to_deadline(self, timeout: int) -> Optional[int]: - if timeout is None: - return None - return self._loop.time() + timeout + def __call__(self, + request: Any, + *, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Call: + """Asynchronously invokes the underlying RPC. + + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow + for the RPC. + metadata: Optional :term:`metadata` to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Only valid for + secure Channel. + wait_for_ready: This is an EXPERIMENTAL argument. An optional + flag to enable wait for ready mechanism + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. + + Returns: + A Call object instance which is an awaitable object. + + Raises: + RpcError: Indicating that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. + """ + + if metadata: + raise NotImplementedError("TODO: metadata not implemented yet") + + if credentials: + raise NotImplementedError("TODO: credentials not implemented yet") + + if wait_for_ready: + raise NotImplementedError( + "TODO: wait_for_ready not implemented yet") + + if compression: + raise NotImplementedError("TODO: compression not implemented yet") + + deadline = _timeout_to_deadline(self._loop, timeout) + + return UnaryUnaryCall( + request, + deadline, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + ) + + +class UnaryStreamMultiCallable: + """Afford invoking a unary-stream RPC from client-side in an asynchronous way.""" + + def __init__(self, channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> None: + self._channel = channel + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + self._loop = asyncio.get_event_loop() def __call__(self, - request, + request: Any, *, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None) -> Call: + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Call: """Asynchronously invokes the underlying RPC. Args: @@ -86,15 +155,16 @@ class UnaryUnaryMultiCallable: if compression: raise NotImplementedError("TODO: compression not implemented yet") - serialized_request = _common.serialize(request, - self._request_serializer) - timeout = self._timeout_to_deadline(timeout) - aio_cancel_status = cygrpc.AioCancelStatus() - aio_call = asyncio.ensure_future( - self._channel.unary_unary(self._method, serialized_request, timeout, - aio_cancel_status), - loop=self._loop) - return Call(aio_call, self._response_deserializer, aio_cancel_status) + deadline = _timeout_to_deadline(self._loop, timeout) + + return UnaryStreamCall( + request, + deadline, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + ) class Channel: @@ -103,7 +173,10 @@ class Channel: A cygrpc.AioChannel-backed implementation. """ - def __init__(self, target, options, credentials, compression): + def __init__(self, target: Text, + options: Optional[Sequence[Tuple[Text, Any]]], + credentials: Optional[grpc.ChannelCredentials], + compression: Optional[grpc.Compression]): """Constructor. Args: @@ -125,10 +198,11 @@ class Channel: self._channel = cygrpc.AioChannel(_common.encode(target)) - def unary_unary(self, - method, - request_serializer=None, - response_deserializer=None): + def unary_unary( + self, + method: Text, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None): """Creates a UnaryUnaryMultiCallable for a unary-unary method. Args: @@ -146,6 +220,29 @@ class Channel: request_serializer, response_deserializer) + def unary_stream( + self, + method: Text, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None): + return UnaryStreamMultiCallable(self._channel, _common.encode(method), + request_serializer, + response_deserializer) + + def stream_unary( + self, + method: Text, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None): + """Placeholder method for stream-unary calls.""" + + def stream_stream( + self, + method: Text, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None): + """Placeholder method for stream-stream calls.""" + async def _close(self): # TODO: Send cancellation status self._channel.close() diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pxd.pxi b/src/python/grpcio/grpc/experimental/aio/_typing.py similarity index 60% rename from src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pxd.pxi rename to src/python/grpcio/grpc/experimental/aio/_typing.py index 47670e5deb1..818b89d8dbc 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pxd.pxi +++ b/src/python/grpcio/grpc/experimental/aio/_typing.py @@ -1,4 +1,4 @@ -# Copyright 2019 gRPC authors. +# Copyright 2019 The gRPC Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Desired cancellation status for canceling an ongoing RPC calls.""" +"""Common types for gRPC Async API""" +from typing import Any, AnyStr, Callable, Sequence, Text, Tuple, TypeVar -cdef class AioCancelStatus: - cdef readonly: - object _code - str _details - - cpdef object code(self) - cpdef str details(self) +RequestType = TypeVar('RequestType') +ResponseType = TypeVar('ResponseType') +SerializingFunction = Callable[[Any], bytes] +DeserializingFunction = Callable[[bytes], Any] +MetadataType = Sequence[Tuple[Text, AnyStr]] diff --git a/src/python/grpcio_tests/tests_aio/benchmark/BUILD.bazel b/src/python/grpcio_tests/tests_aio/benchmark/BUILD.bazel new file mode 100644 index 00000000000..a3c84ec033a --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/benchmark/BUILD.bazel @@ -0,0 +1,32 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_testonly = 1, + default_visibility = ["//visibility:public"], +) + +py_binary( + name = "server", + srcs = ["server.py"], + python_version = "PY3", + deps = [ + "//external:six", + "//src/proto/grpc/testing:benchmark_service_py_pb2", + "//src/proto/grpc/testing:benchmark_service_py_pb2_grpc", + "//src/proto/grpc/testing:py_messages_proto", + "//src/python/grpcio/grpc:grpcio", + "//src/python/grpcio_tests/tests/unit/framework/common", + ], +) diff --git a/src/python/grpcio_tests/tests_aio/benchmark/server.py b/src/python/grpcio_tests/tests_aio/benchmark/server.py index ef0a3f7ff2c..223de2d2f44 100644 --- a/src/python/grpcio_tests/tests_aio/benchmark/server.py +++ b/src/python/grpcio_tests/tests_aio/benchmark/server.py @@ -27,6 +27,12 @@ class BenchmarkServer(benchmark_service_pb2_grpc.BenchmarkServiceServicer): payload = messages_pb2.Payload(body=b'\0' * request.response_size) return messages_pb2.SimpleResponse(payload=payload) + async def StreamingFromServer(self, request, context): + payload = messages_pb2.Payload(body=b'\0' * request.response_size) + # Sends response at full capacity! + while True: + yield messages_pb2.SimpleResponse(payload=payload) + async def _start_async_server(): server = aio.server() @@ -37,6 +43,7 @@ async def _start_async_server(): servicer, server) await server.start() + logging.info('Benchmark server started at :%d' % port) await server.wait_for_termination() @@ -48,5 +55,5 @@ def main(): if __name__ == '__main__': - logging.basicConfig() + logging.basicConfig(level=logging.DEBUG) main() diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_base.py b/src/python/grpcio_tests/tests_aio/unit/_test_base.py index 61602259043..17406f4a5d1 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_base.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_base.py @@ -12,18 +12,56 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import functools import asyncio +from typing import Callable import unittest from grpc.experimental import aio +__all__ = 'AioTestBase' + +_COROUTINE_FUNCTION_ALLOWLIST = ['setUp', 'tearDown'] + + +def _async_to_sync_decorator(f: Callable, loop: asyncio.AbstractEventLoop): + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return loop.run_until_complete(f(*args, **kwargs)) + + return wrapper + + +def _get_default_loop(debug=True): + try: + loop = asyncio.get_event_loop() + except: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + finally: + loop.set_debug(debug) + return loop + class AioTestBase(unittest.TestCase): - def setUp(self): - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) + @classmethod + def setUpClass(cls): + cls._loop = _get_default_loop() aio.init_grpc_aio() @property def loop(self): return self._loop + + def __getattribute__(self, name): + """Overrides the loading logic to support coroutine functions.""" + attr = super().__getattribute__(name) + + # If possible, converts the coroutine into a sync function. + if name.startswith('test_') or name in _COROUTINE_FUNCTION_ALLOWLIST: + if asyncio.iscoroutinefunction(attr): + return _async_to_sync_decorator(attr, _get_default_loop()) + # For other attributes, let them pass. + return attr diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 4b6ceebc816..014a14aa7f3 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -12,22 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from time import sleep +import logging from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import test_pb2_grpc from tests.unit.framework.common import test_constants +_US_IN_A_SECOND = 1000 * 1000 + class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): async def UnaryCall(self, request, context): return messages_pb2.SimpleResponse() + # TODO(lidizheng) The semantic of this call is not matching its description + # See src/proto/grpc/testing/test.proto async def EmptyCall(self, request, context): while True: - sleep(test_constants.LONG_TIMEOUT) + await asyncio.sleep(test_constants.LONG_TIMEOUT) + + async def StreamingOutputCall( + self, request: messages_pb2.StreamingOutputCallRequest, context): + for response_parameters in request.response_parameters: + if response_parameters.interval_us != 0: + await asyncio.sleep( + response_parameters.interval_us / _US_IN_A_SECOND) + yield messages_pb2.StreamingOutputCallResponse( + payload=messages_pb2.Payload( + type=request.response_type, + body=b'\x00' * response_parameters.size)) async def start_test_server(): diff --git a/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py new file mode 100644 index 00000000000..60bc6ead06f --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py @@ -0,0 +1,45 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests AioRpcError class.""" + +import logging +import unittest + +import grpc + +from grpc.experimental.aio._call import AioRpcError +from tests_aio.unit._test_base import AioTestBase + + +class TestAioRpcError(unittest.TestCase): + _TEST_INITIAL_METADATA = ("initial metadata",) + _TEST_TRAILING_METADATA = ("trailing metadata",) + + def test_attributes(self): + aio_rpc_error = AioRpcError( + grpc.StatusCode.CANCELLED, + "details", + initial_metadata=self._TEST_INITIAL_METADATA, + trailing_metadata=self._TEST_TRAILING_METADATA) + self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(aio_rpc_error.details(), "details") + self.assertEqual(aio_rpc_error.initial_metadata(), + self._TEST_INITIAL_METADATA) + self.assertEqual(aio_rpc_error.trailing_metadata(), + self._TEST_TRAILING_METADATA) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index cadd7e416bb..81480b62180 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -11,186 +11,288 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Tests behavior of the grpc.aio.UnaryUnaryCall class.""" + import asyncio import logging import unittest +import datetime import grpc from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2 +from src.proto.grpc.testing import test_pb2_grpc from tests.unit.framework.common import test_constants from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase - -class TestAioRpcError(unittest.TestCase): - _TEST_INITIAL_METADATA = ("initial metadata",) - _TEST_TRAILING_METADATA = ("trailing metadata",) - - def test_attributes(self): - aio_rpc_error = aio.AioRpcError( - grpc.StatusCode.CANCELLED, - "details", - initial_metadata=self._TEST_INITIAL_METADATA, - trailing_metadata=self._TEST_TRAILING_METADATA) - self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED) - self.assertEqual(aio_rpc_error.details(), "details") - self.assertEqual(aio_rpc_error.initial_metadata(), - self._TEST_INITIAL_METADATA) - self.assertEqual(aio_rpc_error.trailing_metadata(), - self._TEST_TRAILING_METADATA) +_NUM_STREAM_RESPONSES = 5 +_RESPONSE_PAYLOAD_SIZE = 42 +_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' +# _RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000 +_RESPONSE_INTERVAL_US = 200 * 1000 class TestCall(AioTestBase): - def test_call_ok(self): - - async def coro(): - server_target, _ = await start_test_server() # pylint: disable=unused-variable - - async with aio.insecure_channel(server_target) as channel: - hi = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', - request_serializer=messages_pb2.SimpleRequest. - SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString - ) - call = hi(messages_pb2.SimpleRequest()) - - self.assertFalse(call.done()) - - response = await call - - self.assertTrue(call.done()) - self.assertEqual(type(response), messages_pb2.SimpleResponse) - self.assertEqual(await call.code(), grpc.StatusCode.OK) - - # Response is cached at call object level, reentrance - # returns again the same response - response_retry = await call - self.assertIs(response, response_retry) - - self.loop.run_until_complete(coro()) - - def test_call_rpc_error(self): - - async def coro(): - server_target, _ = await start_test_server() # pylint: disable=unused-variable - - async with aio.insecure_channel(server_target) as channel: - empty_call_with_sleep = channel.unary_unary( - "/grpc.testing.TestService/EmptyCall", - request_serializer=messages_pb2.SimpleRequest. - SerializeToString, - response_deserializer=messages_pb2.SimpleResponse. - FromString, - ) - timeout = test_constants.SHORT_TIMEOUT / 2 - # TODO(https://github.com/grpc/grpc/issues/20869 - # Update once the async server is ready, change the - # synchronization mechanism by removing the sleep() - # as both components (client & server) will be on the same - # process. - call = empty_call_with_sleep( - messages_pb2.SimpleRequest(), timeout=timeout) - - with self.assertRaises(grpc.RpcError) as exception_context: - await call - - self.assertTrue(call.done()) - self.assertEqual(await call.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) - - # Exception is cached at call object level, reentrance - # returns again the same exception - with self.assertRaises( - grpc.RpcError) as exception_context_retry: - await call - - self.assertIs(exception_context.exception, - exception_context_retry.exception) - - self.loop.run_until_complete(coro()) - - def test_call_code_awaitable(self): - - async def coro(): - server_target, _ = await start_test_server() # pylint: disable=unused-variable - - async with aio.insecure_channel(server_target) as channel: - hi = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', - request_serializer=messages_pb2.SimpleRequest. - SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString - ) - call = hi(messages_pb2.SimpleRequest()) - self.assertEqual(await call.code(), grpc.StatusCode.OK) - - self.loop.run_until_complete(coro()) - - def test_call_details_awaitable(self): - - async def coro(): - server_target, _ = await start_test_server() # pylint: disable=unused-variable - - async with aio.insecure_channel(server_target) as channel: - hi = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', - request_serializer=messages_pb2.SimpleRequest. - SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString - ) - call = hi(messages_pb2.SimpleRequest()) - self.assertEqual(await call.details(), None) - - self.loop.run_until_complete(coro()) - - def test_cancel(self): - - async def coro(): - server_target, _ = await start_test_server() # pylint: disable=unused-variable - - async with aio.insecure_channel(server_target) as channel: - hi = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', - request_serializer=messages_pb2.SimpleRequest. - SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString - ) - call = hi(messages_pb2.SimpleRequest()) - - self.assertFalse(call.cancelled()) - - # TODO(https://github.com/grpc/grpc/issues/20869) remove sleep. - # Force the loop to execute the RPC task. - await asyncio.sleep(0) - - self.assertTrue(call.cancel()) - self.assertTrue(call.cancelled()) - self.assertFalse(call.cancel()) - - with self.assertRaises( - asyncio.CancelledError) as exception_context: - await call - - self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) - self.assertEqual(await call.details(), - 'Locally cancelled by application!') - - # Exception is cached at call object level, reentrance - # returns again the same exception - with self.assertRaises( - asyncio.CancelledError) as exception_context_retry: - await call - - self.assertIs(exception_context.exception, - exception_context_retry.exception) - - self.loop.run_until_complete(coro()) + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_call_ok(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = hi(messages_pb2.SimpleRequest()) + + self.assertFalse(call.done()) + + response = await call + + self.assertTrue(call.done()) + self.assertEqual(type(response), messages_pb2.SimpleResponse) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + # Response is cached at call object level, reentrance + # returns again the same response + response_retry = await call + self.assertIs(response, response_retry) + + async def test_call_rpc_error(self): + async with aio.insecure_channel(self._server_target) as channel: + empty_call_with_sleep = channel.unary_unary( + "/grpc.testing.TestService/EmptyCall", + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) + timeout = test_constants.SHORT_TIMEOUT / 2 + # TODO(https://github.com/grpc/grpc/issues/20869 + # Update once the async server is ready, change the + # synchronization mechanism by removing the sleep() + # as both components (client & server) will be on the same + # process. + call = empty_call_with_sleep( + messages_pb2.SimpleRequest(), timeout=timeout) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call + + self.assertTrue(call.done()) + self.assertEqual(await call.code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + + # Exception is cached at call object level, reentrance + # returns again the same exception + with self.assertRaises(grpc.RpcError) as exception_context_retry: + await call + + self.assertIs(exception_context.exception, + exception_context_retry.exception) + + async def test_call_code_awaitable(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = hi(messages_pb2.SimpleRequest()) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_call_details_awaitable(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = hi(messages_pb2.SimpleRequest()) + self.assertEqual('', await call.details()) + + async def test_cancel_unary_unary(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = hi(messages_pb2.SimpleRequest()) + + self.assertFalse(call.cancelled()) + + # TODO(https://github.com/grpc/grpc/issues/20869) remove sleep. + # Force the loop to execute the RPC task. + await asyncio.sleep(0) + + self.assertTrue(call.cancel()) + self.assertFalse(call.cancel()) + + with self.assertRaises(asyncio.CancelledError) as exception_context: + await call + + self.assertTrue(call.cancelled()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + 'Locally cancelled by application!') + + # NOTE(lidiz) The CancelledError is almost always re-created, + # so we might not want to use it to transmit data. + # https://github.com/python/cpython/blob/master/Lib/asyncio/tasks.py#L785 + + async def test_cancel_unary_stream(self): + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + # Invokes the actual RPC + call = stub.StreamingOutputCall(request) + self.assertFalse(call.cancelled()) + + response = await call.read() + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertTrue(call.cancel()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await + call.details()) + self.assertFalse(call.cancel()) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call.read() + self.assertTrue(call.cancelled()) + + async def test_multiple_cancel_unary_stream(self): + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + # Invokes the actual RPC + call = stub.StreamingOutputCall(request) + self.assertFalse(call.cancelled()) + + response = await call.read() + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertTrue(call.cancel()) + self.assertFalse(call.cancel()) + self.assertFalse(call.cancel()) + self.assertFalse(call.cancel()) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call.read() + + async def test_early_cancel_unary_stream(self): + """Test cancellation before receiving messages.""" + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + # Invokes the actual RPC + call = stub.StreamingOutputCall(request) + + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertFalse(call.cancel()) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call.read() + + self.assertTrue(call.cancelled()) + + self.assertEqual(grpc.StatusCode.CANCELLED, + exception_context.exception.code()) + self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, + exception_context.exception.details()) + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await + call.details()) + + async def test_late_cancel_unary_stream(self): + """Test cancellation after received all messages.""" + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE,)) + + # Invokes the actual RPC + call = stub.StreamingOutputCall(request) + + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + # After all messages received, it is possible that the final state + # is received or on its way. It's basically a data race, so our + # expectation here is do not crash :) + call.cancel() + + async def test_too_many_reads_unary_stream(self): + """Test cancellation after received all messages.""" + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE,)) + + # Invokes the actual RPC + call = stub.StreamingOutputCall(request) + + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + # After the RPC is finished, further reads will lead to exception. + self.assertEqual(await call.code(), grpc.StatusCode.OK) + with self.assertRaises(asyncio.InvalidStateError): + await call.read() if __name__ == '__main__': - logging.basicConfig() + logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 076300786fb..0487ea050c3 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Tests behavior of the grpc.aio.Channel class.""" + import logging import unittest @@ -18,38 +20,38 @@ import grpc from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2 +from src.proto.grpc.testing import test_pb2_grpc from tests.unit.framework.common import test_constants from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _EMPTY_CALL_METHOD = '/grpc.testing.TestService/EmptyCall' +_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' +_NUM_STREAM_RESPONSES = 5 +_RESPONSE_PAYLOAD_SIZE = 42 -class TestChannel(AioTestBase): - - def test_async_context(self): - async def coro(): - server_target, _ = await start_test_server() # pylint: disable=unused-variable - - async with aio.insecure_channel(server_target) as channel: - hi = channel.unary_unary( - _UNARY_CALL_METHOD, - request_serializer=messages_pb2.SimpleRequest. - SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString - ) - await hi(messages_pb2.SimpleRequest()) +class TestChannel(AioTestBase): - self.loop.run_until_complete(coro()) + async def setUp(self): + self._server_target, self._server = await start_test_server() - def test_unary_unary(self): + async def tearDown(self): + await self._server.stop(None) - async def coro(): - server_target, _ = await start_test_server() # pylint: disable=unused-variable + async def test_async_context(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + _UNARY_CALL_METHOD, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + await hi(messages_pb2.SimpleRequest()) - channel = aio.insecure_channel(server_target) + async def test_unary_unary(self): + async with aio.insecure_channel(self._server_target) as channel: + channel = aio.insecure_channel(self._server_target) hi = channel.unary_unary( _UNARY_CALL_METHOD, request_serializer=messages_pb2.SimpleRequest.SerializeToString, @@ -58,63 +60,70 @@ class TestChannel(AioTestBase): self.assertIs(type(response), messages_pb2.SimpleResponse) - await channel.close() - - self.loop.run_until_complete(coro()) - - def test_unary_call_times_out(self): - - async def coro(): - server_target, _ = await start_test_server() # pylint: disable=unused-variable - - async with aio.insecure_channel(server_target) as channel: - empty_call_with_sleep = channel.unary_unary( - _EMPTY_CALL_METHOD, - request_serializer=messages_pb2.SimpleRequest. - SerializeToString, - response_deserializer=messages_pb2.SimpleResponse. - FromString, - ) - timeout = test_constants.SHORT_TIMEOUT / 2 - # TODO(https://github.com/grpc/grpc/issues/20869) - # Update once the async server is ready, change the - # synchronization mechanism by removing the sleep() - # as both components (client & server) will be on the same - # process. - with self.assertRaises(grpc.RpcError) as exception_context: - await empty_call_with_sleep( - messages_pb2.SimpleRequest(), timeout=timeout) - - _, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) - self.assertEqual(exception_context.exception.details(), - details.title()) - self.assertIsNotNone( - exception_context.exception.initial_metadata()) - self.assertIsNotNone( - exception_context.exception.trailing_metadata()) - - self.loop.run_until_complete(coro()) + async def test_unary_call_times_out(self): + async with aio.insecure_channel(self._server_target) as channel: + empty_call_with_sleep = channel.unary_unary( + _EMPTY_CALL_METHOD, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) + timeout = test_constants.SHORT_TIMEOUT / 2 + # TODO(https://github.com/grpc/grpc/issues/20869) + # Update once the async server is ready, change the + # synchronization mechanism by removing the sleep() + # as both components (client & server) will be on the same + # process. + with self.assertRaises(grpc.RpcError) as exception_context: + await empty_call_with_sleep( + messages_pb2.SimpleRequest(), timeout=timeout) + + _, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code()) + self.assertEqual(details.title(), + exception_context.exception.details()) + self.assertIsNotNone(exception_context.exception.initial_metadata()) + self.assertIsNotNone( + exception_context.exception.trailing_metadata()) @unittest.skip('https://github.com/grpc/grpc/issues/20818') - def test_call_to_the_void(self): + async def test_call_to_the_void(self): + channel = aio.insecure_channel('0.1.1.1:1111') + hi = channel.unary_unary( + _UNARY_CALL_METHOD, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + response = await hi(messages_pb2.SimpleRequest()) - async def coro(): - channel = aio.insecure_channel('0.1.1.1:1111') - hi = channel.unary_unary( - _UNARY_CALL_METHOD, - request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) - response = await hi(messages_pb2.SimpleRequest()) + self.assertIs(type(response), messages_pb2.SimpleResponse) - self.assertIs(type(response), messages_pb2.SimpleResponse) + await channel.close() + + async def test_unary_stream(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + # Invokes the actual RPC + call = stub.StreamingOutputCall(request) - await channel.close() + # Validates the responses + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) - self.loop.run_until_complete(coro()) + self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) + await channel.close() if __name__ == '__main__': - logging.basicConfig() + logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/init_test.py b/src/python/grpcio_tests/tests_aio/unit/init_test.py index 9f5d8bb0d85..e619428b38e 100644 --- a/src/python/grpcio_tests/tests_aio/unit/init_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/init_test.py @@ -21,15 +21,11 @@ from tests_aio.unit._test_base import AioTestBase class TestInsecureChannel(AioTestBase): - def test_insecure_channel(self): + async def test_insecure_channel(self): + server_target, _ = await start_test_server() # pylint: disable=unused-variable - async def coro(): - server_target, _ = await start_test_server() # pylint: disable=unused-variable - - channel = aio.insecure_channel(server_target) - self.assertIsInstance(channel, aio.Channel) - - self.loop.run_until_complete(coro()) + channel = aio.insecure_channel(server_target) + self.assertIsInstance(channel, aio.Channel) if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py index 1e86de65404..962ab520ca9 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -26,9 +26,12 @@ from tests.unit.framework.common import test_constants _SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary' _BLOCK_FOREVER = '/test/BlockForever' _BLOCK_BRIEFLY = '/test/BlockBriefly' +_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen' +_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter' _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x01\x01\x01' +_NUM_STREAM_RESPONSES = 5 class _GenericHandler(grpc.GenericRpcHandler): @@ -43,10 +46,18 @@ class _GenericHandler(grpc.GenericRpcHandler): async def _block_forever(self, unused_request, unused_context): await asyncio.get_event_loop().create_future() - async def _BLOCK_BRIEFLY(self, unused_request, unused_context): + async def _block_briefly(self, unused_request, unused_context): await asyncio.sleep(test_constants.SHORT_TIMEOUT / 2) return _RESPONSE + async def _unary_stream_async_gen(self, unused_request, unused_context): + for _ in range(_NUM_STREAM_RESPONSES): + yield _RESPONSE + + async def _unary_stream_reader_writer(self, unused_request, context): + for _ in range(_NUM_STREAM_RESPONSES): + context.write(_RESPONSE) + def service(self, handler_details): self._called.set_result(None) if handler_details.method == _SIMPLE_UNARY_UNARY: @@ -54,7 +65,13 @@ class _GenericHandler(grpc.GenericRpcHandler): if handler_details.method == _BLOCK_FOREVER: return grpc.unary_unary_rpc_method_handler(self._block_forever) if handler_details.method == _BLOCK_BRIEFLY: - return grpc.unary_unary_rpc_method_handler(self._BLOCK_BRIEFLY) + return grpc.unary_unary_rpc_method_handler(self._block_briefly) + if handler_details.method == _UNARY_STREAM_ASYNC_GEN: + return grpc.unary_stream_rpc_method_handler( + self._unary_stream_async_gen) + if handler_details.method == _UNARY_STREAM_READER_WRITER: + return grpc.unary_stream_rpc_method_handler( + self._unary_stream_reader_writer) async def wait_for_call(self): await self._called @@ -71,148 +88,132 @@ async def _start_test_server(): class TestServer(AioTestBase): - def test_unary_unary(self): - - async def test_unary_unary_body(): - result = await _start_test_server() - server_target = result[0] - - async with aio.insecure_channel(server_target) as channel: - unary_call = channel.unary_unary(_SIMPLE_UNARY_UNARY) - response = await unary_call(_REQUEST) - self.assertEqual(response, _RESPONSE) - - self.loop.run_until_complete(test_unary_unary_body()) - - def test_shutdown(self): - - async def test_shutdown_body(): - _, server, _ = await _start_test_server() - await server.stop(None) - - self.loop.run_until_complete(test_shutdown_body()) - # Ensures no SIGSEGV triggered, and ends within timeout. - - def test_shutdown_after_call(self): - - async def test_shutdown_body(): - server_target, server, _ = await _start_test_server() - - async with aio.insecure_channel(server_target) as channel: - await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST) - - await server.stop(None) - - self.loop.run_until_complete(test_shutdown_body()) - - def test_graceful_shutdown_success(self): - - async def test_graceful_shutdown_success_body(): - server_target, server, generic_handler = await _start_test_server() + async def setUp(self): + self._server_target, self._server, self._generic_handler = await _start_test_server( + ) - channel = aio.insecure_channel(server_target) - call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) - await generic_handler.wait_for_call() + async def tearDown(self): + await self._server.stop(None) - shutdown_start_time = time.time() - await server.stop(test_constants.SHORT_TIMEOUT) - grace_period_length = time.time() - shutdown_start_time - self.assertGreater(grace_period_length, - test_constants.SHORT_TIMEOUT / 3) + async def test_unary_unary(self): + async with aio.insecure_channel(self._server_target) as channel: + unary_unary_call = channel.unary_unary(_SIMPLE_UNARY_UNARY) + response = await unary_unary_call(_REQUEST) + self.assertEqual(response, _RESPONSE) - # Validates the states. - await channel.close() - self.assertEqual(_RESPONSE, await call) - self.assertTrue(call.done()) + async def test_unary_stream_async_generator(self): + async with aio.insecure_channel(self._server_target) as channel: + unary_stream_call = channel.unary_stream(_UNARY_STREAM_ASYNC_GEN) + call = unary_stream_call(_REQUEST) - self.loop.run_until_complete(test_graceful_shutdown_success_body()) + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertEqual(_RESPONSE, response) - def test_graceful_shutdown_failed(self): + self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) + self.assertEqual(await call.code(), grpc.StatusCode.OK) - async def test_graceful_shutdown_failed_body(): - server_target, server, generic_handler = await _start_test_server() + async def test_unary_stream_reader_writer(self): + async with aio.insecure_channel(self._server_target) as channel: + unary_stream_call = channel.unary_stream(_UNARY_STREAM_ASYNC_GEN) + call = unary_stream_call(_REQUEST) - channel = aio.insecure_channel(server_target) - call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) - await generic_handler.wait_for_call() + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertEqual(_RESPONSE, response) - await server.stop(test_constants.SHORT_TIMEOUT) + self.assertEqual(await call.code(), grpc.StatusCode.OK) - with self.assertRaises(aio.AioRpcError) as exception_context: - await call - self.assertEqual(grpc.StatusCode.UNAVAILABLE, - exception_context.exception.code()) - self.assertIn('GOAWAY', exception_context.exception.details()) - await channel.close() - - self.loop.run_until_complete(test_graceful_shutdown_failed_body()) - - def test_concurrent_graceful_shutdown(self): - - async def test_concurrent_graceful_shutdown_body(): - server_target, server, generic_handler = await _start_test_server() - - channel = aio.insecure_channel(server_target) - call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) - await generic_handler.wait_for_call() - - # Expects the shortest grace period to be effective. - shutdown_start_time = time.time() - await asyncio.gather( - server.stop(test_constants.LONG_TIMEOUT), - server.stop(test_constants.SHORT_TIMEOUT), - server.stop(test_constants.LONG_TIMEOUT), - ) - grace_period_length = time.time() - shutdown_start_time - self.assertGreater(grace_period_length, - test_constants.SHORT_TIMEOUT / 3) - - await channel.close() - self.assertEqual(_RESPONSE, await call) - self.assertTrue(call.done()) - - self.loop.run_until_complete(test_concurrent_graceful_shutdown_body()) - - def test_concurrent_graceful_shutdown_immediate(self): - - async def test_concurrent_graceful_shutdown_immediate_body(): - server_target, server, generic_handler = await _start_test_server() - - channel = aio.insecure_channel(server_target) - call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) - await generic_handler.wait_for_call() - - # Expects no grace period, due to the "server.stop(None)". - await asyncio.gather( - server.stop(test_constants.LONG_TIMEOUT), - server.stop(None), - server.stop(test_constants.SHORT_TIMEOUT), - server.stop(test_constants.LONG_TIMEOUT), - ) - - with self.assertRaises(aio.AioRpcError) as exception_context: - await call - self.assertEqual(grpc.StatusCode.UNAVAILABLE, - exception_context.exception.code()) - self.assertIn('GOAWAY', exception_context.exception.details()) - await channel.close() + async def test_shutdown(self): + await self._server.stop(None) + # Ensures no SIGSEGV triggered, and ends within timeout. - self.loop.run_until_complete( - test_concurrent_graceful_shutdown_immediate_body()) + async def test_shutdown_after_call(self): + async with aio.insecure_channel(self._server_target) as channel: + await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST) + + await self._server.stop(None) + + async def test_graceful_shutdown_success(self): + channel = aio.insecure_channel(self._server_target) + call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) + await self._generic_handler.wait_for_call() + + shutdown_start_time = time.time() + await self._server.stop(test_constants.SHORT_TIMEOUT) + grace_period_length = time.time() - shutdown_start_time + self.assertGreater(grace_period_length, + test_constants.SHORT_TIMEOUT / 3) + + # Validates the states. + await channel.close() + self.assertEqual(_RESPONSE, await call) + self.assertTrue(call.done()) + + async def test_graceful_shutdown_failed(self): + channel = aio.insecure_channel(self._server_target) + call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) + await self._generic_handler.wait_for_call() + + await self._server.stop(test_constants.SHORT_TIMEOUT) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + self.assertIn('GOAWAY', exception_context.exception.details()) + await channel.close() + + async def test_concurrent_graceful_shutdown(self): + channel = aio.insecure_channel(self._server_target) + call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) + await self._generic_handler.wait_for_call() + + # Expects the shortest grace period to be effective. + shutdown_start_time = time.time() + await asyncio.gather( + self._server.stop(test_constants.LONG_TIMEOUT), + self._server.stop(test_constants.SHORT_TIMEOUT), + self._server.stop(test_constants.LONG_TIMEOUT), + ) + grace_period_length = time.time() - shutdown_start_time + self.assertGreater(grace_period_length, + test_constants.SHORT_TIMEOUT / 3) + + await channel.close() + self.assertEqual(_RESPONSE, await call) + self.assertTrue(call.done()) + + async def test_concurrent_graceful_shutdown_immediate(self): + channel = aio.insecure_channel(self._server_target) + call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) + await self._generic_handler.wait_for_call() + + # Expects no grace period, due to the "server.stop(None)". + await asyncio.gather( + self._server.stop(test_constants.LONG_TIMEOUT), + self._server.stop(None), + self._server.stop(test_constants.SHORT_TIMEOUT), + self._server.stop(test_constants.LONG_TIMEOUT), + ) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + self.assertIn('GOAWAY', exception_context.exception.details()) + await channel.close() @unittest.skip('https://github.com/grpc/grpc/issues/20818') - def test_shutdown_before_call(self): - - async def test_shutdown_body(): - server_target, server, _ = _start_test_server() - await server.stop(None) - - # Ensures the server is cleaned up at this point. - # Some proper exception should be raised. - async with aio.insecure_channel('localhost:%d' % port) as channel: - await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST) - - self.loop.run_until_complete(test_shutdown_body()) + async def test_shutdown_before_call(self): + server_target, server, _ = _start_test_server() + await server.stop(None) + + # Ensures the server is cleaned up at this point. + # Some proper exception should be raised. + async with aio.insecure_channel('localhost:%d' % port) as channel: + await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST) if __name__ == '__main__':