Merge pull request #22032 from ZHmao/implement-server-interceptor-for-unary-unary-call

[Aio] Implement server interceptor for unary unary call
pull/22370/head
Lidi Zheng 5 years ago committed by GitHub
commit 87d01bf9e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  2. 44
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 3
      src/python/grpcio/grpc/experimental/aio/__init__.py
  4. 30
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  5. 13
      src/python/grpcio/grpc/experimental/aio/_server.py
  6. 5
      src/python/grpcio_tests/tests_aio/tests.json
  7. 9
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  8. 0
      src/python/grpcio_tests/tests_aio/unit/client_interceptor_test.py
  9. 168
      src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py

@ -61,3 +61,4 @@ cdef class AioServer:
cdef CallbackWrapper _shutdown_callback_wrapper
cdef object _crash_exception # Exception
cdef set _ongoing_rpc_tasks
cdef tuple _interceptors

@ -15,6 +15,7 @@
import inspect
import traceback
import functools
cdef int _EMPTY_FLAG = 0
@ -214,15 +215,34 @@ cdef class _ServicerContext:
self._rpc_state.disable_next_compression = True
cdef _find_method_handler(str method, tuple metadata, list generic_handlers):
async def _run_interceptor(object interceptors, object query_handler,
object handler_call_details):
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors,
query_handler)
return await interceptor.intercept_service(continuation, handler_call_details)
else:
return query_handler(handler_call_details)
async def _find_method_handler(str method, tuple metadata, list generic_handlers,
tuple interceptors):
def query_handlers(handler_call_details):
for generic_handler in generic_handlers:
method_handler = generic_handler.service(handler_call_details)
if method_handler is not None:
return method_handler
return None
cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method,
metadata)
for generic_handler in generic_handlers:
method_handler = generic_handler.service(handler_call_details)
if method_handler is not None:
return method_handler
return None
# interceptor
if interceptors:
return await _run_interceptor(iter(interceptors), query_handlers,
handler_call_details)
else:
return query_handlers(handler_call_details)
async def _finish_handler_with_unary_response(RPCState rpc_state,
@ -523,13 +543,15 @@ async def _schedule_rpc_coro(object rpc_coro,
await _handle_cancellation_from_core(rpc_task, rpc_state, loop)
async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
async def _handle_rpc(list generic_handlers, tuple interceptors,
RPCState rpc_state, object loop):
cdef object method_handler
# Finds the method handler (application logic)
method_handler = _find_method_handler(
method_handler = await _find_method_handler(
rpc_state.method().decode(),
rpc_state.invocation_metadata(),
generic_handlers,
interceptors,
)
if method_handler is None:
rpc_state.status_sent = True
@ -612,8 +634,9 @@ cdef class AioServer:
SERVER_SHUTDOWN_FAILURE_HANDLER)
self._crash_exception = None
self._interceptors = ()
if interceptors:
raise NotImplementedError()
self._interceptors = interceptors
if maximum_concurrent_rpcs:
raise NotImplementedError()
if thread_pool:
@ -669,6 +692,7 @@ cdef class AioServer:
# the coroutine onto event loop inside of the cancellation
# coroutine.
rpc_coro = _handle_rpc(self._generic_handlers,
self._interceptors,
rpc_state,
self._loop)

@ -30,7 +30,7 @@ from ._base_channel import (Channel, StreamStreamMultiCallable,
UnaryUnaryMultiCallable)
from ._call import AioRpcError
from ._interceptor import (ClientCallDetails, InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor)
UnaryUnaryClientInterceptor, ServerInterceptor)
from ._server import server
from ._base_server import Server, ServicerContext
from ._typing import ChannelArgumentType
@ -55,6 +55,7 @@ __all__ = (
'ClientCallDetails',
'UnaryUnaryClientInterceptor',
'InterceptedUnaryUnaryCall',
'ServerInterceptor',
'insecure_channel',
'server',
'Server',

@ -16,7 +16,7 @@ import asyncio
import collections
import functools
from abc import ABCMeta, abstractmethod
from typing import Callable, Optional, Iterator, Sequence, Union
from typing import Callable, Optional, Iterator, Sequence, Union, Awaitable
import grpc
from grpc._cython import cygrpc
@ -30,6 +30,34 @@ from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
class ServerInterceptor(metaclass=ABCMeta):
"""Affords intercepting incoming RPCs on the service-side.
This is an EXPERIMENTAL API.
"""
@abstractmethod
async def intercept_service(
self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[
grpc.RpcMethodHandler]],
handler_call_details: grpc.HandlerCallDetails
) -> grpc.RpcMethodHandler:
"""Intercepts incoming RPCs before handing them over to a handler.
Args:
continuation: A function that takes a HandlerCallDetails and
proceeds to invoke the next interceptor in the chain, if any,
or the RPC handler lookup logic, with the call details passed
as an argument, and returns an RpcMethodHandler instance if
the RPC is considered serviced, or None otherwise.
handler_call_details: A HandlerCallDetails describing the RPC.
Returns:
An RpcMethodHandler with which the RPC may be serviced if the
interceptor chooses to service this RPC, or None otherwise.
"""
class ClientCallDetails(
collections.namedtuple(
'ClientCallDetails',

@ -23,6 +23,7 @@ from grpc._cython import cygrpc
from . import _base_server
from ._typing import ChannelArgumentType
from ._interceptor import ServerInterceptor
def _augment_channel_arguments(base_options: ChannelArgumentType,
@ -41,6 +42,15 @@ class Server(_base_server.Server):
maximum_concurrent_rpcs: Optional[int],
compression: Optional[grpc.Compression]):
self._loop = asyncio.get_event_loop()
if interceptors:
invalid_interceptors = [
interceptor for interceptor in interceptors
if not isinstance(interceptor, ServerInterceptor)
]
if invalid_interceptors:
raise ValueError(
'Interceptor must be ServerInterceptor, the '
f'following are invalid: {invalid_interceptors}')
self._server = cygrpc.AioServer(
self._loop, thread_pool, generic_handlers, interceptors,
_augment_channel_arguments(options, compression),
@ -152,7 +162,8 @@ class Server(_base_server.Server):
The Cython AioServer doesn't hold a ref-count to this class. It should
be safe to slightly extend the underlying Cython object's life span.
"""
self._loop.create_task(self._server.shutdown(None))
if hasattr(self, '_server'):
self._loop.create_task(self._server.shutdown(None))
def server(migration_thread_pool: Optional[Executor] = None,

@ -12,15 +12,16 @@
"unit.channel_argument_test.TestChannelArgument",
"unit.channel_ready_test.TestChannelReady",
"unit.channel_test.TestChannel",
"unit.client_interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.client_interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.close_channel_test.TestCloseChannel",
"unit.compression_test.TestCompression",
"unit.connectivity_test.TestConnectivityState",
"unit.done_callback_test.TestDoneCallback",
"unit.init_test.TestInsecureChannel",
"unit.init_test.TestSecureChannel",
"unit.interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.metadata_test.TestMetadata",
"unit.server_interceptor_test.TestServerInterceptor",
"unit.server_test.TestServer",
"unit.timeout_test.TestTimeout",
"unit.wait_for_ready_test.TestWaitForReady"

@ -14,7 +14,6 @@
import asyncio
import datetime
import logging
import grpc
from grpc.experimental import aio
@ -117,8 +116,12 @@ def _create_extra_generic_handler(servicer: _TestServiceServicer):
rpc_method_handlers)
async def start_test_server(port=0, secure=False, server_credentials=None):
server = aio.server(options=(('grpc.so_reuseport', 0),))
async def start_test_server(port=0,
secure=False,
server_credentials=None,
interceptors=None):
server = aio.server(options=(('grpc.so_reuseport', 0),),
interceptors=interceptors)
servicer = _TestServiceServicer()
test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)

@ -0,0 +1,168 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import unittest
from typing import Callable, Awaitable, Any
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
class _LoggingInterceptor(aio.ServerInterceptor):
def __init__(self, tag: str, record: list) -> None:
self.tag = tag
self.record = record
async def intercept_service(
self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[
grpc.RpcMethodHandler]],
handler_call_details: grpc.HandlerCallDetails
) -> grpc.RpcMethodHandler:
self.record.append(self.tag + ':intercept_service')
return await continuation(handler_call_details)
class _GenericInterceptor(aio.ServerInterceptor):
def __init__(self, fn: Callable[[
Callable[[grpc.HandlerCallDetails], Awaitable[grpc.
RpcMethodHandler]],
grpc.HandlerCallDetails
], Any]) -> None:
self._fn = fn
async def intercept_service(
self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[
grpc.RpcMethodHandler]],
handler_call_details: grpc.HandlerCallDetails
) -> grpc.RpcMethodHandler:
return await self._fn(continuation, handler_call_details)
def _filter_server_interceptor(condition: Callable,
interceptor: aio.ServerInterceptor
) -> aio.ServerInterceptor:
async def intercept_service(
continuation: Callable[[grpc.HandlerCallDetails], Awaitable[
grpc.RpcMethodHandler]],
handler_call_details: grpc.HandlerCallDetails
) -> grpc.RpcMethodHandler:
if condition(handler_call_details):
return await interceptor.intercept_service(continuation,
handler_call_details)
return await continuation(handler_call_details)
return _GenericInterceptor(intercept_service)
class TestServerInterceptor(AioTestBase):
async def test_invalid_interceptor(self):
class InvalidInterceptor:
"""Just an invalid Interceptor"""
with self.assertRaises(ValueError):
server_target, _ = await start_test_server(
interceptors=(InvalidInterceptor(),))
async def test_executed_right_order(self):
record = []
server_target, _ = await start_test_server(interceptors=(
_LoggingInterceptor('log1', record),
_LoggingInterceptor('log2', record),
))
async with aio.insecure_channel(server_target) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
# Check that all interceptors were executed, and were executed
# in the right order.
self.assertSequenceEqual([
'log1:intercept_service',
'log2:intercept_service',
], record)
self.assertIsInstance(response, messages_pb2.SimpleResponse)
async def test_response_ok(self):
record = []
server_target, _ = await start_test_server(
interceptors=(_LoggingInterceptor('log1', record),))
async with aio.insecure_channel(server_target) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
code = await call.code()
self.assertSequenceEqual(['log1:intercept_service'], record)
self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(code, grpc.StatusCode.OK)
async def test_apply_different_interceptors_by_metadata(self):
record = []
conditional_interceptor = _filter_server_interceptor(
lambda x: ('secret', '42') in x.invocation_metadata,
_LoggingInterceptor('log3', record))
server_target, _ = await start_test_server(interceptors=(
_LoggingInterceptor('log1', record),
conditional_interceptor,
_LoggingInterceptor('log2', record),
))
async with aio.insecure_channel(server_target) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
metadata = (('key', 'value'),)
call = multicallable(messages_pb2.SimpleRequest(),
metadata=metadata)
await call
self.assertSequenceEqual([
'log1:intercept_service',
'log2:intercept_service',
], record)
record.clear()
metadata = (('key', 'value'), ('secret', '42'))
call = multicallable(messages_pb2.SimpleRequest(),
metadata=metadata)
await call
self.assertSequenceEqual([
'log1:intercept_service',
'log3:intercept_service',
'log2:intercept_service',
], record)
if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)
Loading…
Cancel
Save