diff --git a/mesonbuild/mtest.py b/mesonbuild/mtest.py index 2caa3678d..59ddec112 100644 --- a/mesonbuild/mtest.py +++ b/mesonbuild/mtest.py @@ -608,6 +608,13 @@ def load_tests(build_dir: str) -> T.List[TestSerialisation]: # Custom waiting primitives for asyncio +async def complete(future: asyncio.Future) -> None: + """Wait for completion of the given future, ignoring cancellation.""" + try: + await future + except asyncio.CancelledError: + pass + async def complete_all(futures: T.Iterable[asyncio.Future]) -> None: """Wait for completion of all the given futures, ignoring cancellation.""" while futures: @@ -1161,10 +1168,11 @@ class TestHarness: if self.options.wd: os.chdir(self.options.wd) self.build_data = build.load(os.getcwd()) + interrupted = False async def run_test(test: SingleTestRunner, name: str, index: int) -> None: - if self.options.repeat > 1 and self.fail_count: + if interrupted or (self.options.repeat > 1 and self.fail_count): return res = await asyncio.get_event_loop().run_in_executor(executor, test.run) self.process_test_result(res) @@ -1175,6 +1183,17 @@ class TestHarness: f.result() futures.remove(f) + def cancel_all_futures() -> None: + nonlocal interrupted + if interrupted: + return + interrupted = True + mlog.warning('CTRL-C detected, interrupting') + for f in futures: + f.cancel() + + if sys.platform != 'win32': + asyncio.get_event_loop().add_signal_handler(signal.SIGINT, cancel_all_futures) try: for _ in range(self.options.repeat): for i, test in enumerate(tests, 1): @@ -1183,11 +1202,11 @@ class TestHarness: if not test.is_parallel or single_test.options.gdb: await complete_all(futures) - await run_test(single_test, visible_name, i) - else: - future = asyncio.ensure_future(run_test(single_test, visible_name, i)) - futures.append(future) - future.add_done_callback(test_done) + future = asyncio.ensure_future(run_test(single_test, visible_name, i)) + futures.append(future) + future.add_done_callback(test_done) + if not test.is_parallel or single_test.options.gdb: + await complete(future) if self.options.repeat > 1 and self.fail_count: break @@ -1198,6 +1217,8 @@ class TestHarness: if self.logfilename: print('Full log written to {}'.format(self.logfilename)) finally: + if sys.platform != 'win32': + asyncio.get_event_loop().remove_signal_handler(signal.SIGINT) os.chdir(startdir) def list_tests(th: TestHarness) -> bool: