Merge branch 'main' into exp

exp-a
Laughing-q 11 months ago
commit c480ac3127
  1. 8
      tests/test_engine.py
  2. 2
      ultralytics/__init__.py
  3. 8
      ultralytics/cfg/__init__.py
  4. 4
      ultralytics/engine/exporter.py
  5. 18
      ultralytics/engine/model.py
  6. 2
      ultralytics/engine/trainer.py
  7. 4
      ultralytics/hub/utils.py
  8. 3
      ultralytics/nn/autobackend.py
  9. 9
      ultralytics/utils/__init__.py

@ -1,5 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import sys
from unittest import mock
from ultralytics import YOLO
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
@ -49,8 +51,10 @@ def test_detect():
pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
pred.add_callback("on_predict_start", test_func)
assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
result = pred(source=ASSETS, model=f"{MODEL}.pt")
assert len(result), "predictor test failed"
# Confirm there is no issue with sys.argv being empty.
with mock.patch.object(sys, 'argv', []):
result = pred(source=ASSETS, model=f"{MODEL}.pt")
assert len(result), "predictor test failed"
overrides["resume"] = trainer.last
trainer = detect.DetectionTrainer(overrides=overrides)

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.35"
__version__ = "8.1.37"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

@ -54,8 +54,9 @@ TASK2METRIC = {
"obb": "metrics/mAP50-95(B)",
}
ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
CLI_HELP_MSG = f"""
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:
yolo TASK MODE ARGS
@ -93,7 +94,7 @@ CLI_HELP_MSG = f"""
"""
# Define keys for arg type checks
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"}
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time", "workspace"}
CFG_FRACTION_KEYS = {
"dropout",
"iou",
@ -131,7 +132,6 @@ CFG_INT_KEYS = {
"max_det",
"vid_stride",
"line_width",
"workspace",
"nbs",
"save_period",
}
@ -452,7 +452,7 @@ def entrypoint(debug=""):
It uses the package's default cfg and initializes it using the passed overrides.
Then it calls the CLI function with the composed cfg
"""
args = (debug.split(" ") if debug else sys.argv)[1:]
args = (debug.split(" ") if debug else ARGV)[1:]
if not args: # no arguments passed
LOGGER.info(CLI_HELP_MSG)
return

@ -675,9 +675,7 @@ class Exporter:
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = self.args.workspace * 1 << 30
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
config.max_workspace_size = int(self.args.workspace * (1 << 30))
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)

@ -1,7 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import inspect
import sys
from pathlib import Path
from typing import Union
@ -11,7 +10,18 @@ import torch
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load
from ultralytics.utils import (
ARGV,
ASSETS,
DEFAULT_CFG_DICT,
LOGGER,
RANK,
SETTINGS,
callbacks,
checks,
emojis,
yaml_load,
)
class Model(nn.Module):
@ -421,8 +431,8 @@ class Model(nn.Module):
source = ASSETS
LOGGER.warning(f"WARNING ⚠ 'source' is missing. Using 'source={source}'.")
is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any(
x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any(
x in ARGV for x in ("predict", "track", "mode=predict", "mode=track")
)
custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults

@ -422,7 +422,7 @@ class BaseTrainer:
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
self.run_callbacks("on_train_epoch_end")
if RANK in (-1, 0):
final_epoch = epoch + 1 == self.epochs
final_epoch = epoch + 1 >= self.epochs
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
# Validation

@ -3,7 +3,6 @@
import os
import platform
import random
import sys
import threading
import time
from pathlib import Path
@ -11,6 +10,7 @@ from pathlib import Path
import requests
from ultralytics.utils import (
ARGV,
ENVIRONMENT,
LOGGER,
ONLINE,
@ -188,7 +188,7 @@ class Events:
self.rate_limit = 60.0 # rate limit (seconds)
self.t = 0.0 # rate limit timer (seconds)
self.metadata = {
"cli": Path(sys.argv[0]).name == "yolo",
"cli": Path(ARGV[0]).name == "yolo",
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
"python": ".".join(platform.python_version_tuple()[:2]), # i.e. 3.10
"version": __version__,

@ -140,7 +140,8 @@ class AutoBackend(nn.Module):
# In-memory PyTorch model
if nn_module:
model = weights.to(device)
model = model.fuse(verbose=verbose) if fuse else model
if fuse:
model = model.fuse(verbose=verbose)
if hasattr(model, "kpt_shape"):
kpt_shape = model.kpt_shape # pose-only
stride = max(int(model.stride.max()), 32) # model stride

@ -30,6 +30,7 @@ RANK = int(os.getenv("RANK", -1))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
# Other Constants
ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLO
ASSETS = ROOT / "assets" # default images
@ -522,7 +523,7 @@ def is_pytest_running():
Returns:
(bool): True if pytest is running, False otherwise.
"""
return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(sys.argv[0]).stem)
return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(ARGV[0]).stem)
def is_github_action_running() -> bool:
@ -869,8 +870,8 @@ def set_sentry():
return None # do not send event
event["tags"] = {
"sys_argv": sys.argv[0],
"sys_argv_name": Path(sys.argv[0]).name,
"sys_argv": ARGV[0],
"sys_argv_name": Path(ARGV[0]).name,
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
"os": ENVIRONMENT,
}
@ -879,7 +880,7 @@ def set_sentry():
if (
SETTINGS["sync"]
and RANK in (-1, 0)
and Path(sys.argv[0]).name == "yolo"
and Path(ARGV[0]).name == "yolo"
and not TESTS_RUNNING
and ONLINE
and is_pip_package()

Loading…
Cancel
Save