Merge pull request #8137 from kpayson64/python_server_args

Add parameter for server options
pull/8146/head
kpayson64 9 years ago committed by GitHub
commit a6a6fa4f12
  1. 12
      src/python/grpcio/grpc/__init__.py
  2. 15
      src/python/grpcio/grpc/_channel.py
  3. 10
      src/python/grpcio/grpc/_common.py
  4. 5
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
  5. 4
      src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
  6. 5
      src/python/grpcio/grpc/_server.py
  7. 1
      src/python/grpcio_tests/tests/tests.json
  8. 53
      src/python/grpcio_tests/tests/unit/_channel_args_test.py
  9. 10
      src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
  10. 2
      src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
  11. 5
      src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
  12. 5
      src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
  13. 9
      src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py

@ -1189,7 +1189,7 @@ def insecure_channel(target, options=None):
A Channel to the target through which RPCs may be conducted. A Channel to the target through which RPCs may be conducted.
""" """
from grpc import _channel from grpc import _channel
return _channel.Channel(target, options, None) return _channel.Channel(target, () if options is None else options, None)
def secure_channel(target, credentials, options=None): def secure_channel(target, credentials, options=None):
@ -1205,10 +1205,11 @@ def secure_channel(target, credentials, options=None):
A Channel to the target through which RPCs may be conducted. A Channel to the target through which RPCs may be conducted.
""" """
from grpc import _channel from grpc import _channel
return _channel.Channel(target, options, credentials._credentials) return _channel.Channel(target, () if options is None else options,
credentials._credentials)
def server(thread_pool, handlers=None): def server(thread_pool, handlers=None, options=None):
"""Creates a Server with which RPCs can be serviced. """Creates a Server with which RPCs can be serviced.
Args: Args:
@ -1219,12 +1220,15 @@ def server(thread_pool, handlers=None):
only handlers the server will use to service RPCs; other handlers may only handlers the server will use to service RPCs; other handlers may
later be added by calling add_generic_rpc_handlers any time before the later be added by calling add_generic_rpc_handlers any time before the
returned Server is started. returned Server is started.
options: A sequence of string-value pairs according to which to configure
the created server.
Returns: Returns:
A Server with which RPCs can be serviced. A Server with which RPCs can be serviced.
""" """
from grpc import _server from grpc import _server
return _server.Server(thread_pool, () if handlers is None else handlers) return _server.Server(thread_pool, () if handlers is None else handlers,
() if options is None else options)
################################### __all__ ################################# ################################### __all__ #################################

@ -842,18 +842,8 @@ def _unsubscribe(state, callback):
def _options(options): def _options(options):
if options is None: return list(options) + [
pairs = ((cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),)
else:
pairs = list(options) + [
(cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)] (cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)]
encoded_pairs = [
(_common.encode(arg_name), arg_value) if isinstance(arg_value, int)
else (_common.encode(arg_name), _common.encode(arg_value))
for arg_name, arg_value in pairs]
return cygrpc.ChannelArgs([
cygrpc.ChannelArg(arg_name, arg_value)
for arg_name, arg_value in encoded_pairs])
class Channel(grpc.Channel): class Channel(grpc.Channel):
@ -867,7 +857,8 @@ class Channel(grpc.Channel):
credentials: A cygrpc.ChannelCredentials or None. credentials: A cygrpc.ChannelCredentials or None.
""" """
self._channel = cygrpc.Channel( self._channel = cygrpc.Channel(
_common.encode(target), _options(options), credentials) _common.encode(target), _common.channel_args(_options(options)),
credentials)
self._call_state = _ChannelCallState(self._channel) self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel) self._connectivity_state = _ChannelConnectivityState(self._channel)

