From 67f0330c10993ae222471acc63acf69a275d56ce Mon Sep 17 00:00:00 2001 From: Richard Belleville Date: Tue, 1 Jun 2021 15:25:56 -0700 Subject: [PATCH] Fix metadata plugin for concurrent invocations (#26405) * Fix metadata plugin for concurrent invocations * Disable broad-except lint error --- .../grpc/_cython/_cygrpc/credentials.pyx.pxi | 2 +- .../unit/_contextvars_propagation_test.py | 42 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi index acd79db8780..23de3a0b188 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi @@ -59,7 +59,7 @@ cdef int _get_metadata(void *state, args = context.service_url, context.method_name, callback, plugin = state if plugin._stored_ctx is not None: - plugin._stored_ctx.run(_spawn_callback_async, plugin, args) + plugin._stored_ctx.copy().run(_spawn_callback_async, plugin, args) else: _spawn_callback_async(state, args) return 0 # Asynchronous return diff --git a/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py index fec0fbd7df4..e3540245663 100644 --- a/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py +++ b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py @@ -17,11 +17,13 @@ import contextlib import logging import os import sys +import threading import unittest import grpc from tests.unit import test_common +from six.moves import queue _UNARY_UNARY = "/test/UnaryUnary" _REQUEST = b"0000" @@ -112,6 +114,46 @@ class ContextVarsPropagationTest(unittest.TestCase): response = stub(_REQUEST, wait_for_ready=True) self.assertEqual(_REQUEST, response) + def test_concurrent_propagation(self): + _THREAD_COUNT = 32 + _RPC_COUNT = 32 + + set_up_expected_context() + with _server() as port: + target = "localhost:{}".format(port) + local_credentials = grpc.local_channel_credentials() + test_call_credentials = TestCallCredentials() + call_credentials = grpc.metadata_call_credentials( + test_call_credentials, "test call credentials") + composite_credentials = grpc.composite_channel_credentials( + local_credentials, call_credentials) + start_event = threading.Event() + + def _run_on_thread(exception_queue): + try: + for _ in range(_THREAD_COUNT): + with grpc.secure_channel( + target, composite_credentials) as channel: + start_event.wait() + stub = channel.unary_unary(_UNARY_UNARY) + response = stub(_REQUEST, wait_for_ready=True) + self.assertEqual(_REQUEST, response) + except Exception as e: # pylint: disable=broad-except + exception_queue.put(e) + + threads = [] + for _ in range(_RPC_COUNT): + q = queue.Queue() + thread = threading.Thread(target=_run_on_thread, args=(q,)) + thread.setDaemon(True) + thread.start() + threads.append((thread, q)) + start_event.set() + for thread, q in threads: + thread.join() + if not q.empty(): + raise q.get() + if __name__ == '__main__': logging.basicConfig()