From 68e1978e9e972314ea91b0fe3a333117736ace2e Mon Sep 17 00:00:00 2001 From: Mehrdad Afshari Date: Sun, 10 Dec 2017 23:20:06 -0800 Subject: [PATCH] Add gRPC Python service-side interceptor machinery --- src/python/grpcio/grpc/__init__.py | 101 ++++++++++++------ src/python/grpcio/grpc/_interceptor.py | 38 +++++++ src/python/grpcio/grpc/_server.py | 46 +++++--- .../grpcio_tests/tests/unit/_api_test.py | 4 +- 4 files changed, 137 insertions(+), 52 deletions(-) create mode 100644 src/python/grpcio/grpc/_interceptor.py diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 0f6bdfc1c2d..30ed6d05778 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -962,6 +962,34 @@ class ServiceRpcHandler(six.with_metaclass(abc.ABCMeta, GenericRpcHandler)): raise NotImplementedError() +#################### Service-Side Interceptor Interfaces ##################### + + +class ServerInterceptor(six.with_metaclass(abc.ABCMeta)): + """Affords intercepting incoming RPCs on the service-side. + + This is an EXPERIMENTAL API. + """ + + @abc.abstractmethod + def intercept_service(self, continuation, handler_call_details): + """Intercepts incoming RPCs before handing them over to a handler. + + Args: + continuation: A function that takes a HandlerCallDetails and + proceeds to invoke the next interceptor in the chain, if any, + or the RPC handler lookup logic, with the call details passed + as an argument, and returns an RpcMethodHandler instance if + the RPC is considered serviced, or None otherwise. + handler_call_details: A HandlerCallDetails describing the RPC. + + Returns: + An RpcMethodHandler with which the RPC may be serviced if the + interceptor chooses to service this RPC, or None otherwise. + """ + raise NotImplementedError() + + ############################# Server Interface ############################### @@ -1378,51 +1406,56 @@ def secure_channel(target, credentials, options=None): def server(thread_pool, handlers=None, + interceptors=None, options=None, maximum_concurrent_rpcs=None): """Creates a Server with which RPCs can be serviced. - Args: - thread_pool: A futures.ThreadPoolExecutor to be used by the Server - to execute RPC handlers. - handlers: An optional list of GenericRpcHandlers used for executing RPCs. - More handlers may be added by calling add_generic_rpc_handlers any time - before the server is started. - options: An optional list of key-value pairs (channel args in gRPC runtime) - to configure the channel. - maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server - will service before returning RESOURCE_EXHAUSTED status, or None to - indicate no limit. + Args: + thread_pool: A futures.ThreadPoolExecutor to be used by the Server + to execute RPC handlers. + handlers: An optional list of GenericRpcHandlers used for executing RPCs. + More handlers may be added by calling add_generic_rpc_handlers any time + before the server is started. + interceptors: An optional list of ServerInterceptor objects that observe + and optionally manipulate the incoming RPCs before handing them over to + handlers. The interceptors are given control in the order they are + specified. This is an EXPERIMENTAL API. + options: An optional list of key-value pairs (channel args in gRPC runtime) + to configure the channel. + maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server + will service before returning RESOURCE_EXHAUSTED status, or None to + indicate no limit. - Returns: - A Server object. - """ + Returns: + A Server object. + """ from grpc import _server # pylint: disable=cyclic-import return _server.Server(thread_pool, () if handlers is None else handlers, () - if options is None else options, - maximum_concurrent_rpcs) + if interceptors is None else interceptors, () if + options is None else options, maximum_concurrent_rpcs) ################################### __all__ ################################# -__all__ = ('FutureTimeoutError', 'FutureCancelledError', 'Future', - 'ChannelConnectivity', 'StatusCode', 'RpcError', 'RpcContext', - 'Call', 'ChannelCredentials', 'CallCredentials', - 'AuthMetadataContext', 'AuthMetadataPluginCallback', - 'AuthMetadataPlugin', 'ServerCertificateConfiguration', - 'ServerCredentials', 'UnaryUnaryMultiCallable', - 'UnaryStreamMultiCallable', 'StreamUnaryMultiCallable', - 'StreamStreamMultiCallable', 'Channel', 'ServicerContext', - 'RpcMethodHandler', 'HandlerCallDetails', 'GenericRpcHandler', - 'ServiceRpcHandler', 'Server', 'unary_unary_rpc_method_handler', - 'unary_stream_rpc_method_handler', 'stream_unary_rpc_method_handler', - 'stream_stream_rpc_method_handler', - 'method_handlers_generic_handler', 'ssl_channel_credentials', - 'metadata_call_credentials', 'access_token_call_credentials', - 'composite_call_credentials', 'composite_channel_credentials', - 'ssl_server_credentials', 'ssl_server_certificate_configuration', - 'dynamic_ssl_server_credentials', 'channel_ready_future', - 'insecure_channel', 'secure_channel', 'server',) +__all__ = ( + 'FutureTimeoutError', 'FutureCancelledError', 'Future', + 'ChannelConnectivity', 'StatusCode', 'RpcError', 'RpcContext', 'Call', + 'ChannelCredentials', 'CallCredentials', 'AuthMetadataContext', + 'AuthMetadataPluginCallback', 'AuthMetadataPlugin', + 'ServerCertificateConfiguration', 'ServerCredentials', + 'UnaryUnaryMultiCallable', 'UnaryStreamMultiCallable', + 'StreamUnaryMultiCallable', 'StreamStreamMultiCallable', 'Channel', + 'ServicerContext', 'RpcMethodHandler', 'HandlerCallDetails', + 'GenericRpcHandler', 'ServiceRpcHandler', 'Server', 'ServerInterceptor', + 'unary_unary_rpc_method_handler', 'unary_stream_rpc_method_handler', + 'stream_unary_rpc_method_handler', 'stream_stream_rpc_method_handler', + 'method_handlers_generic_handler', 'ssl_channel_credentials', + 'metadata_call_credentials', 'access_token_call_credentials', + 'composite_call_credentials', 'composite_channel_credentials', + 'ssl_server_credentials', 'ssl_server_certificate_configuration', + 'dynamic_ssl_server_credentials', 'channel_ready_future', + 'insecure_channel', 'secure_channel', 'server',) ############################### Extension Shims ################################ diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py new file mode 100644 index 00000000000..7aa2c7ec222 --- /dev/null +++ b/src/python/grpcio/grpc/_interceptor.py @@ -0,0 +1,38 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of gRPC Python interceptors.""" + + +class _ServicePipeline(object): + + def __init__(self, interceptors): + self.interceptors = tuple(interceptors) + + def _continuation(self, thunk, index): + return lambda context: self._intercept_at(thunk, index, context) + + def _intercept_at(self, thunk, index, context): + if index < len(self.interceptors): + interceptor = self.interceptors[index] + thunk = self._continuation(thunk, index + 1) + return interceptor.intercept_service(thunk, context) + else: + return thunk(context) + + def execute(self, thunk, context): + return self._intercept_at(thunk, 0, context) + + +def service_pipeline(interceptors): + return _ServicePipeline(interceptors) if interceptors else None diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 7b78a8be736..8857bd3eb54 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -23,6 +23,7 @@ import six import grpc from grpc import _common +from grpc import _interceptor from grpc._cython import cygrpc from grpc.framework.foundation import callable_util @@ -541,17 +542,25 @@ def _handle_stream_stream(rpc_event, state, method_handler, thread_pool): method_handler.request_deserializer, method_handler.response_serializer) -def _find_method_handler(rpc_event, generic_handlers): - for generic_handler in generic_handlers: - method_handler = generic_handler.service( - _HandlerCallDetails( - _common.decode(rpc_event.request_call_details.method), - rpc_event.request_metadata)) - if method_handler is not None: - return method_handler - else: +def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline): + + def query_handlers(handler_call_details): + for generic_handler in generic_handlers: + method_handler = generic_handler.service(handler_call_details) + if method_handler is not None: + return method_handler return None + handler_call_details = _HandlerCallDetails( + _common.decode(rpc_event.request_call_details.method), + rpc_event.request_metadata) + + if interceptor_pipeline is not None: + return interceptor_pipeline.execute(query_handlers, + handler_call_details) + else: + return query_handlers(handler_call_details) + def _reject_rpc(rpc_event, status, details): operations = (cygrpc.operation_send_initial_metadata((), _EMPTY_FLAGS), @@ -587,13 +596,14 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool): method_handler, thread_pool) -def _handle_call(rpc_event, generic_handlers, thread_pool, +def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool, concurrency_exceeded): if not rpc_event.success: return None, None if rpc_event.request_call_details.method is not None: try: - method_handler = _find_method_handler(rpc_event, generic_handlers) + method_handler = _find_method_handler(rpc_event, generic_handlers, + interceptor_pipeline) except Exception as exception: # pylint: disable=broad-except details = 'Exception servicing handler: {}'.format(exception) logging.exception(details) @@ -621,12 +631,14 @@ class _ServerStage(enum.Enum): class _ServerState(object): - def __init__(self, completion_queue, server, generic_handlers, thread_pool, - maximum_concurrent_rpcs): + # pylint: disable=too-many-arguments + def __init__(self, completion_queue, server, generic_handlers, + interceptor_pipeline, thread_pool, maximum_concurrent_rpcs): self.lock = threading.Lock() self.completion_queue = completion_queue self.server = server self.generic_handlers = list(generic_handlers) + self.interceptor_pipeline = interceptor_pipeline self.thread_pool = thread_pool self.stage = _ServerStage.STOPPED self.shutdown_events = None @@ -691,8 +703,8 @@ def _serve(state): state.maximum_concurrent_rpcs is not None and state.active_rpc_count >= state.maximum_concurrent_rpcs) rpc_state, rpc_future = _handle_call( - event, state.generic_handlers, state.thread_pool, - concurrency_exceeded) + event, state.generic_handlers, state.interceptor_pipeline, + state.thread_pool, concurrency_exceeded) if rpc_state is not None: state.rpc_states.add(rpc_state) if rpc_future is not None: @@ -780,12 +792,14 @@ def _start(state): class Server(grpc.Server): - def __init__(self, thread_pool, generic_handlers, options, + # pylint: disable=too-many-arguments + def __init__(self, thread_pool, generic_handlers, interceptors, options, maximum_concurrent_rpcs): completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server(_common.channel_args(options)) server.register_completion_queue(completion_queue) self._state = _ServerState(completion_queue, server, generic_handlers, + _interceptor.service_pipeline(interceptors), thread_pool, maximum_concurrent_rpcs) def add_generic_rpc_handlers(self, generic_rpc_handlers): diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py index b14e8d5c753..49f9b9aa7b9 100644 --- a/src/python/grpcio_tests/tests/unit/_api_test.py +++ b/src/python/grpcio_tests/tests/unit/_api_test.py @@ -35,8 +35,8 @@ class AllTest(unittest.TestCase): 'UnaryStreamMultiCallable', 'StreamUnaryMultiCallable', 'StreamStreamMultiCallable', 'Channel', 'ServicerContext', 'RpcMethodHandler', 'HandlerCallDetails', 'GenericRpcHandler', - 'ServiceRpcHandler', 'Server', 'unary_unary_rpc_method_handler', - 'unary_stream_rpc_method_handler', + 'ServiceRpcHandler', 'Server', 'ServerInterceptor', + 'unary_unary_rpc_method_handler', 'unary_stream_rpc_method_handler', 'stream_unary_rpc_method_handler', 'stream_stream_rpc_method_handler', 'method_handlers_generic_handler', 'ssl_channel_credentials',