Elide cygrpc.ChannelArg and cygrpc.ChannelArgs

pull/14107/head
Nathaniel Manista 7 years ago
parent 31ddbff8cf
commit c73758acb0
  1. 23
      src/python/grpcio/grpc/_channel.py
  2. 10
      src/python/grpcio/grpc/_common.py
  3. 40
      src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi
  4. 88
      src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi
  5. 2
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi
  6. 19
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
  7. 13
      src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi
  8. 81
      src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
  9. 2
      src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi
  10. 17
      src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
  11. 1
      src/python/grpcio/grpc/_cython/cygrpc.pxd
  12. 1
      src/python/grpcio/grpc/_cython/cygrpc.pyx
  13. 2
      src/python/grpcio/grpc/_server.py
  14. 11
      src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
  15. 2
      src/python/grpcio_tests/tests/unit/_cython/_channel_test.py
  16. 6
      src/python/grpcio_tests/tests/unit/_cython/_common.py
  17. 9
      src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
  18. 2
      src/python/grpcio_tests/tests/unit/_cython/_server_test.py
  19. 44
      src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
  20. 16
      src/python/grpcio_tests/tests/unit/_metadata_test.py

@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
"""Invocation-side implementation of gRPC Python.""" """Invocation-side implementation of gRPC Python."""
import logging
import sys import sys
import threading import threading
import time import time
import logging
import grpc import grpc
from grpc import _common from grpc import _common
@ -882,8 +882,12 @@ def _unsubscribe(state, callback):
def _options(options): def _options(options):
return list(options) + [(cygrpc.ChannelArgKey.primary_user_agent_string, return list(options) + [
_USER_AGENT)] (
cygrpc.ChannelArgKey.primary_user_agent_string,
_USER_AGENT,
),
]
class Channel(grpc.Channel): class Channel(grpc.Channel):
@ -892,14 +896,13 @@ class Channel(grpc.Channel):
def __init__(self, target, options, credentials): def __init__(self, target, options, credentials):
"""Constructor. """Constructor.
Args: Args:
target: The target to which to connect. target: The target to which to connect.
options: Configuration options for the channel. options: Configuration options for the channel.
credentials: A cygrpc.ChannelCredentials or None. credentials: A cygrpc.ChannelCredentials or None.
""" """
self._channel = cygrpc.Channel( self._channel = cygrpc.Channel(
_common.encode(target), _common.channel_args(_options(options)), _common.encode(target), _options(options), credentials)
credentials)
self._call_state = _ChannelCallState(self._channel) self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel) self._connectivity_state = _ChannelConnectivityState(self._channel)

@ -79,16 +79,6 @@ def decode(b):
return b.decode('latin1') return b.decode('latin1')
def channel_args(options):
cygrpc_args = []
for key, value in options:
if isinstance(value, six.string_types):
cygrpc_args.append(cygrpc.ChannelArg(encode(key), encode(value)))
else:
cygrpc_args.append(cygrpc.ChannelArg(encode(key), value))
return cygrpc.ChannelArgs(cygrpc_args)
def _transform(message, transformer, exception_message): def _transform(message, transformer, exception_message):
if transformer is None: if transformer is None:
return message return message

@ -0,0 +1,40 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cdef void* _copy_pointer(void* pointer)
cdef void _destroy_pointer(void* pointer)
cdef int _compare_pointer(void* first_pointer, void* second_pointer)
cdef class _ArgumentProcessor:
cdef grpc_arg c_argument
cdef void c(self, argument, grpc_arg_pointer_vtable *vtable, references)
cdef class _ArgumentsProcessor:
cdef readonly tuple _arguments
cdef list _argument_processors
cdef readonly list _references
cdef grpc_channel_args _c_arguments
cdef grpc_channel_args *c(self, grpc_arg_pointer_vtable *vtable)
cdef un_c(self)

