Merge pull request #21910 from lidizheng/aio-no-mix

[Aio] Prohibit mixing two styles of API on client side
pull/21946/head
Lidi Zheng 5 years ago committed by GitHub
commit 56aa2c1143
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  2. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  3. 24
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  4. 25
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  5. 4
      src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
  6. 7
      src/python/grpcio/grpc/experimental/aio/__init__.py
  7. 57
      src/python/grpcio/grpc/experimental/aio/_call.py
  8. 2
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py
  9. 12
      src/python/grpcio_tests/tests_aio/unit/server_test.py

@ -66,7 +66,7 @@ cdef class CallbackWrapper:
cdef CallbackFailureHandler CQ_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler( cdef CallbackFailureHandler CQ_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler(
'grpc_completion_queue_shutdown', 'grpc_completion_queue_shutdown',
'Unknown', 'Unknown',
RuntimeError) InternalError)
cdef class CallbackCompletionQueue: cdef class CallbackCompletionQueue:

@ -71,8 +71,7 @@ cdef class AioChannel:
other design of API if necessary. other design of API if necessary.
""" """
if self._status in (AIO_CHANNEL_STATUS_DESTROYED, AIO_CHANNEL_STATUS_CLOSING): if self._status in (AIO_CHANNEL_STATUS_DESTROYED, AIO_CHANNEL_STATUS_CLOSING):
# TODO(lidiz) switch to UsageError raise UsageError('Channel is closed.')
raise RuntimeError('Channel is closed.')
cdef gpr_timespec c_deadline = _timespec_from_time(deadline) cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
@ -115,8 +114,7 @@ cdef class AioChannel:
The _AioCall object. The _AioCall object.
""" """
if self.closed(): if self.closed():
# TODO(lidiz) switch to UsageError raise UsageError('Channel is closed.')
raise RuntimeError('Channel is closed.')
cdef CallCredentials cython_call_credentials cdef CallCredentials cython_call_credentials
if python_call_credentials is not None: if python_call_credentials is not None:

@ -73,3 +73,27 @@ _COMPRESSION_METADATA_STRING_MAPPING = {
CompressionAlgorithm.deflate: 'deflate', CompressionAlgorithm.deflate: 'deflate',
CompressionAlgorithm.gzip: 'gzip', CompressionAlgorithm.gzip: 'gzip',
} }
class BaseError(Exception):
"""The base class for exceptions generated by gRPC AsyncIO stack."""
class UsageError(BaseError):
"""Raised when the usage of API by applications is inappropriate.
For example, trying to invoke RPC on a closed channel, mixing two styles
of streaming API on the client side. This exception should not be
suppressed.
"""
class AbortError(BaseError):
"""Raised when calling abort in servicer methods.
This exception should not be suppressed. Applications may catch it to
perform certain clean-up logic, and then re-raise it.
"""
class InternalError(BaseError):
"""Raised upon unexpected errors in native code."""

