From fa4eb94ea22aa6370a6f572cd3e8a01a75032b27 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Thu, 5 Dec 2019 18:34:42 -0800 Subject: [PATCH] Remove the add_callback method & fix segfault --- .../grpc/_cython/_cygrpc/aio/call.pyx.pxi | 9 ++-- .../grpcio/grpc/experimental/aio/__init__.py | 7 ++-- .../grpc/experimental/aio/_base_call.py | 41 +++++++++++++++++-- .../grpcio/grpc/experimental/aio/_call.py | 17 ++++---- .../grpcio_tests/tests_aio/unit/call_test.py | 25 +++++++++++ 5 files changed, 79 insertions(+), 20 deletions(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 14a7e6df197..defff7fe881 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -190,16 +190,15 @@ cdef class _AioCall: ) status_observer(status) self._status_received.set() - self._destroy_grpc_call() def _handle_cancellation_from_application(self, object cancellation_future, object status_observer): def _cancellation_action(finished_future): - status = self._cancel_and_create_status(finished_future) - status_observer(status) - self._status_received.set() - self._destroy_grpc_call() + if not self._status_received.set(): + status = self._cancel_and_create_status(finished_future) + status_observer(status) + self._status_received.set() cancellation_future.add_done_callback(_cancellation_action) diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index af45fd2c692..b84c96c93a6 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -23,7 +23,7 @@ import six import grpc from grpc._cython.cygrpc import init_grpc_aio -from ._base_call import Call, UnaryUnaryCall, UnaryStreamCall +from ._base_call import RpcContext, Call, UnaryUnaryCall, UnaryStreamCall from ._channel import Channel from ._channel import UnaryUnaryMultiCallable from ._server import server @@ -48,5 +48,6 @@ def insecure_channel(target, options=None, compression=None): ################################### __all__ ################################# -__all__ = ('Call', 'UnaryUnaryCall', 'UnaryStreamCall', 'init_grpc_aio', - 'Channel', 'UnaryUnaryMultiCallable', 'insecure_channel', 'server') +__all__ = ('RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall', + 'init_grpc_aio', 'Channel', 'UnaryUnaryMultiCallable', + 'insecure_channel', 'server') diff --git a/src/python/grpcio/grpc/experimental/aio/_base_call.py b/src/python/grpcio/grpc/experimental/aio/_base_call.py index d76096e8e1d..dd96c5e1107 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_call.py @@ -19,17 +19,17 @@ RPC, e.g. cancellation. """ from abc import ABCMeta, abstractmethod -from typing import AsyncIterable, Awaitable, Generic, Text +from typing import Any, AsyncIterable, Awaitable, Callable, Generic, Text, Optional import grpc from ._typing import MetadataType, RequestType, ResponseType -__all__ = 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' +__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' -class Call(grpc.RpcContext, metaclass=ABCMeta): - """The abstract base class of an RPC on the client-side.""" +class RpcContext(metaclass=ABCMeta): + """Provides RPC-related information and action.""" @abstractmethod def cancelled(self) -> bool: @@ -51,6 +51,39 @@ class Call(grpc.RpcContext, metaclass=ABCMeta): A bool indicates if the RPC is done. """ + @abstractmethod + def time_remaining(self) -> Optional[float]: + """Describes the length of allowed time remaining for the RPC. + + Returns: + A nonnegative float indicating the length of allowed time in seconds + remaining for the RPC to complete before it is considered to have + timed out, or None if no deadline was specified for the RPC. + """ + + @abstractmethod + def cancel(self) -> bool: + """Cancels the RPC. + + Idempotent and has no effect if the RPC has already terminated. + + Returns: + A bool indicates if the cancellation is performed or not. + """ + + @abstractmethod + def add_done_callback(self, callback: Callable[[Any], None]) -> None: + """Registers a callback to be called on RPC termination. + + Args: + callback: A callable object will be called with the context object as + its only argument. + """ + + +class Call(RpcContext, metaclass=ABCMeta): + """The abstract base class of an RPC on the client-side.""" + @abstractmethod async def initial_metadata(self) -> MetadataType: """Accesses the initial metadata sent by the server. diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 395e57756a6..09257e397a6 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -173,8 +173,8 @@ class Call(_base_call.Call): def done(self) -> bool: return self._status.done() - def add_callback(self, unused_callback) -> None: - pass + def add_done_callback(self, unused_callback) -> None: + raise NotImplementedError() def is_active(self) -> bool: return self.done() @@ -335,7 +335,8 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction _call: asyncio.Task - _aiter: AsyncIterable[ResponseType] + _bytes_aiter: AsyncIterable[bytes] + _message_aiter: AsyncIterable[ResponseType] def __init__(self, request: RequestType, deadline: Optional[float], channel: cygrpc.AioChannel, method: bytes, @@ -349,7 +350,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._call = self._loop.create_task(self._invoke()) - self._aiter = self._process() + self._message_aiter = self._process() def __del__(self) -> None: if not self._status.done(): @@ -361,7 +362,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): serialized_request = _common.serialize(self._request, self._request_serializer) - self._aiter = await self._channel.unary_stream( + self._bytes_aiter = await self._channel.unary_stream( self._method, serialized_request, self._deadline, @@ -372,7 +373,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): async def _process(self) -> ResponseType: await self._call - async for serialized_response in self._aiter: + async for serialized_response in self._bytes_aiter: if self._cancellation.done(): await self._status if self._status.done(): @@ -407,10 +408,10 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): _LOCAL_CANCELLATION_DETAILS, None, None)) def __aiter__(self) -> AsyncIterable[ResponseType]: - return self._aiter + return self._message_aiter async def read(self) -> ResponseType: if self._status.done(): await self._raise_rpc_error_if_not_ok() raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) - return await self._aiter.__anext__() + return await self._message_aiter.__anext__() diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 44477c36159..0ed296c6b7b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -300,6 +300,31 @@ class TestUnaryStreamCall(AioTestBase): with self.assertRaises(asyncio.InvalidStateError): await call.read() + async def test_unary_stream_async_generator(self): + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + # Invokes the actual RPC + call = stub.StreamingOutputCall(request) + self.assertFalse(call.cancelled()) + + async for response in call: + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG)