Merge pull request #21910 from lidizheng/aio-no-mix

[Aio] Prohibit mixing two styles of API on client side
pull/21946/head
Lidi Zheng 5 years ago committed by GitHub
commit 56aa2c1143
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  2. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  3. 24
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  4. 25
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  5. 4
      src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
  6. 7
      src/python/grpcio/grpc/experimental/aio/__init__.py
  7. 57
      src/python/grpcio/grpc/experimental/aio/_call.py
  8. 2
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py
  9. 8
      src/python/grpcio_tests/tests_aio/unit/server_test.py

@ -66,7 +66,7 @@ cdef class CallbackWrapper:
cdef CallbackFailureHandler CQ_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler(
'grpc_completion_queue_shutdown',
'Unknown',
RuntimeError)
InternalError)
cdef class CallbackCompletionQueue:

@ -71,8 +71,7 @@ cdef class AioChannel:
other design of API if necessary.
"""
if self._status in (AIO_CHANNEL_STATUS_DESTROYED, AIO_CHANNEL_STATUS_CLOSING):
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
raise UsageError('Channel is closed.')
cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
@ -115,8 +114,7 @@ cdef class AioChannel:
The _AioCall object.
"""
if self.closed():
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
raise UsageError('Channel is closed.')
cdef CallCredentials cython_call_credentials
if python_call_credentials is not None:

@ -73,3 +73,27 @@ _COMPRESSION_METADATA_STRING_MAPPING = {
CompressionAlgorithm.deflate: 'deflate',
CompressionAlgorithm.gzip: 'gzip',
}
class BaseError(Exception):
"""The base class for exceptions generated by gRPC AsyncIO stack."""
class UsageError(BaseError):
"""Raised when the usage of API by applications is inappropriate.
For example, trying to invoke RPC on a closed channel, mixing two styles
of streaming API on the client side. This exception should not be
suppressed.
"""
class AbortError(BaseError):
"""Raised when calling abort in servicer methods.
This exception should not be suppressed. Applications may catch it to
perform certain clean-up logic, and then re-raise it.
"""
class InternalError(BaseError):
"""Raised upon unexpected errors in native code."""

@ -37,7 +37,7 @@ cdef class _HandlerCallDetails:
self.invocation_metadata = invocation_metadata
class _ServerStoppedError(RuntimeError):
class _ServerStoppedError(BaseError):
"""Raised if the server is stopped."""
@ -77,7 +77,7 @@ cdef class RPCState:
if self.abort_exception is not None:
raise self.abort_exception
if self.status_sent:
raise RuntimeError(_RPC_FINISHED_DETAILS)
raise UsageError(_RPC_FINISHED_DETAILS)
if self.server._status == AIO_SERVER_STATUS_STOPPED:
raise _ServerStoppedError(_SERVER_STOPPED_DETAILS)
@ -107,11 +107,6 @@ cdef class RPCState:
grpc_call_unref(self.call)
# TODO(lidiz) inherit this from Python level `AioRpcStatus`, we need to improve
# current code structure to make it happen.
class AbortError(Exception): pass
cdef class _ServicerContext:
cdef RPCState _rpc_state
cdef object _loop
@ -155,7 +150,7 @@ cdef class _ServicerContext:
self._rpc_state.raise_for_termination()
if self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent')
raise UsageError('Send initial metadata failed: already sent')
else:
await _send_initial_metadata(
self._rpc_state,
@ -170,7 +165,7 @@ cdef class _ServicerContext:
str details='',
tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA):
if self._rpc_state.abort_exception is not None:
raise RuntimeError('Abort already called!')
raise UsageError('Abort already called!')
else:
# Keeps track of the exception object. After abort happen, the RPC
# should stop execution. However, if users decided to suppress it, it
@ -579,7 +574,7 @@ cdef CallbackFailureHandler REQUEST_CALL_FAILURE_HANDLER = CallbackFailureHandle
cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler(
'grpc_server_shutdown_and_notify',
None,
RuntimeError)
InternalError)
cdef class AioServer:
@ -642,7 +637,7 @@ cdef class AioServer:
wrapper.c_functor()
)
if error != GRPC_CALL_OK:
raise RuntimeError("Error in grpc_server_request_call: %s" % error)
raise InternalError("Error in grpc_server_request_call: %s" % error)
await future
return rpc_state
@ -692,7 +687,7 @@ cdef class AioServer:
if self._status == AIO_SERVER_STATUS_RUNNING:
return
elif self._status != AIO_SERVER_STATUS_READY:
raise RuntimeError('Server not in ready state')
raise UsageError('Server not in ready state')
self._status = AIO_SERVER_STATUS_RUNNING
cdef object server_started = self._loop.create_future()
@ -788,11 +783,7 @@ cdef class AioServer:
return True
def __dealloc__(self):
"""Deallocation of Core objects are ensured by Python grpc.aio.Server.
If the Cython representation is deallocated without underlying objects
freed, raise an RuntimeError.
"""
"""Deallocation of Core objects are ensured by Python layer."""
# TODO(lidiz) if users create server, and then dealloc it immediately.
# There is a potential memory leak of created Core server.
if self._status != AIO_SERVER_STATUS_STOPPED:

