[Aio] Add time_remaining method to ServicerContext (#25719)

* [Aio] Add time_remaining method to ServicerContext

* Fix comments

* Resolve reviewer's requests
reviewable/pr25108/r13^2
Lidi Zheng 4 years ago committed by GitHub
parent 4693b9b1e5
commit 83b19b2efe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 9
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  2. 9
      src/python/grpcio/grpc/aio/_base_server.py
  3. 1
      src/python/grpcio_tests/tests_aio/tests.json
  4. 19
      src/python/grpcio_tests/tests_aio/unit/_common.py
  5. 51
      src/python/grpcio_tests/tests_aio/unit/compatibility_test.py
  6. 70
      src/python/grpcio_tests/tests_aio/unit/server_time_remaining_test.py

@ -252,6 +252,12 @@ cdef class _ServicerContext:
else:
return {}
def time_remaining(self):
if self._rpc_state.details.deadline.seconds == _GPR_INF_FUTURE.seconds:
return None
else:
return max(_time_from_timespec(self._rpc_state.details.deadline) - time.time(), 0)
cdef class _SyncServicerContext:
"""Sync servicer context for sync handler compatibility."""
@ -311,6 +317,9 @@ cdef class _SyncServicerContext:
def auth_context(self):
return self._context.auth_context()
def time_remaining(self):
return self._context.time_remaining()
async def _run_interceptor(object interceptors, object query_handler,
object handler_call_details):

@ -295,3 +295,12 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC):
Returns:
A map of strings to an iterable of bytes for each auth property.
"""
def time_remaining(self) -> float:
"""Describes the length of allowed time remaining for the RPC.
Returns:
A nonnegative float indicating the length of allowed time in seconds
remaining for the RPC to complete before it is considered to have
timed out, or None if no deadline was specified for the RPC.
"""

@ -36,6 +36,7 @@
"unit.secure_call_test.TestUnaryUnarySecureCall",
"unit.server_interceptor_test.TestServerInterceptor",
"unit.server_test.TestServer",
"unit.server_time_remaining_test.TestServerTimeRemaining",
"unit.timeout_test.TestTimeout",
"unit.wait_for_connection_test.TestWaitForConnection",
"unit.wait_for_ready_test.TestWaitForReady"

@ -21,6 +21,8 @@ from grpc.aio._metadata import Metadata
from tests.unit.framework.common import test_constants
ADHOC_METHOD = '/test/AdHoc'
def seen_metadata(expected: Metadata, actual: Metadata):
return not bool(set(tuple(expected)) - set(tuple(actual)))
@ -97,3 +99,20 @@ class CountingResponseIterator:
def __aiter__(self):
return self._forward_responses()
class AdhocGenericHandler(grpc.GenericRpcHandler):
"""A generic handler to plugin testing server methods on the fly."""
_handler: grpc.RpcMethodHandler
def __init__(self):
self._handler = None
def set_adhoc_handler(self, handler: grpc.RpcMethodHandler):
self._handler = handler
def service(self, handler_call_details):
if handler_call_details.method == ADHOC_METHOD:
return self._handler
else:
return None

@ -35,29 +35,12 @@ _NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST = b'\x03\x07'
_ADHOC_METHOD = '/test/AdHoc'
def _unique_options() -> Sequence[Tuple[str, float]]:
return (('iv', random.random()),)
class _AdhocGenericHandler(grpc.GenericRpcHandler):
_handler: grpc.RpcMethodHandler
def __init__(self):
self._handler = None
def set_adhoc_handler(self, handler: grpc.RpcMethodHandler):
self._handler = handler
def service(self, handler_call_details):
if handler_call_details.method == _ADHOC_METHOD:
return self._handler
else:
return None
@unittest.skipIf(
os.environ.get('GRPC_ASYNCIO_ENGINE', '').lower() == 'custom_io_manager',
'Compatible mode needs POLLER completion queue.')
@ -70,7 +53,7 @@ class TestCompatibility(AioTestBase):
test_pb2_grpc.add_TestServiceServicer_to_server(TestServiceServicer(),
self._async_server)
self._adhoc_handlers = _AdhocGenericHandler()
self._adhoc_handlers = _common.AdhocGenericHandler()
self._async_server.add_generic_rpc_handlers((self._adhoc_handlers,))
port = self._async_server.add_insecure_port('[::]:0')
@ -240,8 +223,8 @@ class TestCompatibility(AioTestBase):
return request
self._adhoc_handlers.set_adhoc_handler(echo_unary_unary)
response = await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST
)
response = await self._async_channel.unary_unary(_common.ADHOC_METHOD
)(_REQUEST)
self.assertEqual(_REQUEST, response)
async def test_sync_unary_unary_metadata(self):
@ -253,7 +236,7 @@ class TestCompatibility(AioTestBase):
return request
self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary)
call = self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
call = self._async_channel.unary_unary(_common.ADHOC_METHOD)(_REQUEST)
self.assertTrue(
_common.seen_metadata(aio.Metadata(*metadata), await
call.initial_metadata()))
@ -266,7 +249,8 @@ class TestCompatibility(AioTestBase):
self._adhoc_handlers.set_adhoc_handler(abort_unary_unary)
with self.assertRaises(aio.AioRpcError) as exception_context:
await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
await self._async_channel.unary_unary(_common.ADHOC_METHOD
)(_REQUEST)
self.assertEqual(grpc.StatusCode.INTERNAL,
exception_context.exception.code())
@ -278,7 +262,8 @@ class TestCompatibility(AioTestBase):
self._adhoc_handlers.set_adhoc_handler(set_code_unary_unary)
with self.assertRaises(aio.AioRpcError) as exception_context:
await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
await self._async_channel.unary_unary(_common.ADHOC_METHOD
)(_REQUEST)
self.assertEqual(grpc.StatusCode.INTERNAL,
exception_context.exception.code())
@ -290,7 +275,7 @@ class TestCompatibility(AioTestBase):
yield request
self._adhoc_handlers.set_adhoc_handler(echo_unary_stream)
call = self._async_channel.unary_stream(_ADHOC_METHOD)(_REQUEST)
call = self._async_channel.unary_stream(_common.ADHOC_METHOD)(_REQUEST)
async for response in call:
self.assertEqual(_REQUEST, response)
@ -303,7 +288,7 @@ class TestCompatibility(AioTestBase):
raise RuntimeError('Test')
self._adhoc_handlers.set_adhoc_handler(error_unary_stream)
call = self._async_channel.unary_stream(_ADHOC_METHOD)(_REQUEST)
call = self._async_channel.unary_stream(_common.ADHOC_METHOD)(_REQUEST)
with self.assertRaises(aio.AioRpcError) as exception_context:
async for response in call:
self.assertEqual(_REQUEST, response)
@ -320,8 +305,8 @@ class TestCompatibility(AioTestBase):
self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
response = await self._async_channel.stream_unary(_ADHOC_METHOD)(
request_iterator)
response = await self._async_channel.stream_unary(_common.ADHOC_METHOD
)(request_iterator)
self.assertEqual(_REQUEST, response)
async def test_sync_stream_unary_error(self):
@ -335,8 +320,8 @@ class TestCompatibility(AioTestBase):
self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
with self.assertRaises(aio.AioRpcError) as exception_context:
response = await self._async_channel.stream_unary(_ADHOC_METHOD)(
request_iterator)
response = await self._async_channel.stream_unary(
_common.ADHOC_METHOD)(request_iterator)
self.assertEqual(grpc.StatusCode.UNKNOWN,
exception_context.exception.code())
@ -350,8 +335,8 @@ class TestCompatibility(AioTestBase):
self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
call = self._async_channel.stream_stream(_ADHOC_METHOD)(
request_iterator)
call = self._async_channel.stream_stream(
_common.ADHOC_METHOD)(request_iterator)
async for response in call:
self.assertEqual(_REQUEST, response)
@ -366,8 +351,8 @@ class TestCompatibility(AioTestBase):
self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
call = self._async_channel.stream_stream(_ADHOC_METHOD)(
request_iterator)
call = self._async_channel.stream_stream(
_common.ADHOC_METHOD)(request_iterator)
with self.assertRaises(aio.AioRpcError) as exception_context:
async for response in call:
self.assertEqual(_REQUEST, response)

@ -0,0 +1,70 @@
# Copyright 2021 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the time_remaining() method of async ServicerContext."""
import asyncio
import logging
import unittest
import datetime
import grpc
from grpc import aio
from tests_aio.unit._common import ADHOC_METHOD, AdhocGenericHandler
from tests_aio.unit._test_base import AioTestBase
_REQUEST = b'\x09\x05'
_REQUEST_TIMEOUT_S = datetime.timedelta(seconds=5).total_seconds()
class TestServerTimeRemaining(AioTestBase):
async def setUp(self):
# Create async server
self._server = aio.server(options=(('grpc.so_reuseport', 0),))
self._adhoc_handlers = AdhocGenericHandler()
self._server.add_generic_rpc_handlers((self._adhoc_handlers,))
port = self._server.add_insecure_port('[::]:0')
address = 'localhost:%d' % port
await self._server.start()
# Create async channel
self._channel = aio.insecure_channel(address)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_servicer_context_time_remaining(self):
seen_time_remaining = []
@grpc.unary_unary_rpc_method_handler
def log_time_remaining(request: bytes,
context: grpc.ServicerContext) -> bytes:
seen_time_remaining.append(context.time_remaining())
return b""
# Check if the deadline propagates properly
self._adhoc_handlers.set_adhoc_handler(log_time_remaining)
await self._channel.unary_unary(ADHOC_METHOD)(
_REQUEST, timeout=_REQUEST_TIMEOUT_S)
self.assertGreater(seen_time_remaining[0], _REQUEST_TIMEOUT_S / 2)
# Check if there is no timeout, the time_remaining will be None
self._adhoc_handlers.set_adhoc_handler(log_time_remaining)
await self._channel.unary_unary(ADHOC_METHOD)(_REQUEST)
self.assertIsNone(seen_time_remaining[1])
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save