diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi index 756ba6e3d1f..67848cadaf8 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi @@ -66,7 +66,7 @@ cdef class CallbackWrapper: cdef CallbackFailureHandler CQ_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler( 'grpc_completion_queue_shutdown', 'Unknown', - RuntimeError) + InternalError) cdef class CallbackCompletionQueue: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index 38cb8887350..bfa9477b6d1 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -71,8 +71,7 @@ cdef class AioChannel: other design of API if necessary. """ if self._status in (AIO_CHANNEL_STATUS_DESTROYED, AIO_CHANNEL_STATUS_CLOSING): - # TODO(lidiz) switch to UsageError - raise RuntimeError('Channel is closed.') + raise UsageError('Channel is closed.') cdef gpr_timespec c_deadline = _timespec_from_time(deadline) @@ -115,8 +114,7 @@ cdef class AioChannel: The _AioCall object. """ if self.closed(): - # TODO(lidiz) switch to UsageError - raise RuntimeError('Channel is closed.') + raise UsageError('Channel is closed.') cdef CallCredentials cython_call_credentials if python_call_credentials is not None: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi index 9997de195e4..cf9269364ce 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi @@ -73,3 +73,27 @@ _COMPRESSION_METADATA_STRING_MAPPING = { CompressionAlgorithm.deflate: 'deflate', CompressionAlgorithm.gzip: 'gzip', } + +class BaseError(Exception): + """The base class for exceptions generated by gRPC AsyncIO stack.""" + + +class UsageError(BaseError): + """Raised when the usage of API by applications is inappropriate. + + For example, trying to invoke RPC on a closed channel, mixing two styles + of streaming API on the client side. This exception should not be + suppressed. + """ + + +class AbortError(BaseError): + """Raised when calling abort in servicer methods. + + This exception should not be suppressed. Applications may catch it to + perform certain clean-up logic, and then re-raise it. + """ + + +class InternalError(BaseError): + """Raised upon unexpected errors in native code.""" diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index f0b9670666a..903c20796f7 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -37,7 +37,7 @@ cdef class _HandlerCallDetails: self.invocation_metadata = invocation_metadata -class _ServerStoppedError(RuntimeError): +class _ServerStoppedError(BaseError): """Raised if the server is stopped.""" @@ -77,7 +77,7 @@ cdef class RPCState: if self.abort_exception is not None: raise self.abort_exception if self.status_sent: - raise RuntimeError(_RPC_FINISHED_DETAILS) + raise UsageError(_RPC_FINISHED_DETAILS) if self.server._status == AIO_SERVER_STATUS_STOPPED: raise _ServerStoppedError(_SERVER_STOPPED_DETAILS) @@ -107,11 +107,6 @@ cdef class RPCState: grpc_call_unref(self.call) -# TODO(lidiz) inherit this from Python level `AioRpcStatus`, we need to improve -# current code structure to make it happen. -class AbortError(Exception): pass - - cdef class _ServicerContext: cdef RPCState _rpc_state cdef object _loop @@ -155,7 +150,7 @@ cdef class _ServicerContext: self._rpc_state.raise_for_termination() if self._rpc_state.metadata_sent: - raise RuntimeError('Send initial metadata failed: already sent') + raise UsageError('Send initial metadata failed: already sent') else: await _send_initial_metadata( self._rpc_state, @@ -170,7 +165,7 @@ cdef class _ServicerContext: str details='', tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA): if self._rpc_state.abort_exception is not None: - raise RuntimeError('Abort already called!') + raise UsageError('Abort already called!') else: # Keeps track of the exception object. After abort happen, the RPC # should stop execution. However, if users decided to suppress it, it @@ -579,7 +574,7 @@ cdef CallbackFailureHandler REQUEST_CALL_FAILURE_HANDLER = CallbackFailureHandle cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler( 'grpc_server_shutdown_and_notify', None, - RuntimeError) + InternalError) cdef class AioServer: @@ -642,7 +637,7 @@ cdef class AioServer: wrapper.c_functor() ) if error != GRPC_CALL_OK: - raise RuntimeError("Error in grpc_server_request_call: %s" % error) + raise InternalError("Error in grpc_server_request_call: %s" % error) await future return rpc_state @@ -692,7 +687,7 @@ cdef class AioServer: if self._status == AIO_SERVER_STATUS_RUNNING: return elif self._status != AIO_SERVER_STATUS_READY: - raise RuntimeError('Server not in ready state') + raise UsageError('Server not in ready state') self._status = AIO_SERVER_STATUS_RUNNING cdef object server_started = self._loop.create_future() @@ -788,11 +783,7 @@ cdef class AioServer: return True def __dealloc__(self): - """Deallocation of Core objects are ensured by Python grpc.aio.Server. - - If the Cython representation is deallocated without underlying objects - freed, raise an RuntimeError. - """ + """Deallocation of Core objects are ensured by Python layer.""" # TODO(lidiz) if users create server, and then dealloc it immediately. # There is a potential memory leak of created Core server. if self._status != AIO_SERVER_STATUS_STOPPED: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index a3be6bae479..00311b5ea2a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -118,7 +118,7 @@ cdef class Server: def cancel_all_calls(self): if not self.is_shutting_down: - raise RuntimeError("the server must be shutting down to cancel all calls") + raise UsageError("the server must be shutting down to cancel all calls") elif self.is_shutdown: return else: @@ -136,7 +136,7 @@ cdef class Server: pass elif not self.is_shutting_down: if self.backup_shutdown_queue is None: - raise RuntimeError('Server shutdown failed: no completion queue.') + raise InternalError('Server shutdown failed: no completion queue.') else: # the user didn't call shutdown - use our backup queue self._c_shutdown(self.backup_shutdown_queue, None) diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index 0839c79010d..db2feb5b2d4 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -17,12 +17,11 @@ gRPC Async API objects may only be used on the thread on which they were created. AsyncIO doesn't provide thread safety for most of its APIs. """ -import abc from typing import Any, Optional, Sequence, Text, Tuple -import six import grpc -from grpc._cython.cygrpc import EOF, AbortError, init_grpc_aio +from grpc._cython.cygrpc import (EOF, AbortError, BaseError, UsageError, + init_grpc_aio) from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall from ._call import AioRpcError @@ -88,4 +87,4 @@ __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryUnaryMultiCallable', 'ClientCallDetails', 'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall', 'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel', - 'AbortError') + 'AbortError', 'BaseError', 'UsageError') diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 1a43c44f5c5..d06cc18d872 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -16,6 +16,7 @@ import asyncio from functools import partial import logging +import enum from typing import AsyncIterable, Awaitable, Dict, Optional import grpc @@ -238,6 +239,12 @@ class Call: return self._repr() +class _APIStyle(enum.IntEnum): + UNKNOWN = 0 + ASYNC_GENERATOR = 1 + READER_WRITER = 2 + + class _UnaryResponseMixin(Call): _call_response: asyncio.Task @@ -283,10 +290,19 @@ class _UnaryResponseMixin(Call): class _StreamResponseMixin(Call): _message_aiter: AsyncIterable[ResponseType] _preparation: asyncio.Task + _response_style: _APIStyle def _init_stream_response_mixin(self, preparation: asyncio.Task): self._message_aiter = None self._preparation = preparation + self._response_style = _APIStyle.UNKNOWN + + def _update_response_style(self, style: _APIStyle): + if self._response_style is _APIStyle.UNKNOWN: + self._response_style = style + elif self._response_style is not style: + raise cygrpc.UsageError( + 'Please don\'t mix two styles of API for streaming responses') def cancel(self) -> bool: if super().cancel(): @@ -302,6 +318,7 @@ class _StreamResponseMixin(Call): message = await self._read() def __aiter__(self) -> AsyncIterable[ResponseType]: + self._update_response_style(_APIStyle.ASYNC_GENERATOR) if self._message_aiter is None: self._message_aiter = self._fetch_stream_responses() return self._message_aiter @@ -328,6 +345,7 @@ class _StreamResponseMixin(Call): if self.done(): await self._raise_for_status() return cygrpc.EOF + self._update_response_style(_APIStyle.READER_WRITER) response_message = await self._read() @@ -339,20 +357,28 @@ class _StreamResponseMixin(Call): class _StreamRequestMixin(Call): _metadata_sent: asyncio.Event - _done_writing: bool + _done_writing_flag: bool _async_request_poller: Optional[asyncio.Task] + _request_style: _APIStyle def _init_stream_request_mixin( self, request_async_iterator: Optional[AsyncIterable[RequestType]]): self._metadata_sent = asyncio.Event(loop=self._loop) - self._done_writing = False + self._done_writing_flag = False # If user passes in an async iterator, create a consumer Task. if request_async_iterator is not None: self._async_request_poller = self._loop.create_task( self._consume_request_iterator(request_async_iterator)) + self._request_style = _APIStyle.ASYNC_GENERATOR else: self._async_request_poller = None + self._request_style = _APIStyle.READER_WRITER + + def _raise_for_different_style(self, style: _APIStyle): + if self._request_style is not style: + raise cygrpc.UsageError( + 'Please don\'t mix two styles of API for streaming requests') def cancel(self) -> bool: if super().cancel(): @@ -369,8 +395,8 @@ class _StreamRequestMixin(Call): self, request_async_iterator: AsyncIterable[RequestType]) -> None: try: async for request in request_async_iterator: - await self.write(request) - await self.done_writing() + await self._write(request) + await self._done_writing() except AioRpcError as rpc_error: # Rpc status should be exposed through other API. Exceptions raised # within this Task won't be retrieved by another coroutine. It's @@ -378,10 +404,10 @@ class _StreamRequestMixin(Call): _LOGGER.debug('Exception while consuming the request_iterator: %s', rpc_error) - async def write(self, request: RequestType) -> None: + async def _write(self, request: RequestType) -> None: if self.done(): raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) - if self._done_writing: + if self._done_writing_flag: raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) if not self._metadata_sent.is_set(): await self._metadata_sent.wait() @@ -398,14 +424,13 @@ class _StreamRequestMixin(Call): self.cancel() await self._raise_for_status() - async def done_writing(self) -> None: - """Implementation of done_writing is idempotent.""" + async def _done_writing(self) -> None: if self.done(): # If the RPC is finished, do nothing. return - if not self._done_writing: + if not self._done_writing_flag: # If the done writing is not sent before, try to send it. - self._done_writing = True + self._done_writing_flag = True try: await self._cython_call.send_receive_close() except asyncio.CancelledError: @@ -413,6 +438,18 @@ class _StreamRequestMixin(Call): self.cancel() await self._raise_for_status() + async def write(self, request: RequestType) -> None: + self._raise_for_different_style(_APIStyle.READER_WRITER) + await self._write(request) + + async def done_writing(self) -> None: + """Signal peer that client is done writing. + + This method is idempotent. + """ + self._raise_for_different_style(_APIStyle.READER_WRITER) + await self._done_writing() + class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): """Object for managing unary-unary RPC calls. diff --git a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py index 7acf53b95c7..7f98329070b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py @@ -102,7 +102,7 @@ class TestConnectivityState(AioTestBase): # It can raise exceptions since it is an usage error, but it should not # segfault or abort. - with self.assertRaises(RuntimeError): + with self.assertRaises(aio.UsageError): await channel.wait_for_state_change( grpc.ChannelConnectivity.SHUTDOWN) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py index 39288d90777..70240fefee1 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -231,14 +231,10 @@ class TestServer(AioTestBase): # Uses reader API self.assertEqual(_RESPONSE, await call.read()) - # Uses async generator API - response_cnt = 0 - async for response in call: - response_cnt += 1 - self.assertEqual(_RESPONSE, response) - - self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt) - self.assertEqual(await call.code(), grpc.StatusCode.OK) + # Uses async generator API, mixed! + with self.assertRaises(aio.UsageError): + async for response in call: + self.assertEqual(_RESPONSE, response) async def test_stream_unary_async_generator(self): stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)