@ -37,7 +37,7 @@ cdef class _HandlerCallDetails:
self.invocation_metadata = invocation_metadata self.invocation_metadata = invocation_metadata
class _ServerStoppedError(RuntimeError): class _ServerStoppedError(BaseError):
"""Raised if the server is stopped.""" """Raised if the server is stopped."""
@ -77,7 +77,7 @@ cdef class RPCState:
if self.abort_exception is not None: if self.abort_exception is not None:
raise self.abort_exception raise self.abort_exception
if self.status_sent: if self.status_sent:
raise RuntimeError(_RPC_FINISHED_DETAILS) raise UsageError(_RPC_FINISHED_DETAILS)
if self.server._status == AIO_SERVER_STATUS_STOPPED: if self.server._status == AIO_SERVER_STATUS_STOPPED:
raise _ServerStoppedError(_SERVER_STOPPED_DETAILS) raise _ServerStoppedError(_SERVER_STOPPED_DETAILS)
@ -107,11 +107,6 @@ cdef class RPCState:
grpc_call_unref(self.call) grpc_call_unref(self.call)
# TODO(lidiz) inherit this from Python level `AioRpcStatus`, we need to improve
# current code structure to make it happen.
class AbortError(Exception): pass
cdef class _ServicerContext: cdef class _ServicerContext:
cdef RPCState _rpc_state cdef RPCState _rpc_state
cdef object _loop cdef object _loop
@ -155,7 +150,7 @@ cdef class _ServicerContext:
self._rpc_state.raise_for_termination() self._rpc_state.raise_for_termination()
if self._rpc_state.metadata_sent: if self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent') raise UsageError('Send initial metadata failed: already sent')
else: else:
await _send_initial_metadata( await _send_initial_metadata(
self._rpc_state, self._rpc_state,
@ -170,7 +165,7 @@ cdef class _ServicerContext:
str details='', str details='',
tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA): tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA):
if self._rpc_state.abort_exception is not None: if self._rpc_state.abort_exception is not None:
raise RuntimeError('Abort already called!') raise UsageError('Abort already called!')
else: else:
# Keeps track of the exception object. After abort happen, the RPC # Keeps track of the exception object. After abort happen, the RPC
# should stop execution. However, if users decided to suppress it, it # should stop execution. However, if users decided to suppress it, it
@ -579,7 +574,7 @@ cdef CallbackFailureHandler REQUEST_CALL_FAILURE_HANDLER = CallbackFailureHandle
cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler( cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler(
'grpc_server_shutdown_and_notify', 'grpc_server_shutdown_and_notify',
None, None,
RuntimeError) InternalError)
cdef class AioServer: cdef class AioServer:
@ -642,7 +637,7 @@ cdef class AioServer:
wrapper.c_functor() wrapper.c_functor()
) )
if error != GRPC_CALL_OK: if error != GRPC_CALL_OK:
raise RuntimeError("Error in grpc_server_request_call: %s" % error) raise InternalError("Error in grpc_server_request_call: %s" % error)
await future await future
return rpc_state return rpc_state
@ -692,7 +687,7 @@ cdef class AioServer:
if self._status == AIO_SERVER_STATUS_RUNNING: if self._status == AIO_SERVER_STATUS_RUNNING:
return return
elif self._status != AIO_SERVER_STATUS_READY: elif self._status != AIO_SERVER_STATUS_READY:
raise RuntimeError('Server not in ready state') raise UsageError('Server not in ready state')
self._status = AIO_SERVER_STATUS_RUNNING self._status = AIO_SERVER_STATUS_RUNNING
cdef object server_started = self._loop.create_future() cdef object server_started = self._loop.create_future()
@ -788,11 +783,7 @@ cdef class AioServer:
return True return True
def __dealloc__(self): def __dealloc__(self):
"""Deallocation of Core objects are ensured by Python grpc.aio.Server. """Deallocation of Core objects are ensured by Python layer."""
If the Cython representation is deallocated without underlying objects
freed, raise an RuntimeError.
"""
# TODO(lidiz) if users create server, and then dealloc it immediately. # TODO(lidiz) if users create server, and then dealloc it immediately.
# There is a potential memory leak of created Core server. # There is a potential memory leak of created Core server.
if self._status != AIO_SERVER_STATUS_STOPPED: if self._status != AIO_SERVER_STATUS_STOPPED:

@ -118,7 +118,7 @@ cdef class Server:
def cancel_all_calls(self): def cancel_all_calls(self):
if not self.is_shutting_down: if not self.is_shutting_down:
raise RuntimeError("the server must be shutting down to cancel all calls") raise UsageError("the server must be shutting down to cancel all calls")
elif self.is_shutdown: elif self.is_shutdown:
return return
else: else:
@ -136,7 +136,7 @@ cdef class Server:
pass pass
elif not self.is_shutting_down: elif not self.is_shutting_down:
if self.backup_shutdown_queue is None: if self.backup_shutdown_queue is None:
raise RuntimeError('Server shutdown failed: no completion queue.') raise InternalError('Server shutdown failed: no completion queue.')
else: else:
# the user didn't call shutdown - use our backup queue # the user didn't call shutdown - use our backup queue
self._c_shutdown(self.backup_shutdown_queue, None) self._c_shutdown(self.backup_shutdown_queue, None)

