Remove the add_callback method & fix segfault

pull/21232/head
Lidi Zheng 5 years ago
parent 46e963f8bc
commit fa4eb94ea2
  1. 9
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 7
      src/python/grpcio/grpc/experimental/aio/__init__.py
  3. 41
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  4. 17
      src/python/grpcio/grpc/experimental/aio/_call.py
  5. 25
      src/python/grpcio_tests/tests_aio/unit/call_test.py

@ -190,16 +190,15 @@ cdef class _AioCall:
) )
status_observer(status) status_observer(status)
self._status_received.set() self._status_received.set()
self._destroy_grpc_call()
def _handle_cancellation_from_application(self, def _handle_cancellation_from_application(self,
object cancellation_future, object cancellation_future,
object status_observer): object status_observer):
def _cancellation_action(finished_future): def _cancellation_action(finished_future):
status = self._cancel_and_create_status(finished_future) if not self._status_received.set():
status_observer(status) status = self._cancel_and_create_status(finished_future)
self._status_received.set() status_observer(status)
self._destroy_grpc_call() self._status_received.set()
cancellation_future.add_done_callback(_cancellation_action) cancellation_future.add_done_callback(_cancellation_action)

@ -23,7 +23,7 @@ import six
import grpc import grpc
from grpc._cython.cygrpc import init_grpc_aio from grpc._cython.cygrpc import init_grpc_aio
from ._base_call import Call, UnaryUnaryCall, UnaryStreamCall from ._base_call import RpcContext, Call, UnaryUnaryCall, UnaryStreamCall
from ._channel import Channel from ._channel import Channel
from ._channel import UnaryUnaryMultiCallable from ._channel import UnaryUnaryMultiCallable
from ._server import server from ._server import server
@ -48,5 +48,6 @@ def insecure_channel(target, options=None, compression=None):
################################### __all__ ################################# ################################### __all__ #################################
__all__ = ('Call', 'UnaryUnaryCall', 'UnaryStreamCall', 'init_grpc_aio', __all__ = ('RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall',
'Channel', 'UnaryUnaryMultiCallable', 'insecure_channel', 'server') 'init_grpc_aio', 'Channel', 'UnaryUnaryMultiCallable',
'insecure_channel', 'server')

@ -19,17 +19,17 @@ RPC, e.g. cancellation.
""" """
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import AsyncIterable, Awaitable, Generic, Text from typing import Any, AsyncIterable, Awaitable, Callable, Generic, Text, Optional
import grpc import grpc
from ._typing import MetadataType, RequestType, ResponseType from ._typing import MetadataType, RequestType, ResponseType
__all__ = 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' __all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
class Call(grpc.RpcContext, metaclass=ABCMeta): class RpcContext(metaclass=ABCMeta):
"""The abstract base class of an RPC on the client-side.""" """Provides RPC-related information and action."""
@abstractmethod @abstractmethod
def cancelled(self) -> bool: def cancelled(self) -> bool:
@ -51,6 +51,39 @@ class Call(grpc.RpcContext, metaclass=ABCMeta):
A bool indicates if the RPC is done. A bool indicates if the RPC is done.
""" """
@abstractmethod
def time_remaining(self) -> Optional[float]:
"""Describes the length of allowed time remaining for the RPC.
Returns:
A nonnegative float indicating the length of allowed time in seconds
remaining for the RPC to complete before it is considered to have
timed out, or None if no deadline was specified for the RPC.
"""
@abstractmethod
def cancel(self) -> bool:
"""Cancels the RPC.
Idempotent and has no effect if the RPC has already terminated.
Returns:
A bool indicates if the cancellation is performed or not.
"""
@abstractmethod
def add_done_callback(self, callback: Callable[[Any], None]) -> None:
"""Registers a callback to be called on RPC termination.
Args:
callback: A callable object will be called with the context object as
its only argument.
"""
class Call(RpcContext, metaclass=ABCMeta):
"""The abstract base class of an RPC on the client-side."""
@abstractmethod @abstractmethod
async def initial_metadata(self) -> MetadataType: async def initial_metadata(self) -> MetadataType:
"""Accesses the initial metadata sent by the server. """Accesses the initial metadata sent by the server.

@ -173,8 +173,8 @@ class Call(_base_call.Call):
def done(self) -> bool: def done(self) -> bool:
return self._status.done() return self._status.done()
def add_callback(self, unused_callback) -> None: def add_done_callback(self, unused_callback) -> None:
pass raise NotImplementedError()
def is_active(self) -> bool: def is_active(self) -> bool:
return self.done() return self.done()
@ -335,7 +335,8 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
_request_serializer: SerializingFunction _request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction _response_deserializer: DeserializingFunction
_call: asyncio.Task _call: asyncio.Task
_aiter: AsyncIterable[ResponseType] _bytes_aiter: AsyncIterable[bytes]
_message_aiter: AsyncIterable[ResponseType]
def __init__(self, request: RequestType, deadline: Optional[float], def __init__(self, request: RequestType, deadline: Optional[float],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
@ -349,7 +350,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._call = self._loop.create_task(self._invoke()) self._call = self._loop.create_task(self._invoke())
self._aiter = self._process() self._message_aiter = self._process()
def __del__(self) -> None: def __del__(self) -> None:
if not self._status.done(): if not self._status.done():
@ -361,7 +362,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
serialized_request = _common.serialize(self._request, serialized_request = _common.serialize(self._request,
self._request_serializer) self._request_serializer)
self._aiter = await self._channel.unary_stream( self._bytes_aiter = await self._channel.unary_stream(
self._method, self._method,
serialized_request, serialized_request,
self._deadline, self._deadline,
@ -372,7 +373,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
async def _process(self) -> ResponseType: async def _process(self) -> ResponseType:
await self._call await self._call
async for serialized_response in self._aiter: async for serialized_response in self._bytes_aiter:
if self._cancellation.done(): if self._cancellation.done():
await self._status await self._status
if self._status.done(): if self._status.done():
@ -407,10 +408,10 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
_LOCAL_CANCELLATION_DETAILS, None, None)) _LOCAL_CANCELLATION_DETAILS, None, None))
def __aiter__(self) -> AsyncIterable[ResponseType]: def __aiter__(self) -> AsyncIterable[ResponseType]:
return self._aiter return self._message_aiter
async def read(self) -> ResponseType: async def read(self) -> ResponseType:
if self._status.done(): if self._status.done():
await self._raise_rpc_error_if_not_ok() await self._raise_rpc_error_if_not_ok()
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
return await self._aiter.__anext__() return await self._message_aiter.__anext__()

@ -300,6 +300,31 @@ class TestUnaryStreamCall(AioTestBase):
with self.assertRaises(asyncio.InvalidStateError): with self.assertRaises(asyncio.InvalidStateError):
await call.read() await call.read()
async def test_unary_stream_async_generator(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
async for response in call:
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save