|
|
@ -685,6 +685,110 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): |
|
|
|
self.fail("Callback was not called") |
|
|
|
self.fail("Callback was not called") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _LoggingServerInterceptor(aio.ServerInterceptor): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, tag, record): |
|
|
|
|
|
|
|
self.tag = tag |
|
|
|
|
|
|
|
self.record = record |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def intercept_service(self, continuation, handler_call_details): |
|
|
|
|
|
|
|
self.record.append(self.tag + ':intercept_service') |
|
|
|
|
|
|
|
return await continuation(handler_call_details) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _GenericServerInterceptor(aio.ServerInterceptor): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, fn): |
|
|
|
|
|
|
|
self._fn = fn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def intercept_service(self, continuation, handler_call_details): |
|
|
|
|
|
|
|
return await self._fn(continuation, handler_call_details) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _filter_server_interceptor(condition, interceptor): |
|
|
|
|
|
|
|
async def intercept_service(continuation, handler_call_details): |
|
|
|
|
|
|
|
if condition(handler_call_details): |
|
|
|
|
|
|
|
return await interceptor.intercept_service(continuation, |
|
|
|
|
|
|
|
handler_call_details) |
|
|
|
|
|
|
|
return await continuation(handler_call_details) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return _GenericServerInterceptor(intercept_service) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestServerInterceptor(AioTestBase): |
|
|
|
|
|
|
|
async def setUp(self) -> None: |
|
|
|
|
|
|
|
self._record = [] |
|
|
|
|
|
|
|
conditional_interceptor = _filter_server_interceptor( |
|
|
|
|
|
|
|
lambda x: ('secret', '42') in x.invocation_metadata, |
|
|
|
|
|
|
|
_LoggingServerInterceptor('log3', self._record)) |
|
|
|
|
|
|
|
self._interceptors = ( |
|
|
|
|
|
|
|
_LoggingServerInterceptor('log1', self._record), |
|
|
|
|
|
|
|
conditional_interceptor, |
|
|
|
|
|
|
|
_LoggingServerInterceptor('log2', self._record), |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
self._server_target, self._server = await start_test_server( |
|
|
|
|
|
|
|
interceptors=self._interceptors) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def tearDown(self) -> None: |
|
|
|
|
|
|
|
self._server.stop(None) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_invalid_interceptor(self): |
|
|
|
|
|
|
|
class InvalidInterceptor: |
|
|
|
|
|
|
|
"""Just an invalid Interceptor""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with self.assertRaises(aio.AioRpcError): |
|
|
|
|
|
|
|
server_target, _ = await start_test_server( |
|
|
|
|
|
|
|
interceptors=(InvalidInterceptor(),)) |
|
|
|
|
|
|
|
channel = aio.insecure_channel(server_target) |
|
|
|
|
|
|
|
multicallable = channel.unary_unary( |
|
|
|
|
|
|
|
'/grpc.testing.TestService/UnaryCall', |
|
|
|
|
|
|
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
|
|
|
|
|
|
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
|
|
|
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest()) |
|
|
|
|
|
|
|
await call |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_executed_right_order(self): |
|
|
|
|
|
|
|
self._record.clear() |
|
|
|
|
|
|
|
async with aio.insecure_channel(self._server_target) as channel: |
|
|
|
|
|
|
|
multicallable = channel.unary_unary( |
|
|
|
|
|
|
|
'/grpc.testing.TestService/UnaryCall', |
|
|
|
|
|
|
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
|
|
|
|
|
|
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
|
|
|
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest()) |
|
|
|
|
|
|
|
response = await call |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Check that all interceptors were executed, and were executed |
|
|
|
|
|
|
|
# in the right order. |
|
|
|
|
|
|
|
self.assertSequenceEqual(['log1:intercept_service', |
|
|
|
|
|
|
|
'log2:intercept_service',], self._record) |
|
|
|
|
|
|
|
self.assertIsInstance(response, messages_pb2.SimpleResponse) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_apply_different_interceptors_by_metadata(self): |
|
|
|
|
|
|
|
async with aio.insecure_channel(self._server_target) as channel: |
|
|
|
|
|
|
|
multicallable = channel.unary_unary( |
|
|
|
|
|
|
|
'/grpc.testing.TestService/UnaryCall', |
|
|
|
|
|
|
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
|
|
|
|
|
|
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
|
|
|
|
|
|
|
self._record.clear() |
|
|
|
|
|
|
|
metadata = (('key', 'value'),) |
|
|
|
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest(), |
|
|
|
|
|
|
|
metadata=metadata) |
|
|
|
|
|
|
|
await call |
|
|
|
|
|
|
|
self.assertSequenceEqual(['log1:intercept_service', |
|
|
|
|
|
|
|
'log2:intercept_service',], |
|
|
|
|
|
|
|
self._record) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self._record.clear() |
|
|
|
|
|
|
|
metadata = (('key', 'value'), ('secret', '42')) |
|
|
|
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest(), |
|
|
|
|
|
|
|
metadata=metadata) |
|
|
|
|
|
|
|
await call |
|
|
|
|
|
|
|
self.assertSequenceEqual(['log1:intercept_service', |
|
|
|
|
|
|
|
'log3:intercept_service', |
|
|
|
|
|
|
|
'log2:intercept_service',], |
|
|
|
|
|
|
|
self._record) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
if __name__ == '__main__': |
|
|
|
logging.basicConfig() |
|
|
|
logging.basicConfig() |
|
|
|
unittest.main(verbosity=2) |
|
|
|
unittest.main(verbosity=2) |
|
|
|