[Aio] Add add_done_callback/done/cancelled methods to ServicerContext (#27767)

* WIP

* [Aio] Add add_done_callback method to aio server

* Allow >20 public methods on the ServicerContext class
pull/27854/head
Lidi Zheng 4 years ago committed by GitHub
parent 2bc96189b9
commit 09a55f26f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  2. 31
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 34
      src/python/grpcio/grpc/aio/_base_server.py
  4. 3
      src/python/grpcio_tests/tests_aio/tests.json
  5. 158
      src/python/grpcio_tests/tests_aio/unit/done_callback_test.py

@ -33,6 +33,7 @@ cdef class RPCState(GrpcCallWrapper):
cdef tuple trailing_metadata
cdef object compression_algorithm
cdef bint disable_next_compression
cdef object callbacks
cdef bytes method(self)
cdef tuple invocation_metadata(self)

@ -59,6 +59,7 @@ cdef class RPCState:
self.trailing_metadata = _IMMUTABLE_EMPTY_METADATA
self.compression_algorithm = None
self.disable_next_compression = False
self.callbacks = []
cdef bytes method(self):
return _slice_bytes(self.details.method)
@ -173,11 +174,16 @@ cdef class _ServicerContext:
if trailing_metadata == _IMMUTABLE_EMPTY_METADATA and self._rpc_state.trailing_metadata:
trailing_metadata = self._rpc_state.trailing_metadata
else:
self._rpc_state.trailing_metadata = trailing_metadata
if details == '' and self._rpc_state.status_details:
details = self._rpc_state.status_details
else:
self._rpc_state.status_details = details
actual_code = get_status_code(code)
self._rpc_state.status_code = actual_code
self._rpc_state.status_sent = True
await _send_error_status_from_server(
@ -267,6 +273,16 @@ cdef class _ServicerContext:
else:
return max(_time_from_timespec(self._rpc_state.details.deadline) - time.time(), 0)
def add_done_callback(self, callback):
cb = functools.partial(callback, self)
self._rpc_state.callbacks.append(cb)
def done(self):
return self._rpc_state.status_sent
def cancelled(self):
return self._rpc_state.status_code == StatusCode.cancelled
cdef class _SyncServicerContext:
"""Sync servicer context for sync handler compatibility."""
@ -697,6 +713,7 @@ async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
else:
status_code = rpc_state.status_code
rpc_state.status_sent = True
await _send_error_status_from_server(
rpc_state,
status_code,
@ -707,6 +724,19 @@ async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
)
cdef _add_callback_handler(object rpc_task, RPCState rpc_state):
def handle_callbacks(object unused_task):
try:
for callback in rpc_state.callbacks:
# The _ServicerContext object is bound in add_done_callback.
callback()
except:
_LOGGER.exception('Error in callback for method [%s]', _decode(rpc_state.method()))
rpc_task.add_done_callback(handle_callbacks)
async def _handle_cancellation_from_core(object rpc_task,
RPCState rpc_state,
object loop):
@ -733,6 +763,7 @@ async def _schedule_rpc_coro(object rpc_coro,
rpc_coro,
loop,
))
_add_callback_handler(rpc_task, rpc_state)
await _handle_cancellation_from_core(rpc_task, rpc_state, loop)

@ -19,6 +19,7 @@ from typing import Generic, Iterable, Mapping, Optional, Sequence
import grpc
from ._metadata import Metadata
from ._typing import DoneCallbackType
from ._typing import MetadataType
from ._typing import RequestType
from ._typing import ResponseType
@ -133,6 +134,7 @@ class Server(abc.ABC):
"""
# pylint: disable=too-many-public-methods
class ServicerContext(Generic[RequestType, ResponseType], abc.ABC):
"""A context object passed to method implementations."""
@ -337,3 +339,35 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC):
The details string of the RPC.
"""
raise NotImplementedError()
def add_done_callback(self, callback: DoneCallbackType) -> None:
"""Registers a callback to be called on RPC termination.
This is an EXPERIMENTAL API.
Args:
callback: A callable object will be called with the servicer context
object as its only argument.
"""
def cancelled(self) -> bool:
"""Return True if the RPC is cancelled.
The RPC is cancelled when the cancellation was requested with cancel().
This is an EXPERIMENTAL API.
Returns:
A bool indicates whether the RPC is cancelled or not.
"""
def done(self) -> bool:
"""Return True if the RPC is done.
An RPC is done if the RPC is completed, cancelled or aborted.
This is an EXPERIMENTAL API.
Returns:
A bool indicates if the RPC is done.
"""

@ -27,7 +27,8 @@
"unit.compression_test.TestCompression",
"unit.connectivity_test.TestConnectivityState",
"unit.context_peer_test.TestContextPeer",
"unit.done_callback_test.TestDoneCallback",
"unit.done_callback_test.TestClientSideDoneCallback",
"unit.done_callback_test.TestServerSideDoneCallback",
"unit.init_test.TestInit",
"unit.metadata_test.TestMetadata",
"unit.outside_init_test.TestOutsideInit",

@ -14,9 +14,7 @@
"""Testing the done callbacks mechanism."""
import asyncio
import gc
import logging
import time
import unittest
import grpc
@ -24,7 +22,6 @@ from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._common import inject_callbacks
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
@ -32,9 +29,13 @@ from tests_aio.unit._test_server import start_test_server
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST = b'\x01\x02\x03'
_RESPONSE = b'\x04\x05\x06'
_TEST_METHOD = '/test/Test'
_FAKE_METHOD = '/test/Fake'
class TestDoneCallback(AioTestBase):
class TestClientSideDoneCallback(AioTestBase):
async def setUp(self):
address, self._server = await start_test_server()
@ -121,6 +122,155 @@ class TestDoneCallback(AioTestBase):
await validation
class TestServerSideDoneCallback(AioTestBase):
async def setUp(self):
self._server = aio.server()
port = self._server.add_insecure_port('[::]:0')
self._channel = aio.insecure_channel('localhost:%d' % port)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def _register_method_handler(self, method_handler):
"""Registers method handler and starts the server"""
generic_handler = grpc.method_handlers_generic_handler(
'test',
dict(Test=method_handler),
)
self._server.add_generic_rpc_handlers((generic_handler,))
await self._server.start()
async def test_unary_unary(self):
validation_future = self.loop.create_future()
async def test_handler(request: bytes, context: aio.ServicerContext):
self.assertEqual(_REQUEST, request)
validation_future.set_result(inject_callbacks(context))
return _RESPONSE
await self._register_method_handler(
grpc.unary_unary_rpc_method_handler(test_handler))
response = await self._channel.unary_unary(_TEST_METHOD)(_REQUEST)
self.assertEqual(_RESPONSE, response)
validation = await validation_future
await validation
async def test_unary_stream(self):
validation_future = self.loop.create_future()
async def test_handler(request: bytes, context: aio.ServicerContext):
self.assertEqual(_REQUEST, request)
validation_future.set_result(inject_callbacks(context))
for _ in range(_NUM_STREAM_RESPONSES):
yield _RESPONSE
await self._register_method_handler(
grpc.unary_stream_rpc_method_handler(test_handler))
call = self._channel.unary_stream(_TEST_METHOD)(_REQUEST)
async for response in call:
self.assertEqual(_RESPONSE, response)
validation = await validation_future
await validation
async def test_stream_unary(self):
validation_future = self.loop.create_future()
async def test_handler(request_iterator, context: aio.ServicerContext):
validation_future.set_result(inject_callbacks(context))
async for request in request_iterator:
self.assertEqual(_REQUEST, request)
return _RESPONSE
await self._register_method_handler(
grpc.stream_unary_rpc_method_handler(test_handler))
call = self._channel.stream_unary(_TEST_METHOD)()
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(_REQUEST)
await call.done_writing()
self.assertEqual(_RESPONSE, await call)
validation = await validation_future
await validation
async def test_stream_stream(self):
validation_future = self.loop.create_future()
async def test_handler(request_iterator, context: aio.ServicerContext):
validation_future.set_result(inject_callbacks(context))
async for request in request_iterator:
self.assertEqual(_REQUEST, request)
return _RESPONSE
await self._register_method_handler(
grpc.stream_stream_rpc_method_handler(test_handler))
call = self._channel.stream_stream(_TEST_METHOD)()
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(_REQUEST)
await call.done_writing()
async for response in call:
self.assertEqual(_RESPONSE, response)
validation = await validation_future
await validation
async def test_error_in_handler(self):
"""Errors in the handler still triggers callbacks."""
validation_future = self.loop.create_future()
async def test_handler(request: bytes, context: aio.ServicerContext):
self.assertEqual(_REQUEST, request)
validation_future.set_result(inject_callbacks(context))
raise RuntimeError('A test RuntimeError')
await self._register_method_handler(
grpc.unary_unary_rpc_method_handler(test_handler))
with self.assertRaises(aio.AioRpcError) as exception_context:
await self._channel.unary_unary(_TEST_METHOD)(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNKNOWN, rpc_error.code())
validation = await validation_future
await validation
async def test_error_in_callback(self):
"""Errors in the callback won't be propagated to client."""
validation_future = self.loop.create_future()
async def test_handler(request: bytes, context: aio.ServicerContext):
self.assertEqual(_REQUEST, request)
def exception_raiser(unused_context):
raise RuntimeError('A test RuntimeError')
context.add_done_callback(exception_raiser)
validation_future.set_result(inject_callbacks(context))
return _RESPONSE
await self._register_method_handler(
grpc.unary_unary_rpc_method_handler(test_handler))
response = await self._channel.unary_unary(_TEST_METHOD)(_REQUEST)
self.assertEqual(_RESPONSE, response)
# Following callbacks won't be invoked, if one of the callback crashed.
validation = await validation_future
with self.assertRaises(asyncio.TimeoutError):
await validation
# Invoke RPC one more time to ensure the toxic callback won't break the
# server.
with self.assertRaises(aio.AioRpcError) as exception_context:
await self._channel.unary_unary(_FAKE_METHOD)(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

Loading…
Cancel
Save