Merge pull request #22032 from ZHmao/implement-server-interceptor-for-unary-unary-call
[Aio] Implement server interceptor for unary unary callpull/22370/head
commit
87d01bf9e5
9 changed files with 255 additions and 18 deletions
@ -0,0 +1,168 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
import logging |
||||
import unittest |
||||
from typing import Callable, Awaitable, Any |
||||
|
||||
import grpc |
||||
|
||||
from grpc.experimental import aio |
||||
|
||||
from tests_aio.unit._test_server import start_test_server |
||||
from tests_aio.unit._test_base import AioTestBase |
||||
from src.proto.grpc.testing import messages_pb2 |
||||
|
||||
|
||||
class _LoggingInterceptor(aio.ServerInterceptor): |
||||
|
||||
def __init__(self, tag: str, record: list) -> None: |
||||
self.tag = tag |
||||
self.record = record |
||||
|
||||
async def intercept_service( |
||||
self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ |
||||
grpc.RpcMethodHandler]], |
||||
handler_call_details: grpc.HandlerCallDetails |
||||
) -> grpc.RpcMethodHandler: |
||||
self.record.append(self.tag + ':intercept_service') |
||||
return await continuation(handler_call_details) |
||||
|
||||
|
||||
class _GenericInterceptor(aio.ServerInterceptor): |
||||
|
||||
def __init__(self, fn: Callable[[ |
||||
Callable[[grpc.HandlerCallDetails], Awaitable[grpc. |
||||
RpcMethodHandler]], |
||||
grpc.HandlerCallDetails |
||||
], Any]) -> None: |
||||
self._fn = fn |
||||
|
||||
async def intercept_service( |
||||
self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ |
||||
grpc.RpcMethodHandler]], |
||||
handler_call_details: grpc.HandlerCallDetails |
||||
) -> grpc.RpcMethodHandler: |
||||
return await self._fn(continuation, handler_call_details) |
||||
|
||||
|
||||
def _filter_server_interceptor(condition: Callable, |
||||
interceptor: aio.ServerInterceptor |
||||
) -> aio.ServerInterceptor: |
||||
|
||||
async def intercept_service( |
||||
continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ |
||||
grpc.RpcMethodHandler]], |
||||
handler_call_details: grpc.HandlerCallDetails |
||||
) -> grpc.RpcMethodHandler: |
||||
if condition(handler_call_details): |
||||
return await interceptor.intercept_service(continuation, |
||||
handler_call_details) |
||||
return await continuation(handler_call_details) |
||||
|
||||
return _GenericInterceptor(intercept_service) |
||||
|
||||
|
||||
class TestServerInterceptor(AioTestBase): |
||||
|
||||
async def test_invalid_interceptor(self): |
||||
|
||||
class InvalidInterceptor: |
||||
"""Just an invalid Interceptor""" |
||||
|
||||
with self.assertRaises(ValueError): |
||||
server_target, _ = await start_test_server( |
||||
interceptors=(InvalidInterceptor(),)) |
||||
|
||||
async def test_executed_right_order(self): |
||||
record = [] |
||||
server_target, _ = await start_test_server(interceptors=( |
||||
_LoggingInterceptor('log1', record), |
||||
_LoggingInterceptor('log2', record), |
||||
)) |
||||
|
||||
async with aio.insecure_channel(server_target) as channel: |
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
response = await call |
||||
|
||||
# Check that all interceptors were executed, and were executed |
||||
# in the right order. |
||||
self.assertSequenceEqual([ |
||||
'log1:intercept_service', |
||||
'log2:intercept_service', |
||||
], record) |
||||
self.assertIsInstance(response, messages_pb2.SimpleResponse) |
||||
|
||||
async def test_response_ok(self): |
||||
record = [] |
||||
server_target, _ = await start_test_server( |
||||
interceptors=(_LoggingInterceptor('log1', record),)) |
||||
|
||||
async with aio.insecure_channel(server_target) as channel: |
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
response = await call |
||||
code = await call.code() |
||||
|
||||
self.assertSequenceEqual(['log1:intercept_service'], record) |
||||
self.assertIsInstance(response, messages_pb2.SimpleResponse) |
||||
self.assertEqual(code, grpc.StatusCode.OK) |
||||
|
||||
async def test_apply_different_interceptors_by_metadata(self): |
||||
record = [] |
||||
conditional_interceptor = _filter_server_interceptor( |
||||
lambda x: ('secret', '42') in x.invocation_metadata, |
||||
_LoggingInterceptor('log3', record)) |
||||
server_target, _ = await start_test_server(interceptors=( |
||||
_LoggingInterceptor('log1', record), |
||||
conditional_interceptor, |
||||
_LoggingInterceptor('log2', record), |
||||
)) |
||||
|
||||
async with aio.insecure_channel(server_target) as channel: |
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
|
||||
metadata = (('key', 'value'),) |
||||
call = multicallable(messages_pb2.SimpleRequest(), |
||||
metadata=metadata) |
||||
await call |
||||
self.assertSequenceEqual([ |
||||
'log1:intercept_service', |
||||
'log2:intercept_service', |
||||
], record) |
||||
|
||||
record.clear() |
||||
metadata = (('key', 'value'), ('secret', '42')) |
||||
call = multicallable(messages_pb2.SimpleRequest(), |
||||
metadata=metadata) |
||||
await call |
||||
self.assertSequenceEqual([ |
||||
'log1:intercept_service', |
||||
'log3:intercept_service', |
||||
'log2:intercept_service', |
||||
], record) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
logging.basicConfig() |
||||
unittest.main(verbosity=2) |
Loading…
Reference in new issue