@ -94,6 +94,16 @@ def decode(b):
return b.decode('latin1') return b.decode('latin1')
def channel_args(options):
channel_args = []
for key, value in options:
if isinstance(value, six.string_types):
channel_args.append(cygrpc.ChannelArg(encode(key), encode(value)))
else:
channel_args.append(cygrpc.ChannelArg(encode(key), value))
return cygrpc.ChannelArgs(channel_args)
def cygrpc_metadata(application_metadata): def cygrpc_metadata(application_metadata):
return _EMPTY_METADATA if application_metadata is None else cygrpc.Metadata( return _EMPTY_METADATA if application_metadata is None else cygrpc.Metadata(
cygrpc.Metadatum(encode(key), encode(value)) cygrpc.Metadatum(encode(key), encode(value))

@ -32,15 +32,16 @@ cimport cpython
cdef class Channel: cdef class Channel:
def __cinit__(self, bytes target, ChannelArgs arguments=None, def __cinit__(self, bytes target, ChannelArgs arguments,
ChannelCredentials channel_credentials=None): ChannelCredentials channel_credentials=None):
grpc_init() grpc_init()
cdef grpc_channel_args *c_arguments = NULL cdef grpc_channel_args *c_arguments = NULL
cdef char *c_target = NULL cdef char *c_target = NULL
self.c_channel = NULL self.c_channel = NULL
self.references = [] self.references = []
if arguments is not None: if len(arguments) > 0:
c_arguments = &arguments.c_args c_arguments = &arguments.c_args
self.references.append(arguments)
c_target = target c_target = target
if channel_credentials is None: if channel_credentials is None:
with nogil: with nogil:

@ -34,12 +34,12 @@ import time
cdef class Server: cdef class Server:
def __cinit__(self, ChannelArgs arguments=None): def __cinit__(self, ChannelArgs arguments):
grpc_init() grpc_init()
cdef grpc_channel_args *c_arguments = NULL cdef grpc_channel_args *c_arguments = NULL
self.references = [] self.references = []
self.registered_completion_queues = [] self.registered_completion_queues = []
if arguments is not None: if len(arguments) > 0:
c_arguments = &arguments.c_args c_arguments = &arguments.c_args
self.references.append(arguments) self.references.append(arguments)
with nogil: with nogil:

@ -728,12 +728,11 @@ def _start(state):
cleanup_server, target=_serve, args=(state,)) cleanup_server, target=_serve, args=(state,))
thread.start() thread.start()
class Server(grpc.Server): class Server(grpc.Server):
def __init__(self, thread_pool, generic_handlers): def __init__(self, thread_pool, generic_handlers, options):
completion_queue = cygrpc.CompletionQueue() completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server() server = cygrpc.Server(_common.channel_args(options))
server.register_completion_queue(completion_queue) server.register_completion_queue(completion_queue)
self._state = _ServerState( self._state = _ServerState(
completion_queue, server, generic_handlers, thread_pool) completion_queue, server, generic_handlers, thread_pool)

@ -7,6 +7,7 @@
"_beta_features_test.BetaFeaturesTest", "_beta_features_test.BetaFeaturesTest",
"_beta_features_test.ContextManagementAndLifecycleTest", "_beta_features_test.ContextManagementAndLifecycleTest",
"_cancel_many_calls_test.CancelManyCallsTest", "_cancel_many_calls_test.CancelManyCallsTest",
"_channel_args_test.ChannelArgsTest",
"_channel_connectivity_test.ChannelConnectivityTest", "_channel_connectivity_test.ChannelConnectivityTest",
"_channel_ready_future_test.ChannelReadyFutureTest", "_channel_ready_future_test.ChannelReadyFutureTest",
"_channel_test.ChannelTest", "_channel_test.ChannelTest",

