diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index a6c9275d5c8..9e349c5b023 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -628,7 +628,7 @@ class AuthMetadataPlugin(six.with_metaclass(abc.ABCMeta)): def __call__(self, context, callback): """Implements authentication by passing metadata to a callback. - Implementations of this method must not block. + This method will be invoked asynchronously in a separate thread. Args: context: An AuthMetadataContext providing information on the RPC that diff --git a/src/python/grpcio/grpc/_auth.py b/src/python/grpcio/grpc/_auth.py index 724229a8f56..2d38320afff 100644 --- a/src/python/grpcio/grpc/_auth.py +++ b/src/python/grpcio/grpc/_auth.py @@ -14,7 +14,6 @@ """GRPCAuthMetadataPlugins for standard authentication.""" import inspect -from concurrent import futures import grpc @@ -24,43 +23,29 @@ def _sign_request(callback, token, error): callback(metadata, error) -def _create_get_token_callback(callback): - - def get_token_callback(future): - try: - access_token = future.result().access_token - except Exception as exception: # pylint: disable=broad-except - _sign_request(callback, None, exception) - else: - _sign_request(callback, access_token, None) - - return get_token_callback - - class GoogleCallCredentials(grpc.AuthMetadataPlugin): """Metadata wrapper for GoogleCredentials from the oauth2client library.""" def __init__(self, credentials): self._credentials = credentials - self._pool = futures.ThreadPoolExecutor(max_workers=1) - # Hack to determine if these are JWT creds and we need to pass # additional_claims when getting a token self._is_jwt = 'additional_claims' in inspect.getargspec( # pylint: disable=deprecated-method credentials.get_access_token).args def __call__(self, context, callback): - # MetadataPlugins cannot block (see grpc.beta.interfaces.py) - if self._is_jwt: - future = self._pool.submit( - self._credentials.get_access_token, - additional_claims={'aud': context.service_url}) + try: + if self._is_jwt: + access_token = self._credentials.get_access_token( + additional_claims={ + 'aud': context.service_url + }).access_token + else: + access_token = self._credentials.get_access_token().access_token + except Exception as exception: # pylint: disable=broad-except + _sign_request(callback, None, exception) else: - future = self._pool.submit(self._credentials.get_access_token) - future.add_done_callback(_create_get_token_callback(callback)) - - def __del__(self): - self._pool.shutdown(wait=False) + _sign_request(callback, access_token, None) class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):