@ -0,0 +1,88 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cimport cpython
cdef void* _copy_pointer(void* pointer):
return pointer
cdef void _destroy_pointer(void* pointer):
pass
cdef int _compare_pointer(void* first_pointer, void* second_pointer):
if first_pointer < second_pointer:
return -1
elif first_pointer > second_pointer:
return 1
else:
return 0
cdef class _ArgumentProcessor:
cdef void c(self, argument, grpc_arg_pointer_vtable *vtable, references):
key, value = argument
cdef bytes encoded_key = _encode(key)
if encoded_key is not key:
references.append(encoded_key)
self.c_argument.key = encoded_key
if isinstance(value, int):
self.c_argument.type = GRPC_ARG_INTEGER
self.c_argument.value.integer = value
elif isinstance(value, (bytes, str, unicode,)):
self.c_argument.type = GRPC_ARG_STRING
encoded_value = _encode(value)
if encoded_value is not value:
references.append(encoded_value)
self.c_argument.value.string = encoded_value
elif hasattr(value, '__int__'):
# Pointer objects must override __int__() to return
# the underlying C address (Python ints are word size). The
# lifecycle of the pointer is fixed to the lifecycle of the
# python object wrapping it.
self.c_argument.type = GRPC_ARG_POINTER
self.c_argument.value.pointer.vtable = vtable
self.c_argument.value.pointer.address = <void*>(<intptr_t>int(value))
else:
raise TypeError(
'Expected int, bytes, or behavior, got {}'.format(type(value)))
cdef class _ArgumentsProcessor:
def __cinit__(self, arguments):
self._arguments = () if arguments is None else tuple(arguments)
self._argument_processors = []
self._references = []
cdef grpc_channel_args *c(self, grpc_arg_pointer_vtable *vtable):
self._c_arguments.arguments_length = len(self._arguments)
if self._c_arguments.arguments_length == 0:
return NULL
else:
self._c_arguments.arguments = <grpc_arg *>gpr_malloc(
self._c_arguments.arguments_length * sizeof(grpc_arg))
for index, argument in enumerate(self._arguments):
argument_processor = _ArgumentProcessor()
argument_processor.c(argument, vtable, self._references)
self._c_arguments.arguments[index] = argument_processor.c_argument
self._argument_processors.append(argument_processor)
return &self._c_arguments
cdef un_c(self):
if self._arguments:
gpr_free(self._c_arguments.arguments)

@ -15,5 +15,7 @@
cdef class Channel: cdef class Channel:
cdef grpc_arg_pointer_vtable _vtable
cdef grpc_channel *c_channel cdef grpc_channel *c_channel
cdef list references cdef list references
cdef readonly _ArgumentsProcessor _arguments_processor

@ -17,26 +17,25 @@ cimport cpython
cdef class Channel: cdef class Channel:
def __cinit__(self, bytes target, ChannelArgs arguments, def __cinit__(self, bytes target, object arguments,
ChannelCredentials channel_credentials=None): ChannelCredentials channel_credentials=None):
grpc_init() grpc_init()
cdef grpc_channel_args *c_arguments = NULL self._vtable.copy = &_copy_pointer
cdef char *c_target = NULL self._vtable.destroy = &_destroy_pointer
self.c_channel = NULL self._vtable.cmp = &_compare_pointer
cdef _ArgumentsProcessor arguments_processor = _ArgumentsProcessor(
arguments)
cdef grpc_channel_args *c_arguments = arguments_processor.c(&self._vtable)
self.references = [] self.references = []
if len(arguments) > 0:
c_arguments = &arguments.c_args
self.references.append(arguments)
c_target = target c_target = target
if channel_credentials is None: if channel_credentials is None:
with nogil: self.c_channel = grpc_insecure_channel_create(c_target, c_arguments, NULL)
self.c_channel = grpc_insecure_channel_create(c_target, c_arguments,
NULL)
else: else:
c_channel_credentials = channel_credentials.c() c_channel_credentials = channel_credentials.c()
self.c_channel = grpc_secure_channel_create( self.c_channel = grpc_secure_channel_create(
c_channel_credentials, c_target, c_arguments, NULL) c_channel_credentials, c_target, c_arguments, NULL)
grpc_channel_credentials_release(c_channel_credentials) grpc_channel_credentials_release(c_channel_credentials)
arguments_processor.un_c()
self.references.append(target) self.references.append(target)
self.references.append(arguments) self.references.append(arguments)

@ -29,19 +29,6 @@ cdef class SslPemKeyCertPair:
cdef readonly object private_key, certificate_chain cdef readonly object private_key, certificate_chain
cdef class ChannelArg:
cdef grpc_arg c_arg
cdef grpc_arg_pointer_vtable ptr_vtable
cdef readonly object key, value
cdef class ChannelArgs:
cdef grpc_channel_args c_args
cdef list args
cdef class CompressionOptions: cdef class CompressionOptions:
cdef grpc_compression_options c_options cdef grpc_compression_options c_options

@ -157,81 +157,6 @@ cdef class SslPemKeyCertPair:
self.c_pair.certificate_chain = self.certificate_chain self.c_pair.certificate_chain = self.certificate_chain
cdef void* copy_ptr(void* ptr):
return ptr
cdef void destroy_ptr(void* ptr):
pass
cdef int compare_ptr(void* ptr1, void* ptr2):
if ptr1 < ptr2:
return -1
elif ptr1 > ptr2:
return 1
else:
return 0
cdef class ChannelArg:
def __cinit__(self, bytes key, value):
self.key = key
self.value = value
self.c_arg.key = self.key
if isinstance(value, int):
self.c_arg.type = GRPC_ARG_INTEGER
self.c_arg.value.integer = self.value
elif isinstance(value, bytes):
self.c_arg.type = GRPC_ARG_STRING
self.c_arg.value.string = self.value
elif hasattr(value, '__int__'):
# Pointer objects must override __int__() to return
# the underlying C address (Python ints are word size). The
# lifecycle of the pointer is fixed to the lifecycle of the
# python object wrapping it.
self.ptr_vtable.copy = &copy_ptr
self.ptr_vtable.destroy = &destroy_ptr
self.ptr_vtable.cmp = &compare_ptr
self.c_arg.type = GRPC_ARG_POINTER
self.c_arg.value.pointer.vtable = &self.ptr_vtable
self.c_arg.value.pointer.address = <void*>(<intptr_t>int(self.value))
else:
# TODO Add supported pointer types to this message
raise TypeError('Expected int or bytes, got {}'.format(type(value)))
cdef class ChannelArgs:
def __cinit__(self, args):
grpc_init()
self.args = list(args)
for arg in self.args:
if not isinstance(arg, ChannelArg):
raise TypeError("expected list of ChannelArg")
self.c_args.arguments_length = len(self.args)
with nogil:
self.c_args.arguments = <grpc_arg *>gpr_malloc(
self.c_args.arguments_length*sizeof(grpc_arg))
for i in range(self.c_args.arguments_length):
self.c_args.arguments[i] = (<ChannelArg>self.args[i]).c_arg
def __dealloc__(self):
with nogil:
gpr_free(self.c_args.arguments)
grpc_shutdown()
def __len__(self):
# self.args is never stale; it's only updated from this file
return len(self.args)
def __getitem__(self, size_t i):
# self.args is never stale; it's only updated from this file
return self.args[i]
cdef class CompressionOptions: cdef class CompressionOptions:
def __cinit__(self): def __cinit__(self):
@ -254,8 +179,10 @@ cdef class CompressionOptions:
return result return result
def to_channel_arg(self): def to_channel_arg(self):
return ChannelArg(GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET, return (
self.c_options.enabled_algorithms_bitset) GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET,
self.c_options.enabled_algorithms_bitset,
)
def compression_algorithm_name(grpc_compression_algorithm algorithm): def compression_algorithm_name(grpc_compression_algorithm algorithm):

@ -15,6 +15,8 @@
cdef class Server: cdef class Server:
cdef grpc_arg_pointer_vtable _vtable
cdef readonly _ArgumentsProcessor _arguments_processor
cdef grpc_server *c_server cdef grpc_server *c_server
cdef bint is_started # start has been called cdef bint is_started # start has been called
cdef bint is_shutting_down # shutdown has been called cdef bint is_shutting_down # shutdown has been called

