Merge pull request #22032 from ZHmao/implement-server-interceptor-for-unary-unary-call
[Aio] Implement server interceptor for unary unary callpull/22370/head
commit
87d01bf9e5
9 changed files with 255 additions and 18 deletions
@ -0,0 +1,168 @@ |
|||||||
|
# Copyright 2020 The gRPC Authors |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
import logging |
||||||
|
import unittest |
||||||
|
from typing import Callable, Awaitable, Any |
||||||
|
|
||||||
|
import grpc |
||||||
|
|
||||||
|
from grpc.experimental import aio |
||||||
|
|
||||||
|
from tests_aio.unit._test_server import start_test_server |
||||||
|
from tests_aio.unit._test_base import AioTestBase |
||||||
|
from src.proto.grpc.testing import messages_pb2 |
||||||
|
|
||||||
|
|
||||||
|
class _LoggingInterceptor(aio.ServerInterceptor): |
||||||
|
|
||||||
|
def __init__(self, tag: str, record: list) -> None: |
||||||
|
self.tag = tag |
||||||
|
self.record = record |
||||||
|
|
||||||
|
async def intercept_service( |
||||||
|
self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ |
||||||
|
grpc.RpcMethodHandler]], |
||||||
|
handler_call_details: grpc.HandlerCallDetails |
||||||
|
) -> grpc.RpcMethodHandler: |
||||||
|
self.record.append(self.tag + ':intercept_service') |
||||||
|
return await continuation(handler_call_details) |
||||||
|
|
||||||
|
|
||||||
|
class _GenericInterceptor(aio.ServerInterceptor): |
||||||
|
|
||||||
|
def __init__(self, fn: Callable[[ |
||||||
|
Callable[[grpc.HandlerCallDetails], Awaitable[grpc. |
||||||
|
RpcMethodHandler]], |
||||||
|
grpc.HandlerCallDetails |
||||||
|
], Any]) -> None: |
||||||
|
self._fn = fn |
||||||
|
|
||||||
|
async def intercept_service( |
||||||
|
self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ |
||||||
|
grpc.RpcMethodHandler]], |
||||||
|
handler_call_details: grpc.HandlerCallDetails |
||||||
|
) -> grpc.RpcMethodHandler: |
||||||
|
return await self._fn(continuation, handler_call_details) |
||||||
|
|
||||||
|
|
||||||
|
def _filter_server_interceptor(condition: Callable, |
||||||
|
interceptor: aio.ServerInterceptor |
||||||
|
) -> aio.ServerInterceptor: |
||||||
|
|
||||||
|
async def intercept_service( |
||||||
|
continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ |
||||||
|
grpc.RpcMethodHandler]], |
||||||
|
handler_call_details: grpc.HandlerCallDetails |
||||||
|
) -> grpc.RpcMethodHandler: |
||||||
|
if condition(handler_call_details): |
||||||
|
return await interceptor.intercept_service(continuation, |
||||||
|
handler_call_details) |
||||||
|
return await continuation(handler_call_details) |
||||||
|
|
||||||
|
return _GenericInterceptor(intercept_service) |
||||||
|
|
||||||
|
|
||||||
|
class TestServerInterceptor(AioTestBase): |
||||||
|
|
||||||
|
async def test_invalid_interceptor(self): |
||||||
|
|
||||||
|
class InvalidInterceptor: |
||||||
|
"""Just an invalid Interceptor""" |
||||||
|
|
||||||
|
with self.assertRaises(ValueError): |
||||||
|
server_target, _ = await start_test_server( |
||||||
|
interceptors=(InvalidInterceptor(),)) |
||||||
|
|
||||||
|
async def test_executed_right_order(self): |
||||||
|
record = [] |
||||||
|
server_target, _ = await start_test_server(interceptors=( |
||||||
|
_LoggingInterceptor('log1', record), |
||||||
|
_LoggingInterceptor('log2', record), |
||||||
|
)) |
||||||
|
|
||||||
|
async with aio.insecure_channel(server_target) as channel: |
||||||
|
multicallable = channel.unary_unary( |
||||||
|
'/grpc.testing.TestService/UnaryCall', |
||||||
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||||
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||||
|
call = multicallable(messages_pb2.SimpleRequest()) |
||||||
|
response = await call |
||||||
|
|
||||||
|
# Check that all interceptors were executed, and were executed |
||||||
|
# in the right order. |
||||||
|
self.assertSequenceEqual([ |
||||||
|
'log1:intercept_service', |
||||||
|
'log2:intercept_service', |
||||||
|
], record) |
||||||
|
self.assertIsInstance(response, messages_pb2.SimpleResponse) |
||||||
|
|
||||||
|
async def test_response_ok(self): |
||||||
|
record = [] |
||||||
|
server_target, _ = await start_test_server( |
||||||
|
interceptors=(_LoggingInterceptor('log1', record),)) |
||||||
|
|
||||||
|
async with aio.insecure_channel(server_target) as channel: |
||||||
|
multicallable = channel.unary_unary( |
||||||
|
'/grpc.testing.TestService/UnaryCall', |
||||||
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||||
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||||
|
call = multicallable(messages_pb2.SimpleRequest()) |
||||||
|
response = await call |
||||||
|
code = await call.code() |
||||||
|
|
||||||
|
self.assertSequenceEqual(['log1:intercept_service'], record) |
||||||
|
self.assertIsInstance(response, messages_pb2.SimpleResponse) |
||||||
|
self.assertEqual(code, grpc.StatusCode.OK) |
||||||
|
|
||||||
|
async def test_apply_different_interceptors_by_metadata(self): |
||||||
|
record = [] |
||||||
|
conditional_interceptor = _filter_server_interceptor( |
||||||
|
lambda x: ('secret', '42') in x.invocation_metadata, |
||||||
|
_LoggingInterceptor('log3', record)) |
||||||
|
server_target, _ = await start_test_server(interceptors=( |
||||||
|
_LoggingInterceptor('log1', record), |
||||||
|
conditional_interceptor, |
||||||
|
_LoggingInterceptor('log2', record), |
||||||
|
)) |
||||||
|
|
||||||
|
async with aio.insecure_channel(server_target) as channel: |
||||||
|
multicallable = channel.unary_unary( |
||||||
|
'/grpc.testing.TestService/UnaryCall', |
||||||
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||||
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||||
|
|
||||||
|
metadata = (('key', 'value'),) |
||||||
|
call = multicallable(messages_pb2.SimpleRequest(), |
||||||
|
metadata=metadata) |
||||||
|
await call |
||||||
|
self.assertSequenceEqual([ |
||||||
|
'log1:intercept_service', |
||||||
|
'log2:intercept_service', |
||||||
|
], record) |
||||||
|
|
||||||
|
record.clear() |
||||||
|
metadata = (('key', 'value'), ('secret', '42')) |
||||||
|
call = multicallable(messages_pb2.SimpleRequest(), |
||||||
|
metadata=metadata) |
||||||
|
await call |
||||||
|
self.assertSequenceEqual([ |
||||||
|
'log1:intercept_service', |
||||||
|
'log3:intercept_service', |
||||||
|
'log2:intercept_service', |
||||||
|
], record) |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
logging.basicConfig() |
||||||
|
unittest.main(verbosity=2) |
Loading…
Reference in new issue