diff --git a/src/python/grpcio/grpc/experimental/aio/_base_call.py b/src/python/grpcio/grpc/experimental/aio/_base_call.py index e71076fe3d5..8ea8e90c8b1 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_call.py @@ -117,6 +117,19 @@ class Call(RpcContext, metaclass=ABCMeta): The details string of the RPC. """ + @abstractmethod + async def wait_for_connection(self) -> None: + """Waits until connected to peer and raises aio.AioRpcError if failed. + + This is an EXPERIMENTAL method. + + This method makes ensure if the RPC has been successfully connected. + Otherwise, an AioRpcError will be raised to explain the reason of the + connection failure. + + This method is recommended for building retry mechanisms. + """ + class UnaryUnaryCall(Generic[RequestType, ResponseType], Call, @@ -158,23 +171,6 @@ class UnaryStreamCall(Generic[RequestType, ResponseType], stream. """ - @abstractmethod - async def try_connect(self) -> None: - """Tries to connect to peer and raise aio.AioRpcError if failed. - - This is an EXPERIMENTAL method. - - This method is available for streaming RPCs. This method enables the - application to ensure if the RPC has been successfully connected. - Otherwise, an AioRpcError will be raised to explain the reason of the - connection failure. - - For unary-unary RPCs, the connectivity issue will be raised once the - application awaits the call. - - This method is recommended for building retry mechanisms. - """ - class StreamUnaryCall(Generic[RequestType, ResponseType], Call, @@ -204,23 +200,6 @@ class StreamUnaryCall(Generic[RequestType, ResponseType], The response message of the stream. """ - @abstractmethod - async def try_connect(self) -> None: - """Tries to connect to peer and raise aio.AioRpcError if failed. - - This is an EXPERIMENTAL method. - - This method is available for streaming RPCs. This method enables the - application to ensure if the RPC has been successfully connected. - Otherwise, an AioRpcError will be raised to explain the reason of the - connection failure. - - For unary-unary RPCs, the connectivity issue will be raised once the - application awaits the call. - - This method is recommended for building retry mechanisms. - """ - class StreamStreamCall(Generic[RequestType, ResponseType], Call, @@ -263,20 +242,3 @@ class StreamStreamCall(Generic[RequestType, ResponseType], After done_writing is called, any additional invocation to the write function will fail. This function is idempotent. """ - - @abstractmethod - async def try_connect(self) -> None: - """Tries to connect to peer and raise aio.AioRpcError if failed. - - This is an EXPERIMENTAL method. - - This method is available for streaming RPCs. This method enables the - application to ensure if the RPC has been successfully connected. - Otherwise, an AioRpcError will be raised to explain the reason of the - connection failure. - - For unary-unary RPCs, the connectivity issue will be raised once the - application awaits the call. - - This method is recommended for building retry mechanisms. - """ diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index dc9dbfe481b..ab8056e7339 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -18,7 +18,7 @@ import enum import inspect import logging from functools import partial -from typing import AsyncIterable, Awaitable, Optional, Tuple +from typing import AsyncIterable, Optional, Tuple import grpc from grpc import _common @@ -250,9 +250,8 @@ class _APIStyle(enum.IntEnum): class _UnaryResponseMixin(Call): _call_response: asyncio.Task - def _init_unary_response_mixin(self, - response_coro: Awaitable[ResponseType]): - self._call_response = self._loop.create_task(response_coro) + def _init_unary_response_mixin(self, response_task: asyncio.Task): + self._call_response = response_task def cancel(self) -> bool: if super().cancel(): @@ -458,7 +457,7 @@ class _StreamRequestMixin(Call): self._raise_for_different_style(_APIStyle.READER_WRITER) await self._done_writing() - async def try_connect(self) -> None: + async def wait_for_connection(self) -> None: await self._metadata_sent.wait() if self.done(): await self._raise_for_status() @@ -470,6 +469,7 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): Returned when an instance of `UnaryUnaryMultiCallable` object is called. """ _request: RequestType + _invocation_task: asyncio.Task # pylint: disable=too-many-arguments def __init__(self, request: RequestType, deadline: Optional[float], @@ -483,7 +483,8 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): channel.call(method, deadline, credentials, wait_for_ready), metadata, request_serializer, response_deserializer, loop) self._request = request - self._init_unary_response_mixin(self._invoke()) + self._invocation_task = loop.create_task(self._invoke()) + self._init_unary_response_mixin(self._invocation_task) async def _invoke(self) -> ResponseType: serialized_request = _common.serialize(self._request, @@ -505,6 +506,11 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): else: return cygrpc.EOF + async def wait_for_connection(self) -> None: + await self._invocation_task + if self.done(): + await self._raise_for_status() + class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): """Object for managing unary-stream RPC calls. @@ -541,7 +547,7 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): self.cancel() raise - async def try_connect(self) -> None: + async def wait_for_connection(self) -> None: await self._send_unary_request_task if self.done(): await self._raise_for_status() @@ -566,8 +572,13 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, channel.call(method, deadline, credentials, wait_for_ready), metadata, request_serializer, response_deserializer, loop) +<<<<<<< HEAD self._init_stream_request_mixin(request_iterator) self._init_unary_response_mixin(self._conduct_rpc()) +======= + self._init_stream_request_mixin(request_async_iterator) + self._init_unary_response_mixin(loop.create_task(self._conduct_rpc())) +>>>>>>> Rename to wait_for_conneciton && Add to unary-unary RPC async def _conduct_rpc(self) -> ResponseType: try: diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 9e99a1b125d..d4aca3ae0fc 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -330,6 +330,10 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): response = yield from call.__await__() return response + async def wait_for_connection(self) -> None: + call = await self._interceptors_task + return await call.wait_for_connection() + class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): """Final UnaryUnaryCall class finished with a response.""" @@ -374,3 +378,6 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): # for telling the interpreter that __await__ is a generator. yield None return self._response + + async def wait_for_connection(self) -> None: + pass diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index b76b5b893b4..71f8733f5f9 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -28,6 +28,6 @@ "unit.server_interceptor_test.TestServerInterceptor", "unit.server_test.TestServer", "unit.timeout_test.TestTimeout", - "unit.try_connect_test.TestTryConnect", + "unit.wait_for_connection_test.TestWaitForConnection", "unit.wait_for_ready_test.TestWaitForReady" ] diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index b0c126640a9..2548e777783 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -24,6 +24,7 @@ from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._constants import UNREACHABLE_TARGET _SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds() @@ -32,7 +33,6 @@ _RESPONSE_PAYLOAD_SIZE = 42 _REQUEST_PAYLOAD_SIZE = 7 _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' _RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) -_UNREACHABLE_TARGET = '0.1:1111' _INFINITE_INTERVAL_US = 2**31 - 1 @@ -78,7 +78,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): self.assertIs(response, response_retry) async def test_call_rpc_error(self): - async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel: + async with aio.insecure_channel(UNREACHABLE_TARGET) as channel: stub = test_pb2_grpc.TestServiceStub(channel) call = stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -577,7 +577,7 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_call_rpc_error(self): - async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel: + async with aio.insecure_channel(UNREACHABLE_TARGET) as channel: stub = test_pb2_grpc.TestServiceStub(channel) # The error should be raised automatically without any traffic. diff --git a/src/python/grpcio_tests/tests_aio/unit/try_connect_test.py b/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py similarity index 79% rename from src/python/grpcio_tests/tests_aio/unit/try_connect_test.py rename to src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py index 7fc292ea3a1..cb6f7985290 100644 --- a/src/python/grpcio_tests/tests_aio/unit/try_connect_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests behavior of the try connect API on client side.""" +"""Tests behavior of the wait for connection API on client side.""" import asyncio import logging @@ -26,9 +26,9 @@ from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_server import start_test_server from tests_aio.unit import _common from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._constants import UNREACHABLE_TARGET _REQUEST = b'\x01\x02\x03' -_UNREACHABLE_TARGET = '0.1:1111' _TEST_METHOD = '/test/Test' _NUM_STREAM_RESPONSES = 5 @@ -36,13 +36,13 @@ _REQUEST_PAYLOAD_SIZE = 7 _RESPONSE_PAYLOAD_SIZE = 42 -class TestTryConnect(AioTestBase): - """Tests if try connect raises connectivity issue.""" +class TestWaitForConnection(AioTestBase): + """Tests if wait_for_connection raises connectivity issue.""" async def setUp(self): address, self._server = await start_test_server() self._channel = aio.insecure_channel(address) - self._dummy_channel = aio.insecure_channel(_UNREACHABLE_TARGET) + self._dummy_channel = aio.insecure_channel(UNREACHABLE_TARGET) self._stub = test_pb2_grpc.TestServiceStub(self._channel) async def tearDown(self): @@ -50,6 +50,15 @@ class TestTryConnect(AioTestBase): await self._channel.close() await self._server.stop(None) + async def test_unary_unary_ok(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + + # No exception raised and no message swallowed. + await call.wait_for_connection() + + response = await call + self.assertIsInstance(response, messages_pb2.SimpleResponse) + async def test_unary_stream_ok(self): request = messages_pb2.StreamingOutputCallRequest() for _ in range(_NUM_STREAM_RESPONSES): @@ -59,7 +68,7 @@ class TestTryConnect(AioTestBase): call = self._stub.StreamingOutputCall(request) # No exception raised and no message swallowed. - await call.try_connect() + await call.wait_for_connection() response_cnt = 0 async for response in call: @@ -75,7 +84,7 @@ class TestTryConnect(AioTestBase): call = self._stub.StreamingInputCall() # No exception raised and no message swallowed. - await call.try_connect() + await call.wait_for_connection() payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) @@ -95,7 +104,7 @@ class TestTryConnect(AioTestBase): call = self._stub.FullDuplexCall() # No exception raised and no message swallowed. - await call.try_connect() + await call.wait_for_connection() request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.append( @@ -112,11 +121,19 @@ class TestTryConnect(AioTestBase): self.assertEqual(grpc.StatusCode.OK, await call.code()) + async def test_unary_unary_error(self): + call = self._dummy_channel.unary_unary(_TEST_METHOD)(_REQUEST) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.wait_for_connection() + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) + async def test_unary_stream_error(self): call = self._dummy_channel.unary_stream(_TEST_METHOD)(_REQUEST) with self.assertRaises(aio.AioRpcError) as exception_context: - await call.try_connect() + await call.wait_for_connection() rpc_error = exception_context.exception self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) @@ -124,7 +141,7 @@ class TestTryConnect(AioTestBase): call = self._dummy_channel.stream_unary(_TEST_METHOD)() with self.assertRaises(aio.AioRpcError) as exception_context: - await call.try_connect() + await call.wait_for_connection() rpc_error = exception_context.exception self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) @@ -132,7 +149,7 @@ class TestTryConnect(AioTestBase): call = self._dummy_channel.stream_stream(_TEST_METHOD)() with self.assertRaises(aio.AioRpcError) as exception_context: - await call.try_connect() + await call.wait_for_connection() rpc_error = exception_context.exception self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code())