Added google call creds/per_rpc interop tests

pull/6254/head
Ken Payson 9 years ago
parent fcbe7daf83
commit 60a83c744b
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi
  2. 2
      src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
  3. 73
      src/python/grpcio/grpc/beta/_auth.py
  4. 33
      src/python/grpcio/grpc/beta/implementations.py
  5. 39
      src/python/grpcio/tests/interop/client.py
  6. 30
      src/python/grpcio/tests/interop/methods.py
  7. 3
      src/python/grpcio/tests/tests.json
  8. 96
      src/python/grpcio/tests/unit/beta/_auth_test.py
  9. 17
      src/python/grpcio/tests/unit/beta/_implementations_test.py
  10. 3
      tools/run_tests/run_interop_tests.py

@ -68,4 +68,4 @@ cdef void plugin_get_metadata(
void *state, grpc_auth_metadata_context context, void *state, grpc_auth_metadata_context context,
grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil
cdef void plugin_destroy_c_plugin_state(void *state) cdef void plugin_destroy_c_plugin_state(void *state) with gil

@ -137,7 +137,7 @@ cdef void plugin_get_metadata(
cy_context.context = context cy_context.context = context
self.plugin_callback(cy_context, python_callback) self.plugin_callback(cy_context, python_callback)
cdef void plugin_destroy_c_plugin_state(void *state): cdef void plugin_destroy_c_plugin_state(void *state) with gil:
cpython.Py_DECREF(<CredentialsMetadataPlugin>state) cpython.Py_DECREF(<CredentialsMetadataPlugin>state)
def channel_credentials_google_default(): def channel_credentials_google_default():

@ -0,0 +1,73 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""GRPCAuthMetadataPlugins for standard authentication."""
from concurrent import futures
from grpc.beta import interfaces
def _sign_request(callback, token, error):
metadata = (('authorization', 'Bearer {}'.format(token)),)
callback(metadata, error)
class GoogleCallCredentials(interfaces.GRPCAuthMetadataPlugin):
"""Metadata wrapper for GoogleCredentials from the oauth2client library."""
def __init__(self, credentials):
self._credentials = credentials
self._pool = futures.ThreadPoolExecutor(max_workers=1)
def __call__(self, context, callback):
# MetadataPlugins cannot block (see grpc.beta.interfaces.py)
future = self._pool.submit(self._credentials.get_access_token)
future.add_done_callback(lambda x: self._get_token_callback(callback, x))
def _get_token_callback(self, callback, future):
try:
access_token = future.result().access_token
except Exception as e:
_sign_request(callback, None, e)
else:
_sign_request(callback, access_token, None)
def __del__(self):
self._pool.shutdown(wait=False)
class AccessTokenCallCredentials(interfaces.GRPCAuthMetadataPlugin):
"""Metadata wrapper for raw access token credentials."""
def __init__(self, access_token):
self._access_token = access_token
def __call__(self, context, callback):
_sign_request(callback, self._access_token, None)

@ -38,6 +38,7 @@ import threading # pylint: disable=unused-import
from grpc._adapter import _intermediary_low from grpc._adapter import _intermediary_low
from grpc._adapter import _low from grpc._adapter import _low
from grpc._adapter import _types from grpc._adapter import _types
from grpc.beta import _auth
from grpc.beta import _connectivity_channel from grpc.beta import _connectivity_channel
from grpc.beta import _server from grpc.beta import _server
from grpc.beta import _stub from grpc.beta import _stub
@ -105,10 +106,40 @@ def metadata_call_credentials(metadata_plugin, name=None):
A CallCredentials object for use in a GRPCCallOptions object. A CallCredentials object for use in a GRPCCallOptions object.
""" """
if name is None: if name is None:
name = metadata_plugin.__name__ try:
name = metadata_plugin.__name__
except AttributeError:
name = metadata_plugin.__class__.__name__
return CallCredentials( return CallCredentials(
_low.call_credentials_metadata_plugin(metadata_plugin, name)) _low.call_credentials_metadata_plugin(metadata_plugin, name))
def google_call_credentials(credentials):
"""Construct CallCredentials from GoogleCredentials.
Args:
credentials: A GoogleCredentials object from the oauth2client library.
Returns:
A CallCredentials object for use in a GRPCCallOptions object.
"""
return metadata_call_credentials(_auth.GoogleCallCredentials(credentials))
def access_token_call_credentials(access_token):
"""Construct CallCredentials from an access token.
Args:
access_token: A string to place directly in the http request
authorization header, ie "Authorization: Bearer <access_token>".
Returns:
A CallCredentials object for use in a GRPCCallOptions object.
"""
return metadata_call_credentials(
_auth.AccessTokenCallCredentials(access_token))
def composite_call_credentials(call_credentials, additional_call_credentials): def composite_call_credentials(call_credentials, additional_call_credentials):
"""Compose two CallCredentials to make a new one. """Compose two CallCredentials to make a new one.

@ -65,39 +65,34 @@ def _args():
help='email address of the default service account', type=str) help='email address of the default service account', type=str)
return parser.parse_args() return parser.parse_args()
def _oauth_access_token(args):
credentials = oauth2client_client.GoogleCredentials.get_application_default()
scoped_credentials = credentials.create_scoped([args.oauth_scope])
return scoped_credentials.get_access_token().access_token
def _stub(args): def _stub(args):
if args.oauth_scope: if args.test_case == 'oauth2_auth_token':
if args.test_case == 'oauth2_auth_token': creds = oauth2client_client.GoogleCredentials.get_application_default()
# TODO(jtattermusch): This testcase sets the auth metadata key-value scoped_creds = creds.create_scoped([args.oauth_scope])
# manually, which also means that the user would need to do the same access_token = scoped_creds.get_access_token().access_token
# thing every time he/she would like to use and out of band oauth token. call_creds = implementations.access_token_call_credentials(access_token)
# The transformer function that produces the metadata key-value from elif args.test_case == 'compute_engine_creds':
# the access token should be provided by gRPC auth library. creds = oauth2client_client.GoogleCredentials.get_application_default()
access_token = _oauth_access_token(args) scoped_creds = creds.create_scoped([args.oauth_scope])
metadata_transformer = lambda x: [ call_creds = implementations.google_call_credentials(scoped_creds)
('authorization', 'Bearer %s' % access_token)]
else:
metadata_transformer = lambda x: [
('authorization', 'Bearer %s' % _oauth_access_token(args))]
else: else:
metadata_transformer = lambda x: [] call_creds = None
if args.use_tls: if args.use_tls:
if args.use_test_ca: if args.use_test_ca:
root_certificates = resources.test_root_certificates() root_certificates = resources.test_root_certificates()
else: else:
root_certificates = None # will load default roots. root_certificates = None # will load default roots.
channel_creds = implementations.ssl_channel_credentials(root_certificates)
if call_creds is not None:
channel_creds = implementations.composite_channel_credentials(
channel_creds, call_creds)
channel = test_utilities.not_really_secure_channel( channel = test_utilities.not_really_secure_channel(
args.server_host, args.server_port, args.server_host, args.server_port, channel_creds,
implementations.ssl_channel_credentials(root_certificates),
args.server_host_override) args.server_host_override)
stub = test_pb2.beta_create_TestService_stub( stub = test_pb2.beta_create_TestService_stub(channel)
channel, metadata_transformer=metadata_transformer)
else: else:
channel = implementations.insecure_channel( channel = implementations.insecure_channel(
args.server_host, args.server_port) args.server_host, args.server_port)

@ -39,6 +39,8 @@ import time
from oauth2client import client as oauth2client_client from oauth2client import client as oauth2client_client
from grpc.beta import implementations
from grpc.beta import interfaces
from grpc.framework.common import cardinality from grpc.framework.common import cardinality
from grpc.framework.interfaces.face import face from grpc.framework.interfaces.face import face
@ -88,13 +90,15 @@ class TestService(test_pb2.BetaTestServiceServicer):
return self.FullDuplexCall(request_iterator, context) return self.FullDuplexCall(request_iterator, context)
def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope): def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
protocol_options=None):
with stub: with stub:
request = messages_pb2.SimpleRequest( request = messages_pb2.SimpleRequest(
response_type=messages_pb2.COMPRESSABLE, response_size=314159, response_type=messages_pb2.COMPRESSABLE, response_size=314159,
payload=messages_pb2.Payload(body=b'\x00' * 271828), payload=messages_pb2.Payload(body=b'\x00' * 271828),
fill_username=fill_username, fill_oauth_scope=fill_oauth_scope) fill_username=fill_username, fill_oauth_scope=fill_oauth_scope)
response_future = stub.UnaryCall.future(request, _TIMEOUT) response_future = stub.UnaryCall.future(request, _TIMEOUT,
protocol_options=protocol_options)
response = response_future.result() response = response_future.result()
if response.payload.type is not messages_pb2.COMPRESSABLE: if response.payload.type is not messages_pb2.COMPRESSABLE:
raise ValueError( raise ValueError(
@ -303,7 +307,24 @@ def _oauth2_auth_token(stub, args):
if args.oauth_scope.find(response.oauth_scope) == -1: if args.oauth_scope.find(response.oauth_scope) == -1:
raise ValueError( raise ValueError(
'expected to find oauth scope "%s" in received "%s"' % 'expected to find oauth scope "%s" in received "%s"' %
(response.oauth_scope, args.oauth_scope)) (response.oauth_scope, args.oauth_scope))
def _per_rpc_creds(stub, args):
json_key_filename = os.environ[
oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
credentials = oauth2client_client.GoogleCredentials.get_application_default()
scoped_credentials = credentials.create_scoped([args.oauth_scope])
call_creds = implementations.google_call_credentials(scoped_credentials)
options = interfaces.grpc_call_options(disable_compression=False,
credentials=call_creds)
response = _large_unary_common_behavior(stub, True, False,
protocol_options=options)
if wanted_email != response.username:
raise ValueError(
'expected username %s, got %s' % (wanted_email, response.username))
@enum.unique @enum.unique
class TestCase(enum.Enum): class TestCase(enum.Enum):
@ -317,6 +338,7 @@ class TestCase(enum.Enum):
EMPTY_STREAM = 'empty_stream' EMPTY_STREAM = 'empty_stream'
COMPUTE_ENGINE_CREDS = 'compute_engine_creds' COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
OAUTH2_AUTH_TOKEN = 'oauth2_auth_token' OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
PER_RPC_CREDS = 'per_rpc_creds'
TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server' TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
def test_interoperability(self, stub, args): def test_interoperability(self, stub, args):
@ -342,5 +364,7 @@ class TestCase(enum.Enum):
_compute_engine_creds(stub, args) _compute_engine_creds(stub, args)
elif self is TestCase.OAUTH2_AUTH_TOKEN: elif self is TestCase.OAUTH2_AUTH_TOKEN:
_oauth2_auth_token(stub, args) _oauth2_auth_token(stub, args)
elif self is TestCase.PER_RPC_CREDS:
_per_rpc_creds(stub, args)
else: else:
raise NotImplementedError('Test case "%s" not implemented!' % self.name) raise NotImplementedError('Test case "%s" not implemented!' % self.name)

@ -1,4 +1,6 @@
[ [
"_auth_test.AccessTokenCallCredentialsTest",
"_auth_test.GoogleCallCredentialsTest",
"_base_interface_test.AsyncEasyTest", "_base_interface_test.AsyncEasyTest",
"_base_interface_test.AsyncPeasyTest", "_base_interface_test.AsyncPeasyTest",
"_base_interface_test.SyncEasyTest", "_base_interface_test.SyncEasyTest",
@ -30,6 +32,7 @@
"_face_interface_test.MultiCallableInvokerBlockingInvocationInlineServiceTest", "_face_interface_test.MultiCallableInvokerBlockingInvocationInlineServiceTest",
"_face_interface_test.MultiCallableInvokerFutureInvocationAsynchronousEventServiceTest", "_face_interface_test.MultiCallableInvokerFutureInvocationAsynchronousEventServiceTest",
"_health_servicer_test.HealthServicerTest", "_health_servicer_test.HealthServicerTest",
"_implementations_test.CallCredentialsTest",
"_implementations_test.ChannelCredentialsTest", "_implementations_test.ChannelCredentialsTest",
"_insecure_interop_test.InsecureInteropTest", "_insecure_interop_test.InsecureInteropTest",
"_intermediary_low_test.CancellationTest", "_intermediary_low_test.CancellationTest",

@ -0,0 +1,96 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests of standard AuthMetadataPlugins."""
import collections
import threading
import unittest
from grpc.beta import _auth
class MockGoogleCreds(object):
def get_access_token(self):
token = collections.namedtuple('MockAccessTokenInfo',
('access_token', 'expires_in'))
token.access_token = 'token'
return token
class MockExceptionGoogleCreds(object):
def get_access_token(self):
raise Exception()
class GoogleCallCredentialsTest(unittest.TestCase):
def test_google_call_credentials_success(self):
callback_event = threading.Event()
def mock_callback(metadata, error):
self.assertEqual(metadata, (('authorization', 'Bearer token'),))
self.assertIsNone(error)
callback_event.set()
call_creds = _auth.GoogleCallCredentials(MockGoogleCreds())
call_creds(None, mock_callback)
self.assertTrue(callback_event.wait(1.0))
def test_google_call_credentials_error(self):
callback_event = threading.Event()
def mock_callback(metadata, error):
self.assertIsNotNone(error)
callback_event.set()
call_creds = _auth.GoogleCallCredentials(MockExceptionGoogleCreds())
call_creds(None, mock_callback)
self.assertTrue(callback_event.wait(1.0))
class AccessTokenCallCredentialsTest(unittest.TestCase):
def test_google_call_credentials_success(self):
callback_event = threading.Event()
def mock_callback(metadata, error):
self.assertEqual(metadata, (('authorization', 'Bearer token'),))
self.assertIsNone(error)
callback_event.set()
call_creds = _auth.AccessTokenCallCredentials('token')
call_creds(None, mock_callback)
self.assertTrue(callback_event.wait(1.0))
if __name__ == '__main__':
unittest.main(verbosity=2)

@ -29,8 +29,11 @@
"""Tests the implementations module of the gRPC Python Beta API.""" """Tests the implementations module of the gRPC Python Beta API."""
import datetime
import unittest import unittest
from oauth2client import client as oauth2client_client
from grpc.beta import implementations from grpc.beta import implementations
from tests.unit import resources from tests.unit import resources
@ -49,5 +52,19 @@ class ChannelCredentialsTest(unittest.TestCase):
channel_credentials, implementations.ChannelCredentials) channel_credentials, implementations.ChannelCredentials)
class CallCredentialsTest(unittest.TestCase):
def test_google_call_credentials(self):
creds = oauth2client_client.GoogleCredentials(
'token', 'client_id', 'secret', 'refresh_token',
datetime.datetime(2008, 6, 24), 'https://refresh.uri.com/',
'user_agent')
call_creds = implementations.google_call_credentials(creds)
self.assertIsInstance(call_creds, implementations.CallCredentials)
def test_access_token_call_credentials(self):
call_creds = implementations.access_token_call_credentials('token')
self.assertIsInstance(call_creds, implementations.CallCredentials)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(verbosity=2) unittest.main(verbosity=2)

@ -317,8 +317,7 @@ class PythonLanguage:
'PYTHONPATH': '{}/src/python/gens'.format(DOCKER_WORKDIR_ROOT)} 'PYTHONPATH': '{}/src/python/gens'.format(DOCKER_WORKDIR_ROOT)}
def unimplemented_test_cases(self): def unimplemented_test_cases(self):
return _SKIP_ADVANCED + _SKIP_COMPRESSION + ['jwt_token_creds', return _SKIP_ADVANCED + _SKIP_COMPRESSION + ['jwt_token_creds']
'per_rpc_creds']
def unimplemented_test_cases_server(self): def unimplemented_test_cases_server(self):
return _SKIP_ADVANCED + _SKIP_COMPRESSION return _SKIP_ADVANCED + _SKIP_COMPRESSION

Loading…
Cancel
Save