@ -57,16 +57,19 @@ cdef grpc_ssl_certificate_config_reload_status _server_cert_config_fetcher_wrapp
cdef class Server: cdef class Server:
def __cinit__(self, ChannelArgs arguments): def __cinit__(self, object arguments):
grpc_init() grpc_init()
cdef grpc_channel_args *c_arguments = NULL
self.references = [] self.references = []
self.registered_completion_queues = [] self.registered_completion_queues = []
if len(arguments) > 0: self._vtable.copy = &_copy_pointer
c_arguments = &arguments.c_args self._vtable.destroy = &_destroy_pointer
self.references.append(arguments) self._vtable.cmp = &_compare_pointer
with nogil: cdef _ArgumentsProcessor arguments_processor = _ArgumentsProcessor(
self.c_server = grpc_server_create(c_arguments, NULL) arguments)
cdef grpc_channel_args *c_arguments = arguments_processor.c(&self._vtable)
self.c_server = grpc_server_create(c_arguments, NULL)
arguments_processor.un_c()
self.references.append(arguments)
self.is_started = False self.is_started = False
self.is_shutting_down = False self.is_shutting_down = False
self.is_shutdown = False self.is_shutdown = False

@ -14,6 +14,7 @@
include "_cygrpc/grpc.pxi" include "_cygrpc/grpc.pxi"
include "_cygrpc/arguments.pxd.pxi"
include "_cygrpc/call.pxd.pxi" include "_cygrpc/call.pxd.pxi"
include "_cygrpc/channel.pxd.pxi" include "_cygrpc/channel.pxd.pxi"
include "_cygrpc/credentials.pxd.pxi" include "_cygrpc/credentials.pxd.pxi"

@ -21,6 +21,7 @@ import sys
# TODO(atash): figure out why the coverage tool gets confused about the Cython # TODO(atash): figure out why the coverage tool gets confused about the Cython
# coverage plugin when the following files don't have a '.pxi' suffix. # coverage plugin when the following files don't have a '.pxi' suffix.
include "_cygrpc/grpc_string.pyx.pxi" include "_cygrpc/grpc_string.pyx.pxi"
include "_cygrpc/arguments.pyx.pxi"
include "_cygrpc/call.pyx.pxi" include "_cygrpc/call.pyx.pxi"
include "_cygrpc/channel.pyx.pxi" include "_cygrpc/channel.pyx.pxi"
include "_cygrpc/credentials.pyx.pxi" include "_cygrpc/credentials.pyx.pxi"

@ -791,7 +791,7 @@ class Server(grpc.Server):
def __init__(self, thread_pool, generic_handlers, interceptors, options, def __init__(self, thread_pool, generic_handlers, interceptors, options,
maximum_concurrent_rpcs): maximum_concurrent_rpcs):
completion_queue = cygrpc.CompletionQueue() completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server(_common.channel_args(options)) server = cygrpc.Server(options)
server.register_completion_queue(completion_queue) server.register_completion_queue(completion_queue)
self._state = _ServerState(completion_queue, server, generic_handlers, self._state = _ServerState(completion_queue, server, generic_handlers,
_interceptor.service_pipeline(interceptors), _interceptor.service_pipeline(interceptors),

@ -141,13 +141,16 @@ class CancelManyCallsTest(unittest.TestCase):
test_constants.THREAD_CONCURRENCY) test_constants.THREAD_CONCURRENCY)
server_completion_queue = cygrpc.CompletionQueue() server_completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server( server = cygrpc.Server([
cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) (
b'grpc.so_reuseport',
0,
),
])
server.register_completion_queue(server_completion_queue) server.register_completion_queue(server_completion_queue)
port = server.add_http2_port(b'[::]:0') port = server.add_http2_port(b'[::]:0')
server.start() server.start()
channel = cygrpc.Channel('localhost:{}'.format(port).encode(), channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None)
cygrpc.ChannelArgs([]))
state = _State() state = _State()

