fix sanity checks

pull/22032/head
Zhanghui Mao 5 years ago
parent 99e26eb647
commit 26985fd722
  1. 10
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  2. 14
      src/python/grpcio/grpc/experimental/aio/_server.py
  3. 5
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  4. 48
      src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py

@ -37,11 +37,11 @@ class ServerInterceptor(metaclass=ABCMeta):
"""
@abstractmethod
async def intercept_service(self,
continuation: Callable[
[grpc.HandlerCallDetails], grpc.RpcMethodHandler],
handler_call_details: grpc.HandlerCallDetails
) -> grpc.RpcMethodHandler:
async def intercept_service(
self, continuation: Callable[[grpc.HandlerCallDetails], grpc.
RpcMethodHandler],
handler_call_details: grpc.HandlerCallDetails
) -> grpc.RpcMethodHandler:
"""Intercepts incoming RPCs before handing them over to a handler.
Args:

@ -20,10 +20,10 @@ from typing import Any, Optional, Sequence
import grpc
from grpc import _common, _compression
from grpc._cython import cygrpc
from grpc.experimental.aio import ServerInterceptor
from . import _base_server
from ._typing import ChannelArgumentType
from ._interceptor import ServerInterceptor
def _augment_channel_arguments(base_options: ChannelArgumentType,
@ -43,12 +43,14 @@ class Server(_base_server.Server):
compression: Optional[grpc.Compression]):
self._loop = asyncio.get_event_loop()
if interceptors:
invalid_interceptors = [interceptor for interceptor in interceptors
if not isinstance(interceptor,
ServerInterceptor)]
invalid_interceptors = [
interceptor for interceptor in interceptors
if not isinstance(interceptor, ServerInterceptor)
]
if invalid_interceptors:
raise ValueError('Interceptor must be ServerInterceptor, the '
f'following are invalid: {invalid_interceptors}')
raise ValueError(
'Interceptor must be ServerInterceptor, the '
f'following are invalid: {invalid_interceptors}')
self._server = cygrpc.AioServer(
self._loop, thread_pool, generic_handlers, interceptors,
_augment_channel_arguments(options, compression),

@ -14,7 +14,6 @@
import asyncio
import datetime
import logging
import grpc
from grpc.experimental import aio
@ -117,7 +116,9 @@ def _create_extra_generic_handler(servicer: _TestServiceServicer):
rpc_method_handlers)
async def start_test_server(port=0, secure=False, server_credentials=None,
async def start_test_server(port=0,
secure=False,
server_credentials=None,
interceptors=None):
server = aio.server(options=(('grpc.so_reuseport', 0),),
interceptors=interceptors)

@ -1,4 +1,4 @@
# Copyright 2019 The gRPC Authors.
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -44,9 +44,10 @@ class _GenericInterceptor(aio.ServerInterceptor):
return await self._fn(continuation, handler_call_details)
def _filter_server_interceptor(
condition: Callable,
interceptor: aio.ServerInterceptor) -> aio.ServerInterceptor:
def _filter_server_interceptor(condition: Callable,
interceptor: aio.ServerInterceptor
) -> aio.ServerInterceptor:
async def intercept_service(continuation, handler_call_details):
if condition(handler_call_details):
return await interceptor.intercept_service(continuation,
@ -59,6 +60,7 @@ def _filter_server_interceptor(
class TestServerInterceptor(AioTestBase):
async def test_invalid_interceptor(self):
class InvalidInterceptor:
"""Just an invalid Interceptor"""
@ -68,9 +70,10 @@ class TestServerInterceptor(AioTestBase):
async def test_executed_right_order(self):
record = []
server_target, _ = await start_test_server(
interceptors=(_LoggingInterceptor('log1', record),
_LoggingInterceptor('log2', record),))
server_target, _ = await start_test_server(interceptors=(
_LoggingInterceptor('log1', record),
_LoggingInterceptor('log2', record),
))
async with aio.insecure_channel(server_target) as channel:
multicallable = channel.unary_unary(
@ -82,8 +85,10 @@ class TestServerInterceptor(AioTestBase):
# Check that all interceptors were executed, and were executed
# in the right order.
self.assertSequenceEqual(['log1:intercept_service',
'log2:intercept_service',], record)
self.assertSequenceEqual([
'log1:intercept_service',
'log2:intercept_service',
], record)
self.assertIsInstance(response, messages_pb2.SimpleResponse)
async def test_response_ok(self):
@ -109,10 +114,11 @@ class TestServerInterceptor(AioTestBase):
conditional_interceptor = _filter_server_interceptor(
lambda x: ('secret', '42') in x.invocation_metadata,
_LoggingInterceptor('log3', record))
server_target, _ = await start_test_server(
interceptors=(_LoggingInterceptor('log1', record),
conditional_interceptor,
_LoggingInterceptor('log2', record),))
server_target, _ = await start_test_server(interceptors=(
_LoggingInterceptor('log1', record),
conditional_interceptor,
_LoggingInterceptor('log2', record),
))
async with aio.insecure_channel(server_target) as channel:
multicallable = channel.unary_unary(
@ -124,19 +130,21 @@ class TestServerInterceptor(AioTestBase):
call = multicallable(messages_pb2.SimpleRequest(),
metadata=metadata)
await call
self.assertSequenceEqual(['log1:intercept_service',
'log2:intercept_service',],
record)
self.assertSequenceEqual([
'log1:intercept_service',
'log2:intercept_service',
], record)
record.clear()
metadata = (('key', 'value'), ('secret', '42'))
call = multicallable(messages_pb2.SimpleRequest(),
metadata=metadata)
await call
self.assertSequenceEqual(['log1:intercept_service',
'log3:intercept_service',
'log2:intercept_service',],
record)
self.assertSequenceEqual([
'log1:intercept_service',
'log3:intercept_service',
'log2:intercept_service',
], record)
if __name__ == '__main__':

Loading…
Cancel
Save