diff --git a/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py index e3540245663..d97990aff97 100644 --- a/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py +++ b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py @@ -127,15 +127,16 @@ class ContextVarsPropagationTest(unittest.TestCase): test_call_credentials, "test call credentials") composite_credentials = grpc.composite_channel_credentials( local_credentials, call_credentials) - start_event = threading.Event() + wait_group = test_common.WaitGroup(_THREAD_COUNT) def _run_on_thread(exception_queue): try: - for _ in range(_THREAD_COUNT): - with grpc.secure_channel( - target, composite_credentials) as channel: - start_event.wait() - stub = channel.unary_unary(_UNARY_UNARY) + with grpc.secure_channel(target, + composite_credentials) as channel: + stub = channel.unary_unary(_UNARY_UNARY) + wait_group.done() + wait_group.wait() + for i in range(_RPC_COUNT): response = stub(_REQUEST, wait_for_ready=True) self.assertEqual(_REQUEST, response) except Exception as e: # pylint: disable=broad-except @@ -148,7 +149,7 @@ class ContextVarsPropagationTest(unittest.TestCase): thread.setDaemon(True) thread.start() threads.append((thread, q)) - start_event.set() + for thread, q in threads: thread.join() if not q.empty():