|
|
@ -13,6 +13,7 @@ |
|
|
|
# limitations under the License. |
|
|
|
# limitations under the License. |
|
|
|
"""Test the functionality of server interceptors.""" |
|
|
|
"""Test the functionality of server interceptors.""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio |
|
|
|
import functools |
|
|
|
import functools |
|
|
|
import logging |
|
|
|
import logging |
|
|
|
import unittest |
|
|
|
import unittest |
|
|
@ -79,6 +80,43 @@ def _filter_server_interceptor(condition: Callable, |
|
|
|
return _GenericInterceptor(intercept_service) |
|
|
|
return _GenericInterceptor(intercept_service) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _CacheInterceptor(aio.ServerInterceptor): |
|
|
|
|
|
|
|
"""An interceptor that caches response based on request message.""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, cache_store=None): |
|
|
|
|
|
|
|
self.cache_store = cache_store or {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def intercept_service( |
|
|
|
|
|
|
|
self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ |
|
|
|
|
|
|
|
grpc.RpcMethodHandler]], |
|
|
|
|
|
|
|
handler_call_details: grpc.HandlerCallDetails |
|
|
|
|
|
|
|
) -> grpc.RpcMethodHandler: |
|
|
|
|
|
|
|
# Get the actual handler |
|
|
|
|
|
|
|
handler = await continuation(handler_call_details) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Only intercept unary call RPCs |
|
|
|
|
|
|
|
if handler and (handler.request_streaming or |
|
|
|
|
|
|
|
handler.response_streaming): |
|
|
|
|
|
|
|
return handler |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def wrapper(behavior: Callable[ |
|
|
|
|
|
|
|
[messages_pb2.SimpleRequest, aio. |
|
|
|
|
|
|
|
ServicerContext], messages_pb2.SimpleResponse]): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@functools.wraps(behavior) |
|
|
|
|
|
|
|
async def wrapper(request: messages_pb2.SimpleRequest, |
|
|
|
|
|
|
|
context: aio.ServicerContext |
|
|
|
|
|
|
|
) -> messages_pb2.SimpleResponse: |
|
|
|
|
|
|
|
if request.response_size not in self.cache_store: |
|
|
|
|
|
|
|
self.cache_store[request.response_size] = await behavior( |
|
|
|
|
|
|
|
request, context) |
|
|
|
|
|
|
|
return self.cache_store[request.response_size] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return wrap_server_method_handler(wrapper, handler) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _create_server_stub_pair( |
|
|
|
async def _create_server_stub_pair( |
|
|
|
*interceptors: aio.ServerInterceptor |
|
|
|
*interceptors: aio.ServerInterceptor |
|
|
|
) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]: |
|
|
|
) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]: |
|
|
@ -182,55 +220,29 @@ class TestServerInterceptor(AioTestBase): |
|
|
|
|
|
|
|
|
|
|
|
async def test_response_caching(self): |
|
|
|
async def test_response_caching(self): |
|
|
|
# Prepares a preset value to help testing |
|
|
|
# Prepares a preset value to help testing |
|
|
|
cache_store = { |
|
|
|
interceptor = _CacheInterceptor({ |
|
|
|
42: |
|
|
|
42: |
|
|
|
messages_pb2.SimpleResponse(payload=messages_pb2.Payload( |
|
|
|
messages_pb2.SimpleResponse(payload=messages_pb2.Payload( |
|
|
|
body=b'\x42')) |
|
|
|
body=b'\x42')) |
|
|
|
} |
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
async def intercept_and_cache( |
|
|
|
|
|
|
|
continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ |
|
|
|
|
|
|
|
grpc.RpcMethodHandler]], |
|
|
|
|
|
|
|
handler_call_details: grpc.HandlerCallDetails |
|
|
|
|
|
|
|
) -> grpc.RpcMethodHandler: |
|
|
|
|
|
|
|
# Get the actual handler |
|
|
|
|
|
|
|
handler = await continuation(handler_call_details) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def wrapper(behavior: Callable[ |
|
|
|
|
|
|
|
[messages_pb2.SimpleRequest, aio. |
|
|
|
|
|
|
|
ServerInterceptor], messages_pb2.SimpleResponse]): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@functools.wraps(behavior) |
|
|
|
|
|
|
|
async def wrapper(request: messages_pb2.SimpleRequest, |
|
|
|
|
|
|
|
context: aio.ServicerContext |
|
|
|
|
|
|
|
) -> messages_pb2.SimpleResponse: |
|
|
|
|
|
|
|
if request.response_size not in cache_store: |
|
|
|
|
|
|
|
cache_store[request.response_size] = await behavior( |
|
|
|
|
|
|
|
request, context) |
|
|
|
|
|
|
|
return cache_store[request.response_size] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return wrap_server_method_handler(wrapper, handler) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Constructs a server with the cache interceptor |
|
|
|
# Constructs a server with the cache interceptor |
|
|
|
server, stub = await _create_server_stub_pair( |
|
|
|
server, stub = await _create_server_stub_pair(interceptor) |
|
|
|
_GenericInterceptor(intercept_and_cache)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tests if the cache store is used |
|
|
|
# Tests if the cache store is used |
|
|
|
response = await stub.UnaryCall( |
|
|
|
response = await stub.UnaryCall( |
|
|
|
messages_pb2.SimpleRequest(response_size=42)) |
|
|
|
messages_pb2.SimpleRequest(response_size=42)) |
|
|
|
self.assertEqual(1, len(cache_store[42].payload.body)) |
|
|
|
self.assertEqual(1, len(interceptor.cache_store[42].payload.body)) |
|
|
|
self.assertEqual(cache_store[42], response) |
|
|
|
self.assertEqual(interceptor.cache_store[42], response) |
|
|
|
|
|
|
|
|
|
|
|
# Tests response can be cached |
|
|
|
# Tests response can be cached |
|
|
|
response = await stub.UnaryCall( |
|
|
|
response = await stub.UnaryCall( |
|
|
|
messages_pb2.SimpleRequest(response_size=1337)) |
|
|
|
messages_pb2.SimpleRequest(response_size=1337)) |
|
|
|
self.assertEqual(1337, len(cache_store[1337].payload.body)) |
|
|
|
self.assertEqual(1337, len(interceptor.cache_store[1337].payload.body)) |
|
|
|
self.assertEqual(cache_store[1337], response) |
|
|
|
self.assertEqual(interceptor.cache_store[1337], response) |
|
|
|
response = await stub.UnaryCall( |
|
|
|
response = await stub.UnaryCall( |
|
|
|
messages_pb2.SimpleRequest(response_size=1337)) |
|
|
|
messages_pb2.SimpleRequest(response_size=1337)) |
|
|
|
self.assertEqual(cache_store[1337], response) |
|
|
|
self.assertEqual(interceptor.cache_store[1337], response) |
|
|
|
|
|
|
|
|
|
|
|
async def test_interceptor_unary_stream(self): |
|
|
|
async def test_interceptor_unary_stream(self): |
|
|
|
record = [] |
|
|
|
record = [] |
|
|
|