Add test cases for server interceptors

pull/22925/head
Lidi Zheng 5 years ago
parent 729af3a43d
commit 028a7c4e79
  1. 3
      src/python/grpcio/grpc/_common.py
  2. 27
      src/python/grpcio/grpc/experimental/__init__.py
  3. 157
      src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py

@ -14,11 +14,10 @@
"""Shared implementation.""" """Shared implementation."""
import logging import logging
import time import time
import six
import grpc import grpc
import six
from grpc._cython import cygrpc from grpc._cython import cygrpc
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

@ -16,6 +16,7 @@
These APIs are subject to be removed during any minor version release. These APIs are subject to be removed during any minor version release.
""" """
import copy
import functools import functools
import sys import sys
import warnings import warnings
@ -78,11 +79,37 @@ def experimental_api(f):
return _wrapper return _wrapper
def wrap_server_method_handler(wrapper, handler):
"""Wraps the server method handler function.
The server implementation requires all server handlers being wrapped as
RpcMethodHandler objects. This helper function ease the pain of writing
server handler wrappers.
"""
if not handler:
return None
if not handler.request_streaming:
if not handler.response_streaming:
# NOTE(lidiz) _replace is a public API:
# https://docs.python.org/dev/library/collections.html#collections.somenamedtuple._replace
return handler._replace(unary_unary=wrapper(handler.unary_unary))
else:
return handler._replace(unary_stream=wrapper(handler.unary_stream))
else:
if not handler.response_streaming:
return handler._replace(stream_unary=wrapper(handler.stream_unary))
else:
return handler._replace(
stream_stream=wrapper(handler.stream_stream))
__all__ = ( __all__ = (
'ChannelOptions', 'ChannelOptions',
'ExperimentalApiWarning', 'ExperimentalApiWarning',
'UsageError', 'UsageError',
'insecure_channel_credentials', 'insecure_channel_credentials',
'wrap_server_method_handler',
) )
if sys.version_info[0] == 3 and sys.version_info[1] >= 6: if sys.version_info[0] == 3 and sys.version_info[1] >= 6:

@ -11,17 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Test the functionality of server interceptors."""
import functools
import logging import logging
import unittest import unittest
from typing import Callable, Awaitable, Any from typing import Any, Awaitable, Callable, Tuple
import grpc import grpc
from grpc.experimental import aio, wrap_server_method_handler
from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2 from tests_aio.unit._test_server import start_test_server
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
class _LoggingInterceptor(aio.ServerInterceptor): class _LoggingInterceptor(aio.ServerInterceptor):
@ -73,6 +79,18 @@ def _filter_server_interceptor(condition: Callable,
return _GenericInterceptor(intercept_service) return _GenericInterceptor(intercept_service)
async def _create_server_stub_pair(
*interceptors: aio.ServerInterceptor
) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]:
"""Creates a server-stub pair with given interceptors.
Returning the server object to protect it from being garbage collected.
"""
server_target, server = await start_test_server(interceptors=interceptors)
channel = aio.insecure_channel(server_target)
return server, test_pb2_grpc.TestServiceStub(channel)
class TestServerInterceptor(AioTestBase): class TestServerInterceptor(AioTestBase):
async def test_invalid_interceptor(self): async def test_invalid_interceptor(self):
@ -162,6 +180,135 @@ class TestServerInterceptor(AioTestBase):
'log2:intercept_service', 'log2:intercept_service',
], record) ], record)
async def test_response_caching(self):
# Prepares a preset value to help testing
cache_store = {
42:
messages_pb2.SimpleResponse(payload=messages_pb2.Payload(
body=b'\x42'))
}
async def intercept_and_cache(
continuation: Callable[[grpc.HandlerCallDetails], Awaitable[
grpc.RpcMethodHandler]],
handler_call_details: grpc.HandlerCallDetails
) -> grpc.RpcMethodHandler:
# Get the actual handler
handler = await continuation(handler_call_details)
def wrap_handler(handler: grpc.RpcMethodHandler):
@functools.wraps(handler)
async def wrapper(request: messages_pb2.SimpleRequest,
context: aio.ServicerContext):
if request.response_size not in cache_store:
cache_store[request.response_size] = await handler(
request, context)
return cache_store[request.response_size]
return wrapper
return wrap_server_method_handler(wrap_handler, handler)
# Constructs a server with the cache interceptor
server, stub = await _create_server_stub_pair(
_GenericInterceptor(intercept_and_cache))
# Tests if the cache store is used
response = await stub.UnaryCall(
messages_pb2.SimpleRequest(response_size=42))
self.assertEqual(1, len(cache_store[42].payload.body))
self.assertEqual(cache_store[42], response)
# Tests response can be cached
response = await stub.UnaryCall(
messages_pb2.SimpleRequest(response_size=1337))
self.assertEqual(1337, len(cache_store[1337].payload.body))
self.assertEqual(cache_store[1337], response)
response = await stub.UnaryCall(
messages_pb2.SimpleRequest(response_size=1337))
self.assertEqual(cache_store[1337], response)
async def test_interceptor_unary_stream(self):
record = []
server, stub = await _create_server_stub_pair(
_LoggingInterceptor('log_unary_stream', record))
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
# Tests if the cache store is used
call = stub.StreamingOutputCall(request)
# Ensures the RPC goes fine
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertSequenceEqual([
'log_unary_stream:intercept_service',
], record)
async def test_interceptor_stream_unary(self):
record = []
server, stub = await _create_server_stub_pair(
_LoggingInterceptor('log_stream_unary', record))
# Invokes the actual RPC
call = stub.StreamingInputCall()
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
# Sends out requests
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
await call.done_writing()
# Validates the responses
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertSequenceEqual([
'log_stream_unary:intercept_service',
], record)
async def test_interceptor_stream_stream(self):
record = []
server, stub = await _create_server_stub_pair(
_LoggingInterceptor('log_stream_stream', record))
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
async def gen():
for _ in range(_NUM_STREAM_RESPONSES):
yield request
# Invokes the actual RPC
call = stub.StreamingInputCall(gen())
# Validates the responses
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertSequenceEqual([
'log_stream_stream:intercept_service',
], record)
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save