Fix metadata plugin for concurrent invocations (#26405)

* Fix metadata plugin for concurrent invocations

* Disable broad-except lint error
reviewable/pr26409/r1
Richard Belleville 4 years ago committed by GitHub
parent 21ddbabf62
commit 67f0330c10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
  2. 42
      src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py

@ -59,7 +59,7 @@ cdef int _get_metadata(void *state,
args = context.service_url, context.method_name, callback,
plugin = <object>state
if plugin._stored_ctx is not None:
plugin._stored_ctx.run(_spawn_callback_async, plugin, args)
plugin._stored_ctx.copy().run(_spawn_callback_async, plugin, args)
else:
_spawn_callback_async(<object>state, args)
return 0 # Asynchronous return

@ -17,11 +17,13 @@ import contextlib
import logging
import os
import sys
import threading
import unittest
import grpc
from tests.unit import test_common
from six.moves import queue
_UNARY_UNARY = "/test/UnaryUnary"
_REQUEST = b"0000"
@ -112,6 +114,46 @@ class ContextVarsPropagationTest(unittest.TestCase):
response = stub(_REQUEST, wait_for_ready=True)
self.assertEqual(_REQUEST, response)
def test_concurrent_propagation(self):
_THREAD_COUNT = 32
_RPC_COUNT = 32
set_up_expected_context()
with _server() as port:
target = "localhost:{}".format(port)
local_credentials = grpc.local_channel_credentials()
test_call_credentials = TestCallCredentials()
call_credentials = grpc.metadata_call_credentials(
test_call_credentials, "test call credentials")
composite_credentials = grpc.composite_channel_credentials(
local_credentials, call_credentials)
start_event = threading.Event()
def _run_on_thread(exception_queue):
try:
for _ in range(_THREAD_COUNT):
with grpc.secure_channel(
target, composite_credentials) as channel:
start_event.wait()
stub = channel.unary_unary(_UNARY_UNARY)
response = stub(_REQUEST, wait_for_ready=True)
self.assertEqual(_REQUEST, response)
except Exception as e: # pylint: disable=broad-except
exception_queue.put(e)
threads = []
for _ in range(_RPC_COUNT):
q = queue.Queue()
thread = threading.Thread(target=_run_on_thread, args=(q,))
thread.setDaemon(True)
thread.start()
threads.append((thread, q))
start_event.set()
for thread, q in threads:
thread.join()
if not q.empty():
raise q.get()
if __name__ == '__main__':
logging.basicConfig()

Loading…
Cancel
Save