|
|
|
@ -51,7 +51,7 @@ class _GenericHandler(grpc.GenericRpcHandler): |
|
|
|
|
def _server(): |
|
|
|
|
try: |
|
|
|
|
server = test_common.test_server() |
|
|
|
|
target = '[::]:0' |
|
|
|
|
target = 'localhost:0' |
|
|
|
|
port = server.add_insecure_port(target) |
|
|
|
|
server.add_generic_rpc_handlers((_GenericHandler(),)) |
|
|
|
|
server.start() |
|
|
|
@ -65,21 +65,28 @@ if contextvars_supported(): |
|
|
|
|
|
|
|
|
|
_EXPECTED_VALUE = 24601 |
|
|
|
|
test_var = contextvars.ContextVar("test_var", default=None) |
|
|
|
|
test_var.set(_EXPECTED_VALUE) |
|
|
|
|
|
|
|
|
|
def set_up_expected_context(): |
|
|
|
|
test_var.set(_EXPECTED_VALUE) |
|
|
|
|
|
|
|
|
|
class TestCallCredentials(grpc.AuthMetadataPlugin): |
|
|
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
|
self._recorded_value = None |
|
|
|
|
self._invoked = False |
|
|
|
|
|
|
|
|
|
def __call__(self, context, callback): |
|
|
|
|
self._recorded_value = test_var.get() |
|
|
|
|
self._invoked = True |
|
|
|
|
callback((), None) |
|
|
|
|
|
|
|
|
|
def assert_called(self, test): |
|
|
|
|
test.assertTrue(self._invoked) |
|
|
|
|
test.assertEqual(_EXPECTED_VALUE, self._recorded_value) |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
def set_up_expected_context(): |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
class TestCallCredentials(grpc.AuthMetadataPlugin): |
|
|
|
|
|
|
|
|
@ -93,6 +100,7 @@ else: |
|
|
|
|
class ContextVarsPropagationTest(unittest.TestCase): |
|
|
|
|
|
|
|
|
|
def test_propagation_to_auth_plugin(self): |
|
|
|
|
set_up_expected_context() |
|
|
|
|
with _server() as port: |
|
|
|
|
target = "localhost:{}".format(port) |
|
|
|
|
local_credentials = grpc.local_channel_credentials() |
|
|
|
|