diff --git a/src/python/grpcio/grpc/experimental/aio/_base_call.py b/src/python/grpcio/grpc/experimental/aio/_base_call.py index d116982aa79..214e208c005 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_call.py @@ -117,6 +117,19 @@ class Call(RpcContext, metaclass=ABCMeta): The details string of the RPC. """ + @abstractmethod + async def wait_for_connection(self) -> None: + """Waits until connected to peer and raises aio.AioRpcError if failed. + + This is an EXPERIMENTAL method. + + This method ensures the RPC has been successfully connected. Otherwise, + an AioRpcError will be raised to explain the reason of the connection + failure. + + This method is recommended for building retry mechanisms. + """ + class UnaryUnaryCall(Generic[RequestType, ResponseType], Call, diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index be4887d239c..3d1d19fd3fa 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -18,7 +18,7 @@ import enum import inspect import logging from functools import partial -from typing import AsyncIterable, Awaitable, Optional, Tuple +from typing import AsyncIterable, Optional, Tuple import grpc from grpc import _common @@ -250,9 +250,8 @@ class _APIStyle(enum.IntEnum): class _UnaryResponseMixin(Call): _call_response: asyncio.Task - def _init_unary_response_mixin(self, - response_coro: Awaitable[ResponseType]): - self._call_response = self._loop.create_task(response_coro) + def _init_unary_response_mixin(self, response_task: asyncio.Task): + self._call_response = response_task def cancel(self) -> bool: if super().cancel(): @@ -458,6 +457,11 @@ class _StreamRequestMixin(Call): self._raise_for_different_style(_APIStyle.READER_WRITER) await self._done_writing() + async def wait_for_connection(self) -> None: + await self._metadata_sent.wait() + if self.done(): + await self._raise_for_status() + class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): """Object for managing unary-unary RPC calls. @@ -465,6 +469,7 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): Returned when an instance of `UnaryUnaryMultiCallable` object is called. """ _request: RequestType + _invocation_task: asyncio.Task # pylint: disable=too-many-arguments def __init__(self, request: RequestType, deadline: Optional[float], @@ -478,7 +483,8 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): channel.call(method, deadline, credentials, wait_for_ready), metadata, request_serializer, response_deserializer, loop) self._request = request - self._init_unary_response_mixin(self._invoke()) + self._invocation_task = loop.create_task(self._invoke()) + self._init_unary_response_mixin(self._invocation_task) async def _invoke(self) -> ResponseType: serialized_request = _common.serialize(self._request, @@ -500,6 +506,11 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): else: return cygrpc.EOF + async def wait_for_connection(self) -> None: + await self._invocation_task + if self.done(): + await self._raise_for_status() + class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): """Object for managing unary-stream RPC calls. @@ -536,6 +547,11 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): self.cancel() raise + async def wait_for_connection(self) -> None: + await self._send_unary_request_task + if self.done(): + await self._raise_for_status() + class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, _base_call.StreamUnaryCall): @@ -557,7 +573,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, metadata, request_serializer, response_deserializer, loop) self._init_stream_request_mixin(request_iterator) - self._init_unary_response_mixin(self._conduct_rpc()) + self._init_unary_response_mixin(loop.create_task(self._conduct_rpc())) async def _conduct_rpc(self) -> ResponseType: try: diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 9e99a1b125d..d4aca3ae0fc 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -330,6 +330,10 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): response = yield from call.__await__() return response + async def wait_for_connection(self) -> None: + call = await self._interceptors_task + return await call.wait_for_connection() + class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): """Final UnaryUnaryCall class finished with a response.""" @@ -374,3 +378,6 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): # for telling the interpreter that __await__ is a generator. yield None return self._response + + async def wait_for_connection(self) -> None: + pass diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 84dbf02b937..71f8733f5f9 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -28,5 +28,6 @@ "unit.server_interceptor_test.TestServerInterceptor", "unit.server_test.TestServer", "unit.timeout_test.TestTimeout", + "unit.wait_for_connection_test.TestWaitForConnection", "unit.wait_for_ready_test.TestWaitForReady" ] diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 5b52f0e1724..2548e777783 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -16,23 +16,23 @@ import asyncio import logging import unittest +import datetime import grpc from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2, test_pb2_grpc -from tests.unit.framework.common import test_constants from tests_aio.unit._test_base import AioTestBase -from tests.unit import resources - from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._constants import UNREACHABLE_TARGET + +_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds() _NUM_STREAM_RESPONSES = 5 _RESPONSE_PAYLOAD_SIZE = 42 _REQUEST_PAYLOAD_SIZE = 7 _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' -_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000 -_UNREACHABLE_TARGET = '0.1:1111' +_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) _INFINITE_INTERVAL_US = 2**31 - 1 @@ -78,7 +78,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): self.assertIs(response, response_retry) async def test_call_rpc_error(self): - async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel: + async with aio.insecure_channel(UNREACHABLE_TARGET) as channel: stub = test_pb2_grpc.TestServiceStub(channel) call = stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -434,24 +434,24 @@ class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase): interval_us=_RESPONSE_INTERVAL_US, )) - call = self._stub.StreamingOutputCall( - request, timeout=test_constants.SHORT_TIMEOUT * 2) + call = self._stub.StreamingOutputCall(request, + timeout=_SHORT_TIMEOUT_S * 2) response = await call.read() self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) # Should be around the same as the timeout remained_time = call.time_remaining() - self.assertGreater(remained_time, test_constants.SHORT_TIMEOUT * 3 / 2) - self.assertLess(remained_time, test_constants.SHORT_TIMEOUT * 5 / 2) + self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2) + self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2) response = await call.read() self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) # Should be around the timeout minus a unit of wait time remained_time = call.time_remaining() - self.assertGreater(remained_time, test_constants.SHORT_TIMEOUT / 2) - self.assertLess(remained_time, test_constants.SHORT_TIMEOUT * 3 / 2) + self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2) + self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2) self.assertEqual(grpc.StatusCode.OK, await call.code()) @@ -538,14 +538,14 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase): with self.assertRaises(asyncio.CancelledError): for _ in range(_NUM_STREAM_RESPONSES): yield request - await asyncio.sleep(test_constants.SHORT_TIMEOUT) + await asyncio.sleep(_SHORT_TIMEOUT_S) request_iterator_received_the_exception.set() call = self._stub.StreamingInputCall(request_iterator()) # Cancel the RPC after at least one response async def cancel_later(): - await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2) + await asyncio.sleep(_SHORT_TIMEOUT_S * 2) call.cancel() cancel_later_task = self.loop.create_task(cancel_later()) @@ -576,6 +576,33 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.OK) + async def test_call_rpc_error(self): + async with aio.insecure_channel(UNREACHABLE_TARGET) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + + # The error should be raised automatically without any traffic. + call = stub.StreamingInputCall() + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + + async def test_timeout(self): + call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S) + + # The error should be raised automatically without any traffic. + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code()) + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await call.code()) + # Prepares the request that stream in a ping-pong manner. _STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest() @@ -733,14 +760,14 @@ class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase): with self.assertRaises(asyncio.CancelledError): for _ in range(_NUM_STREAM_RESPONSES): yield request - await asyncio.sleep(test_constants.SHORT_TIMEOUT) + await asyncio.sleep(_SHORT_TIMEOUT_S) request_iterator_received_the_exception.set() call = self._stub.FullDuplexCall(request_iterator()) # Cancel the RPC after at least one response async def cancel_later(): - await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2) + await asyncio.sleep(_SHORT_TIMEOUT_S * 2) call.cancel() cancel_later_task = self.loop.create_task(cancel_later()) diff --git a/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py b/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py new file mode 100644 index 00000000000..cb6f7985290 --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py @@ -0,0 +1,159 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior of the wait for connection API on client side.""" + +import asyncio +import logging +import unittest +import datetime +from typing import Callable, Tuple + +import grpc +from grpc.experimental import aio + +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit import _common +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._constants import UNREACHABLE_TARGET + +_REQUEST = b'\x01\x02\x03' +_TEST_METHOD = '/test/Test' + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 + + +class TestWaitForConnection(AioTestBase): + """Tests if wait_for_connection raises connectivity issue.""" + + async def setUp(self): + address, self._server = await start_test_server() + self._channel = aio.insecure_channel(address) + self._dummy_channel = aio.insecure_channel(UNREACHABLE_TARGET) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._dummy_channel.close() + await self._channel.close() + await self._server.stop(None) + + async def test_unary_unary_ok(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + + # No exception raised and no message swallowed. + await call.wait_for_connection() + + response = await call + self.assertIsInstance(response, messages_pb2.SimpleResponse) + + async def test_unary_stream_ok(self): + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + call = self._stub.StreamingOutputCall(request) + + # No exception raised and no message swallowed. + await call.wait_for_connection() + + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_unary_ok(self): + call = self._stub.StreamingInputCall() + + # No exception raised and no message swallowed. + await call.wait_for_connection() + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + await call.done_writing() + + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_stream_ok(self): + call = self._stub.FullDuplexCall() + + # No exception raised and no message swallowed. + await call.wait_for_connection() + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + await call.done_writing() + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_unary_unary_error(self): + call = self._dummy_channel.unary_unary(_TEST_METHOD)(_REQUEST) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.wait_for_connection() + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) + + async def test_unary_stream_error(self): + call = self._dummy_channel.unary_stream(_TEST_METHOD)(_REQUEST) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.wait_for_connection() + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) + + async def test_stream_unary_error(self): + call = self._dummy_channel.stream_unary(_TEST_METHOD)() + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.wait_for_connection() + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) + + async def test_stream_stream_error(self): + call = self._dummy_channel.stream_stream(_TEST_METHOD)() + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.wait_for_connection() + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2)