From 6fef56573e9a0347c33c5aea4bc57ab625b0c6ea Mon Sep 17 00:00:00 2001 From: Zhanghui Mao Date: Sat, 15 Feb 2020 23:39:34 +0800 Subject: [PATCH] Implement server interceptor for unary unary call --- .../grpc/_cython/_cygrpc/aio/server.pxd.pxi | 1 + .../grpc/_cython/_cygrpc/aio/server.pyx.pxi | 44 ++++++-- .../grpcio/grpc/experimental/aio/__init__.py | 6 +- .../grpc/experimental/aio/_interceptor.py | 28 +++++ src/python/grpcio_tests/tests_aio/tests.json | 1 + .../tests_aio/unit/_test_server.py | 6 +- .../tests_aio/unit/interceptor_test.py | 104 ++++++++++++++++++ 7 files changed, 175 insertions(+), 15 deletions(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi index d3edb70dafe..daf1ffaf72f 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi @@ -61,3 +61,4 @@ cdef class AioServer: cdef CallbackWrapper _shutdown_callback_wrapper cdef object _crash_exception # Exception cdef set _ongoing_rpc_tasks + cdef tuple _interceptors diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index 903c20796f7..44342b88c30 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -15,6 +15,7 @@ import inspect import traceback +import functools cdef int _EMPTY_FLAG = 0 @@ -214,15 +215,34 @@ cdef class _ServicerContext: self._rpc_state.disable_next_compression = True -cdef _find_method_handler(str method, tuple metadata, list generic_handlers): +async def _run_interceptor(object interceptors, object query_handler, + object handler_call_details): + interceptor = next(interceptors, None) + if interceptor: + continuation = functools.partial(_run_interceptor, interceptors, + query_handler) + return await interceptor.intercept_service(continuation, handler_call_details) + else: + return query_handler(handler_call_details) + + +async def _find_method_handler(str method, tuple metadata, list generic_handlers, + tuple interceptors): + def query_handlers(handler_call_details): + for generic_handler in generic_handlers: + method_handler = generic_handler.service(handler_call_details) + if method_handler is not None: + return method_handler + return None + cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method, metadata) - - for generic_handler in generic_handlers: - method_handler = generic_handler.service(handler_call_details) - if method_handler is not None: - return method_handler - return None + # interceptor + if interceptors: + return await _run_interceptor(iter(interceptors), query_handlers, + handler_call_details) + else: + return query_handlers(handler_call_details) async def _finish_handler_with_unary_response(RPCState rpc_state, @@ -516,13 +536,15 @@ async def _schedule_rpc_coro(object rpc_coro, await _handle_cancellation_from_core(rpc_task, rpc_state, loop) -async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): +async def _handle_rpc(list generic_handlers, tuple interceptors, + RPCState rpc_state, object loop): cdef object method_handler # Finds the method handler (application logic) - method_handler = _find_method_handler( + method_handler = await _find_method_handler( rpc_state.method().decode(), rpc_state.invocation_metadata(), generic_handlers, + interceptors, ) if method_handler is None: rpc_state.status_sent = True @@ -605,8 +627,9 @@ cdef class AioServer: SERVER_SHUTDOWN_FAILURE_HANDLER) self._crash_exception = None + self._interceptors = () if interceptors: - raise NotImplementedError() + self._interceptors = interceptors if maximum_concurrent_rpcs: raise NotImplementedError() if thread_pool: @@ -662,6 +685,7 @@ cdef class AioServer: # the coroutine onto event loop inside of the cancellation # coroutine. rpc_coro = _handle_rpc(self._generic_handlers, + self._interceptors, rpc_state, self._loop) diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index d8d284780c1..053a53ac413 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -27,7 +27,7 @@ from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall from ._call import AioRpcError from ._channel import Channel, UnaryUnaryMultiCallable from ._interceptor import (ClientCallDetails, InterceptedUnaryUnaryCall, - UnaryUnaryClientInterceptor) + UnaryUnaryClientInterceptor, ServerInterceptor) from ._server import Server, server from ._typing import ChannelArgumentType @@ -86,5 +86,5 @@ __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall', 'init_grpc_aio', 'Channel', 'UnaryUnaryMultiCallable', 'ClientCallDetails', 'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall', - 'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel', - 'AbortError', 'BaseError', 'UsageError') + 'ServerInterceptor', 'insecure_channel', 'server', 'Server', 'EOF', + 'secure_channel', 'AbortError', 'BaseError', 'UsageError') diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index aca93fd468c..3ea0a2f62dc 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -30,6 +30,34 @@ from ._typing import (RequestType, SerializingFunction, DeserializingFunction, _LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' +class ServerInterceptor(metaclass=ABCMeta): + """Affords intercepting incoming RPCs on the service-side. + + This is an EXPERIMENTAL API. + """ + + @abstractmethod + async def intercept_service(self, + continuation: Callable[ + [grpc.HandlerCallDetails], grpc.RpcMethodHandler], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + """Intercepts incoming RPCs before handing them over to a handler. + + Args: + continuation: A function that takes a HandlerCallDetails and + proceeds to invoke the next interceptor in the chain, if any, + or the RPC handler lookup logic, with the call details passed + as an argument, and returns an RpcMethodHandler instance if + the RPC is considered serviced, or None otherwise. + handler_call_details: A HandlerCallDetails describing the RPC. + + Returns: + An RpcMethodHandler with which the RPC may be serviced if the + interceptor chooses to service this RPC, or None otherwise. + """ + + class ClientCallDetails( collections.namedtuple( 'ClientCallDetails', diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index e05d64ac474..300865a0fb7 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -18,6 +18,7 @@ "unit.init_test.TestInsecureChannel", "unit.init_test.TestSecureChannel", "unit.interceptor_test.TestInterceptedUnaryUnaryCall", + "unit.interceptor_test.TestServerInterceptor", "unit.interceptor_test.TestUnaryUnaryClientInterceptor", "unit.metadata_test.TestMetadata", "unit.server_test.TestServer", diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 32f988a6b90..769e0841b7d 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -117,8 +117,10 @@ def _create_extra_generic_handler(servicer: _TestServiceServicer): rpc_method_handlers) -async def start_test_server(port=0, secure=False, server_credentials=None): - server = aio.server(options=(('grpc.so_reuseport', 0),)) +async def start_test_server(port=0, secure=False, server_credentials=None, + interceptors=None): + server = aio.server(options=(('grpc.so_reuseport', 0),), + interceptors=interceptors) servicer = _TestServiceServicer() test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server) diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index 9fa08a78806..accd5457429 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -685,6 +685,110 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.fail("Callback was not called") +class _LoggingServerInterceptor(aio.ServerInterceptor): + + def __init__(self, tag, record): + self.tag = tag + self.record = record + + async def intercept_service(self, continuation, handler_call_details): + self.record.append(self.tag + ':intercept_service') + return await continuation(handler_call_details) + + +class _GenericServerInterceptor(aio.ServerInterceptor): + + def __init__(self, fn): + self._fn = fn + + async def intercept_service(self, continuation, handler_call_details): + return await self._fn(continuation, handler_call_details) + + +def _filter_server_interceptor(condition, interceptor): + async def intercept_service(continuation, handler_call_details): + if condition(handler_call_details): + return await interceptor.intercept_service(continuation, + handler_call_details) + return await continuation(handler_call_details) + + return _GenericServerInterceptor(intercept_service) + + +class TestServerInterceptor(AioTestBase): + async def setUp(self) -> None: + self._record = [] + conditional_interceptor = _filter_server_interceptor( + lambda x: ('secret', '42') in x.invocation_metadata, + _LoggingServerInterceptor('log3', self._record)) + self._interceptors = ( + _LoggingServerInterceptor('log1', self._record), + conditional_interceptor, + _LoggingServerInterceptor('log2', self._record), + ) + self._server_target, self._server = await start_test_server( + interceptors=self._interceptors) + + async def tearDown(self) -> None: + self._server.stop(None) + + async def test_invalid_interceptor(self): + class InvalidInterceptor: + """Just an invalid Interceptor""" + + with self.assertRaises(aio.AioRpcError): + server_target, _ = await start_test_server( + interceptors=(InvalidInterceptor(),)) + channel = aio.insecure_channel(server_target) + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + await call + + async def test_executed_right_order(self): + self._record.clear() + async with aio.insecure_channel(self._server_target) as channel: + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + # Check that all interceptors were executed, and were executed + # in the right order. + self.assertSequenceEqual(['log1:intercept_service', + 'log2:intercept_service',], self._record) + self.assertIsInstance(response, messages_pb2.SimpleResponse) + + async def test_apply_different_interceptors_by_metadata(self): + async with aio.insecure_channel(self._server_target) as channel: + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + self._record.clear() + metadata = (('key', 'value'),) + call = multicallable(messages_pb2.SimpleRequest(), + metadata=metadata) + await call + self.assertSequenceEqual(['log1:intercept_service', + 'log2:intercept_service',], + self._record) + + self._record.clear() + metadata = (('key', 'value'), ('secret', '42')) + call = multicallable(messages_pb2.SimpleRequest(), + metadata=metadata) + await call + self.assertSequenceEqual(['log1:intercept_service', + 'log3:intercept_service', + 'log2:intercept_service',], + self._record) + + if __name__ == '__main__': logging.basicConfig() unittest.main(verbosity=2)