Merge pull request #22925 from lidizheng/aio-server-interceptor-test

[Aio] Add test cases for server interceptors
pull/22940/head
Lidi Zheng 5 years ago committed by GitHub
commit 0b7b6181e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      src/python/grpcio/grpc/_common.py
  2. 35
      src/python/grpcio/grpc/experimental/__init__.py
  3. 172
      src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py

@ -14,7 +14,6 @@
"""Shared implementation.""" """Shared implementation."""
import logging import logging
import time import time
import six import six

@ -16,6 +16,7 @@
These APIs are subject to be removed during any minor version release. These APIs are subject to be removed during any minor version release.
""" """
import copy
import functools import functools
import sys import sys
import warnings import warnings
@ -78,11 +79,45 @@ def experimental_api(f):
return _wrapper return _wrapper
def wrap_server_method_handler(wrapper, handler):
    """Wraps the server method handler function.

    The server implementation requires all server handlers being wrapped as
    RpcMethodHandler objects. This helper function eases the pain of writing
    server handler wrappers.

    Args:
        wrapper: A wrapper function that takes in a method handler behavior
          (the actual function) and returns a wrapped function.
        handler: A RpcMethodHandler object to be wrapped.

    Returns:
        A newly created RpcMethodHandler, or None if no handler was given.
    """
    if not handler:
        return None

    # Exactly one of the four behavior attributes is populated for a given
    # handler; pick it from the streaming arity and swap in the wrapped
    # behavior. NOTE(lidiz) _replace is a public API:
    # https://docs.python.org/dev/library/collections.html
    if handler.request_streaming:
        if handler.response_streaming:
            return handler._replace(
                stream_stream=wrapper(handler.stream_stream))
        return handler._replace(stream_unary=wrapper(handler.stream_unary))
    if handler.response_streaming:
        return handler._replace(unary_stream=wrapper(handler.unary_stream))
    return handler._replace(unary_unary=wrapper(handler.unary_unary))
__all__ = ( __all__ = (
'ChannelOptions', 'ChannelOptions',
'ExperimentalApiWarning', 'ExperimentalApiWarning',
'UsageError', 'UsageError',
'insecure_channel_credentials', 'insecure_channel_credentials',
'wrap_server_method_handler',
) )
if sys.version_info[0] == 3 and sys.version_info[1] >= 6: if sys.version_info[0] == 3 and sys.version_info[1] >= 6:

