`ultralytics 8.0.21` Windows, segments, YAML fixes (#655)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>
pull/640/head v8.0.21
Glenn Jocher 2 years ago committed by GitHub
parent dc9502c700
commit 6c44ce21d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      .github/ISSUE_TEMPLATE/bug-report.yml
  2. 13
      tests/test_cli.py
  3. 8
      tests/test_python.py
  4. 2
      ultralytics/__init__.py
  5. 26
      ultralytics/yolo/cfg/__init__.py
  6. 11
      ultralytics/yolo/data/dataloaders/stream_loaders.py
  7. 10
      ultralytics/yolo/data/dataloaders/v5loader.py
  8. 2
      ultralytics/yolo/engine/exporter.py
  9. 30
      ultralytics/yolo/engine/model.py
  10. 20
      ultralytics/yolo/engine/predictor.py
  11. 3
      ultralytics/yolo/engine/trainer.py
  12. 2
      ultralytics/yolo/engine/validator.py
  13. 117
      ultralytics/yolo/utils/__init__.py
  14. 6
      ultralytics/yolo/utils/checks.py
  15. 2
      ultralytics/yolo/utils/downloads.py
  16. 39
      ultralytics/yolo/utils/torch_utils.py

@ -51,9 +51,9 @@ body:
label: Environment
description: Please specify the software and hardware you used to produce the bug.
placeholder: |
- YOLO: YOLOv8 🚀 v6.0-67-g60e42e1 torch 1.9.0+cu111 CUDA:0 (A100-SXM4-40GB, 40536MiB)
- YOLO: Ultralytics YOLOv8.0.21 🚀 Python-3.8.10 torch-1.13.1+cu117 CUDA:0 (A100-SXM-80GB, 81251MiB)
- OS: Ubuntu 20.04
- Python: 3.9.0
- Python: 3.8.10
validations:
required: false

@ -35,28 +35,29 @@ def test_train_cls():
# Val checks -----------------------------------------------------------------------------------------------------------
def test_val_detect():
run(f'yolo val detect model={MODEL}.pt data=coco8.yaml imgsz=32 epochs=1')
run(f'yolo val detect model={MODEL}.pt data=coco8.yaml imgsz=32')
def test_val_segment():
run(f'yolo val segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32 epochs=1')
run(f'yolo val segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32')
def test_val_classify():
pass
run(f'yolo val classify model={MODEL}-cls.pt data=mnist160 imgsz=32')
# Predict checks -------------------------------------------------------------------------------------------------------
def test_predict_detect():
run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'} imgsz=320 conf=0.25")
run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
run(f"yolo predict detect model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32")
def test_predict_segment():
run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'}")
run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32")
def test_predict_classify():
pass
run(f"yolo predict segment model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32")
# Export checks --------------------------------------------------------------------------------------------------------

@ -111,9 +111,11 @@ def test_export_coreml():
model.export(format='coreml')
def test_export_paddle():
model = YOLO(MODEL)
model.export(format='paddle')
def test_export_paddle(enabled=False):
# Paddle protobuf requirements conflicting with onnx protobuf requirements
if enabled:
model = YOLO(MODEL)
model.export(format='paddle')
def test_all_model_yamls():

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = "8.0.20"
__version__ = "8.0.21"
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops

@ -9,8 +9,8 @@ from types import SimpleNamespace
from typing import Dict, List, Union
from ultralytics import __version__
from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, ROOT, USER_CONFIG_DIR,
IterableSimpleNamespace, colorstr, yaml_load, yaml_print)
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, ROOT,
USER_CONFIG_DIR, IterableSimpleNamespace, colorstr, emojis, yaml_load, yaml_print)
from ultralytics.yolo.utils.checks import check_yolo
CLI_HELP_MSG = \
@ -69,7 +69,7 @@ def cfg2dict(cfg):
return cfg
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None):
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, overrides: Dict = None):
"""
Load and merge configuration data from a file or dictionary.
@ -214,17 +214,19 @@ def entrypoint(debug=False):
# Mode
mode = overrides.pop('mode', None)
model = overrides.pop('model', None)
if mode == 'checks':
if mode is None:
mode = DEFAULT_CFG.mode or 'predict'
LOGGER.warning(f"WARNING ⚠ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
elif mode not in modes:
if mode != 'checks':
raise ValueError(emojis(f"ERROR ❌ Invalid 'mode={mode}'. Valid modes are {modes}."))
LOGGER.warning("WARNING ⚠ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
check_yolo()
return
elif mode is None:
mode = DEFAULT_CFG_DICT['mode'] or 'predict'
LOGGER.warning(f"WARNING ⚠ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
# Model
if model is None:
model = DEFAULT_CFG_DICT['model'] or 'yolov8n.pt'
model = DEFAULT_CFG.model or 'yolov8n.pt'
LOGGER.warning(f"WARNING ⚠ 'model' is missing. Using default 'model={model}'.")
from ultralytics.yolo.engine.model import YOLO
model = YOLO(model)
@ -232,21 +234,21 @@ def entrypoint(debug=False):
# Task
if mode == 'predict' and 'source' not in overrides:
overrides['source'] = DEFAULT_CFG_DICT['source'] or ROOT / "assets" if (ROOT / "assets").exists() \
overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
else "https://ultralytics.com/images/bus.jpg"
LOGGER.warning(f"WARNING ⚠ 'source' is missing. Using default 'source={overrides['source']}'.")
elif mode in ('train', 'val'):
if 'data' not in overrides:
overrides['data'] = DEFAULT_CFG_DICT['data'] or 'mnist160' if task == 'classify' \
overrides['data'] = DEFAULT_CFG.data or 'mnist160' if task == 'classify' \
else 'coco128-seg.yaml' if task == 'segment' else 'coco128.yaml'
LOGGER.warning(f"WARNING ⚠ 'data' is missing. Using default 'data={overrides['data']}'.")
elif mode == 'export':
if 'format' not in overrides:
overrides['format'] = DEFAULT_CFG_DICT['format'] or 'torchscript'
overrides['format'] = DEFAULT_CFG.format or 'torchscript'
LOGGER.warning(f"WARNING ⚠ 'format' is missing. Using default 'format={overrides['format']}'.")
# Run command in python
getattr(model, mode)(verbose=True, **overrides)
getattr(model, mode)(**overrides)
# Special modes --------------------------------------------------------------------------------------------------------

@ -44,7 +44,8 @@ class LoadStreams:
assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
cap = cv2.VideoCapture(s)
assert cap.isOpened(), f'{st}Failed to open {s}'
if not cap.isOpened():
raise ConnectionError(f'{st}Failed to open {s}')
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
@ -188,8 +189,9 @@ class LoadImages:
self._new_video(videos[0]) # new video
else:
self.cap = None
assert self.nf > 0, f'No images or videos found in {p}. ' \
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
if self.nf == 0:
raise FileNotFoundError(f'No images or videos found in {p}. '
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')
def __iter__(self):
self.count = 0
@ -223,7 +225,8 @@ class LoadImages:
# Read image
self.count += 1
im0 = cv2.imread(path) # BGR
assert im0 is not None, f'Image Not Found {path}'
if im0 is None:
raise FileNotFoundError(f'Image Not Found {path}')
s = f'image {self.count}/{self.nf} {path}: '
if self.transforms:

@ -23,14 +23,13 @@ import numpy as np
import psutil
import torch
import torchvision
import yaml
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm
from ultralytics.yolo.data.utils import check_det_dataset, unzip_file
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
is_kaggle)
is_kaggle, yaml_load)
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
@ -1056,10 +1055,9 @@ class HUBDatasetStats():
# Initialize class
zipped, data_dir, yaml_path = self._unzip(Path(path))
try:
with open(check_yaml(yaml_path), errors='ignore') as f:
data = yaml.safe_load(f) # data dict
if zipped:
data['path'] = data_dir
data = yaml_load(check_yaml(yaml_path)) # data dict
if zipped:
data['path'] = data_dir
except Exception as e:
raise Exception("error/HUB/dataset_stats/yaml_load") from e

@ -129,7 +129,7 @@ class Exporter:
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
callbacks.add_integration_callbacks(self)
@smart_inference_mode()

@ -61,8 +61,8 @@ class YOLO:
else:
raise NotImplementedError(f"'{suffix}' model loading not implemented")
def __call__(self, source=None, stream=False, verbose=False, **kwargs):
return self.predict(source, stream, verbose, **kwargs)
def __call__(self, source=None, stream=False, **kwargs):
return self.predict(source, stream, **kwargs)
def _new(self, cfg: str, verbose=True):
"""
@ -118,7 +118,7 @@ class YOLO:
self.model.fuse()
@smart_inference_mode()
def predict(self, source=None, stream=False, verbose=False, **kwargs):
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
@ -126,7 +126,6 @@ class YOLO:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
verbose (bool): Whether to print verbose information or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
@ -143,7 +142,7 @@ class YOLO:
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source=source, stream=stream, verbose=verbose)
return self.predictor(source=source, stream=stream)
@smart_inference_mode()
def val(self, data=None, **kwargs):
@ -234,7 +233,8 @@ class YOLO:
"""
return self.model.names
def add_callback(self, event: str, func):
@staticmethod
def add_callback(event: str, func):
"""
Add callback
"""
@ -242,16 +242,8 @@ class YOLO:
@staticmethod
def _reset_ckpt_args(args):
args.pop("project", None)
args.pop("name", None)
args.pop("exist_ok", None)
args.pop("resume", None)
args.pop("batch", None)
args.pop("epochs", None)
args.pop("cache", None)
args.pop("save_json", None)
args.pop("half", None)
args.pop("v5loader", None)
# set device to '' to prevent from auto DDP usage
args["device"] = ''
for arg in 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', 'save_json', \
'half', 'v5loader':
args.pop(arg, None)
args["device"] = '' # set device to '' to prevent auto-DDP usage

@ -88,7 +88,7 @@ class BasePredictor:
self.vid_path, self.vid_writer = None, None
self.annotator = None
self.data_path = None
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
callbacks.add_integration_callbacks(self)
def preprocess(self, img):
@ -151,19 +151,19 @@ class BasePredictor:
self.bs = bs
@smart_inference_mode()
def __call__(self, source=None, model=None, verbose=False, stream=False):
def __call__(self, source=None, model=None, stream=False):
if stream:
return self.stream_inference(source, model, verbose)
return self.stream_inference(source, model)
else:
return list(self.stream_inference(source, model, verbose)) # merge list of Result into one
return list(self.stream_inference(source, model)) # merge list of Result into one
def predict_cli(self):
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
gen = self.stream_inference(verbose=True)
gen = self.stream_inference()
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
pass
def stream_inference(self, source=None, model=None, verbose=False):
def stream_inference(self, source=None, model=None):
self.run_callbacks("on_predict_start")
# setup model
@ -201,7 +201,7 @@ class BasePredictor:
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
p = Path(p)
if verbose or self.args.save or self.args.save_txt or self.args.show:
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
s += self.write_results(i, self.results, (p, im, im0))
if self.args.show:
@ -214,11 +214,11 @@ class BasePredictor:
yield from self.results
# Print time (inference-only)
if verbose:
if self.args.verbose:
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
# Print results
if verbose and self.seen:
if self.args.verbose and self.seen:
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape '
f'{(1, 3, *self.imgsz)}' % t)
@ -243,7 +243,7 @@ class BasePredictor:
if isinstance(source, (str, int, Path)): # int for local usb carame
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
screenshot = source.lower().startswith('screen')
if is_url and is_file:

@ -85,7 +85,6 @@ class BaseTrainer:
self.console = LOGGER
self.validator = None
self.model = None
self.callbacks = defaultdict(list)
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
# Dirs
@ -141,7 +140,7 @@ class BaseTrainer:
self.plot_idx = [0, 1, 2]
# Callbacks
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
if RANK in {0, -1}:
callbacks.add_integration_callbacks(self)

@ -70,7 +70,7 @@ class BaseValidator:
if self.args.conf is None:
self.args.conf = 0.001 # default conf=0.001
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
@smart_inference_mode()
def __call__(self, trainer=None, model=None):

@ -5,6 +5,7 @@ import inspect
import logging.config
import os
import platform
import re
import subprocess
import sys
import tempfile
@ -113,12 +114,66 @@ class IterableSimpleNamespace(SimpleNamespace):
return getattr(self, key, default)
def yaml_save(file='data.yaml', data=None):
"""
Save YAML data to a file.
Args:
file (str, optional): File name. Default is 'data.yaml'.
data (dict, optional): Data to save in YAML format. Default is None.
Returns:
None: Data is saved to the specified file.
"""
file = Path(file)
if not file.parent.exists():
# Create parent directories if they don't exist
file.parent.mkdir(parents=True, exist_ok=True)
with open(file, 'w') as f:
# Dump data to file in YAML format, converting Path objects to strings
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
def yaml_load(file='data.yaml', append_filename=False):
"""
Load YAML data from a file.
Args:
file (str, optional): File name. Default is 'data.yaml'.
append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
Returns:
dict: YAML data and file name.
"""
with open(file, errors='ignore', encoding='utf-8') as f:
# Add YAML filename to dict and return
s = f.read() # string
if not s.isprintable(): # remove special characters
s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
return {**yaml.safe_load(s), 'yaml_file': str(file)} if append_filename else yaml.safe_load(s)
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
"""
Pretty prints a yaml file or a yaml-formatted dictionary.
Args:
yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
Returns:
None
"""
yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
dump = yaml.dump(yaml_dict, default_flow_style=False)
LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
# Default configuration
with open(DEFAULT_CFG_PATH, errors='ignore') as f:
DEFAULT_CFG_DICT = yaml.safe_load(f)
for k, v in DEFAULT_CFG_DICT.items():
if isinstance(v, str) and v.lower() == 'none':
DEFAULT_CFG_DICT[k] = None
DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
for k, v in DEFAULT_CFG_DICT.items():
if isinstance(v, str) and v.lower() == 'none':
DEFAULT_CFG_DICT[k] = None
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
@ -393,58 +448,6 @@ def threaded(func):
return wrapper
def yaml_save(file='data.yaml', data=None):
"""
Save YAML data to a file.
Args:
file (str, optional): File name. Default is 'data.yaml'.
data (dict, optional): Data to save in YAML format. Default is None.
Returns:
None: Data is saved to the specified file.
"""
file = Path(file)
if not file.parent.exists():
# Create parent directories if they don't exist
file.parent.mkdir(parents=True, exist_ok=True)
with open(file, 'w') as f:
# Dump data to file in YAML format, converting Path objects to strings
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
def yaml_load(file='data.yaml', append_filename=False):
"""
Load YAML data from a file.
Args:
file (str, optional): File name. Default is 'data.yaml'.
append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
Returns:
dict: YAML data and file name.
"""
with open(file, errors='ignore') as f:
# Add YAML filename to dict and return
return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
"""
Pretty prints a yaml file or a yaml-formatted dictionary.
Args:
yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
Returns:
None
"""
yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
dump = yaml.dump(yaml_dict, default_flow_style=False)
LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
def set_sentry():
"""
Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running.

