Implement server interceptor for unary unary call

pull/22032/head
Zhanghui Mao 5 years ago
parent 8072fcb231
commit 6fef56573e
  1. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  2. 44
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 6
      src/python/grpcio/grpc/experimental/aio/__init__.py
  4. 28
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  5. 1
      src/python/grpcio_tests/tests_aio/tests.json
  6. 6
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  7. 104
      src/python/grpcio_tests/tests_aio/unit/interceptor_test.py

@ -61,3 +61,4 @@ cdef class AioServer:
cdef CallbackWrapper _shutdown_callback_wrapper
cdef object _crash_exception # Exception
cdef set _ongoing_rpc_tasks
cdef tuple _interceptors

@ -15,6 +15,7 @@
import inspect
import traceback
import functools
cdef int _EMPTY_FLAG = 0
@ -214,15 +215,34 @@ cdef class _ServicerContext:
self._rpc_state.disable_next_compression = True
cdef _find_method_handler(str method, tuple metadata, list generic_handlers):
async def _run_interceptor(object interceptors, object query_handler,
object handler_call_details):
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors,
query_handler)
return await interceptor.intercept_service(continuation, handler_call_details)
else:
return query_handler(handler_call_details)
async def _find_method_handler(str method, tuple metadata, list generic_handlers,
tuple interceptors):
def query_handlers(handler_call_details):
for generic_handler in generic_handlers:
method_handler = generic_handler.service(handler_call_details)
if method_handler is not None:
return method_handler
return None
cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method,
metadata)
for generic_handler in generic_handlers:
method_handler = generic_handler.service(handler_call_details)
if method_handler is not None:
return method_handler
return None
# interceptor
if interceptors:
return await _run_interceptor(iter(interceptors), query_handlers,
handler_call_details)
else:
return query_handlers(handler_call_details)
async def _finish_handler_with_unary_response(RPCState rpc_state,
@ -516,13 +536,15 @@ async def _schedule_rpc_coro(object rpc_coro,
await _handle_cancellation_from_core(rpc_task, rpc_state, loop)
async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
async def _handle_rpc(list generic_handlers, tuple interceptors,
RPCState rpc_state, object loop):
cdef object method_handler
# Finds the method handler (application logic)
method_handler = _find_method_handler(
method_handler = await _find_method_handler(
rpc_state.method().decode(),
rpc_state.invocation_metadata(),
generic_handlers,
interceptors,
)
if method_handler is None:
rpc_state.status_sent = True
@ -605,8 +627,9 @@ cdef class AioServer:
SERVER_SHUTDOWN_FAILURE_HANDLER)
self._crash_exception = None
self._interceptors = ()
if interceptors:
raise NotImplementedError()
self._interceptors = interceptors
if maximum_concurrent_rpcs:
raise NotImplementedError()
if thread_pool:
@ -662,6 +685,7 @@ cdef class AioServer:
# the coroutine onto event loop inside of the cancellation
# coroutine.
rpc_coro = _handle_rpc(self._generic_handlers,
self._interceptors,
rpc_state,
self._loop)

@ -27,7 +27,7 @@ from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall
from ._call import AioRpcError
from ._channel import Channel, UnaryUnaryMultiCallable
from ._interceptor import (ClientCallDetails, InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor)
UnaryUnaryClientInterceptor, ServerInterceptor)
from ._server import Server, server
from ._typing import ChannelArgumentType
@ -86,5 +86,5 @@ __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall',
'UnaryStreamCall', 'init_grpc_aio', 'Channel',
'UnaryUnaryMultiCallable', 'ClientCallDetails',
'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall',
'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel',
'AbortError', 'BaseError', 'UsageError')
'ServerInterceptor', 'insecure_channel', 'server', 'Server', 'EOF',
'secure_channel', 'AbortError', 'BaseError', 'UsageError')