@ -22,7 +22,7 @@ from tests.unit.framework.common import test_constants
def _channel_and_completion_queue(): def _channel_and_completion_queue():
channel = cygrpc.Channel(b'localhost:54321', cygrpc.ChannelArgs(())) channel = cygrpc.Channel(b'localhost:54321', ())
completion_queue = cygrpc.CompletionQueue() completion_queue = cygrpc.CompletionQueue()
return channel, completion_queue return channel, completion_queue

@ -96,13 +96,11 @@ class RpcTest(object):
def setUp(self): def setUp(self):
self.server_completion_queue = cygrpc.CompletionQueue() self.server_completion_queue = cygrpc.CompletionQueue()
self.server = cygrpc.Server( self.server = cygrpc.Server([(b'grpc.so_reuseport', 0)])
cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)]))
self.server.register_completion_queue(self.server_completion_queue) self.server.register_completion_queue(self.server_completion_queue)
port = self.server.add_http2_port(b'[::]:0') port = self.server.add_http2_port(b'[::]:0')
self.server.start() self.server.start()
self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), [])
cygrpc.ChannelArgs([]))
self._server_shutdown_tag = 'server_shutdown_tag' self._server_shutdown_tag = 'server_shutdown_tag'
self.server_condition = threading.Condition() self.server_condition = threading.Condition()

