diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi index 55c8673dd4d..53657e8b1a9 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi @@ -94,6 +94,8 @@ def fork_handlers_and_grpc_init(): _fork_state.fork_handler_registered = True + + class ForkManagedThread(object): def __init__(self, target, args=()): if _GRPC_ENABLE_FORK_SUPPORT: @@ -102,9 +104,9 @@ class ForkManagedThread(object): target(*args) finally: _fork_state.active_thread_count.decrement() - self._thread = threading.Thread(target=managed_target, args=args) + self._thread = threading.Thread(target=_run_with_context(managed_target), args=args) else: - self._thread = threading.Thread(target=target, args=args) + self._thread = threading.Thread(target=_run_with_context(target), args=args) def setDaemon(self, daemonic): self._thread.daemon = daemonic diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi index 9167cb45173..67aaf4d033d 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi @@ -21,7 +21,7 @@ def fork_handlers_and_grpc_init(): class ForkManagedThread(object): def __init__(self, target, args=()): - self._thread = threading.Thread(target=target, args=args) + self._thread = threading.Thread(target=_run_with_context(target), args=args) def setDaemon(self, daemonic): self._thread.daemon = daemonic diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/thread.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/thread.pyx.pxi new file mode 100644 index 00000000000..be4cb8b9a8e --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/thread.pyx.pxi @@ -0,0 +1,59 @@ +# Copyright 2020 The gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def _contextvars_supported(): + """Determines if the contextvars module is supported. + + We use a 'try it and see if it works approach' here rather than predicting + based on interpreter version in order to support older interpreters that + may have a backported module based on, e.g. `threading.local`. + + Returns: + A bool indicating whether `contextvars` are supported in the current + environment. + """ + try: + import contextvars + return True + except ImportError: + return False + + +def _run_with_context(target): + """Runs a callable with contextvars propagated. + + If contextvars are supported, the calling thread's context will be copied + and propagated. If they are not supported, this function is equivalent + to the identity function. + + Args: + target: A callable object to wrap. + Returns: + A callable object with the same signature as `target` but with + contextvars propagated. + """ + + +if _contextvars_supported(): + import contextvars + def _run_with_context(target): + ctx = contextvars.copy_context() + def _run(*args): + ctx.run(target, *args) + return _run +else: + def _run_with_context(target): + def _run(*args): + target(*args) + return _run diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx index b0a753c7ebe..0ce8bda0d89 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pyx +++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx @@ -59,6 +59,8 @@ include "_cygrpc/iomgr.pyx.pxi" include "_cygrpc/grpc_gevent.pyx.pxi" +include "_cygrpc/thread.pyx.pxi" + IF UNAME_SYSNAME == "Windows": include "_cygrpc/fork_windows.pyx.pxi" ELSE: diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py index f7cd7c6b8a1..f6b5859ced4 100644 --- a/src/python/grpcio_tests/commands.py +++ b/src/python/grpcio_tests/commands.py @@ -220,6 +220,9 @@ class TestGevent(setuptools.Command): 'unit._cython._channel_test.ChannelTest.test_negative_deadline_connectivity', # TODO(https://github.com/grpc/grpc/issues/15411) enable this test 'unit._local_credentials_test.LocalCredentialsTest', + # TODO(https://github.com/grpc/grpc/issues/22020) LocalCredentials + # aren't supported with custom io managers. + 'unit._contextvars_propagation_test', 'testing._time_test.StrictRealTimeTest', ) BANNED_WINDOWS_TESTS = ( diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index 196e9f08b0a..c7d913f49ca 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -35,6 +35,7 @@ "unit._channel_connectivity_test.ChannelConnectivityTest", "unit._channel_ready_future_test.ChannelReadyFutureTest", "unit._compression_test.CompressionTest", + "unit._contextvars_propagation_test.ContextVarsPropagationTest", "unit._credentials_test.CredentialsTest", "unit._cython._cancel_many_calls_test.CancelManyCallsTest", "unit._cython._channel_test.ChannelTest", diff --git a/src/python/grpcio_tests/tests/unit/BUILD.bazel b/src/python/grpcio_tests/tests/unit/BUILD.bazel index 42b99023463..690397942cc 100644 --- a/src/python/grpcio_tests/tests/unit/BUILD.bazel +++ b/src/python/grpcio_tests/tests/unit/BUILD.bazel @@ -13,6 +13,7 @@ GRPCIO_TESTS_UNIT = [ "_channel_connectivity_test.py", "_channel_ready_future_test.py", "_compression_test.py", + "_contextvars_propagation_test.py", "_credentials_test.py", "_dns_resolver_test.py", "_empty_message_test.py", diff --git a/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py new file mode 100644 index 00000000000..fec0fbd7df4 --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py @@ -0,0 +1,118 @@ +# Copyright 2020 The gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of propagation of contextvars to AuthMetadataPlugin threads..""" + +import contextlib +import logging +import os +import sys +import unittest + +import grpc + +from tests.unit import test_common + +_UNARY_UNARY = "/test/UnaryUnary" +_REQUEST = b"0000" + + +def _unary_unary_handler(request, context): + return request + + +def contextvars_supported(): + try: + import contextvars + return True + except ImportError: + return False + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return grpc.unary_unary_rpc_method_handler(_unary_unary_handler) + else: + raise NotImplementedError() + + +@contextlib.contextmanager +def _server(): + try: + server = test_common.test_server() + target = 'localhost:0' + port = server.add_insecure_port(target) + server.add_generic_rpc_handlers((_GenericHandler(),)) + server.start() + yield port + finally: + server.stop(None) + + +if contextvars_supported(): + import contextvars + + _EXPECTED_VALUE = 24601 + test_var = contextvars.ContextVar("test_var", default=None) + + def set_up_expected_context(): + test_var.set(_EXPECTED_VALUE) + + class TestCallCredentials(grpc.AuthMetadataPlugin): + + def __call__(self, context, callback): + if test_var.get() != _EXPECTED_VALUE: + raise AssertionError("{} != {}".format(test_var.get(), + _EXPECTED_VALUE)) + callback((), None) + + def assert_called(self, test): + test.assertTrue(self._invoked) + test.assertEqual(_EXPECTED_VALUE, self._recorded_value) + +else: + + def set_up_expected_context(): + pass + + class TestCallCredentials(grpc.AuthMetadataPlugin): + + def __call__(self, context, callback): + callback((), None) + + +# TODO(https://github.com/grpc/grpc/issues/22257) +@unittest.skipIf(os.name == "nt", "LocalCredentials not supported on Windows.") +class ContextVarsPropagationTest(unittest.TestCase): + + def test_propagation_to_auth_plugin(self): + set_up_expected_context() + with _server() as port: + target = "localhost:{}".format(port) + local_credentials = grpc.local_channel_credentials() + test_call_credentials = TestCallCredentials() + call_credentials = grpc.metadata_call_credentials( + test_call_credentials, "test call credentials") + composite_credentials = grpc.composite_channel_credentials( + local_credentials, call_credentials) + with grpc.secure_channel(target, composite_credentials) as channel: + stub = channel.unary_unary(_UNARY_UNARY) + response = stub(_REQUEST, wait_for_ready=True) + self.assertEqual(_REQUEST, response) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2)