diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 45ee2c2acd1..ea5e7457a67 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -1156,20 +1156,6 @@ def ssl_channel_credentials(root_certificates=None, _cygrpc.channel_credentials_ssl(root_certificates, pair)) -def _metadata_call_credentials(metadata_plugin, name): - from grpc import _plugin_wrapping # pylint: disable=cyclic-import - if name is None: - try: - effective_name = metadata_plugin.__name__ - except AttributeError: - effective_name = metadata_plugin.__class__.__name__ - else: - effective_name = name - return CallCredentials( - _plugin_wrapping.call_credentials_metadata_plugin(metadata_plugin, - effective_name)) - - def metadata_call_credentials(metadata_plugin, name=None): """Construct CallCredentials from an AuthMetadataPlugin. @@ -1180,7 +1166,10 @@ def metadata_call_credentials(metadata_plugin, name=None): Returns: A CallCredentials. """ - return _metadata_call_credentials(metadata_plugin, name) + from grpc import _plugin_wrapping # pylint: disable=cyclic-import + return CallCredentials( + _plugin_wrapping.metadata_plugin_call_credentials(metadata_plugin, + name)) def access_token_call_credentials(access_token): @@ -1195,8 +1184,10 @@ def access_token_call_credentials(access_token): A CallCredentials. """ from grpc import _auth # pylint: disable=cyclic-import - return _metadata_call_credentials( - _auth.AccessTokenAuthMetadataPlugin(access_token), None) + from grpc import _plugin_wrapping # pylint: disable=cyclic-import + return CallCredentials( + _plugin_wrapping.metadata_plugin_call_credentials( + _auth.AccessTokenAuthMetadataPlugin(access_token), None)) def composite_call_credentials(*call_credentials): diff --git a/src/python/grpcio/grpc/_plugin_wrapping.py b/src/python/grpcio/grpc/_plugin_wrapping.py index bea2c0f1396..d089d762c48 100644 --- a/src/python/grpcio/grpc/_plugin_wrapping.py +++ b/src/python/grpcio/grpc/_plugin_wrapping.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +import logging import threading import grpc @@ -20,89 +21,80 @@ from grpc import _common from grpc._cython import cygrpc -class AuthMetadataContext( +class _AuthMetadataContext( collections.namedtuple('AuthMetadataContext', ( 'service_url', 'method_name',)), grpc.AuthMetadataContext): pass -class AuthMetadataPluginCallback(grpc.AuthMetadataContext): +class _CallbackState(object): - def __init__(self, callback): - self._callback = callback - - def __call__(self, metadata, error): - self._callback(metadata, error) + def __init__(self): + self.lock = threading.Lock() + self.called = False + self.exception = None -class _WrappedCygrpcCallback(object): +class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback): - def __init__(self, cygrpc_callback): - self.is_called = False - self.error = None - self.is_called_lock = threading.Lock() - self.cygrpc_callback = cygrpc_callback - - def _invoke_failure(self, error): - # TODO(atash) translate different Exception superclasses into different - # status codes. - self.cygrpc_callback(_common.EMPTY_METADATA, cygrpc.StatusCode.internal, - _common.encode(str(error))) - - def _invoke_success(self, metadata): - try: - cygrpc_metadata = _common.to_cygrpc_metadata(metadata) - except Exception as exception: # pylint: disable=broad-except - self._invoke_failure(exception) - return - self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, b'') + def __init__(self, state, callback): + self._state = state + self._callback = callback def __call__(self, metadata, error): - with self.is_called_lock: - if self.is_called: - raise RuntimeError('callback should only ever be invoked once') - if self.error: - self._invoke_failure(self.error) - return - self.is_called = True + with self._state.lock: + if self._state.exception is None: + if self._state.called: + raise RuntimeError( + 'AuthMetadataPluginCallback invoked more than once!') + else: + self._state.called = True + else: + raise RuntimeError( + 'AuthMetadataPluginCallback raised exception "{}"!'.format( + self._state.exception)) if error is None: - self._invoke_success(metadata) + self._callback( + _common.to_cygrpc_metadata(metadata), cygrpc.StatusCode.ok, b'') else: - self._invoke_failure(error) - - def notify_failure(self, error): - with self.is_called_lock: - if not self.is_called: - self.error = error + self._callback(_common.EMPTY_METADATA, cygrpc.StatusCode.internal, + _common.encode(str(error))) -class _WrappedPlugin(object): +class _Plugin(object): - def __init__(self, plugin): - self.plugin = plugin + def __init__(self, metadata_plugin): + self._metadata_plugin = metadata_plugin - def __call__(self, context, cygrpc_callback): - wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback) - wrapped_context = AuthMetadataContext( + def __call__(self, context, callback): + wrapped_context = _AuthMetadataContext( _common.decode(context.service_url), _common.decode(context.method_name)) + callback_state = _CallbackState() + try: + self._metadata_plugin( + wrapped_context, + _AuthMetadataPluginCallback(callback_state, callback)) + except Exception as exception: # pylint: disable=broad-except + logging.exception( + 'AuthMetadataPluginCallback "%s" raised exception!', + self._metadata_plugin) + with callback_state.lock: + callback_state.exception = exception + if callback_state.called: + return + callback(_common.EMPTY_METADATA, cygrpc.StatusCode.internal, + _common.encode(str(exception))) + + +def metadata_plugin_call_credentials(metadata_plugin, name): + if name is None: try: - self.plugin(wrapped_context, - AuthMetadataPluginCallback(wrapped_cygrpc_callback)) - except Exception as error: - wrapped_cygrpc_callback.notify_failure(error) - raise - - -def call_credentials_metadata_plugin(plugin, name): - """ - Args: - plugin: A callable accepting a grpc.AuthMetadataContext - object and a callback (itself accepting a list of metadata key/value - 2-tuples and a None-able exception value). The callback must be eventually - called, but need not be called in plugin's invocation. - plugin's invocation must be non-blocking. - """ + effective_name = metadata_plugin.__name__ + except AttributeError: + effective_name = metadata_plugin.__class__.__name__ + else: + effective_name = name return cygrpc.call_credentials_metadata_plugin( cygrpc.CredentialsMetadataPlugin( - _WrappedPlugin(plugin), _common.encode(name))) + _Plugin(metadata_plugin), _common.encode(effective_name)))