|
|
|
@ -11,17 +11,24 @@ |
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
|
# limitations under the License. |
|
|
|
|
"""Test the functionality of server interceptors.""" |
|
|
|
|
|
|
|
|
|
import asyncio |
|
|
|
|
import functools |
|
|
|
|
import logging |
|
|
|
|
import unittest |
|
|
|
|
from typing import Callable, Awaitable, Any |
|
|
|
|
from typing import Any, Awaitable, Callable, Tuple |
|
|
|
|
|
|
|
|
|
import grpc |
|
|
|
|
from grpc.experimental import aio, wrap_server_method_handler |
|
|
|
|
|
|
|
|
|
from grpc.experimental import aio |
|
|
|
|
|
|
|
|
|
from tests_aio.unit._test_server import start_test_server |
|
|
|
|
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc |
|
|
|
|
from tests_aio.unit._test_base import AioTestBase |
|
|
|
|
from src.proto.grpc.testing import messages_pb2 |
|
|
|
|
from tests_aio.unit._test_server import start_test_server |
|
|
|
|
|
|
|
|
|
_NUM_STREAM_RESPONSES = 5 |
|
|
|
|
_REQUEST_PAYLOAD_SIZE = 7 |
|
|
|
|
_RESPONSE_PAYLOAD_SIZE = 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _LoggingInterceptor(aio.ServerInterceptor): |
|
|
|
@ -73,6 +80,55 @@ def _filter_server_interceptor(condition: Callable, |
|
|
|
|
return _GenericInterceptor(intercept_service) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _CacheInterceptor(aio.ServerInterceptor): |
|
|
|
|
"""An interceptor that caches response based on request message.""" |
|
|
|
|
|
|
|
|
|
def __init__(self, cache_store=None): |
|
|
|
|
self.cache_store = cache_store or {} |
|
|
|
|
|
|
|
|
|
async def intercept_service( |
|
|
|
|
self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ |
|
|
|
|
grpc.RpcMethodHandler]], |
|
|
|
|
handler_call_details: grpc.HandlerCallDetails |
|
|
|
|
) -> grpc.RpcMethodHandler: |
|
|
|
|
# Get the actual handler |
|
|
|
|
handler = await continuation(handler_call_details) |
|
|
|
|
|
|
|
|
|
# Only intercept unary call RPCs |
|
|
|
|
if handler and (handler.request_streaming or # pytype: disable=attribute-error |
|
|
|
|
handler.response_streaming): # pytype: disable=attribute-error |
|
|
|
|
return handler |
|
|
|
|
|
|
|
|
|
def wrapper(behavior: Callable[ |
|
|
|
|
[messages_pb2.SimpleRequest, aio. |
|
|
|
|
ServicerContext], messages_pb2.SimpleResponse]): |
|
|
|
|
|
|
|
|
|
@functools.wraps(behavior) |
|
|
|
|
async def wrapper(request: messages_pb2.SimpleRequest, |
|
|
|
|
context: aio.ServicerContext |
|
|
|
|
) -> messages_pb2.SimpleResponse: |
|
|
|
|
if request.response_size not in self.cache_store: |
|
|
|
|
self.cache_store[request.response_size] = await behavior( |
|
|
|
|
request, context) |
|
|
|
|
return self.cache_store[request.response_size] |
|
|
|
|
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
return wrap_server_method_handler(wrapper, handler) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _create_server_stub_pair( |
|
|
|
|
*interceptors: aio.ServerInterceptor |
|
|
|
|
) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]: |
|
|
|
|
"""Creates a server-stub pair with given interceptors. |
|
|
|
|
|
|
|
|
|
Returning the server object to protect it from being garbage collected. |
|
|
|
|
""" |
|
|
|
|
server_target, server = await start_test_server(interceptors=interceptors) |
|
|
|
|
channel = aio.insecure_channel(server_target) |
|
|
|
|
return server, test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestServerInterceptor(AioTestBase): |
|
|
|
|
|
|
|
|
|
async def test_invalid_interceptor(self): |
|
|
|
@ -162,6 +218,112 @@ class TestServerInterceptor(AioTestBase): |
|
|
|
|
'log2:intercept_service', |
|
|
|
|
], record) |
|
|
|
|
|
|
|
|
|
async def test_response_caching(self): |
|
|
|
|
# Prepares a preset value to help testing |
|
|
|
|
interceptor = _CacheInterceptor({ |
|
|
|
|
42: |
|
|
|
|
messages_pb2.SimpleResponse(payload=messages_pb2.Payload( |
|
|
|
|
body=b'\x42')) |
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
# Constructs a server with the cache interceptor |
|
|
|
|
server, stub = await _create_server_stub_pair(interceptor) |
|
|
|
|
|
|
|
|
|
# Tests if the cache store is used |
|
|
|
|
response = await stub.UnaryCall( |
|
|
|
|
messages_pb2.SimpleRequest(response_size=42)) |
|
|
|
|
self.assertEqual(1, len(interceptor.cache_store[42].payload.body)) |
|
|
|
|
self.assertEqual(interceptor.cache_store[42], response) |
|
|
|
|
|
|
|
|
|
# Tests response can be cached |
|
|
|
|
response = await stub.UnaryCall( |
|
|
|
|
messages_pb2.SimpleRequest(response_size=1337)) |
|
|
|
|
self.assertEqual(1337, len(interceptor.cache_store[1337].payload.body)) |
|
|
|
|
self.assertEqual(interceptor.cache_store[1337], response) |
|
|
|
|
response = await stub.UnaryCall( |
|
|
|
|
messages_pb2.SimpleRequest(response_size=1337)) |
|
|
|
|
self.assertEqual(interceptor.cache_store[1337], response) |
|
|
|
|
|
|
|
|
|
async def test_interceptor_unary_stream(self): |
|
|
|
|
record = [] |
|
|
|
|
server, stub = await _create_server_stub_pair( |
|
|
|
|
_LoggingInterceptor('log_unary_stream', record)) |
|
|
|
|
|
|
|
|
|
# Prepares the request |
|
|
|
|
request = messages_pb2.StreamingOutputCallRequest() |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
request.response_parameters.append( |
|
|
|
|
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) |
|
|
|
|
|
|
|
|
|
# Tests if the cache store is used |
|
|
|
|
call = stub.StreamingOutputCall(request) |
|
|
|
|
|
|
|
|
|
# Ensures the RPC goes fine |
|
|
|
|
async for response in call: |
|
|
|
|
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) |
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
|
|
|
|
|
|
|
|
|
self.assertSequenceEqual([ |
|
|
|
|
'log_unary_stream:intercept_service', |
|
|
|
|
], record) |
|
|
|
|
|
|
|
|
|
async def test_interceptor_stream_unary(self): |
|
|
|
|
record = [] |
|
|
|
|
server, stub = await _create_server_stub_pair( |
|
|
|
|
_LoggingInterceptor('log_stream_unary', record)) |
|
|
|
|
|
|
|
|
|
# Invokes the actual RPC |
|
|
|
|
call = stub.StreamingInputCall() |
|
|
|
|
|
|
|
|
|
# Prepares the request |
|
|
|
|
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) |
|
|
|
|
request = messages_pb2.StreamingInputCallRequest(payload=payload) |
|
|
|
|
|
|
|
|
|
# Sends out requests |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
await call.write(request) |
|
|
|
|
await call.done_writing() |
|
|
|
|
|
|
|
|
|
# Validates the responses |
|
|
|
|
response = await call |
|
|
|
|
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) |
|
|
|
|
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, |
|
|
|
|
response.aggregated_payload_size) |
|
|
|
|
|
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
|
|
|
|
|
|
|
|
|
self.assertSequenceEqual([ |
|
|
|
|
'log_stream_unary:intercept_service', |
|
|
|
|
], record) |
|
|
|
|
|
|
|
|
|
async def test_interceptor_stream_stream(self): |
|
|
|
|
record = [] |
|
|
|
|
server, stub = await _create_server_stub_pair( |
|
|
|
|
_LoggingInterceptor('log_stream_stream', record)) |
|
|
|
|
|
|
|
|
|
# Prepares the request |
|
|
|
|
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) |
|
|
|
|
request = messages_pb2.StreamingInputCallRequest(payload=payload) |
|
|
|
|
|
|
|
|
|
async def gen(): |
|
|
|
|
for _ in range(_NUM_STREAM_RESPONSES): |
|
|
|
|
yield request |
|
|
|
|
|
|
|
|
|
# Invokes the actual RPC |
|
|
|
|
call = stub.StreamingInputCall(gen()) |
|
|
|
|
|
|
|
|
|
# Validates the responses |
|
|
|
|
response = await call |
|
|
|
|
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) |
|
|
|
|
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, |
|
|
|
|
response.aggregated_payload_size) |
|
|
|
|
|
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
|
|
|
|
|
|
|
|
|
self.assertSequenceEqual([ |
|
|
|
|
'log_stream_stream:intercept_service', |
|
|
|
|
], record) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
logging.basicConfig(level=logging.DEBUG) |
|
|
|
|