Add metadata auth plugin API support

pull/4228/head
Masood Malekghassemi 9 years ago
parent 25ef5c8fad
commit 0f1bf32387
  1. 48
      src/python/grpcio/grpc/_adapter/_implementations.py
  2. 25
      src/python/grpcio/grpc/_adapter/_intermediary_low.py
  3. 80
      src/python/grpcio/grpc/_adapter/_low.py
  4. 23
      src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd
  5. 71
      src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx
  6. 26
      src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxd
  7. 2
      src/python/grpcio/grpc/_cython/_cygrpc/records.pyx
  8. 3
      src/python/grpcio/grpc/_cython/cygrpc.pyx
  9. 2
      src/python/grpcio/grpc/_links/invocation.py
  10. 2
      src/python/grpcio/grpc/beta/_server.py
  11. 96
      src/python/grpcio/grpc/beta/implementations.py
  12. 49
      src/python/grpcio/grpc/beta/interfaces.py
  13. 2
      src/python/grpcio/tests/interop/_secure_interop_test.py
  14. 2
      src/python/grpcio/tests/interop/client.py
  15. 188
      src/python/grpcio/tests/unit/_cython/cygrpc_test.py
  16. 58
      src/python/grpcio/tests/unit/beta/_beta_features_test.py
  17. 4
      src/python/grpcio/tests/unit/beta/_face_interface_test.py
  18. 6
      src/python/grpcio/tests/unit/beta/test_utilities.py

@ -0,0 +1,48 @@
# Copyright 2015, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import collections
from grpc.beta import interfaces
class AuthMetadataContext(collections.namedtuple(
'AuthMetadataContext', [
'service_url',
'method_name'
]), interfaces.GRPCAuthMetadataContext):
pass
class AuthMetadataPluginCallback(interfaces.GRPCAuthMetadataContext):
def __init__(self, callback):
self._callback = callback
def __call__(self, metadata, error):
self._callback(metadata, error)

@ -173,20 +173,17 @@ class Call(object):
return self._internal.peer()
def set_credentials(self, creds):
return self._internal.set_credentials(creds._internal)
return self._internal.set_credentials(creds)
class Channel(object):
"""Adapter from old _low.Channel interface to new _low.Channel."""
def __init__(self, hostport, client_credentials, server_host_override=None):
def __init__(self, hostport, channel_credentials, server_host_override=None):
args = []
if server_host_override:
args.append((_types.GrpcChannelArgumentKeys.SSL_TARGET_NAME_OVERRIDE.value, server_host_override))
creds = None
if client_credentials:
creds = client_credentials._internal
self._internal = _low.Channel(hostport, args, creds)
self._internal = _low.Channel(hostport, args, channel_credentials)
class CompletionQueue(object):
@ -245,7 +242,7 @@ class Server(object):
if server_credentials is None:
return self._internal.add_http2_port(addr, None)
else:
return self._internal.add_http2_port(addr, server_credentials._internal)
return self._internal.add_http2_port(addr, server_credentials)
def start(self):
return self._internal.start()
@ -259,17 +256,3 @@ class Server(object):
def stop(self):
return self._internal.shutdown(_TagAdapter(None, Event.Kind.STOP))
class ClientCredentials(object):
"""Adapter from old _low.ClientCredentials interface to new _low.ChannelCredentials."""
def __init__(self, root_certificates, private_key, certificate_chain):
self._internal = _low.channel_credentials_ssl(root_certificates, private_key, certificate_chain)
class ServerCredentials(object):
"""Adapter from old _low.ServerCredentials interface to new _low.ServerCredentials."""
def __init__(self, root_credentials, pair_sequence, force_client_auth):
self._internal = _low.server_credentials_ssl(
root_credentials, pair_sequence, force_client_auth)

