From 70821ebe4ab23573c4eade8a8b6af95dc124469e Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Wed, 23 Oct 2019 18:49:58 -0700 Subject: [PATCH] Supply the event loop to aio iomgr in test case level --- src/python/grpcio_tests/commands.py | 2 -- src/python/grpcio_tests/tests_aio/unit/BUILD.bazel | 7 +++++++ src/python/grpcio_tests/tests_aio/unit/channel_test.py | 9 +++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py index 436ca4d3c7a..301a2bea23b 100644 --- a/src/python/grpcio_tests/commands.py +++ b/src/python/grpcio_tests/commands.py @@ -119,8 +119,6 @@ class TestAio(setuptools.Command): pass def run(self): - from grpc.experimental import aio - aio.init_grpc_aio() self._add_eggs_to_path() import tests diff --git a/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel b/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel index f635dc534be..82a12f9b7e4 100644 --- a/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel +++ b/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel @@ -20,6 +20,12 @@ package( GRPC_ASYNC_TESTS = glob(["*_test.py"]) +py_library( + name = "_test_base", + srcs_version = "PY3", + srcs = ["_test_base.py"], +) + py_library( name = "_test_server", srcs_version = "PY3", @@ -41,6 +47,7 @@ py_library( python_version="PY3", deps=[ ":_test_server", + ":_test_base", "//src/python/grpcio/grpc:grpcio", "//src/proto/grpc/testing:py_messages_proto", "//src/proto/grpc/testing:benchmark_service_py_pb2_grpc", diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 3bacf3c01dd..4f95fcfb850 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -22,9 +22,10 @@ from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2 from tests.unit.framework.common import test_constants from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_base import AioTestBase -class TestChannel(unittest.TestCase): +class TestChannel(AioTestBase): def test_async_context(self): @@ -40,7 +41,7 @@ class TestChannel(unittest.TestCase): ) await hi(messages_pb2.SimpleRequest()) - asyncio.get_event_loop().run_until_complete(coro()) + self._loop.run_until_complete(coro()) def test_unary_unary(self): @@ -58,7 +59,7 @@ class TestChannel(unittest.TestCase): await channel.close() - asyncio.get_event_loop().run_until_complete(coro()) + self._loop.run_until_complete(coro()) def test_unary_call_times_out(self): @@ -90,7 +91,7 @@ class TestChannel(unittest.TestCase): self.assertIsNotNone( exception_context.exception.trailing_metadata()) - asyncio.get_event_loop().run_until_complete(coro()) + self._loop.run_until_complete(coro()) if __name__ == '__main__':