|
|
|
@ -1,4 +1,4 @@ |
|
|
|
|
# Copyright 2019 The gRPC Authors. |
|
|
|
|
# Copyright 2020 The gRPC Authors |
|
|
|
|
# |
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
@ -44,9 +44,10 @@ class _GenericInterceptor(aio.ServerInterceptor): |
|
|
|
|
return await self._fn(continuation, handler_call_details) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _filter_server_interceptor( |
|
|
|
|
condition: Callable, |
|
|
|
|
interceptor: aio.ServerInterceptor) -> aio.ServerInterceptor: |
|
|
|
|
def _filter_server_interceptor(condition: Callable, |
|
|
|
|
interceptor: aio.ServerInterceptor |
|
|
|
|
) -> aio.ServerInterceptor: |
|
|
|
|
|
|
|
|
|
async def intercept_service(continuation, handler_call_details): |
|
|
|
|
if condition(handler_call_details): |
|
|
|
|
return await interceptor.intercept_service(continuation, |
|
|
|
@ -59,6 +60,7 @@ def _filter_server_interceptor( |
|
|
|
|
class TestServerInterceptor(AioTestBase): |
|
|
|
|
|
|
|
|
|
async def test_invalid_interceptor(self): |
|
|
|
|
|
|
|
|
|
class InvalidInterceptor: |
|
|
|
|
"""Just an invalid Interceptor""" |
|
|
|
|
|
|
|
|
@ -68,9 +70,10 @@ class TestServerInterceptor(AioTestBase): |
|
|
|
|
|
|
|
|
|
async def test_executed_right_order(self): |
|
|
|
|
record = [] |
|
|
|
|
server_target, _ = await start_test_server( |
|
|
|
|
interceptors=(_LoggingInterceptor('log1', record), |
|
|
|
|
_LoggingInterceptor('log2', record),)) |
|
|
|
|
server_target, _ = await start_test_server(interceptors=( |
|
|
|
|
_LoggingInterceptor('log1', record), |
|
|
|
|
_LoggingInterceptor('log2', record), |
|
|
|
|
)) |
|
|
|
|
|
|
|
|
|
async with aio.insecure_channel(server_target) as channel: |
|
|
|
|
multicallable = channel.unary_unary( |
|
|
|
@ -82,8 +85,10 @@ class TestServerInterceptor(AioTestBase): |
|
|
|
|
|
|
|
|
|
# Check that all interceptors were executed, and were executed |
|
|
|
|
# in the right order. |
|
|
|
|
self.assertSequenceEqual(['log1:intercept_service', |
|
|
|
|
'log2:intercept_service',], record) |
|
|
|
|
self.assertSequenceEqual([ |
|
|
|
|
'log1:intercept_service', |
|
|
|
|
'log2:intercept_service', |
|
|
|
|
], record) |
|
|
|
|
self.assertIsInstance(response, messages_pb2.SimpleResponse) |
|
|
|
|
|
|
|
|
|
async def test_response_ok(self): |
|
|
|
@ -109,10 +114,11 @@ class TestServerInterceptor(AioTestBase): |
|
|
|
|
conditional_interceptor = _filter_server_interceptor( |
|
|
|
|
lambda x: ('secret', '42') in x.invocation_metadata, |
|
|
|
|
_LoggingInterceptor('log3', record)) |
|
|
|
|
server_target, _ = await start_test_server( |
|
|
|
|
interceptors=(_LoggingInterceptor('log1', record), |
|
|
|
|
conditional_interceptor, |
|
|
|
|
_LoggingInterceptor('log2', record),)) |
|
|
|
|
server_target, _ = await start_test_server(interceptors=( |
|
|
|
|
_LoggingInterceptor('log1', record), |
|
|
|
|
conditional_interceptor, |
|
|
|
|
_LoggingInterceptor('log2', record), |
|
|
|
|
)) |
|
|
|
|
|
|
|
|
|
async with aio.insecure_channel(server_target) as channel: |
|
|
|
|
multicallable = channel.unary_unary( |
|
|
|
@ -124,19 +130,21 @@ class TestServerInterceptor(AioTestBase): |
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest(), |
|
|
|
|
metadata=metadata) |
|
|
|
|
await call |
|
|
|
|
self.assertSequenceEqual(['log1:intercept_service', |
|
|
|
|
'log2:intercept_service',], |
|
|
|
|
record) |
|
|
|
|
self.assertSequenceEqual([ |
|
|
|
|
'log1:intercept_service', |
|
|
|
|
'log2:intercept_service', |
|
|
|
|
], record) |
|
|
|
|
|
|
|
|
|
record.clear() |
|
|
|
|
metadata = (('key', 'value'), ('secret', '42')) |
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest(), |
|
|
|
|
metadata=metadata) |
|
|
|
|
await call |
|
|
|
|
self.assertSequenceEqual(['log1:intercept_service', |
|
|
|
|
'log3:intercept_service', |
|
|
|
|
'log2:intercept_service',], |
|
|
|
|
record) |
|
|
|
|
self.assertSequenceEqual([ |
|
|
|
|
'log1:intercept_service', |
|
|
|
|
'log3:intercept_service', |
|
|
|
|
'log2:intercept_service', |
|
|
|
|
], record) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|