diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index c93fe1b7b0e..047e1401891 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the functionality of server interceptors.""" +import asyncio import functools import logging import unittest @@ -79,6 +80,43 @@ def _filter_server_interceptor(condition: Callable, return _GenericInterceptor(intercept_service) +class _CacheInterceptor(aio.ServerInterceptor): + """An interceptor that caches response based on request message.""" + + def __init__(self, cache_store=None): + self.cache_store = cache_store or {} + + async def intercept_service( + self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + # Get the actual handler + handler = await continuation(handler_call_details) + + # Only intercept unary call RPCs + if handler and (handler.request_streaming or + handler.response_streaming): + return handler + + def wrapper(behavior: Callable[ + [messages_pb2.SimpleRequest, aio. + ServicerContext], messages_pb2.SimpleResponse]): + + @functools.wraps(behavior) + async def wrapper(request: messages_pb2.SimpleRequest, + context: aio.ServicerContext + ) -> messages_pb2.SimpleResponse: + if request.response_size not in self.cache_store: + self.cache_store[request.response_size] = await behavior( + request, context) + return self.cache_store[request.response_size] + + return wrapper + + return wrap_server_method_handler(wrapper, handler) + + async def _create_server_stub_pair( *interceptors: aio.ServerInterceptor ) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]: @@ -182,55 +220,29 @@ class TestServerInterceptor(AioTestBase): async def test_response_caching(self): # Prepares a preset value to help testing - cache_store = { + interceptor = _CacheInterceptor({ 42: messages_pb2.SimpleResponse(payload=messages_pb2.Payload( body=b'\x42')) - } - - async def intercept_and_cache( - continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ - grpc.RpcMethodHandler]], - handler_call_details: grpc.HandlerCallDetails - ) -> grpc.RpcMethodHandler: - # Get the actual handler - handler = await continuation(handler_call_details) - - def wrapper(behavior: Callable[ - [messages_pb2.SimpleRequest, aio. - ServerInterceptor], messages_pb2.SimpleResponse]): - - @functools.wraps(behavior) - async def wrapper(request: messages_pb2.SimpleRequest, - context: aio.ServicerContext - ) -> messages_pb2.SimpleResponse: - if request.response_size not in cache_store: - cache_store[request.response_size] = await behavior( - request, context) - return cache_store[request.response_size] - - return wrapper - - return wrap_server_method_handler(wrapper, handler) + }) # Constructs a server with the cache interceptor - server, stub = await _create_server_stub_pair( - _GenericInterceptor(intercept_and_cache)) + server, stub = await _create_server_stub_pair(interceptor) # Tests if the cache store is used response = await stub.UnaryCall( messages_pb2.SimpleRequest(response_size=42)) - self.assertEqual(1, len(cache_store[42].payload.body)) - self.assertEqual(cache_store[42], response) + self.assertEqual(1, len(interceptor.cache_store[42].payload.body)) + self.assertEqual(interceptor.cache_store[42], response) # Tests response can be cached response = await stub.UnaryCall( messages_pb2.SimpleRequest(response_size=1337)) - self.assertEqual(1337, len(cache_store[1337].payload.body)) - self.assertEqual(cache_store[1337], response) + self.assertEqual(1337, len(interceptor.cache_store[1337].payload.body)) + self.assertEqual(interceptor.cache_store[1337], response) response = await stub.UnaryCall( messages_pb2.SimpleRequest(response_size=1337)) - self.assertEqual(cache_store[1337], response) + self.assertEqual(interceptor.cache_store[1337], response) async def test_interceptor_unary_stream(self): record = []