|
|
|
@ -1,6 +1,7 @@ |
|
|
|
|
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
|
|
|
|
|
|
|
|
import contextlib |
|
|
|
|
import csv |
|
|
|
|
import urllib |
|
|
|
|
from copy import copy |
|
|
|
|
from pathlib import Path |
|
|
|
@ -12,7 +13,7 @@ import torch |
|
|
|
|
import yaml |
|
|
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
from tests import CFG, IS_TMP_WRITEABLE, MODEL, SOURCE, TMP |
|
|
|
|
from tests import CFG, MODEL, SOURCE, SOURCES_LIST, TMP |
|
|
|
|
from ultralytics import RTDETR, YOLO |
|
|
|
|
from ultralytics.cfg import MODELS, TASK2DATA, TASKS |
|
|
|
|
from ultralytics.data.build import load_inference_source |
|
|
|
@ -26,11 +27,14 @@ from ultralytics.utils import ( |
|
|
|
|
WEIGHTS_DIR, |
|
|
|
|
WINDOWS, |
|
|
|
|
checks, |
|
|
|
|
is_dir_writeable, |
|
|
|
|
is_github_action_running, |
|
|
|
|
) |
|
|
|
|
from ultralytics.utils.downloads import download |
|
|
|
|
from ultralytics.utils.torch_utils import TORCH_1_9 |
|
|
|
|
|
|
|
|
|
IS_TMP_WRITEABLE = is_dir_writeable(TMP) # WARNING: must be run once tests start as TMP does not exist on tests/init |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_model_forward(): |
|
|
|
|
"""Test the forward pass of the YOLO model.""" |
|
|
|
@ -70,11 +74,37 @@ def test_model_profile(): |
|
|
|
|
@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") |
|
|
|
|
def test_predict_txt(): |
|
|
|
|
"""Tests YOLO predictions with file, directory, and pattern sources listed in a text file.""" |
|
|
|
|
txt_file = TMP / "sources.txt" |
|
|
|
|
with open(txt_file, "w") as f: |
|
|
|
|
for x in [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]: |
|
|
|
|
f.write(f"{x}\n") |
|
|
|
|
_ = YOLO(MODEL)(source=txt_file, imgsz=32) |
|
|
|
|
file = TMP / "sources_multi_row.txt" |
|
|
|
|
with open(file, "w") as f: |
|
|
|
|
for src in SOURCES_LIST: |
|
|
|
|
f.write(f"{src}\n") |
|
|
|
|
results = YOLO(MODEL)(source=file, imgsz=32) |
|
|
|
|
assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(True, reason="disabled for testing") |
|
|
|
|
@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") |
|
|
|
|
def test_predict_csv_multi_row(): |
|
|
|
|
"""Tests YOLO predictions with sources listed in multiple rows of a CSV file.""" |
|
|
|
|
file = TMP / "sources_multi_row.csv" |
|
|
|
|
with open(file, "w", newline="") as f: |
|
|
|
|
writer = csv.writer(f) |
|
|
|
|
writer.writerow(["source"]) |
|
|
|
|
writer.writerows([[src] for src in SOURCES_LIST]) |
|
|
|
|
results = YOLO(MODEL)(source=file, imgsz=32) |
|
|
|
|
assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(True, reason="disabled for testing") |
|
|
|
|
@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") |
|
|
|
|
def test_predict_csv_single_row(): |
|
|
|
|
"""Tests YOLO predictions with sources listed in a single row of a CSV file.""" |
|
|
|
|
file = TMP / "sources_single_row.csv" |
|
|
|
|
with open(file, "w", newline="") as f: |
|
|
|
|
writer = csv.writer(f) |
|
|
|
|
writer.writerow(SOURCES_LIST) |
|
|
|
|
results = YOLO(MODEL)(source=file, imgsz=32) |
|
|
|
|
assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("model_name", MODELS) |
|
|
|
|