@ -118,7 +118,7 @@ cdef class Server:
def cancel_all_calls(self):
if not self.is_shutting_down:
raise RuntimeError("the server must be shutting down to cancel all calls")
raise UsageError("the server must be shutting down to cancel all calls")
elif self.is_shutdown:
return
else:
@ -136,7 +136,7 @@ cdef class Server:
pass
elif not self.is_shutting_down:
if self.backup_shutdown_queue is None:
raise RuntimeError('Server shutdown failed: no completion queue.')
raise InternalError('Server shutdown failed: no completion queue.')
else:
# the user didn't call shutdown - use our backup queue
self._c_shutdown(self.backup_shutdown_queue, None)

@ -17,12 +17,11 @@ gRPC Async API objects may only be used on the thread on which they were
created. AsyncIO doesn't provide thread safety for most of its APIs.
"""
import abc
from typing import Any, Optional, Sequence, Text, Tuple
import six
import grpc
from grpc._cython.cygrpc import EOF, AbortError, init_grpc_aio
from grpc._cython.cygrpc import (EOF, AbortError, BaseError, UsageError,
init_grpc_aio)
from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall
from ._call import AioRpcError
@ -88,4 +87,4 @@ __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall',
'UnaryUnaryMultiCallable', 'ClientCallDetails',
'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall',
'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel',
'AbortError')
'AbortError', 'BaseError', 'UsageError')

@ -16,6 +16,7 @@
import asyncio
from functools import partial
import logging
import enum
from typing import AsyncIterable, Awaitable, Dict, Optional
import grpc
@ -238,6 +239,12 @@ class Call:
return self._repr()
class _APIStyle(enum.IntEnum):
UNKNOWN = 0
ASYNC_GENERATOR = 1
READER_WRITER = 2
class _UnaryResponseMixin(Call):
_call_response: asyncio.Task
@ -283,10 +290,19 @@ class _UnaryResponseMixin(Call):
class _StreamResponseMixin(Call):
_message_aiter: AsyncIterable[ResponseType]
_preparation: asyncio.Task
_response_style: _APIStyle
def _init_stream_response_mixin(self, preparation: asyncio.Task):
self._message_aiter = None
self._preparation = preparation
self._response_style = _APIStyle.UNKNOWN
def _update_response_style(self, style: _APIStyle):
if self._response_style is _APIStyle.UNKNOWN:
self._response_style = style
elif self._response_style is not style:
raise cygrpc.UsageError(
'Please don\'t mix two styles of API for streaming responses')
def cancel(self) -> bool:
if super().cancel():
@ -302,6 +318,7 @@ class _StreamResponseMixin(Call):
message = await self._read()
def __aiter__(self) -> AsyncIterable[ResponseType]:
self._update_response_style(_APIStyle.ASYNC_GENERATOR)
if self._message_aiter is None:
self._message_aiter = self._fetch_stream_responses()
return self._message_aiter
@ -328,6 +345,7 @@ class _StreamResponseMixin(Call):
if self.done():
await self._raise_for_status()
return cygrpc.EOF
self._update_response_style(_APIStyle.READER_WRITER)
response_message = await self._read()
@ -339,20 +357,28 @@ class _StreamResponseMixin(Call):
class _StreamRequestMixin(Call):
_metadata_sent: asyncio.Event
_done_writing: bool
_done_writing_flag: bool
_async_request_poller: Optional[asyncio.Task]
_request_style: _APIStyle
def _init_stream_request_mixin(
self, request_async_iterator: Optional[AsyncIterable[RequestType]]):
self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing = False
self._done_writing_flag = False
# If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None:
self._async_request_poller = self._loop.create_task(
self._consume_request_iterator(request_async_iterator))
self._request_style = _APIStyle.ASYNC_GENERATOR
else:
self._async_request_poller = None
self._request_style = _APIStyle.READER_WRITER
def _raise_for_different_style(self, style: _APIStyle):
if self._request_style is not style:
raise cygrpc.UsageError(
'Please don\'t mix two styles of API for streaming requests')
def cancel(self) -> bool:
if super().cancel():
@ -369,8 +395,8 @@ class _StreamRequestMixin(Call):
self, request_async_iterator: AsyncIterable[RequestType]) -> None:
try:
async for request in request_async_iterator:
await self.write(request)
await self.done_writing()
await self._write(request)
await self._done_writing()
except AioRpcError as rpc_error:
# Rpc status should be exposed through other API. Exceptions raised
# within this Task won't be retrieved by another coroutine. It's
@ -378,10 +404,10 @@ class _StreamRequestMixin(Call):
_LOGGER.debug('Exception while consuming the request_iterator: %s',
rpc_error)
async def write(self, request: RequestType) -> None:
async def _write(self, request: RequestType) -> None:
if self.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing:
if self._done_writing_flag:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set():
await self._metadata_sent.wait()
@ -398,14 +424,13 @@ class _StreamRequestMixin(Call):
self.cancel()
await self._raise_for_status()
async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
async def _done_writing(self) -> None:
if self.done():
# If the RPC is finished, do nothing.
return
if not self._done_writing:
if not self._done_writing_flag:
# If the done writing is not sent before, try to send it.
self._done_writing = True
self._done_writing_flag = True
try:
await self._cython_call.send_receive_close()
except asyncio.CancelledError:
@ -413,6 +438,18 @@ class _StreamRequestMixin(Call):
self.cancel()
await self._raise_for_status()
async def write(self, request: RequestType) -> None:
self._raise_for_different_style(_APIStyle.READER_WRITER)
await self._write(request)
async def done_writing(self) -> None:
"""Signal peer that client is done writing.
This method is idempotent.
"""
self._raise_for_different_style(_APIStyle.READER_WRITER)
await self._done_writing()
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls.

@ -102,7 +102,7 @@ class TestConnectivityState(AioTestBase):
# It can raise exceptions since it is an usage error, but it should not
# segfault or abort.
with self.assertRaises(RuntimeError):
with self.assertRaises(aio.UsageError):
await channel.wait_for_state_change(
grpc.ChannelConnectivity.SHUTDOWN)

@ -231,15 +231,11 @@ class TestServer(AioTestBase):
# Uses reader API
self.assertEqual(_RESPONSE, await call.read())
# Uses async generator API
response_cnt = 0
# Uses async generator API, mixed!
with self.assertRaises(aio.UsageError):
async for response in call:
response_cnt += 1
self.assertEqual(_RESPONSE, response)
self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_unary_async_generator(self):
stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)
call = stream_unary_call()

Loading…
Cancel
Save