Fixes bug with deadline

pull/21455/head
Pau Freixes 5 years ago
parent 75c858bcef
commit 2a342b22a7
  1. 32
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  2. 3
      src/python/grpcio/grpc/experimental/aio/_utils.py
  3. 28
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  4. 21
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  5. 79
      src/python/grpcio_tests/tests_aio/unit/interceptor_test.py

@ -168,12 +168,12 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
try: try:
call = self._interceptors_task.result() call = self._interceptors_task.result()
except AioRpcError: except AioRpcError as err:
return False return err.code() == grpc.StatusCode.CANCELLED
except asyncio.CancelledError: except asyncio.CancelledError:
return True return True
else:
return call.cancelled() return call.cancelled()
def done(self) -> bool: def done(self) -> bool:
if not self._interceptors_task.done(): if not self._interceptors_task.done():
@ -183,8 +183,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
call = self._interceptors_task.result() call = self._interceptors_task.result()
except (AioRpcError, asyncio.CancelledError): except (AioRpcError, asyncio.CancelledError):
return True return True
else:
return call.done() return call.done()
def add_done_callback(self, unused_callback) -> None: def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError() raise NotImplementedError()
@ -199,8 +199,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
return err.initial_metadata() return err.initial_metadata()
except asyncio.CancelledError: except asyncio.CancelledError:
return None return None
else:
return await call.initial_metadata() return await call.initial_metadata()
async def trailing_metadata(self) -> Optional[MetadataType]: async def trailing_metadata(self) -> Optional[MetadataType]:
try: try:
@ -209,8 +209,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
return err.trailing_metadata() return err.trailing_metadata()
except asyncio.CancelledError: except asyncio.CancelledError:
return None return None
else:
return await call.trailing_metadata() return await call.trailing_metadata()
async def code(self) -> grpc.StatusCode: async def code(self) -> grpc.StatusCode:
try: try:
@ -219,8 +219,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
return err.code() return err.code()
except asyncio.CancelledError: except asyncio.CancelledError:
return grpc.StatusCode.CANCELLED return grpc.StatusCode.CANCELLED
else:
return await call.code() return await call.code()
async def details(self) -> str: async def details(self) -> str:
try: try:
@ -229,8 +229,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
return err.details() return err.details()
except asyncio.CancelledError: except asyncio.CancelledError:
return _LOCAL_CANCELLATION_DETAILS return _LOCAL_CANCELLATION_DETAILS
else:
return await call.details() return await call.details()
async def debug_error_string(self) -> Optional[str]: async def debug_error_string(self) -> Optional[str]:
try: try:
@ -239,8 +239,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
return err.debug_error_string() return err.debug_error_string()
except asyncio.CancelledError: except asyncio.CancelledError:
return '' return ''
else:
return await call.debug_error_string() return await call.debug_error_string()
def __await__(self): def __await__(self):
call = yield from self._interceptors_task.__await__() call = yield from self._interceptors_task.__await__()

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Internal utilities used by the gRPC Aio module.""" """Internal utilities used by the gRPC Aio module."""
import asyncio import asyncio
import time
from typing import Optional from typing import Optional
@ -20,4 +21,4 @@ def _timeout_to_deadline(loop: asyncio.AbstractEventLoop,
timeout: Optional[float]) -> Optional[float]: timeout: Optional[float]) -> Optional[float]:
if timeout is None: if timeout is None:
return None return None
return loop.time() + timeout return time.time() + timeout

@ -16,11 +16,14 @@ import asyncio
import logging import logging
import datetime import datetime
import grpc
from grpc.experimental import aio from grpc.experimental import aio
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc from src.proto.grpc.testing import test_pb2_grpc
UNARY_CALL_WITH_SLEEP_VALUE = 0.2
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
@ -39,11 +42,34 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
body=b'\x00' * body=b'\x00' *
response_parameters.size)) response_parameters.size))
# Next methods are extra ones that are registred programatically
# when the sever is instantiated. They are not being provided by
# the proto file.
async def UnaryCallWithSleep(self, request, context):
await asyncio.sleep(UNARY_CALL_WITH_SLEEP_VALUE)
return messages_pb2.SimpleResponse()
async def start_test_server(): async def start_test_server():
server = aio.server(options=(('grpc.so_reuseport', 0),)) server = aio.server(options=(('grpc.so_reuseport', 0),))
test_pb2_grpc.add_TestServiceServicer_to_server(_TestServiceServicer(), servicer = _TestServiceServicer()
test_pb2_grpc.add_TestServiceServicer_to_server(servicer,
server) server)
# Add programatically extra methods not provided by the proto file
# that are used during the tests
rpc_method_handlers = {
'UnaryCallWithSleep': grpc.unary_unary_rpc_method_handler(
servicer.UnaryCallWithSleep,
request_deserializer=messages_pb2.SimpleRequest.FromString,
response_serializer=messages_pb2.SimpleResponse.SerializeToString
)
}
extra_handler = grpc.method_handlers_generic_handler(
'grpc.testing.TestService', rpc_method_handlers)
server.add_generic_rpc_handlers((extra_handler,))
port = server.add_insecure_port('[::]:0') port = server.add_insecure_port('[::]:0')
await server.start() await server.start()
# NOTE(lidizheng) returning the server to prevent it from deallocation # NOTE(lidizheng) returning the server to prevent it from deallocation

@ -23,11 +23,12 @@ from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import messages_pb2
_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_NUM_STREAM_RESPONSES = 5 _NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42 _RESPONSE_PAYLOAD_SIZE = 42
@ -52,7 +53,6 @@ class TestChannel(AioTestBase):
async def test_unary_unary(self): async def test_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel: async with aio.insecure_channel(self._server_target) as channel:
channel = aio.insecure_channel(self._server_target)
hi = channel.unary_unary( hi = channel.unary_unary(
_UNARY_CALL_METHOD, _UNARY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.SerializeToString, request_serializer=messages_pb2.SimpleRequest.SerializeToString,
@ -62,15 +62,15 @@ class TestChannel(AioTestBase):
self.assertIsInstance(response, messages_pb2.SimpleResponse) self.assertIsInstance(response, messages_pb2.SimpleResponse)
async def test_unary_call_times_out(self): async def test_unary_call_times_out(self):
async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel: async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary( hi = channel.unary_unary(
_UNARY_CALL_METHOD, _UNARY_CALL_METHOD_WITH_SLEEP,
request_serializer=messages_pb2.SimpleRequest.SerializeToString, request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString, response_deserializer=messages_pb2.SimpleResponse.FromString,
) )
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
await hi(messages_pb2.SimpleRequest(), timeout=1.0) await hi(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
_, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable _, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
@ -81,6 +81,17 @@ class TestChannel(AioTestBase):
self.assertIsNotNone( self.assertIsNotNone(
exception_context.exception.trailing_metadata()) exception_context.exception.trailing_metadata())
async def test_unary_call_does_not_times_out(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
_UNARY_CALL_METHOD_WITH_SLEEP,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
)
call = hi(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE * 2)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_unary_stream(self): async def test_unary_stream(self):
channel = aio.insecure_channel(self._server_target) channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)

@ -18,15 +18,22 @@ import unittest
import grpc import grpc
from grpc.experimental import aio from grpc.experimental import aio
from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import messages_pb2
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
class TestUnaryUnaryClientInterceptor(AioTestBase): class TestUnaryUnaryClientInterceptor(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
def test_invalid_interceptor(self): def test_invalid_interceptor(self):
class InvalidInterceptor: class InvalidInterceptor:
@ -50,9 +57,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
interceptors = [Interceptor() for i in range(2)] interceptors = [Interceptor() for i in range(2)]
server_target, _ = await start_test_server() # pylint: disable=unused-variable async with aio.insecure_channel(self._server_target,
async with aio.insecure_channel(server_target,
interceptors=interceptors) as channel: interceptors=interceptors) as channel:
multicallable = channel.unary_unary( multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall', '/grpc.testing.TestService/UnaryCall',
@ -97,9 +102,8 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
return call return call
interceptor = StatusCodeOkInterceptor() interceptor = StatusCodeOkInterceptor()
server_target, server = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target, async with aio.insecure_channel(self._server_target,
interceptors=[interceptor]) as channel: interceptors=[interceptor]) as channel:
# when no error StatusCode.OK must be observed # when no error StatusCode.OK must be observed
@ -121,26 +125,23 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
client_call_details, request): client_call_details, request):
new_client_call_details = aio.ClientCallDetails( new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method, method=client_call_details.method,
timeout=0.1, timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata, metadata=client_call_details.metadata,
credentials=client_call_details.credentials) credentials=client_call_details.credentials)
return await continuation(new_client_call_details, request) return await continuation(new_client_call_details, request)
interceptor = TimeoutInterceptor() interceptor = TimeoutInterceptor()
server_target, server = await start_test_server()
async with aio.insecure_channel(server_target, async with aio.insecure_channel(self._server_target,
interceptors=[interceptor]) as channel: interceptors=[interceptor]) as channel:
multicallable = channel.unary_unary( multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall', '/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString, request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString) response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest()) call = multicallable(messages_pb2.SimpleRequest())
await server.stop(None)
with self.assertRaises(aio.AioRpcError) as exception_context: with self.assertRaises(aio.AioRpcError) as exception_context:
await call await call
@ -165,7 +166,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
new_client_call_details = aio.ClientCallDetails( new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method, method=client_call_details.method,
timeout=0.1, timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata, metadata=client_call_details.metadata,
credentials=client_call_details.credentials) credentials=client_call_details.credentials)
@ -188,13 +189,12 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
return call return call
interceptor = RetryInterceptor() interceptor = RetryInterceptor()
server_target, server = await start_test_server()
async with aio.insecure_channel(server_target, async with aio.insecure_channel(self._server_target,
interceptors=[interceptor]) as channel: interceptors=[interceptor]) as channel:
multicallable = channel.unary_unary( multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall', '/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString, request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString) response_deserializer=messages_pb2.SimpleResponse.FromString)
@ -232,10 +232,9 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
return ResponseInterceptor.response return ResponseInterceptor.response
interceptor, interceptor_response = Interceptor(), ResponseInterceptor() interceptor, interceptor_response = Interceptor(), ResponseInterceptor()
server_target, server = await start_test_server()
async with aio.insecure_channel( async with aio.insecure_channel(
server_target, interceptors=[interceptor, self._server_target, interceptors=[interceptor,
interceptor_response]) as channel: interceptor_response]) as channel:
multicallable = channel.unary_unary( multicallable = channel.unary_unary(
@ -263,6 +262,12 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
class TestInterceptedUnaryUnaryCall(AioTestBase): class TestInterceptedUnaryUnaryCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_call_ok(self): async def test_call_ok(self):
class Interceptor(aio.UnaryUnaryClientInterceptor): class Interceptor(aio.UnaryUnaryClientInterceptor):
@ -272,9 +277,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
call = await continuation(client_call_details, request) call = await continuation(client_call_details, request)
return call return call
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target, async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor() interceptors=[Interceptor()
]) as channel: ]) as channel:
@ -303,9 +307,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
await call await call
return call return call
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target, async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor() interceptors=[Interceptor()
]) as channel: ]) as channel:
@ -333,20 +336,17 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
call = await continuation(client_call_details, request) call = await continuation(client_call_details, request)
return call return call
server_target, server = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target, async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor() interceptors=[Interceptor()
]) as channel: ]) as channel:
multicallable = channel.unary_unary( multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall', '/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString, request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString) response_deserializer=messages_pb2.SimpleResponse.FromString)
await server.stop(None) call = multicallable(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1)
with self.assertRaises(aio.AioRpcError) as exception_context: with self.assertRaises(aio.AioRpcError) as exception_context:
await call await call
@ -359,7 +359,7 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.initial_metadata(), ()) self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ()) self.assertEqual(await call.trailing_metadata(), ())
async def test_call_rpcerror_awaited(self): async def test_call_rpc_error_awaited(self):
class Interceptor(aio.UnaryUnaryClientInterceptor): class Interceptor(aio.UnaryUnaryClientInterceptor):
@ -369,20 +369,17 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
await call await call
return call return call
server_target, server = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target, async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor() interceptors=[Interceptor()
]) as channel: ]) as channel:
multicallable = channel.unary_unary( multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall', '/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString, request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString) response_deserializer=messages_pb2.SimpleResponse.FromString)
await server.stop(None) call = multicallable(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1)
with self.assertRaises(aio.AioRpcError) as exception_context: with self.assertRaises(aio.AioRpcError) as exception_context:
await call await call
@ -409,9 +406,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
# This line should never be reached # This line should never be reached
raise Exception() raise Exception()
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target, async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor() interceptors=[Interceptor()
]) as channel: ]) as channel:
@ -454,9 +450,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
# This line should never be reached # This line should never be reached
raise Exception() raise Exception()
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target, async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor() interceptors=[Interceptor()
]) as channel: ]) as channel:
@ -494,9 +489,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
await call await call
return call return call
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target, async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor() interceptors=[Interceptor()
]) as channel: ]) as channel:
@ -527,9 +521,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
call.cancel() call.cancel()
return call return call
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target, async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor() interceptors=[Interceptor()
]) as channel: ]) as channel:

Loading…
Cancel
Save