@ -17,12 +17,11 @@ gRPC Async API objects may only be used on the thread on which they were
created. AsyncIO doesn't provide thread safety for most of its APIs. created. AsyncIO doesn't provide thread safety for most of its APIs.
""" """
import abc
from typing import Any, Optional, Sequence, Text, Tuple from typing import Any, Optional, Sequence, Text, Tuple
import six
import grpc import grpc
from grpc._cython.cygrpc import EOF, AbortError, init_grpc_aio from grpc._cython.cygrpc import (EOF, AbortError, BaseError, UsageError,
init_grpc_aio)
from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall
from ._call import AioRpcError from ._call import AioRpcError
@ -88,4 +87,4 @@ __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall',
'UnaryUnaryMultiCallable', 'ClientCallDetails', 'UnaryUnaryMultiCallable', 'ClientCallDetails',
'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall', 'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall',
'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel', 'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel',
'AbortError') 'AbortError', 'BaseError', 'UsageError')

@ -16,6 +16,7 @@
import asyncio import asyncio
from functools import partial from functools import partial
import logging import logging
import enum
from typing import AsyncIterable, Awaitable, Dict, Optional from typing import AsyncIterable, Awaitable, Dict, Optional
import grpc import grpc
@ -238,6 +239,12 @@ class Call:
return self._repr() return self._repr()
class _APIStyle(enum.IntEnum):
UNKNOWN = 0
ASYNC_GENERATOR = 1
READER_WRITER = 2
class _UnaryResponseMixin(Call): class _UnaryResponseMixin(Call):
_call_response: asyncio.Task _call_response: asyncio.Task
@ -283,10 +290,19 @@ class _UnaryResponseMixin(Call):
class _StreamResponseMixin(Call): class _StreamResponseMixin(Call):
_message_aiter: AsyncIterable[ResponseType] _message_aiter: AsyncIterable[ResponseType]
_preparation: asyncio.Task _preparation: asyncio.Task
_response_style: _APIStyle
def _init_stream_response_mixin(self, preparation: asyncio.Task): def _init_stream_response_mixin(self, preparation: asyncio.Task):
self._message_aiter = None self._message_aiter = None
self._preparation = preparation self._preparation = preparation
self._response_style = _APIStyle.UNKNOWN
def _update_response_style(self, style: _APIStyle):
if self._response_style is _APIStyle.UNKNOWN:
self._response_style = style
elif self._response_style is not style:
raise cygrpc.UsageError(
'Please don\'t mix two styles of API for streaming responses')
def cancel(self) -> bool: def cancel(self) -> bool:
if super().cancel(): if super().cancel():
@ -302,6 +318,7 @@ class _StreamResponseMixin(Call):
message = await self._read() message = await self._read()
def __aiter__(self) -> AsyncIterable[ResponseType]: def __aiter__(self) -> AsyncIterable[ResponseType]:
self._update_response_style(_APIStyle.ASYNC_GENERATOR)
if self._message_aiter is None: if self._message_aiter is None:
self._message_aiter = self._fetch_stream_responses() self._message_aiter = self._fetch_stream_responses()
return self._message_aiter return self._message_aiter
@ -328,6 +345,7 @@ class _StreamResponseMixin(Call):
if self.done(): if self.done():
await self._raise_for_status() await self._raise_for_status()
return cygrpc.EOF return cygrpc.EOF
self._update_response_style(_APIStyle.READER_WRITER)
response_message = await self._read() response_message = await self._read()
@ -339,20 +357,28 @@ class _StreamResponseMixin(Call):
class _StreamRequestMixin(Call): class _StreamRequestMixin(Call):
_metadata_sent: asyncio.Event _metadata_sent: asyncio.Event
_done_writing: bool _done_writing_flag: bool
_async_request_poller: Optional[asyncio.Task] _async_request_poller: Optional[asyncio.Task]
_request_style: _APIStyle
def _init_stream_request_mixin( def _init_stream_request_mixin(
self, request_async_iterator: Optional[AsyncIterable[RequestType]]): self, request_async_iterator: Optional[AsyncIterable[RequestType]]):
self._metadata_sent = asyncio.Event(loop=self._loop) self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing = False self._done_writing_flag = False
# If user passes in an async iterator, create a consumer Task. # If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None: if request_async_iterator is not None:
self._async_request_poller = self._loop.create_task( self._async_request_poller = self._loop.create_task(
self._consume_request_iterator(request_async_iterator)) self._consume_request_iterator(request_async_iterator))
self._request_style = _APIStyle.ASYNC_GENERATOR
else: else:
self._async_request_poller = None self._async_request_poller = None
self._request_style = _APIStyle.READER_WRITER
def _raise_for_different_style(self, style: _APIStyle):
if self._request_style is not style:
raise cygrpc.UsageError(
'Please don\'t mix two styles of API for streaming requests')
def cancel(self) -> bool: def cancel(self) -> bool:
if super().cancel(): if super().cancel():
@ -369,8 +395,8 @@ class _StreamRequestMixin(Call):
self, request_async_iterator: AsyncIterable[RequestType]) -> None: self, request_async_iterator: AsyncIterable[RequestType]) -> None:
try: try:
async for request in request_async_iterator: async for request in request_async_iterator:
await self.write(request) await self._write(request)
await self.done_writing() await self._done_writing()
except AioRpcError as rpc_error: except AioRpcError as rpc_error:
# Rpc status should be exposed through other API. Exceptions raised # Rpc status should be exposed through other API. Exceptions raised
# within this Task won't be retrieved by another coroutine. It's # within this Task won't be retrieved by another coroutine. It's
@ -378,10 +404,10 @@ class _StreamRequestMixin(Call):
_LOGGER.debug('Exception while consuming the request_iterator: %s', _LOGGER.debug('Exception while consuming the request_iterator: %s',
rpc_error) rpc_error)
async def write(self, request: RequestType) -> None: async def _write(self, request: RequestType) -> None:
if self.done(): if self.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing: if self._done_writing_flag:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set(): if not self._metadata_sent.is_set():
await self._metadata_sent.wait() await self._metadata_sent.wait()
@ -398,14 +424,13 @@ class _StreamRequestMixin(Call):
self.cancel() self.cancel()
await self._raise_for_status() await self._raise_for_status()
async def done_writing(self) -> None: async def _done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
if self.done(): if self.done():
# If the RPC is finished, do nothing. # If the RPC is finished, do nothing.
return return
if not self._done_writing: if not self._done_writing_flag:
# If the done writing is not sent before, try to send it. # If the done writing is not sent before, try to send it.
self._done_writing = True self._done_writing_flag = True
try: try:
await self._cython_call.send_receive_close() await self._cython_call.send_receive_close()
except asyncio.CancelledError: except asyncio.CancelledError:
@ -413,6 +438,18 @@ class _StreamRequestMixin(Call):
self.cancel() self.cancel()
await self._raise_for_status() await self._raise_for_status()
async def write(self, request: RequestType) -> None:
self._raise_for_different_style(_APIStyle.READER_WRITER)
await self._write(request)
async def done_writing(self) -> None:
"""Signal peer that client is done writing.
This method is idempotent.
"""
self._raise_for_different_style(_APIStyle.READER_WRITER)
await self._done_writing()
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls. """Object for managing unary-unary RPC calls.

@ -102,7 +102,7 @@ class TestConnectivityState(AioTestBase):
# It can raise exceptions since it is an usage error, but it should not # It can raise exceptions since it is an usage error, but it should not
# segfault or abort. # segfault or abort.
with self.assertRaises(RuntimeError): with self.assertRaises(aio.UsageError):
await channel.wait_for_state_change( await channel.wait_for_state_change(
grpc.ChannelConnectivity.SHUTDOWN) grpc.ChannelConnectivity.SHUTDOWN)

@ -231,14 +231,10 @@ class TestServer(AioTestBase):
# Uses reader API # Uses reader API
self.assertEqual(_RESPONSE, await call.read()) self.assertEqual(_RESPONSE, await call.read())
# Uses async generator API # Uses async generator API, mixed!
response_cnt = 0 with self.assertRaises(aio.UsageError):
async for response in call: async for response in call:
response_cnt += 1 self.assertEqual(_RESPONSE, response)
self.assertEqual(_RESPONSE, response)
self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_unary_async_generator(self): async def test_stream_unary_async_generator(self):
stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN) stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)

Loading…
Cancel
Save