@ -207,9 +207,9 @@ def check_file(file, suffix=''):
# Search/download file (if necessary) and return path
check_suffix(file, suffix) # optional
file = str(file) # convert to str()
if Path(file).is_file() or not file: # exists
if not file or ('://' not in file and Path(file).is_file()): # exists ('://' check required in Windows Python<3.10)
return file
elif file.startswith(('http:/', 'https:/')): # download
elif file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')): # download
url = file # warning: Pathlib turns :// -> :/
file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
if Path(file).is_file():
@ -276,7 +276,7 @@ def git_describe(path=ROOT): # path must be a directory
try:
assert (Path(path) / '.git').is_dir()
return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
except Exception:
except AssertionError:
return ''

@ -104,7 +104,7 @@ def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1
def download_one(url, dir):
# Download 1 file
success = True
if Path(url).is_file():
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename
else: # does not exist
f = dir / Path(url).name

@ -17,11 +17,8 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import ultralytics
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER
from ultralytics.yolo.utils.checks import git_describe
from .checks import check_version
from ultralytics.yolo.utils.checks import check_version
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
@ -60,8 +57,8 @@ def DDP_model(model):
def select_device(device='', batch=0, newline=False):
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
ver = git_describe() or ultralytics.__version__ # git commit or pip package version
s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
from ultralytics import __version__
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
device = str(device).lower()
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
@ -247,6 +244,7 @@ class ModelEMA:
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
Keeps a moving average of everything in the model state_dict (parameters and buffers)
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
To disable EMA set the `enabled` attribute to `False`.
"""
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
@ -256,22 +254,25 @@ class ModelEMA:
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
for p in self.ema.parameters():
p.requires_grad_(False)
self.enabled = True
def update(self, model):
# Update EMA parameters
self.updates += 1
d = self.decay(self.updates)
if self.enabled:
self.updates += 1
d = self.decay(self.updates)
msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point: # true for FP16 and FP32
v *= d
v += (1 - d) * msd[k].detach()
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point: # true for FP16 and FP32
v *= d
v += (1 - d) * msd[k].detach()
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes
copy_attr(self.ema, model, include, exclude)
if self.enabled:
copy_attr(self.ema, model, include, exclude)
def strip_optimizer(f='best.pt', s=''):
@ -285,8 +286,8 @@ def strip_optimizer(f='best.pt', s=''):
strip_optimizer(f)
Args:
f (str): file path to model state to strip the optimizer from. Default is 'best.pt'.
s (str): file path to save the model with stripped optimizer to. Default is ''. If not provided, the original file will be overwritten.
f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
Returns:
None
@ -364,12 +365,12 @@ class EarlyStopping:
Early stopping class that stops training when a specified number of epochs have passed without improvement.
"""
def __init__(self, patience=30):
def __init__(self, patience=50):
"""
Initialize early stopping object
Args:
patience (int, optional): Number of epochs to wait after fitness stops improving before stopping. Default is 30.
patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
"""
self.best_fitness = 0.0 # i.e. mAP
self.best_epoch = 0

Loading…
Cancel
Save