From 2a17462367f59307cc4dd5477bfac68c7bb3352c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 15 Sep 2024 21:55:58 +0200 Subject: [PATCH] Fix `IS_TMP_WRITEABLE` order of operations (#16294) Co-authored-by: UltralyticsAssistant --- tests/__init__.py | 5 +++-- tests/test_python.py | 42 ++++++++++++++++++++++++++++++++++++------ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 3356f1cadb..ea6b398292 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,13 +1,13 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks, is_dir_writeable +from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks # Constants used in tests MODEL = WEIGHTS_DIR / "path with spaces" / "yolov8n.pt" # test spaces in path CFG = "yolov8n.yaml" SOURCE = ASSETS / "bus.jpg" +SOURCES_LIST = [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"] TMP = (ROOT / "../tests/tmp").resolve() # temp directory for test files -IS_TMP_WRITEABLE = is_dir_writeable(TMP) CUDA_IS_AVAILABLE = checks.cuda_is_available() CUDA_DEVICE_COUNT = checks.cuda_device_count() @@ -15,6 +15,7 @@ __all__ = ( "MODEL", "CFG", "SOURCE", + "SOURCES_LIST", "TMP", "IS_TMP_WRITEABLE", "CUDA_IS_AVAILABLE", diff --git a/tests/test_python.py b/tests/test_python.py index aa18029d75..b5dd0c883b 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -1,6 +1,7 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license import contextlib +import csv import urllib from copy import copy from pathlib import Path @@ -12,7 +13,7 @@ import torch import yaml from PIL import Image -from tests import CFG, IS_TMP_WRITEABLE, MODEL, SOURCE, TMP +from tests import CFG, MODEL, SOURCE, SOURCES_LIST, TMP from ultralytics import RTDETR, YOLO from ultralytics.cfg import MODELS, TASK2DATA, TASKS from ultralytics.data.build import load_inference_source @@ -26,11 +27,14 @@ from ultralytics.utils import ( WEIGHTS_DIR, WINDOWS, checks, + is_dir_writeable, is_github_action_running, ) from ultralytics.utils.downloads import download from ultralytics.utils.torch_utils import TORCH_1_9 +IS_TMP_WRITEABLE = is_dir_writeable(TMP) # WARNING: must be run once tests start as TMP does not exist on tests/init + def test_model_forward(): """Test the forward pass of the YOLO model.""" @@ -70,11 +74,37 @@ def test_model_profile(): @pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") def test_predict_txt(): """Tests YOLO predictions with file, directory, and pattern sources listed in a text file.""" - txt_file = TMP / "sources.txt" - with open(txt_file, "w") as f: - for x in [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]: - f.write(f"{x}\n") - _ = YOLO(MODEL)(source=txt_file, imgsz=32) + file = TMP / "sources_multi_row.txt" + with open(file, "w") as f: + for src in SOURCES_LIST: + f.write(f"{src}\n") + results = YOLO(MODEL)(source=file, imgsz=32) + assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images + + +@pytest.mark.skipif(True, reason="disabled for testing") +@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") +def test_predict_csv_multi_row(): + """Tests YOLO predictions with sources listed in multiple rows of a CSV file.""" + file = TMP / "sources_multi_row.csv" + with open(file, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["source"]) + writer.writerows([[src] for src in SOURCES_LIST]) + results = YOLO(MODEL)(source=file, imgsz=32) + assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images + + +@pytest.mark.skipif(True, reason="disabled for testing") +@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") +def test_predict_csv_single_row(): + """Tests YOLO predictions with sources listed in a single row of a CSV file.""" + file = TMP / "sources_single_row.csv" + with open(file, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(SOURCES_LIST) + results = YOLO(MODEL)(source=file, imgsz=32) + assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images @pytest.mark.parametrize("model_name", MODELS)