From 1c78ccd44eda86f48ec96d2c7735d298873a56d0 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Wed, 15 Jan 2020 18:07:45 -0800 Subject: [PATCH] Implement add_done_callbacks and time_remaining --- .../grpc/_cython/_cygrpc/aio/call.pxd.pxi | 2 + .../grpc/_cython/_cygrpc/aio/call.pyx.pxi | 7 + .../grpc/experimental/aio/_base_call.py | 6 +- .../grpcio/grpc/experimental/aio/_call.py | 22 ++- .../grpcio/grpc/experimental/aio/_typing.py | 1 + .../tests_aio/unit/done_callback_test.py | 160 ++++++++++++++++++ 6 files changed, 187 insertions(+), 11 deletions(-) create mode 100644 src/python/grpcio_tests/tests_aio/unit/done_callback_test.py diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi index d95cf5c52c9..fffa61075ab 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi @@ -28,4 +28,6 @@ cdef class _AioCall(GrpcCallWrapper): # because Core is holding a pointer for the callback handler. bint _is_locally_cancelled + object _deadline + cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except * diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 279c416a923..6f7346dd630 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -36,6 +36,7 @@ cdef class _AioCall(GrpcCallWrapper): self._loop = asyncio.get_event_loop() self._create_grpc_call(deadline, method, call_credentials) self._is_locally_cancelled = False + self._deadline = deadline def __dealloc__(self): if self.call: @@ -84,6 +85,12 @@ cdef class _AioCall(GrpcCallWrapper): grpc_slice_unref(method_slice) + def time_remaining(self): + if self._deadline is None: + return None + else: + return max(0, self._deadline - time.time()) + def cancel(self, AioRpcStatus status): """Cancels the RPC in Core with given RPC status. diff --git a/src/python/grpcio/grpc/experimental/aio/_base_call.py b/src/python/grpcio/grpc/experimental/aio/_base_call.py index bdd6902d893..fb0108de2e1 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_call.py @@ -24,7 +24,7 @@ from typing import (Any, AsyncIterable, Awaitable, Callable, Generic, Optional, import grpc -from ._typing import EOFType, MetadataType, RequestType, ResponseType +from ._typing import EOFType, MetadataType, RequestType, ResponseType, DoneCallbackType __all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -73,11 +73,11 @@ class RpcContext(metaclass=ABCMeta): """ @abstractmethod - def add_done_callback(self, callback: Callable[[Any], None]) -> None: + def add_done_callback(self, callback: DoneCallbackType) -> None: """Registers a callback to be called on RPC termination. Args: - callback: A callable object will be called with the context object as + callback: A callable object will be called with the call object as its only argument. """ diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index c8bd10244b4..d443834accc 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -14,7 +14,7 @@ """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio -from typing import AsyncIterable, Awaitable, Dict, Optional +from typing import AsyncIterable, Awaitable, List, Dict, Optional import grpc from grpc import _common @@ -22,7 +22,7 @@ from grpc._cython import cygrpc from . import _base_call from ._typing import (DeserializingFunction, MetadataType, RequestType, - ResponseType, SerializingFunction) + ResponseType, SerializingFunction, DoneCallbackType) __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -157,6 +157,7 @@ class Call(_base_call.Call): _initial_metadata: Awaitable[MetadataType] _locally_cancelled: bool _cython_call: cygrpc._AioCall + _done_callbacks: List[DoneCallbackType] def __init__(self, cython_call: cygrpc._AioCall) -> None: self._loop = asyncio.get_event_loop() @@ -165,6 +166,7 @@ class Call(_base_call.Call): self._initial_metadata = self._loop.create_future() self._locally_cancelled = False self._cython_call = cython_call + self._done_callbacks = [] def __del__(self) -> None: if not self._status.done(): @@ -192,11 +194,14 @@ class Call(_base_call.Call): def done(self) -> bool: return self._status.done() - def add_done_callback(self, unused_callback) -> None: - raise NotImplementedError() + def add_done_callback(self, callback: DoneCallbackType) -> None: + if self.done(): + callback(self) + else: + self._done_callbacks.append(callback) def time_remaining(self) -> Optional[float]: - raise NotImplementedError() + return self._cython_call.time_remaining() async def initial_metadata(self) -> MetadataType: return await self._initial_metadata @@ -220,9 +225,7 @@ class Call(_base_call.Call): def _set_status(self, status: cygrpc.AioRpcStatus) -> None: """Private method to set final status of the RPC. - This method may be called multiple time due to data race between local - cancellation (by application) and Core receiving status from peer. We - make no promise here which one will win. + This method should only be invoked once. """ # In case of local cancellation, flip the flag. if status.details() is _LOCAL_CANCELLATION_DETAILS: @@ -236,6 +239,9 @@ class Call(_base_call.Call): self._status.set_result(status) self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()] + for callback in self._done_callbacks: + callback(self) + async def _raise_for_status(self) -> None: if self._locally_cancelled: raise asyncio.CancelledError() diff --git a/src/python/grpcio/grpc/experimental/aio/_typing.py b/src/python/grpcio/grpc/experimental/aio/_typing.py index c60eab85449..15583754a63 100644 --- a/src/python/grpcio/grpc/experimental/aio/_typing.py +++ b/src/python/grpcio/grpc/experimental/aio/_typing.py @@ -24,3 +24,4 @@ MetadatumType = Tuple[Text, AnyStr] MetadataType = Sequence[MetadatumType] ChannelArgumentType = Sequence[Tuple[Text, Any]] EOFType = type(EOF) +DoneCallbackType = Callable[[Any], None] diff --git a/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py b/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py new file mode 100644 index 00000000000..f5f6a75974a --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py @@ -0,0 +1,160 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing the done callbacks mechanism.""" + + +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import unittest +import time +import gc + +import grpc +from grpc.experimental import aio +from tests_aio.unit._test_base import AioTestBase +from tests.unit.framework.common import test_constants +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._test_server import start_test_server + + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 + + +def _inject_callbacks(call): + first_callback_ran = asyncio.Event() + def first_callback(unused_call): + first_callback_ran.set() + + second_callback_ran = asyncio.Event() + def second_callback(unused_call): + second_callback_ran.set() + + call.add_done_callback(first_callback) + call.add_done_callback(second_callback) + + async def validation(): + await asyncio.wait_for( + asyncio.gather(first_callback_ran.wait(), second_callback_ran.wait()), + test_constants.SHORT_TIMEOUT + ) + + return validation() + + +class TestDoneCallback(AioTestBase): + + async def setUp(self): + address, self._server = await start_test_server() + self._channel = aio.insecure_channel(address) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + async def test_add_after_done(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + validation = _inject_callbacks(call) + await validation + + async def test_unary_unary(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + validation = _inject_callbacks(call) + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + await validation + + async def test_unary_stream(self): + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + call = self._stub.StreamingOutputCall(request) + validation = _inject_callbacks(call) + + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + await validation + + async def test_stream_unary(self): + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + call = self._stub.StreamingInputCall(gen()) + validation = _inject_callbacks(call) + + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + await validation + + async def test_stream_stream(self): + call = self._stub.FullDuplexCall() + validation = _inject_callbacks(call) + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + await call.done_writing() + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + await validation + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2)