Remove the add_callback method & fix segfault

pull/21232/head
Lidi Zheng 5 years ago
parent 46e963f8bc
commit fa4eb94ea2
  1. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 7
      src/python/grpcio/grpc/experimental/aio/__init__.py
  3. 41
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  4. 17
      src/python/grpcio/grpc/experimental/aio/_call.py
  5. 25
      src/python/grpcio_tests/tests_aio/unit/call_test.py

@ -190,16 +190,15 @@ cdef class _AioCall:
)
status_observer(status)
self._status_received.set()
self._destroy_grpc_call()
def _handle_cancellation_from_application(self,
object cancellation_future,
object status_observer):
def _cancellation_action(finished_future):
if not self._status_received.set():
status = self._cancel_and_create_status(finished_future)
status_observer(status)
self._status_received.set()
self._destroy_grpc_call()
cancellation_future.add_done_callback(_cancellation_action)

@ -23,7 +23,7 @@ import six
import grpc
from grpc._cython.cygrpc import init_grpc_aio
from ._base_call import Call, UnaryUnaryCall, UnaryStreamCall
from ._base_call import RpcContext, Call, UnaryUnaryCall, UnaryStreamCall
from ._channel import Channel
from ._channel import UnaryUnaryMultiCallable
from ._server import server
@ -48,5 +48,6 @@ def insecure_channel(target, options=None, compression=None):
################################### __all__ #################################
__all__ = ('Call', 'UnaryUnaryCall', 'UnaryStreamCall', 'init_grpc_aio',
'Channel', 'UnaryUnaryMultiCallable', 'insecure_channel', 'server')
__all__ = ('RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall',
'init_grpc_aio', 'Channel', 'UnaryUnaryMultiCallable',
'insecure_channel', 'server')

@ -19,17 +19,17 @@ RPC, e.g. cancellation.
"""
from abc import ABCMeta, abstractmethod
from typing import AsyncIterable, Awaitable, Generic, Text
from typing import Any, AsyncIterable, Awaitable, Callable, Generic, Text, Optional
import grpc
from ._typing import MetadataType, RequestType, ResponseType
__all__ = 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
class Call(grpc.RpcContext, metaclass=ABCMeta):
"""The abstract base class of an RPC on the client-side."""
class RpcContext(metaclass=ABCMeta):
"""Provides RPC-related information and action."""
@abstractmethod
def cancelled(self) -> bool:
@ -51,6 +51,39 @@ class Call(grpc.RpcContext, metaclass=ABCMeta):
A bool indicates if the RPC is done.
"""
@abstractmethod
def time_remaining(self) -> Optional[float]:
"""Describes the length of allowed time remaining for the RPC.
Returns:
A nonnegative float indicating the length of allowed time in seconds
remaining for the RPC to complete before it is considered to have
timed out, or None if no deadline was specified for the RPC.
"""
@abstractmethod
def cancel(self) -> bool:
"""Cancels the RPC.
Idempotent and has no effect if the RPC has already terminated.
Returns:
A bool indicates if the cancellation is performed or not.
"""
@abstractmethod
def add_done_callback(self, callback: Callable[[Any], None]) -> None:
"""Registers a callback to be called on RPC termination.
Args:
callback: A callable object will be called with the context object as
its only argument.
"""
class Call(RpcContext, metaclass=ABCMeta):
"""The abstract base class of an RPC on the client-side."""
@abstractmethod
async def initial_metadata(self) -> MetadataType:
"""Accesses the initial metadata sent by the server.

@ -173,8 +173,8 @@ class Call(_base_call.Call):
def done(self) -> bool:
return self._status.done()
def add_callback(self, unused_callback) -> None:
pass
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def is_active(self) -> bool:
return self.done()
@ -335,7 +335,8 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: asyncio.Task
_aiter: AsyncIterable[ResponseType]
_bytes_aiter: AsyncIterable[bytes]
_message_aiter: AsyncIterable[ResponseType]
def __init__(self, request: RequestType, deadline: Optional[float],
channel: cygrpc.AioChannel, method: bytes,
@ -349,7 +350,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._call = self._loop.create_task(self._invoke())
self._aiter = self._process()
self._message_aiter = self._process()
def __del__(self) -> None:
if not self._status.done():
@ -361,7 +362,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
serialized_request = _common.serialize(self._request,
self._request_serializer)
self._aiter = await self._channel.unary_stream(
self._bytes_aiter = await self._channel.unary_stream(
self._method,
serialized_request,
self._deadline,
@ -372,7 +373,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
async def _process(self) -> ResponseType:
await self._call
async for serialized_response in self._aiter:
async for serialized_response in self._bytes_aiter:
if self._cancellation.done():
await self._status
if self._status.done():
@ -407,10 +408,10 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
_LOCAL_CANCELLATION_DETAILS, None, None))
def __aiter__(self) -> AsyncIterable[ResponseType]:
return self._aiter
return self._message_aiter
async def read(self) -> ResponseType:
if self._status.done():
await self._raise_rpc_error_if_not_ok()
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
return await self._aiter.__anext__()
return await self._message_aiter.__anext__()

@ -300,6 +300,31 @@ class TestUnaryStreamCall(AioTestBase):
with self.assertRaises(asyncio.InvalidStateError):
await call.read()
async def test_unary_stream_async_generator(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
async for response in call:
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save