@ -27,8 +27,11 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import threading
from grpc import _grpcio_metadata
from grpc._cython import cygrpc
from grpc._adapter import _implementations
from grpc._adapter import _types
_USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)
@ -37,6 +40,9 @@ ChannelCredentials = cygrpc.ChannelCredentials
CallCredentials = cygrpc.CallCredentials
ServerCredentials = cygrpc.ServerCredentials
channel_credentials_composite = cygrpc.channel_credentials_composite
call_credentials_composite = cygrpc.call_credentials_composite
def server_credentials_ssl(root_credentials, pair_sequence, force_client_auth):
return cygrpc.server_credentials_ssl(
root_credentials,
@ -51,6 +57,80 @@ def channel_credentials_ssl(
return cygrpc.channel_credentials_ssl(root_certificates, pair)
class _WrappedCygrpcCallback(object):
def __init__(self, cygrpc_callback):
self.is_called = False
self.error = None
self.is_called_lock = threading.Lock()
self.cygrpc_callback = cygrpc_callback
def _invoke_failure(self, error):
# TODO(atash) translate different Exception superclasses into different
# status codes.
self.cygrpc_callback(
cygrpc.Metadata([]), cygrpc.StatusCode.internal, error.message)
def _invoke_success(self, metadata):
try:
cygrpc_metadata = cygrpc.Metadata(
cygrpc.Metadatum(key, value)
for key, value in metadata)
except Exception as error:
self._invoke_failure(error)
return
self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '')
def __call__(self, metadata, error):
with self.is_called_lock:
if self.is_called:
raise RuntimeError('callback should only ever be invoked once')
if self.error:
self._invoke_failure(self.error)
return
self.is_called = True
if error is None:
self._invoke_success(metadata)
else:
self._invoke_failure(error)
def notify_failure(self, error):
with self.is_called_lock:
if not self.is_called:
self.error = error
class _WrappedPlugin(object):
def __init__(self, plugin):
self.plugin = plugin
def __call__(self, context, cygrpc_callback):
wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback)
wrapped_context = _implementations.AuthMetadataContext(context.service_url,
context.method_name)
try:
self.plugin(
wrapped_context,
_implementations.AuthMetadataPluginCallback(wrapped_cygrpc_callback))
except Exception as error:
wrapped_cygrpc_callback.notify_failure(error)
raise
def call_credentials_metadata_plugin(plugin, name):
"""
Args:
plugin: A callable accepting a _types.AuthMetadataContext
object and a callback (itself accepting a list of metadata key/value
2-tuples and a None-able exception value). The callback must be eventually
called, but need not be called in plugin's invocation.
plugin's invocation must be non-blocking.
"""
return cygrpc.call_credentials_metadata_plugin(
cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name))
class CompletionQueue(_types.CompletionQueue):
def __init__(self):

