Add try_connect API for UnaryStreamCall and StreamStreamCall

pull/22565/head
Lidi Zheng 5 years ago
parent 4d91e531ab
commit 41866c1250
  1. 30
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  2. 10
      src/python/grpcio/grpc/experimental/aio/_call.py
  3. 1
      src/python/grpcio_tests/tests_aio/tests.json
  4. 55
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  5. 114
      src/python/grpcio_tests/tests_aio/unit/try_connect_test.py

@ -158,6 +158,21 @@ class UnaryStreamCall(Generic[RequestType, ResponseType],
stream.
"""
@abstractmethod
async def try_connect(self) -> None:
"""Tries to connect to peer and raise aio.AioRpcError if failed.
This method is available for RPCs with streaming responses. This method
enables the application to ensure if the RPC has been successfully
connected. Otherwise, an AioRpcError will be raised to explain the
reason of the connection failure.
For RPCs with unary response, the connectivity issue will be raised
once the application awaits the call.
This method is recommended for building retry mechanisms.
"""
class StreamUnaryCall(Generic[RequestType, ResponseType],
Call,
@ -229,3 +244,18 @@ class StreamStreamCall(Generic[RequestType, ResponseType],
After done_writing is called, any additional invocation to the write
function will fail. This function is idempotent.
"""
@abstractmethod
async def try_connect(self) -> None:
"""Tries to connect to peer and raise aio.AioRpcError if failed.
This method is available for RPCs with streaming responses. This method
enables the application to ensure if the RPC has been successfully
connected. Otherwise, an AioRpcError will be raised to explain the
reason of the connection failure.
For RPCs with unary response, the connectivity issue will be raised
once the application awaits the call.
This method is recommended for building retry mechanisms.
"""

@ -536,6 +536,11 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
self.cancel()
raise
async def try_connect(self) -> None:
await self._send_unary_request_task
if self.done():
await self._raise_for_status()
class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
_base_call.StreamUnaryCall):
@ -610,3 +615,8 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
if not self.cancelled():
self.cancel()
# No need to raise RpcError here, because no one will `await` this task.
async def try_connect(self) -> None:
await self._metadata_sent.wait()
if self.done():
await self._raise_for_status()

@ -28,5 +28,6 @@
"unit.server_interceptor_test.TestServerInterceptor",
"unit.server_test.TestServer",
"unit.timeout_test.TestTimeout",
"unit.try_connect_test.TestTryConnect",
"unit.wait_for_ready_test.TestWaitForReady"
]

@ -16,22 +16,22 @@
import asyncio
import logging
import unittest
import datetime
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_base import AioTestBase
from tests.unit import resources
from tests_aio.unit._test_server import start_test_server
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds()
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
_UNREACHABLE_TARGET = '0.1:1111'
_INFINITE_INTERVAL_US = 2**31 - 1
@ -434,24 +434,24 @@ class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
interval_us=_RESPONSE_INTERVAL_US,
))
call = self._stub.StreamingOutputCall(
request, timeout=test_constants.SHORT_TIMEOUT * 2)
call = self._stub.StreamingOutputCall(request,
timeout=_SHORT_TIMEOUT_S * 2)
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
# Should be around the same as the timeout
remained_time = call.time_remaining()
self.assertGreater(remained_time, test_constants.SHORT_TIMEOUT * 3 / 2)
self.assertLess(remained_time, test_constants.SHORT_TIMEOUT * 5 / 2)
self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2)
self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2)
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
# Should be around the timeout minus a unit of wait time
remained_time = call.time_remaining()
self.assertGreater(remained_time, test_constants.SHORT_TIMEOUT / 2)
self.assertLess(remained_time, test_constants.SHORT_TIMEOUT * 3 / 2)
self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2)
self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2)
self.assertEqual(grpc.StatusCode.OK, await call.code())
@ -538,14 +538,14 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
with self.assertRaises(asyncio.CancelledError):
for _ in range(_NUM_STREAM_RESPONSES):
yield request
await asyncio.sleep(test_constants.SHORT_TIMEOUT)
await asyncio.sleep(_SHORT_TIMEOUT_S)
request_iterator_received_the_exception.set()
call = self._stub.StreamingInputCall(request_iterator())
# Cancel the RPC after at least one response
async def cancel_later():
await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
call.cancel()
cancel_later_task = self.loop.create_task(cancel_later())
@ -576,6 +576,33 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_call_rpc_error(self):
async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# The error should be raised automatically without any traffic.
call = stub.StreamingInputCall()
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
async def test_timeout(self):
call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S)
# The error should be raised automatically without any traffic.
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code())
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await call.code())
# Prepares the request that stream in a ping-pong manner.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
@ -733,14 +760,14 @@ class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
with self.assertRaises(asyncio.CancelledError):
for _ in range(_NUM_STREAM_RESPONSES):
yield request
await asyncio.sleep(test_constants.SHORT_TIMEOUT)
await asyncio.sleep(_SHORT_TIMEOUT_S)
request_iterator_received_the_exception.set()
call = self._stub.FullDuplexCall(request_iterator())
# Cancel the RPC after at least one response
async def cancel_later():
await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
call.cancel()
cancel_later_task = self.loop.create_task(cancel_later())

@ -0,0 +1,114 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the try connect API on client side."""
import asyncio
import logging
import unittest
import datetime
from typing import Callable, Tuple
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit import _common
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
_REQUEST = b'\x01\x02\x03'
_UNREACHABLE_TARGET = '0.1:1111'
_TEST_METHOD = '/test/Test'
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
class TestTryConnect(AioTestBase):
"""Tests if try connect raises connectivity issue."""
async def setUp(self):
address, self._server = await start_test_server()
self._channel = aio.insecure_channel(address)
self._dummy_channel = aio.insecure_channel(_UNREACHABLE_TARGET)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._dummy_channel.close()
await self._channel.close()
await self._server.stop(None)
async def test_unary_stream_ok(self):
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
call = self._stub.StreamingOutputCall(request)
# No exception raised and no message swallowed.
await call.try_connect()
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_stream_ok(self):
call = self._stub.FullDuplexCall()
# No exception raised and no message swallowed.
await call.try_connect()
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
response = await call.read()
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
await call.done_writing()
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_unary_stream_error(self):
call = self._dummy_channel.unary_stream(_TEST_METHOD)(_REQUEST)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call.try_connect()
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code())
async def test_stream_stream_error(self):
call = self._dummy_channel.stream_stream(_TEST_METHOD)()
with self.assertRaises(aio.AioRpcError) as exception_context:
await call.try_connect()
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save