diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 8613bc501f1..68e5361bb99 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -14,6 +14,7 @@ """gRPC's Python API.""" import abc +import contextlib import enum import logging import sys @@ -1779,6 +1780,14 @@ def server(thread_pool, maximum_concurrent_rpcs) +@contextlib.contextmanager +def _create_servicer_context(rpc_event, state, request_deserializer): + from grpc import _server # pylint: disable=cyclic-import + context = _server._Context(rpc_event, state, request_deserializer) + yield context + context._finalize_state() # pylint: disable=protected-access + + ################################### __all__ ################################# __all__ = ( diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 6caaece82c4..31f31b0f208 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -302,6 +302,9 @@ class _Context(grpc.ServicerContext): with self._state.condition: self._state.details = _common.encode(details) + def _finalize_state(self): + pass + class _RequestIterator(object): @@ -387,20 +390,24 @@ def _unary_request(rpc_event, state, request_deserializer): def _call_behavior(rpc_event, state, behavior, argument, request_deserializer): - context = _Context(rpc_event, state, request_deserializer) - try: - return behavior(argument, context), True - except Exception as exception: # pylint: disable=broad-except - with state.condition: - if state.aborted: - _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, - b'RPC Aborted') - elif exception not in state.rpc_errors: - details = 'Exception calling application: {}'.format(exception) - _LOGGER.exception(details) - _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, - _common.encode(details)) - return None, False + from grpc import _create_servicer_context + with _create_servicer_context(rpc_event, state, + request_deserializer) as context: + try: + response = behavior(argument, context) + return response, True + except Exception as exception: # pylint: disable=broad-except + with state.condition: + if state.aborted: + _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, + b'RPC Aborted') + elif exception not in state.rpc_errors: + details = 'Exception calling application: {}'.format( + exception) + _LOGGER.exception(details) + _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, + _common.encode(details)) + return None, False def _take_response_from_response_iterator(rpc_event, state, response_iterator):