@ -0,0 +1,53 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests of Channel Args on client/server side."""
import unittest
import grpc
TEST_CHANNEL_ARGS = (
('arg1', b'bytes_val'),
('arg2', 'str_val'),
('arg3', 1),
(b'arg4', 'str_val'),
)
class ChannelArgsTest(unittest.TestCase):
def test_client(self):
grpc.insecure_channel('localhost:8080', options=TEST_CHANNEL_ARGS)
def test_server(self):
grpc.server(None, options=TEST_CHANNEL_ARGS)
if __name__ == '__main__':
unittest.main(verbosity=2)

@ -78,7 +78,7 @@ class ChannelConnectivityTest(unittest.TestCase):
def test_lonely_channel_connectivity(self): def test_lonely_channel_connectivity(self):
callback = _Callback() callback = _Callback()
channel = _channel.Channel('localhost:12345', None, None) channel = _channel.Channel('localhost:12345', (), None)
channel.subscribe(callback.update, try_to_connect=False) channel.subscribe(callback.update, try_to_connect=False)
first_connectivities = callback.block_until_connectivities_satisfy(bool) first_connectivities = callback.block_until_connectivities_satisfy(bool)
channel.subscribe(callback.update, try_to_connect=True) channel.subscribe(callback.update, try_to_connect=True)
@ -105,13 +105,13 @@ class ChannelConnectivityTest(unittest.TestCase):
def test_immediately_connectable_channel_connectivity(self): def test_immediately_connectable_channel_connectivity(self):
thread_pool = _thread_pool.RecordingThreadPool(max_workers=None) thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
server = _server.Server(thread_pool, ()) server = _server.Server(thread_pool, (), ())
port = server.add_insecure_port('[::]:0') port = server.add_insecure_port('[::]:0')
server.start() server.start()
first_callback = _Callback() first_callback = _Callback()
second_callback = _Callback() second_callback = _Callback()
channel = _channel.Channel('localhost:{}'.format(port), None, None) channel = _channel.Channel('localhost:{}'.format(port), (), None)
channel.subscribe(first_callback.update, try_to_connect=False) channel.subscribe(first_callback.update, try_to_connect=False)
first_connectivities = first_callback.block_until_connectivities_satisfy( first_connectivities = first_callback.block_until_connectivities_satisfy(
bool) bool)
@ -146,12 +146,12 @@ class ChannelConnectivityTest(unittest.TestCase):
def test_reachable_then_unreachable_channel_connectivity(self): def test_reachable_then_unreachable_channel_connectivity(self):
thread_pool = _thread_pool.RecordingThreadPool(max_workers=None) thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
server = _server.Server(thread_pool, ()) server = _server.Server(thread_pool, (), ())
port = server.add_insecure_port('[::]:0') port = server.add_insecure_port('[::]:0')
server.start() server.start()
callback = _Callback() callback = _Callback()
channel = _channel.Channel('localhost:{}'.format(port), None, None) channel = _channel.Channel('localhost:{}'.format(port), (), None)
channel.subscribe(callback.update, try_to_connect=True) channel.subscribe(callback.update, try_to_connect=True)
callback.block_until_connectivities_satisfy(_ready_in_connectivities) callback.block_until_connectivities_satisfy(_ready_in_connectivities)
# Now take down the server and confirm that channel readiness is repudiated. # Now take down the server and confirm that channel readiness is repudiated.

@ -79,7 +79,7 @@ class ChannelReadyFutureTest(unittest.TestCase):
def test_immediately_connectable_channel_connectivity(self): def test_immediately_connectable_channel_connectivity(self):
thread_pool = _thread_pool.RecordingThreadPool(max_workers=None) thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
server = _server.Server(thread_pool, ()) server = _server.Server(thread_pool, (), ())
port = server.add_insecure_port('[::]:0') port = server.add_insecure_port('[::]:0')
server.start() server.start()
channel = grpc.insecure_channel('localhost:{}'.format(port)) channel = grpc.insecure_channel('localhost:{}'.format(port))

@ -157,11 +157,12 @@ class CancelManyCallsTest(unittest.TestCase):
server_thread_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) server_thread_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
server_completion_queue = cygrpc.CompletionQueue() server_completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server() server = cygrpc.Server(cygrpc.ChannelArgs([]))
server.register_completion_queue(server_completion_queue) server.register_completion_queue(server_completion_queue)
port = server.add_http2_port(b'[::]:0') port = server.add_http2_port(b'[::]:0')
server.start() server.start()
channel = cygrpc.Channel('localhost:{}'.format(port).encode()) channel = cygrpc.Channel('localhost:{}'.format(port).encode(),
cygrpc.ChannelArgs([]))
state = _State() state = _State()

@ -124,11 +124,12 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
def testReadSomeButNotAllResponses(self): def testReadSomeButNotAllResponses(self):
server_completion_queue = cygrpc.CompletionQueue() server_completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server() server = cygrpc.Server(cygrpc.ChannelArgs([]))
server.register_completion_queue(server_completion_queue) server.register_completion_queue(server_completion_queue)
port = server.add_http2_port(b'[::]:0') port = server.add_http2_port(b'[::]:0')
server.start() server.start()
channel = cygrpc.Channel('localhost:{}'.format(port).encode()) channel = cygrpc.Channel('localhost:{}'.format(port).encode(),
cygrpc.ChannelArgs([]))
server_shutdown_tag = 'server_shutdown_tag' server_shutdown_tag = 'server_shutdown_tag'
server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag) server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag)

@ -121,7 +121,7 @@ class TypeSmokeTest(unittest.TestCase):
del call_credentials del call_credentials
def testServerStartNoExplicitShutdown(self): def testServerStartNoExplicitShutdown(self):
server = cygrpc.Server() server = cygrpc.Server(cygrpc.ChannelArgs([]))
completion_queue = cygrpc.CompletionQueue() completion_queue = cygrpc.CompletionQueue()
server.register_completion_queue(completion_queue) server.register_completion_queue(completion_queue)
port = server.add_http2_port(b'[::]:0') port = server.add_http2_port(b'[::]:0')
@ -131,7 +131,7 @@ class TypeSmokeTest(unittest.TestCase):
def testServerStartShutdown(self): def testServerStartShutdown(self):
completion_queue = cygrpc.CompletionQueue() completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server() server = cygrpc.Server(cygrpc.ChannelArgs([]))
server.add_http2_port(b'[::]:0') server.add_http2_port(b'[::]:0')
server.register_completion_queue(completion_queue) server.register_completion_queue(completion_queue)
server.start() server.start()
@ -148,7 +148,7 @@ class ServerClientMixin(object):
def setUpMixin(self, server_credentials, client_credentials, host_override): def setUpMixin(self, server_credentials, client_credentials, host_override):
self.server_completion_queue = cygrpc.CompletionQueue() self.server_completion_queue = cygrpc.CompletionQueue()
self.server = cygrpc.Server() self.server = cygrpc.Server(cygrpc.ChannelArgs([]))
self.server.register_completion_queue(self.server_completion_queue) self.server.register_completion_queue(self.server_completion_queue)
if server_credentials: if server_credentials:
self.port = self.server.add_http2_port(b'[::]:0', server_credentials) self.port = self.server.add_http2_port(b'[::]:0', server_credentials)
@ -164,7 +164,8 @@ class ServerClientMixin(object):
'localhost:{}'.format(self.port).encode(), client_channel_arguments, 'localhost:{}'.format(self.port).encode(), client_channel_arguments,
client_credentials) client_credentials)
else: else:
self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port).encode()) self.client_channel = cygrpc.Channel(
'localhost:{}'.format(self.port).encode(), cygrpc.ChannelArgs([]))
if host_override: if host_override:
self.host_argument = None # default host self.host_argument = None # default host
self.expected_host = host_override self.expected_host = host_override

Loading…
Cancel
Save