diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index bfc7208310c..25a42109747 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -13,10 +13,10 @@ # limitations under the License. """Invocation-side implementation of gRPC Python.""" +import logging import sys import threading import time -import logging import grpc from grpc import _common @@ -882,8 +882,12 @@ def _unsubscribe(state, callback): def _options(options): - return list(options) + [(cygrpc.ChannelArgKey.primary_user_agent_string, - _USER_AGENT)] + return list(options) + [ + ( + cygrpc.ChannelArgKey.primary_user_agent_string, + _USER_AGENT, + ), + ] class Channel(grpc.Channel): @@ -892,14 +896,13 @@ class Channel(grpc.Channel): def __init__(self, target, options, credentials): """Constructor. - Args: - target: The target to which to connect. - options: Configuration options for the channel. - credentials: A cygrpc.ChannelCredentials or None. - """ + Args: + target: The target to which to connect. + options: Configuration options for the channel. + credentials: A cygrpc.ChannelCredentials or None. + """ self._channel = cygrpc.Channel( - _common.encode(target), _common.channel_args(_options(options)), - credentials) + _common.encode(target), _options(options), credentials) self._call_state = _ChannelCallState(self._channel) self._connectivity_state = _ChannelConnectivityState(self._channel) diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index 130fc426305..bbb69ad4893 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -79,16 +79,6 @@ def decode(b): return b.decode('latin1') -def channel_args(options): - cygrpc_args = [] - for key, value in options: - if isinstance(value, six.string_types): - cygrpc_args.append(cygrpc.ChannelArg(encode(key), encode(value))) - else: - cygrpc_args.append(cygrpc.ChannelArg(encode(key), value)) - return cygrpc.ChannelArgs(cygrpc_args) - - def _transform(message, transformer, exception_message): if transformer is None: return message diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi new file mode 100644 index 00000000000..853bf6f8e04 --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi @@ -0,0 +1,40 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +cdef void* _copy_pointer(void* pointer) + + +cdef void _destroy_pointer(void* pointer) + + +cdef int _compare_pointer(void* first_pointer, void* second_pointer) + + +cdef class _ArgumentProcessor: + + cdef grpc_arg c_argument + + cdef void c(self, argument, grpc_arg_pointer_vtable *vtable, references) + + +cdef class _ArgumentsProcessor: + + cdef readonly tuple _arguments + cdef list _argument_processors + cdef readonly list _references + cdef grpc_channel_args _c_arguments + + cdef grpc_channel_args *c(self, grpc_arg_pointer_vtable *vtable) + cdef un_c(self) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi new file mode 100644 index 00000000000..65de30884c2 --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi @@ -0,0 +1,88 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cimport cpython + + +cdef void* _copy_pointer(void* pointer): + return pointer + + +cdef void _destroy_pointer(void* pointer): + pass + + +cdef int _compare_pointer(void* first_pointer, void* second_pointer): + if first_pointer < second_pointer: + return -1 + elif first_pointer > second_pointer: + return 1 + else: + return 0 + + +cdef class _ArgumentProcessor: + + cdef void c(self, argument, grpc_arg_pointer_vtable *vtable, references): + key, value = argument + cdef bytes encoded_key = _encode(key) + if encoded_key is not key: + references.append(encoded_key) + self.c_argument.key = encoded_key + if isinstance(value, int): + self.c_argument.type = GRPC_ARG_INTEGER + self.c_argument.value.integer = value + elif isinstance(value, (bytes, str, unicode,)): + self.c_argument.type = GRPC_ARG_STRING + encoded_value = _encode(value) + if encoded_value is not value: + references.append(encoded_value) + self.c_argument.value.string = encoded_value + elif hasattr(value, '__int__'): + # Pointer objects must override __int__() to return + # the underlying C address (Python ints are word size). The + # lifecycle of the pointer is fixed to the lifecycle of the + # python object wrapping it. + self.c_argument.type = GRPC_ARG_POINTER + self.c_argument.value.pointer.vtable = vtable + self.c_argument.value.pointer.address = (int(value)) + else: + raise TypeError( + 'Expected int, bytes, or behavior, got {}'.format(type(value))) + + +cdef class _ArgumentsProcessor: + + def __cinit__(self, arguments): + self._arguments = () if arguments is None else tuple(arguments) + self._argument_processors = [] + self._references = [] + + cdef grpc_channel_args *c(self, grpc_arg_pointer_vtable *vtable): + self._c_arguments.arguments_length = len(self._arguments) + if self._c_arguments.arguments_length == 0: + return NULL + else: + self._c_arguments.arguments = gpr_malloc( + self._c_arguments.arguments_length * sizeof(grpc_arg)) + for index, argument in enumerate(self._arguments): + argument_processor = _ArgumentProcessor() + argument_processor.c(argument, vtable, self._references) + self._c_arguments.arguments[index] = argument_processor.c_argument + self._argument_processors.append(argument_processor) + return &self._c_arguments + + cdef un_c(self): + if self._arguments: + gpr_free(self._c_arguments.arguments) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi index 4b07e71cec1..1ba76b7f838 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi @@ -15,5 +15,7 @@ cdef class Channel: + cdef grpc_arg_pointer_vtable _vtable cdef grpc_channel *c_channel cdef list references + cdef readonly _ArgumentsProcessor _arguments_processor diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index efe5f2e0db2..a3966497bcb 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -17,26 +17,25 @@ cimport cpython cdef class Channel: - def __cinit__(self, bytes target, ChannelArgs arguments, + def __cinit__(self, bytes target, object arguments, ChannelCredentials channel_credentials=None): grpc_init() - cdef grpc_channel_args *c_arguments = NULL - cdef char *c_target = NULL - self.c_channel = NULL + self._vtable.copy = &_copy_pointer + self._vtable.destroy = &_destroy_pointer + self._vtable.cmp = &_compare_pointer + cdef _ArgumentsProcessor arguments_processor = _ArgumentsProcessor( + arguments) + cdef grpc_channel_args *c_arguments = arguments_processor.c(&self._vtable) self.references = [] - if len(arguments) > 0: - c_arguments = &arguments.c_args - self.references.append(arguments) c_target = target if channel_credentials is None: - with nogil: - self.c_channel = grpc_insecure_channel_create(c_target, c_arguments, - NULL) + self.c_channel = grpc_insecure_channel_create(c_target, c_arguments, NULL) else: c_channel_credentials = channel_credentials.c() self.c_channel = grpc_secure_channel_create( c_channel_credentials, c_target, c_arguments, NULL) grpc_channel_credentials_release(c_channel_credentials) + arguments_processor.un_c() self.references.append(target) self.references.append(arguments) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi index 297bbadfe08..35e1bdb0aeb 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi @@ -29,19 +29,6 @@ cdef class SslPemKeyCertPair: cdef readonly object private_key, certificate_chain -cdef class ChannelArg: - - cdef grpc_arg c_arg - cdef grpc_arg_pointer_vtable ptr_vtable - cdef readonly object key, value - - -cdef class ChannelArgs: - - cdef grpc_channel_args c_args - cdef list args - - cdef class CompressionOptions: cdef grpc_compression_options c_options diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi index b2343b53d6a..ecd991685fa 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi @@ -157,81 +157,6 @@ cdef class SslPemKeyCertPair: self.c_pair.certificate_chain = self.certificate_chain - -cdef void* copy_ptr(void* ptr): - return ptr - - -cdef void destroy_ptr(void* ptr): - pass - - -cdef int compare_ptr(void* ptr1, void* ptr2): - if ptr1 < ptr2: - return -1 - elif ptr1 > ptr2: - return 1 - else: - return 0 - - -cdef class ChannelArg: - - def __cinit__(self, bytes key, value): - self.key = key - self.value = value - self.c_arg.key = self.key - if isinstance(value, int): - self.c_arg.type = GRPC_ARG_INTEGER - self.c_arg.value.integer = self.value - elif isinstance(value, bytes): - self.c_arg.type = GRPC_ARG_STRING - self.c_arg.value.string = self.value - elif hasattr(value, '__int__'): - # Pointer objects must override __int__() to return - # the underlying C address (Python ints are word size). The - # lifecycle of the pointer is fixed to the lifecycle of the - # python object wrapping it. - self.ptr_vtable.copy = ©_ptr - self.ptr_vtable.destroy = &destroy_ptr - self.ptr_vtable.cmp = &compare_ptr - self.c_arg.type = GRPC_ARG_POINTER - self.c_arg.value.pointer.vtable = &self.ptr_vtable - self.c_arg.value.pointer.address = (int(self.value)) - else: - # TODO Add supported pointer types to this message - raise TypeError('Expected int or bytes, got {}'.format(type(value))) - - -cdef class ChannelArgs: - - def __cinit__(self, args): - grpc_init() - self.args = list(args) - for arg in self.args: - if not isinstance(arg, ChannelArg): - raise TypeError("expected list of ChannelArg") - self.c_args.arguments_length = len(self.args) - with nogil: - self.c_args.arguments = gpr_malloc( - self.c_args.arguments_length*sizeof(grpc_arg)) - for i in range(self.c_args.arguments_length): - self.c_args.arguments[i] = (self.args[i]).c_arg - - def __dealloc__(self): - with nogil: - gpr_free(self.c_args.arguments) - grpc_shutdown() - - def __len__(self): - # self.args is never stale; it's only updated from this file - return len(self.args) - - def __getitem__(self, size_t i): - # self.args is never stale; it's only updated from this file - return self.args[i] - - cdef class CompressionOptions: def __cinit__(self): @@ -254,8 +179,10 @@ cdef class CompressionOptions: return result def to_channel_arg(self): - return ChannelArg(GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET, - self.c_options.enabled_algorithms_bitset) + return ( + GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET, + self.c_options.enabled_algorithms_bitset, + ) def compression_algorithm_name(grpc_compression_algorithm algorithm): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi index df4577e1246..4588db30d36 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi @@ -15,6 +15,8 @@ cdef class Server: + cdef grpc_arg_pointer_vtable _vtable + cdef readonly _ArgumentsProcessor _arguments_processor cdef grpc_server *c_server cdef bint is_started # start has been called cdef bint is_shutting_down # shutdown has been called diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index e5d28a85d58..707ec742dda 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -57,16 +57,19 @@ cdef grpc_ssl_certificate_config_reload_status _server_cert_config_fetcher_wrapp cdef class Server: - def __cinit__(self, ChannelArgs arguments): + def __cinit__(self, object arguments): grpc_init() - cdef grpc_channel_args *c_arguments = NULL self.references = [] self.registered_completion_queues = [] - if len(arguments) > 0: - c_arguments = &arguments.c_args - self.references.append(arguments) - with nogil: - self.c_server = grpc_server_create(c_arguments, NULL) + self._vtable.copy = &_copy_pointer + self._vtable.destroy = &_destroy_pointer + self._vtable.cmp = &_compare_pointer + cdef _ArgumentsProcessor arguments_processor = _ArgumentsProcessor( + arguments) + cdef grpc_channel_args *c_arguments = arguments_processor.c(&self._vtable) + self.c_server = grpc_server_create(c_arguments, NULL) + arguments_processor.un_c() + self.references.append(arguments) self.is_started = False self.is_shutting_down = False self.is_shutdown = False diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pxd b/src/python/grpcio/grpc/_cython/cygrpc.pxd index 01e2da6d542..b6a794c6d7b 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pxd +++ b/src/python/grpcio/grpc/_cython/cygrpc.pxd @@ -14,6 +14,7 @@ include "_cygrpc/grpc.pxi" +include "_cygrpc/arguments.pxd.pxi" include "_cygrpc/call.pxd.pxi" include "_cygrpc/channel.pxd.pxi" include "_cygrpc/credentials.pxd.pxi" diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx index d8ac84a317d..2ee2e6b73e8 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pyx +++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx @@ -21,6 +21,7 @@ import sys # TODO(atash): figure out why the coverage tool gets confused about the Cython # coverage plugin when the following files don't have a '.pxi' suffix. include "_cygrpc/grpc_string.pyx.pxi" +include "_cygrpc/arguments.pyx.pxi" include "_cygrpc/call.pyx.pxi" include "_cygrpc/channel.pyx.pxi" include "_cygrpc/credentials.pyx.pxi" diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 9402941bab1..0b79b50108c 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -791,7 +791,7 @@ class Server(grpc.Server): def __init__(self, thread_pool, generic_handlers, interceptors, options, maximum_concurrent_rpcs): completion_queue = cygrpc.CompletionQueue() - server = cygrpc.Server(_common.channel_args(options)) + server = cygrpc.Server(options) server.register_completion_queue(completion_queue) self._state = _ServerState(completion_queue, server, generic_handlers, _interceptor.service_pipeline(interceptors), diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py index 2ca1fa82f4f..3765ce4fb04 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py @@ -141,13 +141,16 @@ class CancelManyCallsTest(unittest.TestCase): test_constants.THREAD_CONCURRENCY) server_completion_queue = cygrpc.CompletionQueue() - server = cygrpc.Server( - cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) + server = cygrpc.Server([ + ( + b'grpc.so_reuseport', + 0, + ), + ]) server.register_completion_queue(server_completion_queue) port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port).encode(), - cygrpc.ChannelArgs([])) + channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None) state = _State() diff --git a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py index c22c77ddbd2..7305d0fa3f0 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py @@ -22,7 +22,7 @@ from tests.unit.framework.common import test_constants def _channel_and_completion_queue(): - channel = cygrpc.Channel(b'localhost:54321', cygrpc.ChannelArgs(())) + channel = cygrpc.Channel(b'localhost:54321', ()) completion_queue = cygrpc.CompletionQueue() return channel, completion_queue diff --git a/src/python/grpcio_tests/tests/unit/_cython/_common.py b/src/python/grpcio_tests/tests/unit/_cython/_common.py index d4b01ca38b4..7fd3d19b4ec 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_common.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_common.py @@ -96,13 +96,11 @@ class RpcTest(object): def setUp(self): self.server_completion_queue = cygrpc.CompletionQueue() - self.server = cygrpc.Server( - cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) + self.server = cygrpc.Server([(b'grpc.so_reuseport', 0)]) self.server.register_completion_queue(self.server_completion_queue) port = self.server.add_http2_port(b'[::]:0') self.server.start() - self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), - cygrpc.ChannelArgs([])) + self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), []) self._server_shutdown_tag = 'server_shutdown_tag' self.server_condition = threading.Condition() diff --git a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py index ecd23afda71..bc63b548799 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -111,13 +111,14 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): def testReadSomeButNotAllResponses(self): server_completion_queue = cygrpc.CompletionQueue() - server = cygrpc.Server( - cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) + server = cygrpc.Server([( + b'grpc.so_reuseport', + 0, + )]) server.register_completion_queue(server_completion_queue) port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port).encode(), - cygrpc.ChannelArgs([])) + channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set()) server_shutdown_tag = 'server_shutdown_tag' server_driver = _ServerDriver(server_completion_queue, diff --git a/src/python/grpcio_tests/tests/unit/_cython/_server_test.py b/src/python/grpcio_tests/tests/unit/_cython/_server_test.py index 12bf40be6b3..bbd25457b3e 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_server_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_server_test.py @@ -25,7 +25,7 @@ class Test(unittest.TestCase): def test_lonely_server(self): server_call_completion_queue = cygrpc.CompletionQueue() server_shutdown_completion_queue = cygrpc.CompletionQueue() - server = cygrpc.Server(cygrpc.ChannelArgs([])) + server = cygrpc.Server(None) server.register_completion_queue(server_call_completion_queue) server.register_completion_queue(server_shutdown_completion_queue) port = server.add_http2_port(b'[::]:0') diff --git a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py index 561adf7dff0..9045ff58a0f 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py @@ -42,12 +42,16 @@ class TypeSmokeTest(unittest.TestCase): del completion_queue def testServerUpDown(self): - server = cygrpc.Server( - cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) + server = cygrpc.Server(set([ + ( + b'grpc.so_reuseport', + 0, + ), + ])) del server def testChannelUpDown(self): - channel = cygrpc.Channel(b'[::]:0', cygrpc.ChannelArgs([])) + channel = cygrpc.Channel(b'[::]:0', None) del channel def test_metadata_plugin_call_credentials_up_down(self): @@ -55,8 +59,12 @@ class TypeSmokeTest(unittest.TestCase): b'test plugin name!') def testServerStartNoExplicitShutdown(self): - server = cygrpc.Server( - cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) + server = cygrpc.Server([ + ( + b'grpc.so_reuseport', + 0, + ), + ]) completion_queue = cygrpc.CompletionQueue() server.register_completion_queue(completion_queue) port = server.add_http2_port(b'[::]:0') @@ -66,8 +74,12 @@ class TypeSmokeTest(unittest.TestCase): def testServerStartShutdown(self): completion_queue = cygrpc.CompletionQueue() - server = cygrpc.Server( - cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) + server = cygrpc.Server([ + ( + b'grpc.so_reuseport', + 0, + ), + ]) server.add_http2_port(b'[::]:0') server.register_completion_queue(completion_queue) server.start() @@ -85,8 +97,12 @@ class ServerClientMixin(object): def setUpMixin(self, server_credentials, client_credentials, host_override): self.server_completion_queue = cygrpc.CompletionQueue() - self.server = cygrpc.Server( - cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)])) + self.server = cygrpc.Server([ + ( + b'grpc.so_reuseport', + 0, + ), + ]) self.server.register_completion_queue(self.server_completion_queue) if server_credentials: self.port = self.server.add_http2_port(b'[::]:0', @@ -96,16 +112,16 @@ class ServerClientMixin(object): self.server.start() self.client_completion_queue = cygrpc.CompletionQueue() if client_credentials: - client_channel_arguments = cygrpc.ChannelArgs([ - cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override, - host_override) - ]) + client_channel_arguments = (( + cygrpc.ChannelArgKey.ssl_target_name_override, + host_override, + ),) self.client_channel = cygrpc.Channel('localhost:{}'.format( self.port).encode(), client_channel_arguments, client_credentials) else: self.client_channel = cygrpc.Channel('localhost:{}'.format( - self.port).encode(), cygrpc.ChannelArgs([])) + self.port).encode(), set()) if host_override: self.host_argument = None # default host self.expected_host = host_override diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py index a918066ea48..59084210113 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py @@ -80,7 +80,7 @@ _TRAILING_METADATA = ( _EXPECTED_TRAILING_METADATA = _TRAILING_METADATA -def user_agent(metadata): +def _user_agent(metadata): for key, val in metadata: if key == 'user-agent': return val @@ -88,16 +88,14 @@ def user_agent(metadata): def validate_client_metadata(test, servicer_context): + invocation_metadata = servicer_context.invocation_metadata() test.assertTrue( - test_common.metadata_transmitted( - _EXPECTED_INVOCATION_METADATA, - servicer_context.invocation_metadata())) + test_common.metadata_transmitted(_EXPECTED_INVOCATION_METADATA, + invocation_metadata)) + user_agent = _user_agent(invocation_metadata) test.assertTrue( - user_agent(servicer_context.invocation_metadata()) - .startswith('primary-agent ' + _channel._USER_AGENT)) - test.assertTrue( - user_agent(servicer_context.invocation_metadata()) - .endswith('secondary-agent')) + user_agent.startswith('primary-agent ' + _channel._USER_AGENT)) + test.assertTrue(user_agent.endswith('secondary-agent')) def handle_unary_unary(test, request, servicer_context):