From 028a7c4e79fcddd202b8bdc95aa6742000f8f0ab Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Mon, 11 May 2020 13:59:44 -0700 Subject: [PATCH 1/5] Add test cases for server interceptors --- src/python/grpcio/grpc/_common.py | 3 +- .../grpcio/grpc/experimental/__init__.py | 27 +++ .../tests_aio/unit/server_interceptor_test.py | 157 +++++++++++++++++- 3 files changed, 180 insertions(+), 7 deletions(-) diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index 6a170ac6ce2..3c455678ff0 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -14,11 +14,10 @@ """Shared implementation.""" import logging - import time -import six import grpc +import six from grpc._cython import cygrpc _LOGGER = logging.getLogger(__name__) diff --git a/src/python/grpcio/grpc/experimental/__init__.py b/src/python/grpcio/grpc/experimental/__init__.py index 2c63908fe6d..83f0b860eb7 100644 --- a/src/python/grpcio/grpc/experimental/__init__.py +++ b/src/python/grpcio/grpc/experimental/__init__.py @@ -16,6 +16,7 @@ These APIs are subject to be removed during any minor version release. """ +import copy import functools import sys import warnings @@ -78,11 +79,37 @@ def experimental_api(f): return _wrapper +def wrap_server_method_handler(wrapper, handler): + """Wraps the server method handler function. + + The server implementation requires all server handlers being wrapped as + RpcMethodHandler objects. This helper function ease the pain of writing + server handler wrappers. + """ + if not handler: + return None + + if not handler.request_streaming: + if not handler.response_streaming: + # NOTE(lidiz) _replace is a public API: + # https://docs.python.org/dev/library/collections.html#collections.somenamedtuple._replace + return handler._replace(unary_unary=wrapper(handler.unary_unary)) + else: + return handler._replace(unary_stream=wrapper(handler.unary_stream)) + else: + if not handler.response_streaming: + return handler._replace(stream_unary=wrapper(handler.stream_unary)) + else: + return handler._replace( + stream_stream=wrapper(handler.stream_stream)) + + __all__ = ( 'ChannelOptions', 'ExperimentalApiWarning', 'UsageError', 'insecure_channel_credentials', + 'wrap_server_method_handler', ) if sys.version_info[0] == 3 and sys.version_info[1] >= 6: diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index dabf005591f..ad56a44aa1c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -11,17 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Test the functionality of server interceptors.""" + +import functools import logging import unittest -from typing import Callable, Awaitable, Any +from typing import Any, Awaitable, Callable, Tuple import grpc +from grpc.experimental import aio, wrap_server_method_handler -from grpc.experimental import aio - -from tests_aio.unit._test_server import start_test_server +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests_aio.unit._test_base import AioTestBase -from src.proto.grpc.testing import messages_pb2 +from tests_aio.unit._test_server import start_test_server + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 class _LoggingInterceptor(aio.ServerInterceptor): @@ -73,6 +79,18 @@ def _filter_server_interceptor(condition: Callable, return _GenericInterceptor(intercept_service) +async def _create_server_stub_pair( + *interceptors: aio.ServerInterceptor +) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]: + """Creates a server-stub pair with given interceptors. + + Returning the server object to protect it from being garbage collected. + """ + server_target, server = await start_test_server(interceptors=interceptors) + channel = aio.insecure_channel(server_target) + return server, test_pb2_grpc.TestServiceStub(channel) + + class TestServerInterceptor(AioTestBase): async def test_invalid_interceptor(self): @@ -162,6 +180,135 @@ class TestServerInterceptor(AioTestBase): 'log2:intercept_service', ], record) + async def test_response_caching(self): + # Prepares a preset value to help testing + cache_store = { + 42: + messages_pb2.SimpleResponse(payload=messages_pb2.Payload( + body=b'\x42')) + } + + async def intercept_and_cache( + continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + # Get the actual handler + handler = await continuation(handler_call_details) + + def wrap_handler(handler: grpc.RpcMethodHandler): + + @functools.wraps(handler) + async def wrapper(request: messages_pb2.SimpleRequest, + context: aio.ServicerContext): + if request.response_size not in cache_store: + cache_store[request.response_size] = await handler( + request, context) + return cache_store[request.response_size] + + return wrapper + + return wrap_server_method_handler(wrap_handler, handler) + + # Constructs a server with the cache interceptor + server, stub = await _create_server_stub_pair( + _GenericInterceptor(intercept_and_cache)) + + # Tests if the cache store is used + response = await stub.UnaryCall( + messages_pb2.SimpleRequest(response_size=42)) + self.assertEqual(1, len(cache_store[42].payload.body)) + self.assertEqual(cache_store[42], response) + + # Tests response can be cached + response = await stub.UnaryCall( + messages_pb2.SimpleRequest(response_size=1337)) + self.assertEqual(1337, len(cache_store[1337].payload.body)) + self.assertEqual(cache_store[1337], response) + response = await stub.UnaryCall( + messages_pb2.SimpleRequest(response_size=1337)) + self.assertEqual(cache_store[1337], response) + + async def test_interceptor_unary_stream(self): + record = [] + server, stub = await _create_server_stub_pair( + _LoggingInterceptor('log_unary_stream', record)) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) + + # Tests if the cache store is used + call = stub.StreamingOutputCall(request) + + # Ensures the RPC goes fine + async for response in call: + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + self.assertSequenceEqual([ + 'log_unary_stream:intercept_service', + ], record) + + async def test_interceptor_stream_unary(self): + record = [] + server, stub = await _create_server_stub_pair( + _LoggingInterceptor('log_stream_unary', record)) + + # Invokes the actual RPC + call = stub.StreamingInputCall() + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + # Sends out requests + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + await call.done_writing() + + # Validates the responses + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + self.assertSequenceEqual([ + 'log_stream_unary:intercept_service', + ], record) + + async def test_interceptor_stream_stream(self): + record = [] + server, stub = await _create_server_stub_pair( + _LoggingInterceptor('log_stream_stream', record)) + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + # Invokes the actual RPC + call = stub.StreamingInputCall(gen()) + + # Validates the responses + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + self.assertSequenceEqual([ + 'log_stream_stream:intercept_service', + ], record) + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG) From ef6ff6dcfde7e65f114650c57151b73ea87e9ea1 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Tue, 12 May 2020 10:43:12 -0700 Subject: [PATCH 2/5] Make linters happy --- src/python/grpcio/grpc/_common.py | 2 +- src/python/grpcio/grpc/experimental/__init__.py | 2 +- .../tests_aio/unit/server_interceptor_test.py | 13 ++++++++----- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index 3c455678ff0..27fc80b1498 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -15,9 +15,9 @@ import logging import time +import six import grpc -import six from grpc._cython import cygrpc _LOGGER = logging.getLogger(__name__) diff --git a/src/python/grpcio/grpc/experimental/__init__.py b/src/python/grpcio/grpc/experimental/__init__.py index 83f0b860eb7..083a7638486 100644 --- a/src/python/grpcio/grpc/experimental/__init__.py +++ b/src/python/grpcio/grpc/experimental/__init__.py @@ -92,7 +92,7 @@ def wrap_server_method_handler(wrapper, handler): if not handler.request_streaming: if not handler.response_streaming: # NOTE(lidiz) _replace is a public API: - # https://docs.python.org/dev/library/collections.html#collections.somenamedtuple._replace + # https://docs.python.org/dev/library/collections.html return handler._replace(unary_unary=wrapper(handler.unary_unary)) else: return handler._replace(unary_stream=wrapper(handler.unary_stream)) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index ad56a44aa1c..c93fe1b7b0e 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -196,19 +196,22 @@ class TestServerInterceptor(AioTestBase): # Get the actual handler handler = await continuation(handler_call_details) - def wrap_handler(handler: grpc.RpcMethodHandler): + def wrapper(behavior: Callable[ + [messages_pb2.SimpleRequest, aio. + ServerInterceptor], messages_pb2.SimpleResponse]): - @functools.wraps(handler) + @functools.wraps(behavior) async def wrapper(request: messages_pb2.SimpleRequest, - context: aio.ServicerContext): + context: aio.ServicerContext + ) -> messages_pb2.SimpleResponse: if request.response_size not in cache_store: - cache_store[request.response_size] = await handler( + cache_store[request.response_size] = await behavior( request, context) return cache_store[request.response_size] return wrapper - return wrap_server_method_handler(wrap_handler, handler) + return wrap_server_method_handler(wrapper, handler) # Constructs a server with the cache interceptor server, stub = await _create_server_stub_pair( From 7eeab2a23c1b9bf3abd97f380dd560f0175ed822 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Tue, 12 May 2020 12:52:27 -0700 Subject: [PATCH 3/5] Describe the expectations of input arguments and return values --- src/python/grpcio/grpc/experimental/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/python/grpcio/grpc/experimental/__init__.py b/src/python/grpcio/grpc/experimental/__init__.py index 083a7638486..eb642900d9c 100644 --- a/src/python/grpcio/grpc/experimental/__init__.py +++ b/src/python/grpcio/grpc/experimental/__init__.py @@ -85,6 +85,14 @@ def wrap_server_method_handler(wrapper, handler): The server implementation requires all server handlers being wrapped as RpcMethodHandler objects. This helper function ease the pain of writing server handler wrappers. + + Args: + wrapper: A wrapper function that takes in a method handler behavior + (the actual function) and returns a wrapped function. + handler: A RpcMethodHandler object to be wrapped. + + Returns: + A newly created RpcMethodHandler. """ if not handler: return None From f0f99b1b0548e4bc63fea1490a8c36ecf49cb863 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Tue, 12 May 2020 14:00:20 -0700 Subject: [PATCH 4/5] Clean up test logic --- .../tests_aio/unit/server_interceptor_test.py | 80 +++++++++++-------- 1 file changed, 46 insertions(+), 34 deletions(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index c93fe1b7b0e..047e1401891 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the functionality of server interceptors.""" +import asyncio import functools import logging import unittest @@ -79,6 +80,43 @@ def _filter_server_interceptor(condition: Callable, return _GenericInterceptor(intercept_service) +class _CacheInterceptor(aio.ServerInterceptor): + """An interceptor that caches response based on request message.""" + + def __init__(self, cache_store=None): + self.cache_store = cache_store or {} + + async def intercept_service( + self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + # Get the actual handler + handler = await continuation(handler_call_details) + + # Only intercept unary call RPCs + if handler and (handler.request_streaming or + handler.response_streaming): + return handler + + def wrapper(behavior: Callable[ + [messages_pb2.SimpleRequest, aio. + ServicerContext], messages_pb2.SimpleResponse]): + + @functools.wraps(behavior) + async def wrapper(request: messages_pb2.SimpleRequest, + context: aio.ServicerContext + ) -> messages_pb2.SimpleResponse: + if request.response_size not in self.cache_store: + self.cache_store[request.response_size] = await behavior( + request, context) + return self.cache_store[request.response_size] + + return wrapper + + return wrap_server_method_handler(wrapper, handler) + + async def _create_server_stub_pair( *interceptors: aio.ServerInterceptor ) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]: @@ -182,55 +220,29 @@ class TestServerInterceptor(AioTestBase): async def test_response_caching(self): # Prepares a preset value to help testing - cache_store = { + interceptor = _CacheInterceptor({ 42: messages_pb2.SimpleResponse(payload=messages_pb2.Payload( body=b'\x42')) - } - - async def intercept_and_cache( - continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ - grpc.RpcMethodHandler]], - handler_call_details: grpc.HandlerCallDetails - ) -> grpc.RpcMethodHandler: - # Get the actual handler - handler = await continuation(handler_call_details) - - def wrapper(behavior: Callable[ - [messages_pb2.SimpleRequest, aio. - ServerInterceptor], messages_pb2.SimpleResponse]): - - @functools.wraps(behavior) - async def wrapper(request: messages_pb2.SimpleRequest, - context: aio.ServicerContext - ) -> messages_pb2.SimpleResponse: - if request.response_size not in cache_store: - cache_store[request.response_size] = await behavior( - request, context) - return cache_store[request.response_size] - - return wrapper - - return wrap_server_method_handler(wrapper, handler) + }) # Constructs a server with the cache interceptor - server, stub = await _create_server_stub_pair( - _GenericInterceptor(intercept_and_cache)) + server, stub = await _create_server_stub_pair(interceptor) # Tests if the cache store is used response = await stub.UnaryCall( messages_pb2.SimpleRequest(response_size=42)) - self.assertEqual(1, len(cache_store[42].payload.body)) - self.assertEqual(cache_store[42], response) + self.assertEqual(1, len(interceptor.cache_store[42].payload.body)) + self.assertEqual(interceptor.cache_store[42], response) # Tests response can be cached response = await stub.UnaryCall( messages_pb2.SimpleRequest(response_size=1337)) - self.assertEqual(1337, len(cache_store[1337].payload.body)) - self.assertEqual(cache_store[1337], response) + self.assertEqual(1337, len(interceptor.cache_store[1337].payload.body)) + self.assertEqual(interceptor.cache_store[1337], response) response = await stub.UnaryCall( messages_pb2.SimpleRequest(response_size=1337)) - self.assertEqual(cache_store[1337], response) + self.assertEqual(interceptor.cache_store[1337], response) async def test_interceptor_unary_stream(self): record = [] From b2b939d7475a3ddc74b5302beb2ad9f35c2ac522 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Tue, 12 May 2020 14:45:34 -0700 Subject: [PATCH 5/5] Make pytype happy --- .../grpcio_tests/tests_aio/unit/server_interceptor_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index 047e1401891..f85e46c379a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -95,8 +95,8 @@ class _CacheInterceptor(aio.ServerInterceptor): handler = await continuation(handler_call_details) # Only intercept unary call RPCs - if handler and (handler.request_streaming or - handler.response_streaming): + if handler and (handler.request_streaming or # pytype: disable=attribute-error + handler.response_streaming): # pytype: disable=attribute-error return handler def wrapper(behavior: Callable[