Implement add_done_callbacks and time_remaining

pull/21681/head
Lidi Zheng 5 years ago
parent adab340647
commit 1c78ccd44e
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 7
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 6
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  4. 22
      src/python/grpcio/grpc/experimental/aio/_call.py
  5. 1
      src/python/grpcio/grpc/experimental/aio/_typing.py
  6. 160
      src/python/grpcio_tests/tests_aio/unit/done_callback_test.py

@ -28,4 +28,6 @@ cdef class _AioCall(GrpcCallWrapper):
# because Core is holding a pointer for the callback handler. # because Core is holding a pointer for the callback handler.
bint _is_locally_cancelled bint _is_locally_cancelled
object _deadline
cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except * cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except *

@ -36,6 +36,7 @@ cdef class _AioCall(GrpcCallWrapper):
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._create_grpc_call(deadline, method, call_credentials) self._create_grpc_call(deadline, method, call_credentials)
self._is_locally_cancelled = False self._is_locally_cancelled = False
self._deadline = deadline
def __dealloc__(self): def __dealloc__(self):
if self.call: if self.call:
@ -84,6 +85,12 @@ cdef class _AioCall(GrpcCallWrapper):
grpc_slice_unref(method_slice) grpc_slice_unref(method_slice)
def time_remaining(self):
if self._deadline is None:
return None
else:
return max(0, self._deadline - time.time())
def cancel(self, AioRpcStatus status): def cancel(self, AioRpcStatus status):
"""Cancels the RPC in Core with given RPC status. """Cancels the RPC in Core with given RPC status.

@ -24,7 +24,7 @@ from typing import (Any, AsyncIterable, Awaitable, Callable, Generic, Optional,
import grpc import grpc
from ._typing import EOFType, MetadataType, RequestType, ResponseType from ._typing import EOFType, MetadataType, RequestType, ResponseType, DoneCallbackType
__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' __all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -73,11 +73,11 @@ class RpcContext(metaclass=ABCMeta):
""" """
@abstractmethod @abstractmethod
def add_done_callback(self, callback: Callable[[Any], None]) -> None: def add_done_callback(self, callback: DoneCallbackType) -> None:
"""Registers a callback to be called on RPC termination. """Registers a callback to be called on RPC termination.
Args: Args:
callback: A callable object will be called with the context object as callback: A callable object will be called with the call object as
its only argument. its only argument.
""" """

@ -14,7 +14,7 @@
"""Invocation-side implementation of gRPC Asyncio Python.""" """Invocation-side implementation of gRPC Asyncio Python."""
import asyncio import asyncio
from typing import AsyncIterable, Awaitable, Dict, Optional from typing import AsyncIterable, Awaitable, List, Dict, Optional
import grpc import grpc
from grpc import _common from grpc import _common
@ -22,7 +22,7 @@ from grpc._cython import cygrpc
from . import _base_call from . import _base_call
from ._typing import (DeserializingFunction, MetadataType, RequestType, from ._typing import (DeserializingFunction, MetadataType, RequestType,
ResponseType, SerializingFunction) ResponseType, SerializingFunction, DoneCallbackType)
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -157,6 +157,7 @@ class Call(_base_call.Call):
_initial_metadata: Awaitable[MetadataType] _initial_metadata: Awaitable[MetadataType]
_locally_cancelled: bool _locally_cancelled: bool
_cython_call: cygrpc._AioCall _cython_call: cygrpc._AioCall
_done_callbacks: List[DoneCallbackType]
def __init__(self, cython_call: cygrpc._AioCall) -> None: def __init__(self, cython_call: cygrpc._AioCall) -> None:
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
@ -165,6 +166,7 @@ class Call(_base_call.Call):
self._initial_metadata = self._loop.create_future() self._initial_metadata = self._loop.create_future()
self._locally_cancelled = False self._locally_cancelled = False
self._cython_call = cython_call self._cython_call = cython_call
self._done_callbacks = []
def __del__(self) -> None: def __del__(self) -> None:
if not self._status.done(): if not self._status.done():
@ -192,11 +194,14 @@ class Call(_base_call.Call):
def done(self) -> bool: def done(self) -> bool:
return self._status.done() return self._status.done()
def add_done_callback(self, unused_callback) -> None: def add_done_callback(self, callback: DoneCallbackType) -> None:
raise NotImplementedError() if self.done():
callback(self)
else:
self._done_callbacks.append(callback)
def time_remaining(self) -> Optional[float]: def time_remaining(self) -> Optional[float]:
raise NotImplementedError() return self._cython_call.time_remaining()
async def initial_metadata(self) -> MetadataType: async def initial_metadata(self) -> MetadataType:
return await self._initial_metadata return await self._initial_metadata
@ -220,9 +225,7 @@ class Call(_base_call.Call):
def _set_status(self, status: cygrpc.AioRpcStatus) -> None: def _set_status(self, status: cygrpc.AioRpcStatus) -> None:
"""Private method to set final status of the RPC. """Private method to set final status of the RPC.
This method may be called multiple time due to data race between local This method should only be invoked once.
cancellation (by application) and Core receiving status from peer. We
make no promise here which one will win.
""" """
# In case of local cancellation, flip the flag. # In case of local cancellation, flip the flag.
if status.details() is _LOCAL_CANCELLATION_DETAILS: if status.details() is _LOCAL_CANCELLATION_DETAILS:
@ -236,6 +239,9 @@ class Call(_base_call.Call):
self._status.set_result(status) self._status.set_result(status)
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()] self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
for callback in self._done_callbacks:
callback(self)
async def _raise_for_status(self) -> None: async def _raise_for_status(self) -> None:
if self._locally_cancelled: if self._locally_cancelled:
raise asyncio.CancelledError() raise asyncio.CancelledError()

@ -24,3 +24,4 @@ MetadatumType = Tuple[Text, AnyStr]
MetadataType = Sequence[MetadatumType] MetadataType = Sequence[MetadatumType]
ChannelArgumentType = Sequence[Tuple[Text, Any]] ChannelArgumentType = Sequence[Tuple[Text, Any]]
EOFType = type(EOF) EOFType = type(EOF)
DoneCallbackType = Callable[[Any], None]

@ -0,0 +1,160 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing the done callbacks mechanism."""
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import unittest
import time
import gc
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_server import start_test_server
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
def _inject_callbacks(call):
first_callback_ran = asyncio.Event()
def first_callback(unused_call):
first_callback_ran.set()
second_callback_ran = asyncio.Event()
def second_callback(unused_call):
second_callback_ran.set()
call.add_done_callback(first_callback)
call.add_done_callback(second_callback)
async def validation():
await asyncio.wait_for(
asyncio.gather(first_callback_ran.wait(), second_callback_ran.wait()),
test_constants.SHORT_TIMEOUT
)
return validation()
class TestDoneCallback(AioTestBase):
async def setUp(self):
address, self._server = await start_test_server()
self._channel = aio.insecure_channel(address)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_add_after_done(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual(grpc.StatusCode.OK, await call.code())
validation = _inject_callbacks(call)
await validation
async def test_unary_unary(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
validation = _inject_callbacks(call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
await validation
async def test_unary_stream(self):
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
call = self._stub.StreamingOutputCall(request)
validation = _inject_callbacks(call)
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
self.assertEqual(grpc.StatusCode.OK, await call.code())
await validation
async def test_stream_unary(self):
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
async def gen():
for _ in range(_NUM_STREAM_RESPONSES):
yield request
call = self._stub.StreamingInputCall(gen())
validation = _inject_callbacks(call)
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(grpc.StatusCode.OK, await call.code())
await validation
async def test_stream_stream(self):
call = self._stub.FullDuplexCall()
validation = _inject_callbacks(call)
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
response = await call.read()
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
await call.done_writing()
self.assertEqual(grpc.StatusCode.OK, await call.code())
await validation
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save