Rename to wait_for_conneciton && Add to unary-unary RPC

pull/22565/head
Lidi Zheng 5 years ago
parent 2bbf0a79f6
commit 2b6037f113
  1. 64
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  2. 25
      src/python/grpcio/grpc/experimental/aio/_call.py
  3. 7
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  4. 2
      src/python/grpcio_tests/tests_aio/tests.json
  5. 6
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  6. 39
      src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py

@ -117,6 +117,19 @@ class Call(RpcContext, metaclass=ABCMeta):
The details string of the RPC.
"""
@abstractmethod
async def wait_for_connection(self) -> None:
"""Waits until connected to peer and raises aio.AioRpcError if failed.
This is an EXPERIMENTAL method.
This method makes ensure if the RPC has been successfully connected.
Otherwise, an AioRpcError will be raised to explain the reason of the
connection failure.
This method is recommended for building retry mechanisms.
"""
class UnaryUnaryCall(Generic[RequestType, ResponseType],
Call,
@ -158,23 +171,6 @@ class UnaryStreamCall(Generic[RequestType, ResponseType],
stream.
"""
@abstractmethod
async def try_connect(self) -> None:
"""Tries to connect to peer and raise aio.AioRpcError if failed.
This is an EXPERIMENTAL method.
This method is available for streaming RPCs. This method enables the
application to ensure if the RPC has been successfully connected.
Otherwise, an AioRpcError will be raised to explain the reason of the
connection failure.
For unary-unary RPCs, the connectivity issue will be raised once the
application awaits the call.
This method is recommended for building retry mechanisms.
"""
class StreamUnaryCall(Generic[RequestType, ResponseType],
Call,
@ -204,23 +200,6 @@ class StreamUnaryCall(Generic[RequestType, ResponseType],
The response message of the stream.
"""
@abstractmethod
async def try_connect(self) -> None:
"""Tries to connect to peer and raise aio.AioRpcError if failed.
This is an EXPERIMENTAL method.
This method is available for streaming RPCs. This method enables the
application to ensure if the RPC has been successfully connected.
Otherwise, an AioRpcError will be raised to explain the reason of the
connection failure.
For unary-unary RPCs, the connectivity issue will be raised once the
application awaits the call.
This method is recommended for building retry mechanisms.
"""
class StreamStreamCall(Generic[RequestType, ResponseType],
Call,
@ -263,20 +242,3 @@ class StreamStreamCall(Generic[RequestType, ResponseType],
After done_writing is called, any additional invocation to the write
function will fail. This function is idempotent.
"""
@abstractmethod
async def try_connect(self) -> None:
"""Tries to connect to peer and raise aio.AioRpcError if failed.
This is an EXPERIMENTAL method.
This method is available for streaming RPCs. This method enables the
application to ensure if the RPC has been successfully connected.
Otherwise, an AioRpcError will be raised to explain the reason of the
connection failure.
For unary-unary RPCs, the connectivity issue will be raised once the
application awaits the call.
This method is recommended for building retry mechanisms.
"""

@ -18,7 +18,7 @@ import enum
import inspect
import logging
from functools import partial
from typing import AsyncIterable, Awaitable, Optional, Tuple
from typing import AsyncIterable, Optional, Tuple
import grpc
from grpc import _common
@ -250,9 +250,8 @@ class _APIStyle(enum.IntEnum):
class _UnaryResponseMixin(Call):
_call_response: asyncio.Task
def _init_unary_response_mixin(self,
response_coro: Awaitable[ResponseType]):
self._call_response = self._loop.create_task(response_coro)
def _init_unary_response_mixin(self, response_task: asyncio.Task):
self._call_response = response_task
def cancel(self) -> bool:
if super().cancel():
@ -458,7 +457,7 @@ class _StreamRequestMixin(Call):
self._raise_for_different_style(_APIStyle.READER_WRITER)
await self._done_writing()
async def try_connect(self) -> None:
async def wait_for_connection(self) -> None:
await self._metadata_sent.wait()
if self.done():
await self._raise_for_status()
@ -470,6 +469,7 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
Returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
_request: RequestType
_invocation_task: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
@ -483,7 +483,8 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
self._request = request
self._init_unary_response_mixin(self._invoke())
self._invocation_task = loop.create_task(self._invoke())
self._init_unary_response_mixin(self._invocation_task)
async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
@ -505,6 +506,11 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
else:
return cygrpc.EOF
async def wait_for_connection(self) -> None:
await self._invocation_task
if self.done():
await self._raise_for_status()
class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
"""Object for managing unary-stream RPC calls.
@ -541,7 +547,7 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
self.cancel()
raise
async def try_connect(self) -> None:
async def wait_for_connection(self) -> None:
await self._send_unary_request_task
if self.done():
await self._raise_for_status()
@ -566,8 +572,13 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
<<<<<<< HEAD
self._init_stream_request_mixin(request_iterator)
self._init_unary_response_mixin(self._conduct_rpc())
=======
self._init_stream_request_mixin(request_async_iterator)
self._init_unary_response_mixin(loop.create_task(self._conduct_rpc()))
>>>>>>> Rename to wait_for_conneciton && Add to unary-unary RPC
async def _conduct_rpc(self) -> ResponseType:
try:

@ -330,6 +330,10 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
response = yield from call.__await__()
return response
async def wait_for_connection(self) -> None:
call = await self._interceptors_task
return await call.wait_for_connection()
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
"""Final UnaryUnaryCall class finished with a response."""
@ -374,3 +378,6 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
# for telling the interpreter that __await__ is a generator.
yield None
return self._response
async def wait_for_connection(self) -> None:
pass

@ -28,6 +28,6 @@
"unit.server_interceptor_test.TestServerInterceptor",
"unit.server_test.TestServer",
"unit.timeout_test.TestTimeout",
"unit.try_connect_test.TestTryConnect",
"unit.wait_for_connection_test.TestWaitForConnection",
"unit.wait_for_ready_test.TestWaitForReady"
]

@ -24,6 +24,7 @@ from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._constants import UNREACHABLE_TARGET
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds()
@ -32,7 +33,6 @@ _RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
_UNREACHABLE_TARGET = '0.1:1111'
_INFINITE_INTERVAL_US = 2**31 - 1
@ -78,7 +78,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
self.assertIs(response, response_retry)
async def test_call_rpc_error(self):
async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel:
async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.UnaryCall(messages_pb2.SimpleRequest())
@ -577,7 +577,7 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_call_rpc_error(self):
async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel:
async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# The error should be raised automatically without any traffic.

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the try connect API on client side."""
"""Tests behavior of the wait for connection API on client side."""
import asyncio
import logging
@ -26,9 +26,9 @@ from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit import _common
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._constants import UNREACHABLE_TARGET
_REQUEST = b'\x01\x02\x03'
_UNREACHABLE_TARGET = '0.1:1111'
_TEST_METHOD = '/test/Test'
_NUM_STREAM_RESPONSES = 5
@ -36,13 +36,13 @@ _REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
class TestTryConnect(AioTestBase):
"""Tests if try connect raises connectivity issue."""
class TestWaitForConnection(AioTestBase):
"""Tests if wait_for_connection raises connectivity issue."""
async def setUp(self):
address, self._server = await start_test_server()
self._channel = aio.insecure_channel(address)
self._dummy_channel = aio.insecure_channel(_UNREACHABLE_TARGET)
self._dummy_channel = aio.insecure_channel(UNREACHABLE_TARGET)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
@ -50,6 +50,15 @@ class TestTryConnect(AioTestBase):
await self._channel.close()
await self._server.stop(None)
async def test_unary_unary_ok(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
# No exception raised and no message swallowed.
await call.wait_for_connection()
response = await call
self.assertIsInstance(response, messages_pb2.SimpleResponse)
async def test_unary_stream_ok(self):
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
@ -59,7 +68,7 @@ class TestTryConnect(AioTestBase):
call = self._stub.StreamingOutputCall(request)
# No exception raised and no message swallowed.
await call.try_connect()
await call.wait_for_connection()
response_cnt = 0
async for response in call:
@ -75,7 +84,7 @@ class TestTryConnect(AioTestBase):
call = self._stub.StreamingInputCall()
# No exception raised and no message swallowed.
await call.try_connect()
await call.wait_for_connection()
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
@ -95,7 +104,7 @@ class TestTryConnect(AioTestBase):
call = self._stub.FullDuplexCall()
# No exception raised and no message swallowed.
await call.try_connect()
await call.wait_for_connection()
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
@ -112,11 +121,19 @@ class TestTryConnect(AioTestBase):
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_unary_unary_error(self):
call = self._dummy_channel.unary_unary(_TEST_METHOD)(_REQUEST)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call.wait_for_connection()
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code())
async def test_unary_stream_error(self):
call = self._dummy_channel.unary_stream(_TEST_METHOD)(_REQUEST)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call.try_connect()
await call.wait_for_connection()
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code())
@ -124,7 +141,7 @@ class TestTryConnect(AioTestBase):
call = self._dummy_channel.stream_unary(_TEST_METHOD)()
with self.assertRaises(aio.AioRpcError) as exception_context:
await call.try_connect()
await call.wait_for_connection()
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code())
@ -132,7 +149,7 @@ class TestTryConnect(AioTestBase):
call = self._dummy_channel.stream_stream(_TEST_METHOD)()
with self.assertRaises(aio.AioRpcError) as exception_context:
await call.try_connect()
await call.wait_for_connection()
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code())
Loading…
Cancel
Save