@ -11,17 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Test the functionality of server interceptors."""
import asyncio
import functools
import logging import logging
import unittest import unittest
from typing import Callable, Awaitable, Any from typing import Any, Awaitable, Callable, Tuple
import grpc import grpc
from grpc.experimental import aio, wrap_server_method_handler
from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2 from tests_aio.unit._test_server import start_test_server
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
class _LoggingInterceptor(aio.ServerInterceptor): class _LoggingInterceptor(aio.ServerInterceptor):
@ -73,6 +80,55 @@ def _filter_server_interceptor(condition: Callable,
return _GenericInterceptor(intercept_service) return _GenericInterceptor(intercept_service)
class _CacheInterceptor(aio.ServerInterceptor):
    """An interceptor that caches response based on request message."""

    def __init__(self, cache_store=None):
        # Maps request.response_size -> cached SimpleResponse. A falsy
        # argument (None or an empty mapping) is replaced by a fresh dict.
        self.cache_store = cache_store or {}

    async def intercept_service(
            self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[
                grpc.RpcMethodHandler]],
            handler_call_details: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        # Resolve the actual handler first.
        handler = await continuation(handler_call_details)

        # Anything that streams in either direction is passed through
        # untouched; only unary-unary RPCs get the caching treatment.
        if handler and (handler.request_streaming or  # pytype: disable=attribute-error
                        handler.response_streaming):  # pytype: disable=attribute-error
            return handler

        def _caching_wrapper(behavior: Callable[
            [messages_pb2.SimpleRequest, aio.
             ServicerContext], messages_pb2.SimpleResponse]):

            @functools.wraps(behavior)
            async def _cached_behavior(request: messages_pb2.SimpleRequest,
                                       context: aio.ServicerContext
                                      ) -> messages_pb2.SimpleResponse:
                key = request.response_size
                # Invoke the real behavior only on a cache miss.
                if key not in self.cache_store:
                    self.cache_store[key] = await behavior(request, context)
                return self.cache_store[key]

            return _cached_behavior

        return wrap_server_method_handler(_caching_wrapper, handler)
async def _create_server_stub_pair(
        *interceptors: aio.ServerInterceptor
) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]:
    """Creates a server-stub pair with given interceptors.

    Returning the server object to protect it from being garbage collected.
    """
    address, server = await start_test_server(interceptors=interceptors)
    channel = aio.insecure_channel(address)
    stub = test_pb2_grpc.TestServiceStub(channel)
    return server, stub
class TestServerInterceptor(AioTestBase): class TestServerInterceptor(AioTestBase):
async def test_invalid_interceptor(self): async def test_invalid_interceptor(self):
@ -162,6 +218,112 @@ class TestServerInterceptor(AioTestBase):
'log2:intercept_service', 'log2:intercept_service',
], record) ], record)
async def test_response_caching(self):
# Prepares a preset value to help testing
interceptor = _CacheInterceptor({
42:
messages_pb2.SimpleResponse(payload=messages_pb2.Payload(
body=b'\x42'))
})
# Constructs a server with the cache interceptor
server, stub = await _create_server_stub_pair(interceptor)
# Tests if the cache store is used
response = await stub.UnaryCall(
messages_pb2.SimpleRequest(response_size=42))
self.assertEqual(1, len(interceptor.cache_store[42].payload.body))
self.assertEqual(interceptor.cache_store[42], response)
# Tests response can be cached
response = await stub.UnaryCall(
messages_pb2.SimpleRequest(response_size=1337))
self.assertEqual(1337, len(interceptor.cache_store[1337].payload.body))
self.assertEqual(interceptor.cache_store[1337], response)
response = await stub.UnaryCall(
messages_pb2.SimpleRequest(response_size=1337))
self.assertEqual(interceptor.cache_store[1337], response)
async def test_interceptor_unary_stream(self):
record = []
server, stub = await _create_server_stub_pair(
_LoggingInterceptor('log_unary_stream', record))
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
# Tests if the cache store is used
call = stub.StreamingOutputCall(request)
# Ensures the RPC goes fine
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertSequenceEqual([
'log_unary_stream:intercept_service',
], record)
async def test_interceptor_stream_unary(self):
record = []
server, stub = await _create_server_stub_pair(
_LoggingInterceptor('log_stream_unary', record))
# Invokes the actual RPC
call = stub.StreamingInputCall()
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
# Sends out requests
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
await call.done_writing()
# Validates the responses
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertSequenceEqual([
'log_stream_unary:intercept_service',
], record)
    async def test_interceptor_stream_stream(self):
        """Checks an interceptor fires when requests come from a generator.

        NOTE(review): despite the name, this invokes StreamingInputCall — a
        request-streaming, single-response RPC — feeding it an async request
        generator, and awaits one aggregated response. It does not exercise a
        true full-duplex (stream-stream) RPC such as FullDuplexCall; confirm
        whether that is the intent or a naming slip.
        """
        record = []
        server, stub = await _create_server_stub_pair(
            _LoggingInterceptor('log_stream_stream', record))

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Async generator supplying the request stream to the RPC.
        async def gen():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        # Invokes the actual RPC
        call = stub.StreamingInputCall(gen())

        # Validates the responses: one response aggregating all payload bytes.
        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        # The interceptor must have fired exactly once for the whole RPC.
        self.assertSequenceEqual([
            'log_stream_stream:intercept_service',
        ], record)
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save