|
|
|
@ -17,11 +17,13 @@ import contextlib |
|
|
|
|
import logging |
|
|
|
|
import os |
|
|
|
|
import sys |
|
|
|
|
import threading |
|
|
|
|
import unittest |
|
|
|
|
|
|
|
|
|
import grpc |
|
|
|
|
|
|
|
|
|
from tests.unit import test_common |
|
|
|
|
from six.moves import queue |
|
|
|
|
|
|
|
|
|
_UNARY_UNARY = "/test/UnaryUnary" |
|
|
|
|
_REQUEST = b"0000" |
|
|
|
@ -112,6 +114,46 @@ class ContextVarsPropagationTest(unittest.TestCase): |
|
|
|
|
response = stub(_REQUEST, wait_for_ready=True) |
|
|
|
|
self.assertEqual(_REQUEST, response) |
|
|
|
|
|
|
|
|
|
def test_concurrent_propagation(self): |
|
|
|
|
_THREAD_COUNT = 32 |
|
|
|
|
_RPC_COUNT = 32 |
|
|
|
|
|
|
|
|
|
set_up_expected_context() |
|
|
|
|
with _server() as port: |
|
|
|
|
target = "localhost:{}".format(port) |
|
|
|
|
local_credentials = grpc.local_channel_credentials() |
|
|
|
|
test_call_credentials = TestCallCredentials() |
|
|
|
|
call_credentials = grpc.metadata_call_credentials( |
|
|
|
|
test_call_credentials, "test call credentials") |
|
|
|
|
composite_credentials = grpc.composite_channel_credentials( |
|
|
|
|
local_credentials, call_credentials) |
|
|
|
|
start_event = threading.Event() |
|
|
|
|
|
|
|
|
|
def _run_on_thread(exception_queue): |
|
|
|
|
try: |
|
|
|
|
for _ in range(_THREAD_COUNT): |
|
|
|
|
with grpc.secure_channel( |
|
|
|
|
target, composite_credentials) as channel: |
|
|
|
|
start_event.wait() |
|
|
|
|
stub = channel.unary_unary(_UNARY_UNARY) |
|
|
|
|
response = stub(_REQUEST, wait_for_ready=True) |
|
|
|
|
self.assertEqual(_REQUEST, response) |
|
|
|
|
except Exception as e: # pylint: disable=broad-except |
|
|
|
|
exception_queue.put(e) |
|
|
|
|
|
|
|
|
|
threads = [] |
|
|
|
|
for _ in range(_RPC_COUNT): |
|
|
|
|
q = queue.Queue() |
|
|
|
|
thread = threading.Thread(target=_run_on_thread, args=(q,)) |
|
|
|
|
thread.setDaemon(True) |
|
|
|
|
thread.start() |
|
|
|
|
threads.append((thread, q)) |
|
|
|
|
start_event.set() |
|
|
|
|
for thread, q in threads: |
|
|
|
|
thread.join() |
|
|
|
|
if not q.empty(): |
|
|
|
|
raise q.get() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
logging.basicConfig() |
|
|
|
|