@ -27,7 +27,10 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cimport cpython
from grpc._cython._cygrpc cimport grpc
from grpc._cython._cygrpc cimport records
cdef class ChannelCredentials:
@ -49,3 +52,23 @@ cdef class ServerCredentials:
cdef grpc.grpc_ssl_pem_key_cert_pair *c_ssl_pem_key_cert_pairs
cdef size_t c_ssl_pem_key_cert_pairs_count
cdef list references
cdef class CredentialsMetadataPlugin:
cdef object plugin_callback
cdef str plugin_name
cdef grpc.grpc_metadata_credentials_plugin make_c_plugin(self)
cdef class AuthMetadataContext:
cdef grpc.grpc_auth_metadata_context context
cdef void plugin_get_metadata(
void *state, grpc.grpc_auth_metadata_context context,
grpc.grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil
cdef void plugin_destroy_c_plugin_state(void *state)

@ -27,6 +27,8 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cimport cpython
from grpc._cython._cygrpc cimport grpc
from grpc._cython._cygrpc cimport records
@ -78,6 +80,66 @@ cdef class ServerCredentials:
grpc.grpc_server_credentials_release(self.c_credentials)
cdef class CredentialsMetadataPlugin:
def __cinit__(self, object plugin_callback, str name):
"""
Args:
plugin_callback (callable): Callback accepting a service URL (str/bytes)
and callback object (accepting a records.Metadata,
grpc.grpc_status_code, and a str/bytes error message). This argument
when called should be non-blocking and eventually call the callback
object with the appropriate status code/details and metadata (if
successful).
name (str): Plugin name.
"""
if not callable(plugin_callback):
raise ValueError('expected callable plugin_callback')
self.plugin_callback = plugin_callback
self.plugin_name = name
@staticmethod
cdef grpc.grpc_metadata_credentials_plugin make_c_plugin(self):
cdef grpc.grpc_metadata_credentials_plugin result
result.get_metadata = plugin_get_metadata
result.destroy = plugin_destroy_c_plugin_state
result.state = <void *>self
result.type = self.plugin_name
cpython.Py_INCREF(self)
return result
cdef class AuthMetadataContext:
def __cinit__(self):
self.context.service_url = NULL
self.context.method_name = NULL
@property
def service_url(self):
return self.context.service_url
@property
def method_name(self):
return self.context.method_name
cdef void plugin_get_metadata(
void *state, grpc.grpc_auth_metadata_context context,
grpc.grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil:
def python_callback(
records.Metadata metadata, grpc.grpc_status_code status,
const char *error_details):
cb(user_data, metadata.c_metadata_array.metadata,
metadata.c_metadata_array.count, status, error_details)
cdef CredentialsMetadataPlugin self = <CredentialsMetadataPlugin>state
cdef AuthMetadataContext cy_context = AuthMetadataContext()
cy_context.context = context
self.plugin_callback(cy_context, python_callback)
cdef void plugin_destroy_c_plugin_state(void *state):
cpython.Py_DECREF(<CredentialsMetadataPlugin>state)
def channel_credentials_google_default():
cdef ChannelCredentials credentials = ChannelCredentials();
credentials.c_credentials = grpc.grpc_google_default_credentials_create()
@ -185,6 +247,15 @@ def call_credentials_google_iam(authorization_token, authority_selector):
credentials.references.append(authority_selector)
return credentials
def call_credentials_metadata_plugin(CredentialsMetadataPlugin plugin):
cdef CallCredentials credentials = CallCredentials()
credentials.c_credentials = (
grpc.grpc_metadata_credentials_create_from_plugin(plugin.make_c_plugin(),
NULL))
# TODO(atash): the following held reference is *probably* never necessary
credentials.references.append(plugin)
return credentials
def server_credentials_ssl(pem_root_certs, pem_key_cert_pairs,
bint force_client_auth):
cdef char *c_pem_root_certs = NULL

@ -137,8 +137,6 @@ cdef extern from "grpc/grpc.h":
const char *GRPC_ARG_MAX_CONCURRENT_STREAMS
const char *GRPC_ARG_MAX_MESSAGE_LENGTH
const char *GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER
const char *GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_DECODER
const char *GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_ENCODER
const char *GRPC_ARG_DEFAULT_AUTHORITY
const char *GRPC_ARG_PRIMARY_USER_AGENT_STRING
const char *GRPC_ARG_SECONDARY_USER_AGENT_STRING
@ -396,3 +394,27 @@ cdef extern from "grpc/grpc_security.h":
grpc_call_error grpc_call_set_credentials(grpc_call *call,
grpc_call_credentials *creds)
ctypedef struct grpc_auth_context:
# We don't care about the internals (and in fact don't know them)
pass
ctypedef struct grpc_auth_metadata_context:
const char *service_url
const char *method_name
const grpc_auth_context *channel_auth_context
ctypedef void (*grpc_credentials_plugin_metadata_cb)(
void *user_data, const grpc_metadata *creds_md, size_t num_creds_md,
grpc_status_code status, const char *error_details)
ctypedef struct grpc_metadata_credentials_plugin:
void (*get_metadata)(
void *state, grpc_auth_metadata_context context,
grpc_credentials_plugin_metadata_cb cb, void *user_data)
void (*destroy)(void *state)
void *state
const char *type
grpc_call_credentials *grpc_metadata_credentials_create_from_plugin(
grpc_metadata_credentials_plugin plugin, void *reserved)

@ -45,8 +45,6 @@ class ChannelArgKey:
max_concurrent_streams = grpc.GRPC_ARG_MAX_CONCURRENT_STREAMS
max_message_length = grpc.GRPC_ARG_MAX_MESSAGE_LENGTH
http2_initial_sequence_number = grpc.GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER
http2_hpack_table_size_decoder = grpc.GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_DECODER
http2_hpack_table_size_encoder = grpc.GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_ENCODER
default_authority = grpc.GRPC_ARG_DEFAULT_AUTHORITY
primary_user_agent_string = grpc.GRPC_ARG_PRIMARY_USER_AGENT_STRING
secondary_user_agent_string = grpc.GRPC_ARG_SECONDARY_USER_AGENT_STRING

@ -76,6 +76,8 @@ Operations = records.Operations
CallCredentials = credentials.CallCredentials
ChannelCredentials = credentials.ChannelCredentials
ServerCredentials = credentials.ServerCredentials
CredentialsMetadataPlugin = credentials.CredentialsMetadataPlugin
AuthMetadataContext = credentials.AuthMetadataContext
channel_credentials_google_default = (
credentials.channel_credentials_google_default)
@ -91,6 +93,7 @@ call_credentials_jwt_access = (
call_credentials_refresh_token = (
credentials.call_credentials_google_refresh_token)
call_credentials_google_iam = credentials.call_credentials_google_iam
call_credentials_metadata_plugin = credentials.call_credentials_metadata_plugin
server_credentials_ssl = credentials.server_credentials_ssl
CompletionQueue = completion_queue.CompletionQueue

@ -262,7 +262,7 @@ class _Kernel(object):
self._channel, self._completion_queue, '/%s/%s' % (group, method),
self._host, time.time() + timeout)
if options is not None and options.credentials is not None:
call.set_credentials(options.credentials._intermediary_low_credentials)
call.set_credentials(options.credentials._low_credentials)
if transformed_initial_metadata is not None:
for metadata_key, metadata_value in transformed_initial_metadata:
call.add_metadata(metadata_key, metadata_value)

@ -170,7 +170,7 @@ class _Server(interfaces.Server):
with self._lock:
if self._end_link is None:
return self._grpc_link.add_port(
address, server_credentials._intermediary_low_credentials) # pylint: disable=protected-access
address, server_credentials._low_credentials) # pylint: disable=protected-access
else:
raise ValueError('Can\'t add port to serving server!')

@ -36,6 +36,7 @@ import threading # pylint: disable=unused-import
# cardinality and face are referenced from specification in this module.
from grpc._adapter import _intermediary_low
from grpc._adapter import _low
from grpc._adapter import _types
from grpc.beta import _connectivity_channel
from grpc.beta import _server
@ -48,7 +49,7 @@ _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = (
'Exception calling channel subscription callback!')
class ClientCredentials(object):
class ChannelCredentials(object):
"""A value encapsulating the data required to create a secure Channel.
This class and its instances have no supported interface - it exists to define
@ -56,13 +57,12 @@ class ClientCredentials(object):
functions.
"""
def __init__(self, low_credentials, intermediary_low_credentials):
def __init__(self, low_credentials):
self._low_credentials = low_credentials
self._intermediary_low_credentials = intermediary_low_credentials
def ssl_client_credentials(root_certificates, private_key, certificate_chain):
"""Creates a ClientCredentials for use with an SSL-enabled Channel.
def ssl_channel_credentials(root_certificates, private_key, certificate_chain):
"""Creates a ChannelCredentials for use with an SSL-enabled Channel.
Args:
root_certificates: The PEM-encoded root certificates or None to ask for
@ -73,12 +73,73 @@ def ssl_client_credentials(root_certificates, private_key, certificate_chain):
certificate chain should be used.
Returns:
A ClientCredentials for use with an SSL-enabled Channel.
A ChannelCredentials for use with an SSL-enabled Channel.
"""
intermediary_low_credentials = _intermediary_low.ClientCredentials(
root_certificates, private_key, certificate_chain)
return ClientCredentials(
intermediary_low_credentials._internal, intermediary_low_credentials) # pylint: disable=protected-access
return ChannelCredentials(_low.channel_credentials_ssl(
root_certificates, private_key, certificate_chain))
class CallCredentials(object):
"""A value encapsulating data asserting an identity over an *established*
channel. May be composed with ChannelCredentials to always assert identity for
every call over that channel.
This class and its instances have no supported interface - it exists to define
the type of its instances and its instances exist to be passed to other
functions.
"""
def __init__(self, low_credentials):
self._low_credentials = low_credentials
def metadata_call_credentials(metadata_plugin, name=None):
"""Construct CallCredentials from an interfaces.GRPCAuthMetadataPlugin.
Args:
metadata_plugin: An interfaces.GRPCAuthMetadataPlugin to use in constructing
the CallCredentials object.
Returns:
A CallCredentials object for use in a GRPCCallOptions object.
"""
if name is None:
name = metadata_plugin.__name__
return CallCredentials(
_low.call_credentials_metadata_plugin(metadata_plugin, name))
def composite_call_credentials(call_credentials, additional_call_credentials):
"""Compose two CallCredentials to make a new one.
Args:
call_credentials: A CallCredentials object.
additional_call_credentials: Another CallCredentials object to compose on
top of call_credentials.
Returns:
A CallCredentials object for use in a GRPCCallOptions object.
"""
return CallCredentials(
_low.call_credentials_composite(
call_credentials._low_credentials,
additional_call_credentials._low_credentials))
def composite_channel_credentials(channel_credentials,
additional_call_credentials):
"""Compose ChannelCredentials on top of client credentials to make a new one.
Args:
channel_credentials: A ChannelCredentials object.
additional_call_credentials: A CallCredentials object to compose on
top of channel_credentials.
Returns:
A ChannelCredentials object for use in a GRPCCallOptions object.
"""
return ChannelCredentials(
_low.channel_credentials_composite(
channel_credentials._low_credentials,
additional_call_credentials._low_credentials))
class Channel(object):
@ -135,19 +196,19 @@ def insecure_channel(host, port):
return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access
def secure_channel(host, port, client_credentials):
def secure_channel(host, port, channel_credentials):
"""Creates a secure Channel to a remote host.
Args:
host: The name of the remote host to which to connect.
port: The port of the remote host to which to connect.
client_credentials: A ClientCredentials.
channel_credentials: A ChannelCredentials.
Returns:
A secure Channel to the remote host through which RPCs may be conducted.
"""
intermediary_low_channel = _intermediary_low.Channel(
'%s:%d' % (host, port), client_credentials._intermediary_low_credentials)
'%s:%d' % (host, port), channel_credentials._low_credentials)
return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access
@ -251,9 +312,8 @@ class ServerCredentials(object):
functions.
"""
def __init__(self, low_credentials, intermediary_low_credentials):
def __init__(self, low_credentials):
self._low_credentials = low_credentials
self._intermediary_low_credentials = intermediary_low_credentials
def ssl_server_credentials(
@ -282,11 +342,9 @@ def ssl_server_credentials(
raise ValueError(
'Illegal to require client auth without providing root certificates!')
else:
intermediary_low_credentials = _intermediary_low.ServerCredentials(
return ServerCredentials(_low.server_credentials_ssl(
root_certificates, private_key_certificate_chain_pairs,
require_client_auth)
return ServerCredentials(
intermediary_low_credentials._internal, intermediary_low_credentials) # pylint: disable=protected-access
require_client_auth))
class ServerOptions(object):

@ -100,14 +100,55 @@ def grpc_call_options(disable_compression=False, credentials=None):
disable_compression: A boolean indicating whether or not compression should
be disabled for the request object of the RPC. Only valid for
request-unary RPCs.
credentials: Reserved for gRPC per-call credentials. The type for this does
not exist yet at the Python level.
credentials: A CallCredentials object to use for the invoked RPC.
"""
if credentials is not None:
raise ValueError('`credentials` is a reserved argument')
return GRPCCallOptions(disable_compression, None, credentials)
class GRPCAuthMetadataContext(object):
"""Provides information to call credentials metadata plugins.
Attributes:
service_url: A string URL of the service being called into.
method_name: A string of the fully qualified method name being called.
"""
__metaclass__ = abc.ABCMeta
class GRPCAuthMetadataPluginCallback(object):
"""Callback object received by a metadata plugin."""
__metaclass__ = abc.ABCMeta
def __call__(self, metadata, error):
"""Inform the gRPC runtime of the metadata to construct a CallCredentials.
Args:
metadata: An iterable of 2-sequences (e.g. tuples) of metadata key/value
pairs.
error: An Exception to indicate error or None to indicate success.
"""
raise NotImplementedError()
class GRPCAuthMetadataPlugin(object):
"""
"""
__metaclass__ = abc.ABCMeta
def __call__(self, context, callback):
"""Invoke the plugin.
Must not block. Need only be called by the gRPC runtime.
Args:
context: A GRPCAuthMetadataContext providing information on what the
plugin is being used for.
callback: A GRPCAuthMetadataPluginCallback to be invoked either
synchronously or asynchronously.
"""
raise NotImplementedError()
class GRPCServicerContext(object):
"""Exposes gRPC-specific options and behaviors to code servicing RPCs."""
__metaclass__ = abc.ABCMeta

@ -55,7 +55,7 @@ class SecureInteropTest(
self.server.start()
self.stub = test_pb2.beta_create_TestService_stub(
test_utilities.not_really_secure_channel(
'[::]', port, implementations.ssl_client_credentials(
'[::]', port, implementations.ssl_channel_credentials(
resources.test_root_certificates(), None, None),
_SERVER_HOST_OVERRIDE))

@ -94,7 +94,7 @@ def _stub(args):
channel = test_utilities.not_really_secure_channel(
args.server_host, args.server_port,
implementations.ssl_client_credentials(root_certificates, None, None),
implementations.ssl_channel_credentials(root_certificates, None, None),
args.server_host_override)
stub = test_pb2.beta_create_TestService_stub(
channel, metadata_transformer=metadata_transformer)

@ -28,11 +28,24 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import time
import threading
import unittest
from grpc._cython import cygrpc
from tests.unit._cython import test_utilities
from tests.unit import test_common
from tests.unit import resources
_SSL_HOST_OVERRIDE = 'foo.test.google.fr'
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
def _metadata_plugin_callback(context, callback):
callback(cygrpc.Metadata(
[cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY,
_CALL_CREDENTIALS_METADATA_VALUE)]),
cygrpc.StatusCode.ok, '')
class TypeSmokeTest(unittest.TestCase):
@ -89,6 +102,17 @@ class TypeSmokeTest(unittest.TestCase):
channel = cygrpc.Channel('[::]:0', cygrpc.ChannelArgs([]))
del channel
def testCredentialsMetadataPluginUpDown(self):
plugin = cygrpc.CredentialsMetadataPlugin(
lambda ignored_a, ignored_b: None, '')
del plugin
def testCallCredentialsFromPluginUpDown(self):
plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '')
call_credentials = cygrpc.call_credentials_metadata_plugin(plugin)
del plugin
del call_credentials
def testServerStartNoExplicitShutdown(self):
server = cygrpc.Server()
completion_queue = cygrpc.CompletionQueue()
@ -260,5 +284,169 @@ class InsecureServerInsecureClient(unittest.TestCase):
del server_call
class SecureServerSecureClient(unittest.TestCase):
def setUp(self):
server_credentials = cygrpc.server_credentials_ssl(
None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
resources.certificate_chain())], False)
channel_credentials = cygrpc.channel_credentials_ssl(
resources.test_root_certificates(), None)
self.server_completion_queue = cygrpc.CompletionQueue()
self.server = cygrpc.Server()
self.server.register_completion_queue(self.server_completion_queue)
self.port = self.server.add_http2_port('[::]:0', server_credentials)
self.server.start()
self.client_completion_queue = cygrpc.CompletionQueue()
client_channel_arguments = cygrpc.ChannelArgs([
cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
_SSL_HOST_OVERRIDE)])
self.client_channel = cygrpc.Channel(
'localhost:{}'.format(self.port), client_channel_arguments,
channel_credentials)
def tearDown(self):
del self.server
del self.client_completion_queue
del self.server_completion_queue
def testEcho(self):
DEADLINE = time.time()+5
DEADLINE_TOLERANCE = 0.25
CLIENT_METADATA_ASCII_KEY = b'key'
CLIENT_METADATA_ASCII_VALUE = b'val'
CLIENT_METADATA_BIN_KEY = b'key-bin'
CLIENT_METADATA_BIN_VALUE = b'\0'*1000
SERVER_INITIAL_METADATA_KEY = b'init_me_me_me'
SERVER_INITIAL_METADATA_VALUE = b'whodawha?'
SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought'
SERVER_TRAILING_METADATA_VALUE = b'zomg it is'
SERVER_STATUS_CODE = cygrpc.StatusCode.ok
SERVER_STATUS_DETAILS = b'our work is never over'
REQUEST = b'in death a member of project mayhem has a name'
RESPONSE = b'his name is robert paulson'
METHOD = b'/twinkies'
HOST = None # Default host
cygrpc_deadline = cygrpc.Timespec(DEADLINE)
server_request_tag = object()
request_call_result = self.server.request_call(
self.server_completion_queue, self.server_completion_queue,
server_request_tag)
self.assertEqual(cygrpc.CallError.ok, request_call_result)
plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '')
call_credentials = cygrpc.call_credentials_metadata_plugin(plugin)
client_call_tag = object()
client_call = self.client_channel.create_call(
None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline)
client_call.set_credentials(call_credentials)
client_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE),
cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
client_start_batch_result = client_call.start_batch(cygrpc.Operations([
cygrpc.operation_send_initial_metadata(client_initial_metadata),
cygrpc.operation_send_message(REQUEST),
cygrpc.operation_send_close_from_client(),
cygrpc.operation_receive_initial_metadata(),
cygrpc.operation_receive_message(),
cygrpc.operation_receive_status_on_client()
]), client_call_tag)
self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
client_event_future = test_utilities.CompletionQueuePollFuture(
self.client_completion_queue, cygrpc_deadline)
request_event = self.server_completion_queue.poll(cygrpc_deadline)
self.assertEqual(cygrpc.CompletionType.operation_complete,
request_event.type)
self.assertIsInstance(request_event.operation_call, cygrpc.Call)
self.assertIs(server_request_tag, request_event.tag)
self.assertEqual(0, len(request_event.batch_operations))
client_metadata_with_credentials = list(client_initial_metadata) + [
(_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE)]
self.assertTrue(
test_common.metadata_transmitted(client_metadata_with_credentials,
request_event.request_metadata))
self.assertEqual(METHOD, request_event.request_call_details.method)
self.assertEqual(_SSL_HOST_OVERRIDE,
request_event.request_call_details.host)
self.assertLess(
abs(DEADLINE - float(request_event.request_call_details.deadline)),
DEADLINE_TOLERANCE)
server_call_tag = object()
server_call = request_event.operation_call
server_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY,
SERVER_INITIAL_METADATA_VALUE)])
server_trailing_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
SERVER_TRAILING_METADATA_VALUE)])
server_start_batch_result = server_call.start_batch([
cygrpc.operation_send_initial_metadata(server_initial_metadata),
cygrpc.operation_receive_message(),
cygrpc.operation_send_message(RESPONSE),
cygrpc.operation_receive_close_on_server(),
cygrpc.operation_send_status_from_server(
server_trailing_metadata, SERVER_STATUS_CODE, SERVER_STATUS_DETAILS)
], server_call_tag)
self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)
client_event = client_event_future.result()
server_event = self.server_completion_queue.poll(cygrpc_deadline)
self.assertEqual(6, len(client_event.batch_operations))
found_client_op_types = set()
for client_result in client_event.batch_operations:
# we expect each op type to be unique
self.assertNotIn(client_result.type, found_client_op_types)
found_client_op_types.add(client_result.type)
if client_result.type == cygrpc.OperationType.receive_initial_metadata:
self.assertTrue(
test_common.metadata_transmitted(server_initial_metadata,
client_result.received_metadata))
elif client_result.type == cygrpc.OperationType.receive_message:
self.assertEqual(RESPONSE, client_result.received_message.bytes())
elif client_result.type == cygrpc.OperationType.receive_status_on_client:
self.assertTrue(
test_common.metadata_transmitted(server_trailing_metadata,
client_result.received_metadata))
self.assertEqual(SERVER_STATUS_DETAILS,
client_result.received_status_details)
self.assertEqual(SERVER_STATUS_CODE, client_result.received_status_code)
self.assertEqual(set([
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.send_message,
cygrpc.OperationType.send_close_from_client,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.receive_status_on_client
]), found_client_op_types)
self.assertEqual(5, len(server_event.batch_operations))
found_server_op_types = set()
for server_result in server_event.batch_operations:
self.assertNotIn(client_result.type, found_server_op_types)
found_server_op_types.add(server_result.type)
if server_result.type == cygrpc.OperationType.receive_message:
self.assertEqual(REQUEST, server_result.received_message.bytes())
elif server_result.type == cygrpc.OperationType.receive_close_on_server:
self.assertFalse(server_result.received_cancelled)
self.assertEqual(set([
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.send_message,
cygrpc.OperationType.receive_close_on_server,
cygrpc.OperationType.send_status_from_server
]), found_server_op_types)
del client_call
del server_call
if __name__ == '__main__':
unittest.main(verbosity=2)

@ -42,6 +42,9 @@ from tests.unit.framework.common import test_constants
_SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
_PER_RPC_CREDENTIALS_METADATA_KEY = 'my-call-credentials-metadata-key'
_PER_RPC_CREDENTIALS_METADATA_VALUE = 'my-call-credentials-metadata-value'
_GROUP = 'group'
_UNARY_UNARY = 'unary-unary'
_UNARY_STREAM = 'unary-stream'
@ -63,6 +66,7 @@ class _Servicer(object):
with self._condition:
self._request = request
self._peer = context.protocol_context().peer()
self._invocation_metadata = context.invocation_metadata()
context.protocol_context().disable_next_response_compression()
self._serviced = True
self._condition.notify_all()
@ -72,6 +76,7 @@ class _Servicer(object):
with self._condition:
self._request = request
self._peer = context.protocol_context().peer()
self._invocation_metadata = context.invocation_metadata()
context.protocol_context().disable_next_response_compression()
self._serviced = True
self._condition.notify_all()
@ -83,6 +88,7 @@ class _Servicer(object):
self._request = request
with self._condition:
self._peer = context.protocol_context().peer()
self._invocation_metadata = context.invocation_metadata()
context.protocol_context().disable_next_response_compression()
self._serviced = True
self._condition.notify_all()
@ -95,6 +101,7 @@ class _Servicer(object):
context.protocol_context().disable_next_response_compression()
yield _RESPONSE
with self._condition:
self._invocation_metadata = context.invocation_metadata()
self._serviced = True
self._condition.notify_all()
@ -137,6 +144,11 @@ class _BlockingIterator(object):
self._condition.notify_all()
def _metadata_plugin(context, callback):
callback([(_PER_RPC_CREDENTIALS_METADATA_KEY,
_PER_RPC_CREDENTIALS_METADATA_VALUE)], None)
class BetaFeaturesTest(unittest.TestCase):
def setUp(self):
@ -167,10 +179,12 @@ class BetaFeaturesTest(unittest.TestCase):
[(resources.private_key(), resources.certificate_chain(),),])
port = self._server.add_secure_port('[::]:0', server_credentials)
self._server.start()
self._client_credentials = implementations.ssl_client_credentials(
self._channel_credentials = implementations.ssl_channel_credentials(
resources.test_root_certificates(), None, None)
self._call_credentials = implementations.metadata_call_credentials(
_metadata_plugin)
channel = test_utilities.not_really_secure_channel(
'localhost', port, self._client_credentials, _SERVER_HOST_OVERRIDE)
'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE)
stub_options = implementations.stub_options(
thread_pool_size=test_constants.POOL_SIZE)
self._dynamic_stub = implementations.dynamic_stub(
@ -181,21 +195,36 @@ class BetaFeaturesTest(unittest.TestCase):
self._server.stop(test_constants.SHORT_TIMEOUT).wait()
def test_unary_unary(self):
call_options = interfaces.grpc_call_options(disable_compression=True)
call_options = interfaces.grpc_call_options(
disable_compression=True, credentials=self._call_credentials)
response = getattr(self._dynamic_stub, _UNARY_UNARY)(
_REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options)
self.assertEqual(_RESPONSE, response)
self.assertIsNotNone(self._servicer.peer())
invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
self._servicer._invocation_metadata]
self.assertIn(
(_PER_RPC_CREDENTIALS_METADATA_KEY,
_PER_RPC_CREDENTIALS_METADATA_VALUE),
invocation_metadata)
def test_unary_stream(self):
call_options = interfaces.grpc_call_options(disable_compression=True)
call_options = interfaces.grpc_call_options(
disable_compression=True, credentials=self._call_credentials)
response_iterator = getattr(self._dynamic_stub, _UNARY_STREAM)(
_REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options)
self._servicer.block_until_serviced()
self.assertIsNotNone(self._servicer.peer())
invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
self._servicer._invocation_metadata]
self.assertIn(
(_PER_RPC_CREDENTIALS_METADATA_KEY,
_PER_RPC_CREDENTIALS_METADATA_VALUE),
invocation_metadata)
def test_stream_unary(self):
call_options = interfaces.grpc_call_options()
call_options = interfaces.grpc_call_options(
credentials=self._call_credentials)
request_iterator = _BlockingIterator(iter((_REQUEST,)))
response_future = getattr(self._dynamic_stub, _STREAM_UNARY).future(
request_iterator, test_constants.LONG_TIMEOUT,
@ -207,9 +236,16 @@ class BetaFeaturesTest(unittest.TestCase):
self._servicer.block_until_serviced()
self.assertIsNotNone(self._servicer.peer())
self.assertEqual(_RESPONSE, response_future.result())
invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
self._servicer._invocation_metadata]
self.assertIn(
(_PER_RPC_CREDENTIALS_METADATA_KEY,
_PER_RPC_CREDENTIALS_METADATA_VALUE),
invocation_metadata)
def test_stream_stream(self):
call_options = interfaces.grpc_call_options()
call_options = interfaces.grpc_call_options(
credentials=self._call_credentials)
request_iterator = _BlockingIterator(iter((_REQUEST,)))
response_iterator = getattr(self._dynamic_stub, _STREAM_STREAM)(
request_iterator, test_constants.SHORT_TIMEOUT,
@ -222,6 +258,12 @@ class BetaFeaturesTest(unittest.TestCase):
self._servicer.block_until_serviced()
self.assertIsNotNone(self._servicer.peer())
self.assertEqual(_RESPONSE, response)
invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
self._servicer._invocation_metadata]
self.assertIn(
(_PER_RPC_CREDENTIALS_METADATA_KEY,
_PER_RPC_CREDENTIALS_METADATA_VALUE),
invocation_metadata)
class ContextManagementAndLifecycleTest(unittest.TestCase):
@ -250,7 +292,7 @@ class ContextManagementAndLifecycleTest(unittest.TestCase):
thread_pool_size=test_constants.POOL_SIZE)
self._server_credentials = implementations.ssl_server_credentials(
[(resources.private_key(), resources.certificate_chain(),),])
self._client_credentials = implementations.ssl_client_credentials(
self._channel_credentials = implementations.ssl_channel_credentials(
resources.test_root_certificates(), None, None)
self._stub_options = implementations.stub_options(
thread_pool_size=test_constants.POOL_SIZE)
@ -262,7 +304,7 @@ class ContextManagementAndLifecycleTest(unittest.TestCase):
server.start()
channel = test_utilities.not_really_secure_channel(
'localhost', port, self._client_credentials, _SERVER_HOST_OVERRIDE)
'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE)
dynamic_stub = implementations.dynamic_stub(
channel, _GROUP, self._cardinalities, options=self._stub_options)
for _ in range(100):

@ -91,10 +91,10 @@ class _Implementation(test_interfaces.Implementation):
[(resources.private_key(), resources.certificate_chain(),),])
port = server.add_secure_port('[::]:0', server_credentials)
server.start()
client_credentials = implementations.ssl_client_credentials(
channel_credentials = implementations.ssl_channel_credentials(
resources.test_root_certificates(), None, None)
channel = test_utilities.not_really_secure_channel(
'localhost', port, client_credentials, _SERVER_HOST_OVERRIDE)
'localhost', port, channel_credentials, _SERVER_HOST_OVERRIDE)
stub_options = implementations.stub_options(
request_serializers=serialization_behaviors.request_serializers,
response_deserializers=serialization_behaviors.response_deserializers,

@ -34,13 +34,13 @@ from grpc.beta import implementations
def not_really_secure_channel(
host, port, client_credentials, server_host_override):
host, port, channel_credentials, server_host_override):
"""Creates an insecure Channel to a remote host.
Args:
host: The name of the remote host to which to connect.
port: The port of the remote host to which to connect.
client_credentials: The implementations.ClientCredentials with which to
channel_credentials: The implementations.ChannelCredentials with which to
connect.
server_host_override: The target name used for SSL host name checking.
@ -50,7 +50,7 @@ def not_really_secure_channel(
"""
hostport = '%s:%d' % (host, port)
intermediary_low_channel = _intermediary_low.Channel(
hostport, client_credentials._intermediary_low_credentials,
hostport, channel_credentials._low_credentials,
server_host_override=server_host_override)
return implementations.Channel(
intermediary_low_channel._internal, intermediary_low_channel)

Loading…
Cancel
Save