@ -111,13 +111,14 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
def testReadSomeButNotAllResponses(self): def testReadSomeButNotAllResponses(self):
server_completion_queue = cygrpc.CompletionQueue() server_completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server( server = cygrpc.Server([(
cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) b'grpc.so_reuseport',
0,
)])
server.register_completion_queue(server_completion_queue) server.register_completion_queue(server_completion_queue)
port = server.add_http2_port(b'[::]:0') port = server.add_http2_port(b'[::]:0')
server.start() server.start()
channel = cygrpc.Channel('localhost:{}'.format(port).encode(), channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set())
cygrpc.ChannelArgs([]))
server_shutdown_tag = 'server_shutdown_tag' server_shutdown_tag = 'server_shutdown_tag'
server_driver = _ServerDriver(server_completion_queue, server_driver = _ServerDriver(server_completion_queue,

@ -25,7 +25,7 @@ class Test(unittest.TestCase):
def test_lonely_server(self): def test_lonely_server(self):
server_call_completion_queue = cygrpc.CompletionQueue() server_call_completion_queue = cygrpc.CompletionQueue()
server_shutdown_completion_queue = cygrpc.CompletionQueue() server_shutdown_completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server(cygrpc.ChannelArgs([])) server = cygrpc.Server(None)
server.register_completion_queue(server_call_completion_queue) server.register_completion_queue(server_call_completion_queue)
server.register_completion_queue(server_shutdown_completion_queue) server.register_completion_queue(server_shutdown_completion_queue)
port = server.add_http2_port(b'[::]:0') port = server.add_http2_port(b'[::]:0')

@ -42,12 +42,16 @@ class TypeSmokeTest(unittest.TestCase):
del completion_queue del completion_queue
def testServerUpDown(self): def testServerUpDown(self):
server = cygrpc.Server( server = cygrpc.Server(set([
cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) (
b'grpc.so_reuseport',
0,
),
]))
del server del server
def testChannelUpDown(self): def testChannelUpDown(self):
channel = cygrpc.Channel(b'[::]:0', cygrpc.ChannelArgs([])) channel = cygrpc.Channel(b'[::]:0', None)
del channel del channel
def test_metadata_plugin_call_credentials_up_down(self): def test_metadata_plugin_call_credentials_up_down(self):
@ -55,8 +59,12 @@ class TypeSmokeTest(unittest.TestCase):
b'test plugin name!') b'test plugin name!')
def testServerStartNoExplicitShutdown(self): def testServerStartNoExplicitShutdown(self):
server = cygrpc.Server( server = cygrpc.Server([
cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) (
b'grpc.so_reuseport',
0,
),
])
completion_queue = cygrpc.CompletionQueue() completion_queue = cygrpc.CompletionQueue()
server.register_completion_queue(completion_queue) server.register_completion_queue(completion_queue)
port = server.add_http2_port(b'[::]:0') port = server.add_http2_port(b'[::]:0')
@ -66,8 +74,12 @@ class TypeSmokeTest(unittest.TestCase):
def testServerStartShutdown(self): def testServerStartShutdown(self):
completion_queue = cygrpc.CompletionQueue() completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server( server = cygrpc.Server([
cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) (
b'grpc.so_reuseport',
0,
),
])
server.add_http2_port(b'[::]:0') server.add_http2_port(b'[::]:0')
server.register_completion_queue(completion_queue) server.register_completion_queue(completion_queue)
server.start() server.start()
@ -85,8 +97,12 @@ class ServerClientMixin(object):
def setUpMixin(self, server_credentials, client_credentials, host_override): def setUpMixin(self, server_credentials, client_credentials, host_override):
self.server_completion_queue = cygrpc.CompletionQueue() self.server_completion_queue = cygrpc.CompletionQueue()
self.server = cygrpc.Server( self.server = cygrpc.Server([
cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) (
b'grpc.so_reuseport',
0,
),
])
self.server.register_completion_queue(self.server_completion_queue) self.server.register_completion_queue(self.server_completion_queue)
if server_credentials: if server_credentials:
self.port = self.server.add_http2_port(b'[::]:0', self.port = self.server.add_http2_port(b'[::]:0',
@ -96,16 +112,16 @@ class ServerClientMixin(object):
self.server.start() self.server.start()
self.client_completion_queue = cygrpc.CompletionQueue() self.client_completion_queue = cygrpc.CompletionQueue()
if client_credentials: if client_credentials:
client_channel_arguments = cygrpc.ChannelArgs([ client_channel_arguments = ((
cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override, cygrpc.ChannelArgKey.ssl_target_name_override,
host_override) host_override,
]) ),)
self.client_channel = cygrpc.Channel('localhost:{}'.format( self.client_channel = cygrpc.Channel('localhost:{}'.format(
self.port).encode(), client_channel_arguments, self.port).encode(), client_channel_arguments,
client_credentials) client_credentials)
else: else:
self.client_channel = cygrpc.Channel('localhost:{}'.format( self.client_channel = cygrpc.Channel('localhost:{}'.format(
self.port).encode(), cygrpc.ChannelArgs([])) self.port).encode(), set())
if host_override: if host_override:
self.host_argument = None # default host self.host_argument = None # default host
self.expected_host = host_override self.expected_host = host_override

@ -80,7 +80,7 @@ _TRAILING_METADATA = (
_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA _EXPECTED_TRAILING_METADATA = _TRAILING_METADATA
def user_agent(metadata): def _user_agent(metadata):
for key, val in metadata: for key, val in metadata:
if key == 'user-agent': if key == 'user-agent':
return val return val
@ -88,16 +88,14 @@ def user_agent(metadata):
def validate_client_metadata(test, servicer_context): def validate_client_metadata(test, servicer_context):
invocation_metadata = servicer_context.invocation_metadata()
test.assertTrue( test.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(_EXPECTED_INVOCATION_METADATA,
_EXPECTED_INVOCATION_METADATA, invocation_metadata))
servicer_context.invocation_metadata())) user_agent = _user_agent(invocation_metadata)
test.assertTrue( test.assertTrue(
user_agent(servicer_context.invocation_metadata()) user_agent.startswith('primary-agent ' + _channel._USER_AGENT))
.startswith('primary-agent ' + _channel._USER_AGENT)) test.assertTrue(user_agent.endswith('secondary-agent'))
test.assertTrue(
user_agent(servicer_context.invocation_metadata())
.endswith('secondary-agent'))
def handle_unary_unary(test, request, servicer_context): def handle_unary_unary(test, request, servicer_context):

Loading…
Cancel
Save