@ -30,6 +30,34 @@ from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
class ServerInterceptor(metaclass=ABCMeta):
"""Affords intercepting incoming RPCs on the service-side.
This is an EXPERIMENTAL API.
"""
@abstractmethod
async def intercept_service(self,
continuation: Callable[
[grpc.HandlerCallDetails], grpc.RpcMethodHandler],
handler_call_details: grpc.HandlerCallDetails
) -> grpc.RpcMethodHandler:
"""Intercepts incoming RPCs before handing them over to a handler.
Args:
continuation: A function that takes a HandlerCallDetails and
proceeds to invoke the next interceptor in the chain, if any,
or the RPC handler lookup logic, with the call details passed
as an argument, and returns an RpcMethodHandler instance if
the RPC is considered serviced, or None otherwise.
handler_call_details: A HandlerCallDetails describing the RPC.
Returns:
An RpcMethodHandler with which the RPC may be serviced if the
interceptor chooses to service this RPC, or None otherwise.
"""
class ClientCallDetails(
collections.namedtuple(
'ClientCallDetails',

@ -18,6 +18,7 @@
"unit.init_test.TestInsecureChannel",
"unit.init_test.TestSecureChannel",
"unit.interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.interceptor_test.TestServerInterceptor",
"unit.interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.metadata_test.TestMetadata",
"unit.server_test.TestServer",

@ -117,8 +117,10 @@ def _create_extra_generic_handler(servicer: _TestServiceServicer):
rpc_method_handlers)
async def start_test_server(port=0, secure=False, server_credentials=None):
server = aio.server(options=(('grpc.so_reuseport', 0),))
async def start_test_server(port=0, secure=False, server_credentials=None,
interceptors=None):
server = aio.server(options=(('grpc.so_reuseport', 0),),
interceptors=interceptors)
servicer = _TestServiceServicer()
test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)

@ -685,6 +685,110 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.fail("Callback was not called")
class _LoggingServerInterceptor(aio.ServerInterceptor):
def __init__(self, tag, record):
self.tag = tag
self.record = record
async def intercept_service(self, continuation, handler_call_details):
self.record.append(self.tag + ':intercept_service')
return await continuation(handler_call_details)
class _GenericServerInterceptor(aio.ServerInterceptor):
def __init__(self, fn):
self._fn = fn
async def intercept_service(self, continuation, handler_call_details):
return await self._fn(continuation, handler_call_details)
def _filter_server_interceptor(condition, interceptor):
async def intercept_service(continuation, handler_call_details):
if condition(handler_call_details):
return await interceptor.intercept_service(continuation,
handler_call_details)
return await continuation(handler_call_details)
return _GenericServerInterceptor(intercept_service)
class TestServerInterceptor(AioTestBase):
async def setUp(self) -> None:
self._record = []
conditional_interceptor = _filter_server_interceptor(
lambda x: ('secret', '42') in x.invocation_metadata,
_LoggingServerInterceptor('log3', self._record))
self._interceptors = (
_LoggingServerInterceptor('log1', self._record),
conditional_interceptor,
_LoggingServerInterceptor('log2', self._record),
)
self._server_target, self._server = await start_test_server(
interceptors=self._interceptors)
async def tearDown(self) -> None:
self._server.stop(None)
async def test_invalid_interceptor(self):
class InvalidInterceptor:
"""Just an invalid Interceptor"""
with self.assertRaises(aio.AioRpcError):
server_target, _ = await start_test_server(
interceptors=(InvalidInterceptor(),))
channel = aio.insecure_channel(server_target)
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
await call
async def test_executed_right_order(self):
self._record.clear()
async with aio.insecure_channel(self._server_target) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
# Check that all interceptors were executed, and were executed
# in the right order.
self.assertSequenceEqual(['log1:intercept_service',
'log2:intercept_service',], self._record)
self.assertIsInstance(response, messages_pb2.SimpleResponse)
async def test_apply_different_interceptors_by_metadata(self):
async with aio.insecure_channel(self._server_target) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
self._record.clear()
metadata = (('key', 'value'),)
call = multicallable(messages_pb2.SimpleRequest(),
metadata=metadata)
await call
self.assertSequenceEqual(['log1:intercept_service',
'log2:intercept_service',],
self._record)
self._record.clear()
metadata = (('key', 'value'), ('secret', '42'))
call = multicallable(messages_pb2.SimpleRequest(),
metadata=metadata)
await call
self.assertSequenceEqual(['log1:intercept_service',
'log3:intercept_service',
'log2:intercept_service',],
self._record)
if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)

Loading…
Cancel
Save