|
|
|
@ -14,7 +14,6 @@ |
|
|
|
|
"""GRPCAuthMetadataPlugins for standard authentication.""" |
|
|
|
|
|
|
|
|
|
import inspect |
|
|
|
|
from concurrent import futures |
|
|
|
|
|
|
|
|
|
import grpc |
|
|
|
|
|
|
|
|
@ -24,43 +23,29 @@ def _sign_request(callback, token, error): |
|
|
|
|
callback(metadata, error) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_get_token_callback(callback): |
|
|
|
|
|
|
|
|
|
def get_token_callback(future): |
|
|
|
|
try: |
|
|
|
|
access_token = future.result().access_token |
|
|
|
|
except Exception as exception: # pylint: disable=broad-except |
|
|
|
|
_sign_request(callback, None, exception) |
|
|
|
|
else: |
|
|
|
|
_sign_request(callback, access_token, None) |
|
|
|
|
|
|
|
|
|
return get_token_callback |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GoogleCallCredentials(grpc.AuthMetadataPlugin): |
|
|
|
|
"""Metadata wrapper for GoogleCredentials from the oauth2client library.""" |
|
|
|
|
|
|
|
|
|
def __init__(self, credentials): |
|
|
|
|
self._credentials = credentials |
|
|
|
|
self._pool = futures.ThreadPoolExecutor(max_workers=1) |
|
|
|
|
|
|
|
|
|
# Hack to determine if these are JWT creds and we need to pass |
|
|
|
|
# additional_claims when getting a token |
|
|
|
|
self._is_jwt = 'additional_claims' in inspect.getargspec( # pylint: disable=deprecated-method |
|
|
|
|
credentials.get_access_token).args |
|
|
|
|
|
|
|
|
|
def __call__(self, context, callback): |
|
|
|
|
# MetadataPlugins cannot block (see grpc.beta.interfaces.py) |
|
|
|
|
try: |
|
|
|
|
if self._is_jwt: |
|
|
|
|
future = self._pool.submit( |
|
|
|
|
self._credentials.get_access_token, |
|
|
|
|
additional_claims={'aud': context.service_url}) |
|
|
|
|
access_token = self._credentials.get_access_token( |
|
|
|
|
additional_claims={ |
|
|
|
|
'aud': context.service_url |
|
|
|
|
}).access_token |
|
|
|
|
else: |
|
|
|
|
future = self._pool.submit(self._credentials.get_access_token) |
|
|
|
|
future.add_done_callback(_create_get_token_callback(callback)) |
|
|
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
|
self._pool.shutdown(wait=False) |
|
|
|
|
access_token = self._credentials.get_access_token().access_token |
|
|
|
|
except Exception as exception: # pylint: disable=broad-except |
|
|
|
|
_sign_request(callback, None, exception) |
|
|
|
|
else: |
|
|
|
|
_sign_request(callback, access_token, None) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): |
|
|
|
|