diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py
index 6a170ac6ce2..27fc80b1498 100644
--- a/src/python/grpcio/grpc/_common.py
+++ b/src/python/grpcio/grpc/_common.py
@@ -14,7 +14,6 @@
 """Shared implementation."""
 
 import logging
-
 import time
 
 import six
diff --git a/src/python/grpcio/grpc/experimental/__init__.py b/src/python/grpcio/grpc/experimental/__init__.py
index 2c63908fe6d..eb642900d9c 100644
--- a/src/python/grpcio/grpc/experimental/__init__.py
+++ b/src/python/grpcio/grpc/experimental/__init__.py
@@ -16,6 +16,7 @@
 These APIs are subject to be removed during any minor version release.
 """
 
+import copy
 import functools
 import sys
 import warnings
@@ -78,11 +79,45 @@ def experimental_api(f):
     return _wrapper
 
 
+def wrap_server_method_handler(wrapper, handler):
+    """Wraps the server method handler function.
+
+    The server implementation requires all server handlers to be wrapped as
+    RpcMethodHandler objects. This helper function eases the pain of writing
+    server handler wrappers.
+
+    Args:
+      wrapper: A wrapper function that takes in a method handler behavior
+        (the actual function) and returns a wrapped function.
+      handler: A RpcMethodHandler object to be wrapped.
+
+    Returns:
+      A newly created RpcMethodHandler.
+ """ + if not handler: + return None + + if not handler.request_streaming: + if not handler.response_streaming: + # NOTE(lidiz) _replace is a public API: + # https://docs.python.org/dev/library/collections.html + return handler._replace(unary_unary=wrapper(handler.unary_unary)) + else: + return handler._replace(unary_stream=wrapper(handler.unary_stream)) + else: + if not handler.response_streaming: + return handler._replace(stream_unary=wrapper(handler.stream_unary)) + else: + return handler._replace( + stream_stream=wrapper(handler.stream_stream)) + + __all__ = ( 'ChannelOptions', 'ExperimentalApiWarning', 'UsageError', 'insecure_channel_credentials', + 'wrap_server_method_handler', ) if sys.version_info[0] == 3 and sys.version_info[1] >= 6: diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index dabf005591f..f85e46c379a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -11,17 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+"""Test the functionality of server interceptors.""" + +import asyncio +import functools import logging import unittest -from typing import Callable, Awaitable, Any +from typing import Any, Awaitable, Callable, Tuple import grpc +from grpc.experimental import aio, wrap_server_method_handler -from grpc.experimental import aio - -from tests_aio.unit._test_server import start_test_server +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests_aio.unit._test_base import AioTestBase -from src.proto.grpc.testing import messages_pb2 +from tests_aio.unit._test_server import start_test_server + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 class _LoggingInterceptor(aio.ServerInterceptor): @@ -73,6 +80,55 @@ def _filter_server_interceptor(condition: Callable, return _GenericInterceptor(intercept_service) +class _CacheInterceptor(aio.ServerInterceptor): + """An interceptor that caches response based on request message.""" + + def __init__(self, cache_store=None): + self.cache_store = cache_store or {} + + async def intercept_service( + self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + # Get the actual handler + handler = await continuation(handler_call_details) + + # Only intercept unary call RPCs + if handler and (handler.request_streaming or # pytype: disable=attribute-error + handler.response_streaming): # pytype: disable=attribute-error + return handler + + def wrapper(behavior: Callable[ + [messages_pb2.SimpleRequest, aio. 
+ ServicerContext], messages_pb2.SimpleResponse]): + + @functools.wraps(behavior) + async def wrapper(request: messages_pb2.SimpleRequest, + context: aio.ServicerContext + ) -> messages_pb2.SimpleResponse: + if request.response_size not in self.cache_store: + self.cache_store[request.response_size] = await behavior( + request, context) + return self.cache_store[request.response_size] + + return wrapper + + return wrap_server_method_handler(wrapper, handler) + + +async def _create_server_stub_pair( + *interceptors: aio.ServerInterceptor +) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]: + """Creates a server-stub pair with given interceptors. + + Returning the server object to protect it from being garbage collected. + """ + server_target, server = await start_test_server(interceptors=interceptors) + channel = aio.insecure_channel(server_target) + return server, test_pb2_grpc.TestServiceStub(channel) + + class TestServerInterceptor(AioTestBase): async def test_invalid_interceptor(self): @@ -162,6 +218,112 @@ class TestServerInterceptor(AioTestBase): 'log2:intercept_service', ], record) + async def test_response_caching(self): + # Prepares a preset value to help testing + interceptor = _CacheInterceptor({ + 42: + messages_pb2.SimpleResponse(payload=messages_pb2.Payload( + body=b'\x42')) + }) + + # Constructs a server with the cache interceptor + server, stub = await _create_server_stub_pair(interceptor) + + # Tests if the cache store is used + response = await stub.UnaryCall( + messages_pb2.SimpleRequest(response_size=42)) + self.assertEqual(1, len(interceptor.cache_store[42].payload.body)) + self.assertEqual(interceptor.cache_store[42], response) + + # Tests response can be cached + response = await stub.UnaryCall( + messages_pb2.SimpleRequest(response_size=1337)) + self.assertEqual(1337, len(interceptor.cache_store[1337].payload.body)) + self.assertEqual(interceptor.cache_store[1337], response) + response = await stub.UnaryCall( + 
messages_pb2.SimpleRequest(response_size=1337))
+        self.assertEqual(interceptor.cache_store[1337], response)
+
+    async def test_interceptor_unary_stream(self):
+        record = []
+        server, stub = await _create_server_stub_pair(
+            _LoggingInterceptor('log_unary_stream', record))
+
+        # Prepares the request
+        request = messages_pb2.StreamingOutputCallRequest()
+        for _ in range(_NUM_STREAM_RESPONSES):
+            request.response_parameters.append(
+                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
+
+        # Invokes the actual RPC
+        call = stub.StreamingOutputCall(request)
+
+        # Ensures the RPC goes fine
+        async for response in call:
+            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
+        self.assertEqual(await call.code(), grpc.StatusCode.OK)
+
+        self.assertSequenceEqual([
+            'log_unary_stream:intercept_service',
+        ], record)
+
+    async def test_interceptor_stream_unary(self):
+        record = []
+        server, stub = await _create_server_stub_pair(
+            _LoggingInterceptor('log_stream_unary', record))
+
+        # Invokes the actual RPC
+        call = stub.StreamingInputCall()
+
+        # Prepares the request
+        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
+        request = messages_pb2.StreamingInputCallRequest(payload=payload)
+
+        # Sends out requests
+        for _ in range(_NUM_STREAM_RESPONSES):
+            await call.write(request)
+        await call.done_writing()
+
+        # Validates the responses
+        response = await call
+        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
+        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
+                         response.aggregated_payload_size)
+
+        self.assertEqual(await call.code(), grpc.StatusCode.OK)
+
+        self.assertSequenceEqual([
+            'log_stream_unary:intercept_service',
+        ], record)
+
+    async def test_interceptor_stream_stream(self):
+        record = []
+        server, stub = await _create_server_stub_pair(
+            _LoggingInterceptor('log_stream_stream', record))
+
+        # Prepares the request
+        payload = messages_pb2.Payload(body=b'\0' * 
_REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + # Invokes the actual RPC + call = stub.StreamingInputCall(gen()) + + # Validates the responses + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + self.assertSequenceEqual([ + 'log_stream_stream:intercept_service', + ], record) + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG)