|
|
|
@ -13,6 +13,7 @@ |
|
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
import collections |
|
|
|
|
import logging |
|
|
|
|
import threading |
|
|
|
|
|
|
|
|
|
import grpc |
|
|
|
@ -20,89 +21,80 @@ from grpc import _common |
|
|
|
|
from grpc._cython import cygrpc |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AuthMetadataContext( |
|
|
|
|
class _AuthMetadataContext( |
|
|
|
|
collections.namedtuple('AuthMetadataContext', ( |
|
|
|
|
'service_url', 'method_name',)), grpc.AuthMetadataContext): |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AuthMetadataPluginCallback(grpc.AuthMetadataContext): |
|
|
|
|
class _CallbackState(object): |
|
|
|
|
|
|
|
|
|
def __init__(self, callback): |
|
|
|
|
self._callback = callback |
|
|
|
|
|
|
|
|
|
def __call__(self, metadata, error): |
|
|
|
|
self._callback(metadata, error) |
|
|
|
|
def __init__(self): |
|
|
|
|
self.lock = threading.Lock() |
|
|
|
|
self.called = False |
|
|
|
|
self.exception = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _WrappedCygrpcCallback(object): |
|
|
|
|
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback): |
|
|
|
|
|
|
|
|
|
def __init__(self, cygrpc_callback): |
|
|
|
|
self.is_called = False |
|
|
|
|
self.error = None |
|
|
|
|
self.is_called_lock = threading.Lock() |
|
|
|
|
self.cygrpc_callback = cygrpc_callback |
|
|
|
|
|
|
|
|
|
def _invoke_failure(self, error): |
|
|
|
|
# TODO(atash) translate different Exception superclasses into different |
|
|
|
|
# status codes. |
|
|
|
|
self.cygrpc_callback(_common.EMPTY_METADATA, cygrpc.StatusCode.internal, |
|
|
|
|
_common.encode(str(error))) |
|
|
|
|
|
|
|
|
|
def _invoke_success(self, metadata): |
|
|
|
|
try: |
|
|
|
|
cygrpc_metadata = _common.to_cygrpc_metadata(metadata) |
|
|
|
|
except Exception as exception: # pylint: disable=broad-except |
|
|
|
|
self._invoke_failure(exception) |
|
|
|
|
return |
|
|
|
|
self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, b'') |
|
|
|
|
def __init__(self, state, callback): |
|
|
|
|
self._state = state |
|
|
|
|
self._callback = callback |
|
|
|
|
|
|
|
|
|
def __call__(self, metadata, error): |
|
|
|
|
with self.is_called_lock: |
|
|
|
|
if self.is_called: |
|
|
|
|
raise RuntimeError('callback should only ever be invoked once') |
|
|
|
|
if self.error: |
|
|
|
|
self._invoke_failure(self.error) |
|
|
|
|
return |
|
|
|
|
self.is_called = True |
|
|
|
|
with self._state.lock: |
|
|
|
|
if self._state.exception is None: |
|
|
|
|
if self._state.called: |
|
|
|
|
raise RuntimeError( |
|
|
|
|
'AuthMetadataPluginCallback invoked more than once!') |
|
|
|
|
else: |
|
|
|
|
self._state.called = True |
|
|
|
|
else: |
|
|
|
|
raise RuntimeError( |
|
|
|
|
'AuthMetadataPluginCallback raised exception "{}"!'.format( |
|
|
|
|
self._state.exception)) |
|
|
|
|
if error is None: |
|
|
|
|
self._invoke_success(metadata) |
|
|
|
|
self._callback( |
|
|
|
|
_common.to_cygrpc_metadata(metadata), cygrpc.StatusCode.ok, b'') |
|
|
|
|
else: |
|
|
|
|
self._invoke_failure(error) |
|
|
|
|
|
|
|
|
|
def notify_failure(self, error): |
|
|
|
|
with self.is_called_lock: |
|
|
|
|
if not self.is_called: |
|
|
|
|
self.error = error |
|
|
|
|
self._callback(_common.EMPTY_METADATA, cygrpc.StatusCode.internal, |
|
|
|
|
_common.encode(str(error))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _WrappedPlugin(object): |
|
|
|
|
class _Plugin(object): |
|
|
|
|
|
|
|
|
|
def __init__(self, plugin): |
|
|
|
|
self.plugin = plugin |
|
|
|
|
def __init__(self, metadata_plugin): |
|
|
|
|
self._metadata_plugin = metadata_plugin |
|
|
|
|
|
|
|
|
|
def __call__(self, context, cygrpc_callback): |
|
|
|
|
wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback) |
|
|
|
|
wrapped_context = AuthMetadataContext( |
|
|
|
|
def __call__(self, context, callback): |
|
|
|
|
wrapped_context = _AuthMetadataContext( |
|
|
|
|
_common.decode(context.service_url), |
|
|
|
|
_common.decode(context.method_name)) |
|
|
|
|
callback_state = _CallbackState() |
|
|
|
|
try: |
|
|
|
|
self._metadata_plugin( |
|
|
|
|
wrapped_context, |
|
|
|
|
_AuthMetadataPluginCallback(callback_state, callback)) |
|
|
|
|
except Exception as exception: # pylint: disable=broad-except |
|
|
|
|
logging.exception( |
|
|
|
|
'AuthMetadataPluginCallback "%s" raised exception!', |
|
|
|
|
self._metadata_plugin) |
|
|
|
|
with callback_state.lock: |
|
|
|
|
callback_state.exception = exception |
|
|
|
|
if callback_state.called: |
|
|
|
|
return |
|
|
|
|
callback(_common.EMPTY_METADATA, cygrpc.StatusCode.internal, |
|
|
|
|
_common.encode(str(exception))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def metadata_plugin_call_credentials(metadata_plugin, name): |
|
|
|
|
if name is None: |
|
|
|
|
try: |
|
|
|
|
self.plugin(wrapped_context, |
|
|
|
|
AuthMetadataPluginCallback(wrapped_cygrpc_callback)) |
|
|
|
|
except Exception as error: |
|
|
|
|
wrapped_cygrpc_callback.notify_failure(error) |
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def call_credentials_metadata_plugin(plugin, name): |
|
|
|
|
""" |
|
|
|
|
Args: |
|
|
|
|
plugin: A callable accepting a grpc.AuthMetadataContext |
|
|
|
|
object and a callback (itself accepting a list of metadata key/value |
|
|
|
|
2-tuples and a None-able exception value). The callback must be eventually |
|
|
|
|
called, but need not be called in plugin's invocation. |
|
|
|
|
plugin's invocation must be non-blocking. |
|
|
|
|
""" |
|
|
|
|
effective_name = metadata_plugin.__name__ |
|
|
|
|
except AttributeError: |
|
|
|
|
effective_name = metadata_plugin.__class__.__name__ |
|
|
|
|
else: |
|
|
|
|
effective_name = name |
|
|
|
|
return cygrpc.call_credentials_metadata_plugin( |
|
|
|
|
cygrpc.CredentialsMetadataPlugin( |
|
|
|
|
_WrappedPlugin(plugin), _common.encode(name))) |
|
|
|
|
_Plugin(metadata_plugin), _common.encode(effective_name))) |
|
|
|
|