From 26985fd722ae3612b94ca36d529558f37906a4bc Mon Sep 17 00:00:00 2001 From: Zhanghui Mao Date: Fri, 28 Feb 2020 22:06:31 +0800 Subject: [PATCH] fix sanity checks --- .../grpc/experimental/aio/_interceptor.py | 10 ++-- .../grpcio/grpc/experimental/aio/_server.py | 14 +++--- .../tests_aio/unit/_test_server.py | 5 +- .../tests_aio/unit/server_interceptor_test.py | 48 +++++++++++-------- 4 files changed, 44 insertions(+), 33 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 1678224c187..5de3367a86e 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -37,11 +37,11 @@ class ServerInterceptor(metaclass=ABCMeta): """ @abstractmethod - async def intercept_service(self, - continuation: Callable[ - [grpc.HandlerCallDetails], grpc.RpcMethodHandler], - handler_call_details: grpc.HandlerCallDetails - ) -> grpc.RpcMethodHandler: + async def intercept_service( + self, continuation: Callable[[grpc.HandlerCallDetails], grpc. + RpcMethodHandler], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: """Intercepts incoming RPCs before handing them over to a handler. Args: diff --git a/src/python/grpcio/grpc/experimental/aio/_server.py b/src/python/grpcio/grpc/experimental/aio/_server.py index f3447dd048d..587d8096c69 100644 --- a/src/python/grpcio/grpc/experimental/aio/_server.py +++ b/src/python/grpcio/grpc/experimental/aio/_server.py @@ -20,10 +20,10 @@ from typing import Any, Optional, Sequence import grpc from grpc import _common, _compression from grpc._cython import cygrpc -from grpc.experimental.aio import ServerInterceptor from . import _base_server from ._typing import ChannelArgumentType +from ._interceptor import ServerInterceptor def _augment_channel_arguments(base_options: ChannelArgumentType, @@ -43,12 +43,14 @@ class Server(_base_server.Server): compression: Optional[grpc.Compression]): self._loop = asyncio.get_event_loop() if interceptors: - invalid_interceptors = [interceptor for interceptor in interceptors - if not isinstance(interceptor, - ServerInterceptor)] + invalid_interceptors = [ + interceptor for interceptor in interceptors + if not isinstance(interceptor, ServerInterceptor) + ] if invalid_interceptors: - raise ValueError('Interceptor must be ServerInterceptor, the ' - f'following are invalid: {invalid_interceptors}') + raise ValueError( + 'Interceptor must be ServerInterceptor, the ' + f'following are invalid: {invalid_interceptors}') self._server = cygrpc.AioServer( self._loop, thread_pool, generic_handlers, interceptors, _augment_channel_arguments(options, compression), diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 769e0841b7d..7c8afa8ff5c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -14,7 +14,6 @@ import asyncio import datetime -import logging import grpc from grpc.experimental import aio @@ -117,7 +116,9 @@ def _create_extra_generic_handler(servicer: _TestServiceServicer): rpc_method_handlers) -async def start_test_server(port=0, secure=False, server_credentials=None, +async def start_test_server(port=0, + secure=False, + server_credentials=None, interceptors=None): server = aio.server(options=(('grpc.so_reuseport', 0),), interceptors=interceptors) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index 9794bb1d62d..1680fc98b4a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -1,4 +1,4 @@ -# Copyright 2019 The gRPC Authors. +# Copyright 2020 The gRPC Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,9 +44,10 @@ class _GenericInterceptor(aio.ServerInterceptor): return await self._fn(continuation, handler_call_details) -def _filter_server_interceptor( - condition: Callable, - interceptor: aio.ServerInterceptor) -> aio.ServerInterceptor: +def _filter_server_interceptor(condition: Callable, + interceptor: aio.ServerInterceptor + ) -> aio.ServerInterceptor: + async def intercept_service(continuation, handler_call_details): if condition(handler_call_details): return await interceptor.intercept_service(continuation, @@ -59,6 +60,7 @@ def _filter_server_interceptor( class TestServerInterceptor(AioTestBase): async def test_invalid_interceptor(self): + class InvalidInterceptor: """Just an invalid Interceptor""" @@ -68,9 +70,10 @@ class TestServerInterceptor(AioTestBase): async def test_executed_right_order(self): record = [] - server_target, _ = await start_test_server( - interceptors=(_LoggingInterceptor('log1', record), - _LoggingInterceptor('log2', record),)) + server_target, _ = await start_test_server(interceptors=( + _LoggingInterceptor('log1', record), + _LoggingInterceptor('log2', record), + )) async with aio.insecure_channel(server_target) as channel: multicallable = channel.unary_unary( @@ -82,8 +85,10 @@ class TestServerInterceptor(AioTestBase): # Check that all interceptors were executed, and were executed # in the right order. - self.assertSequenceEqual(['log1:intercept_service', - 'log2:intercept_service',], record) + self.assertSequenceEqual([ + 'log1:intercept_service', + 'log2:intercept_service', + ], record) self.assertIsInstance(response, messages_pb2.SimpleResponse) async def test_response_ok(self): @@ -109,10 +114,11 @@ class TestServerInterceptor(AioTestBase): conditional_interceptor = _filter_server_interceptor( lambda x: ('secret', '42') in x.invocation_metadata, _LoggingInterceptor('log3', record)) - server_target, _ = await start_test_server( - interceptors=(_LoggingInterceptor('log1', record), - conditional_interceptor, - _LoggingInterceptor('log2', record),)) + server_target, _ = await start_test_server(interceptors=( + _LoggingInterceptor('log1', record), + conditional_interceptor, + _LoggingInterceptor('log2', record), + )) async with aio.insecure_channel(server_target) as channel: multicallable = channel.unary_unary( @@ -124,19 +130,21 @@ class TestServerInterceptor(AioTestBase): call = multicallable(messages_pb2.SimpleRequest(), metadata=metadata) await call - self.assertSequenceEqual(['log1:intercept_service', - 'log2:intercept_service',], - record) + self.assertSequenceEqual([ + 'log1:intercept_service', + 'log2:intercept_service', + ], record) record.clear() metadata = (('key', 'value'), ('secret', '42')) call = multicallable(messages_pb2.SimpleRequest(), metadata=metadata) await call - self.assertSequenceEqual(['log1:intercept_service', - 'log3:intercept_service', - 'log2:intercept_service',], - record) + self.assertSequenceEqual([ + 'log1:intercept_service', + 'log3:intercept_service', + 'log2:intercept_service', + ], record) if __name__ == '__main__':