`ultralytics 8.0.169``TQDM`, `INTERP_LINEAR` and RTDETR `load_image()` updates (#4704)

Co-authored-by: Rustem Galiullin <rustemgal@gmail.com>
Co-authored-by: Rustem Galiullin <rustem.galiullin@bayanat.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
main
Glenn Jocher 2 years ago committed by GitHub
parent a4fabfdacf
commit 187b504d68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 12
      .github/workflows/ci.yaml
  2. 3
      docker/Dockerfile-arm64
  3. 3
      docker/Dockerfile-cpu
  4. 3
      docker/Dockerfile-jetson
  5. 3
      docker/Dockerfile-python
  6. 9
      docker/Dockerfile-runner
  7. 2
      ultralytics/__init__.py
  8. 18
      ultralytics/cfg/models/README.md
  9. 17
      ultralytics/data/base.py
  10. 7
      ultralytics/data/converter.py
  11. 13
      ultralytics/data/dataset.py
  12. 11
      ultralytics/data/utils.py
  13. 7
      ultralytics/engine/trainer.py
  14. 10
      ultralytics/engine/validator.py
  15. 7
      ultralytics/hub/utils.py
  16. 7
      ultralytics/models/fastsam/prompt.py
  17. 27
      ultralytics/models/rtdetr/val.py
  18. 10
      ultralytics/trackers/README.md
  19. 19
      ultralytics/utils/__init__.py
  20. 7
      ultralytics/utils/benchmarks.py
  21. 1
      ultralytics/utils/callbacks/clearml.py
  22. 1
      ultralytics/utils/callbacks/mlflow.py
  23. 20
      ultralytics/utils/downloads.py

@ -128,18 +128,18 @@ jobs:
python --version python --version
pip --version pip --version
pip list pip list
#- name: Benchmark DetectionModel # - name: Benchmark DetectionModel
# shell: bash # shell: bash
# run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}.pt' imgsz=160 verbose=0.26 # run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}.pt' imgsz=160 verbose=0.318
- name: Benchmark SegmentationModel - name: Benchmark SegmentationModel
shell: bash shell: bash
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-seg.pt' imgsz=160 verbose=0.30 run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-seg.pt' imgsz=160 verbose=0.286
- name: Benchmark ClassificationModel - name: Benchmark ClassificationModel
shell: bash shell: bash
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-cls.pt' imgsz=160 verbose=0.16 run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-cls.pt' imgsz=160 verbose=0.166
- name: Benchmark PoseModel - name: Benchmark PoseModel
shell: bash shell: bash
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-pose.pt' imgsz=160 verbose=0.17 run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-pose.pt' imgsz=160 verbose=0.185
- name: Merge Coverage Reports - name: Merge Coverage Reports
run: | run: |
coverage xml -o coverage-benchmarks.xml coverage xml -o coverage-benchmarks.xml

@ -35,5 +35,8 @@ RUN pip install --no-cache -e . thop
# Run # Run
# t=ultralytics/ultralytics:latest-arm64 && sudo docker run -it --ipc=host $t # t=ultralytics/ultralytics:latest-arm64 && sudo docker run -it --ipc=host $t
# Pull and Run
# t=ultralytics/ultralytics:latest-arm64 && sudo docker pull $t && sudo docker run -it --ipc=host $t
# Pull and Run with local volume mounted # Pull and Run with local volume mounted
# t=ultralytics/ultralytics:latest-arm64 && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/datasets:/usr/src/datasets $t # t=ultralytics/ultralytics:latest-arm64 && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/datasets:/usr/src/datasets $t

@ -45,5 +45,8 @@ RUN rm -rf tmp
# Run # Run
# t=ultralytics/ultralytics:latest-cpu && sudo docker run -it --ipc=host $t # t=ultralytics/ultralytics:latest-cpu && sudo docker run -it --ipc=host $t
# Pull and Run
# t=ultralytics/ultralytics:latest-cpu && sudo docker pull $t && sudo docker run -it --ipc=host $t
# Pull and Run with local volume mounted # Pull and Run with local volume mounted
# t=ultralytics/ultralytics:latest-cpu && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/datasets:/usr/src/datasets $t # t=ultralytics/ultralytics:latest-cpu && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/datasets:/usr/src/datasets $t

@ -42,5 +42,8 @@ ENV OMP_NUM_THREADS=1
# Run # Run
# t=ultralytics/ultralytics:latest-jetson && sudo docker run -it --ipc=host $t # t=ultralytics/ultralytics:latest-jetson && sudo docker run -it --ipc=host $t
# Pull and Run
# t=ultralytics/ultralytics:latest-jetson && sudo docker pull $t && sudo docker run -it --ipc=host $t
# Pull and Run with NVIDIA runtime # Pull and Run with NVIDIA runtime
# t=ultralytics/ultralytics:latest-jetson && sudo docker pull $t && sudo docker run -it --ipc=host --runtime=nvidia $t # t=ultralytics/ultralytics:latest-jetson && sudo docker pull $t && sudo docker run -it --ipc=host --runtime=nvidia $t

@ -45,5 +45,8 @@ RUN rm -rf tmp
# Run # Run
# t=ultralytics/ultralytics:latest-python && sudo docker run -it --ipc=host $t # t=ultralytics/ultralytics:latest-python && sudo docker run -it --ipc=host $t
# Pull and Run
# t=ultralytics/ultralytics:latest-python && sudo docker pull $t && sudo docker run -it --ipc=host $t
# Pull and Run with local volume mounted # Pull and Run with local volume mounted
# t=ultralytics/ultralytics:latest-python && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/datasets:/usr/src/datasets $t # t=ultralytics/ultralytics:latest-python && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/datasets:/usr/src/datasets $t

@ -8,10 +8,11 @@ FROM ultralytics/ultralytics:latest
# Set the working directory # Set the working directory
WORKDIR /actions-runner WORKDIR /actions-runner
# Download and unpack the latest runner # Download and unpack the latest runner from https://github.com/actions/runner
RUN curl -o actions-runner-linux-x64-2.308.0.tar.gz -L https://github.com/actions/runner/releases/download/v2.308.0/actions-runner-linux-x64-2.308.0.tar.gz && \ RUN FILENAME=actions-runner-linux-x64-2.308.0.tar.gz && \
tar xzf actions-runner-linux-x64-2.308.0.tar.gz && \ curl -o $FILENAME -L https://github.com/actions/runner/releases/download/v2.308.0/$FILENAME && \
rm actions-runner-linux-x64-2.308.0.tar.gz tar xzf $FILENAME && \
rm $FILENAME
# Install runner dependencies # Install runner dependencies
ENV RUNNER_ALLOW_RUNASROOT=1 ENV RUNNER_ALLOW_RUNASROOT=1

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.168' __version__ = '8.0.169'
from ultralytics.models import RTDETR, SAM, YOLO from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM from ultralytics.models.fastsam import FastSAM

@ -1,19 +1,10 @@
## Models ## Models
Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image segmentation tasks.
files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted
and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image
segmentation tasks.
These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms, from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this directory provides a great starting point for your custom model development needs.
instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms,
from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this
directory provides a great starting point for your custom model development needs.
To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full details at the Ultralytics [Docs](https://docs.ultralytics.com/models), and if you need help or have any questions, feel free to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now!
selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full
details at the Ultralytics [Docs](https://docs.ultralytics.com/models), and if you need help or have any questions, feel free
to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now!
### Usage ### Usage
@ -37,8 +28,7 @@ model.train(data="coco128.yaml", epochs=100) # train the model
## Pre-trained Model Architectures ## Pre-trained Model Architectures
Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available.
and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available.
## Contributing New Models ## Contributing New Models

@ -13,9 +13,8 @@ import cv2
import numpy as np import numpy as np
import psutil import psutil
from torch.utils.data import Dataset from torch.utils.data import Dataset
from tqdm import tqdm
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
from .utils import HELP_URL, IMG_FORMATS from .utils import HELP_URL, IMG_FORMATS
@ -141,7 +140,7 @@ class BaseDataset(Dataset):
if self.single_cls: if self.single_cls:
self.labels[i]['cls'][:, 0] = 0 self.labels[i]['cls'][:, 0] = 0
def load_image(self, i): def load_image(self, i, rect_mode=True):
"""Loads 1 image from dataset index 'i', returns (im, resized hw).""" """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
if im is None: # not cached in RAM if im is None: # not cached in RAM
@ -152,11 +151,13 @@ class BaseDataset(Dataset):
if im is None: if im is None:
raise FileNotFoundError(f'Image Not Found {f}') raise FileNotFoundError(f'Image Not Found {f}')
h0, w0 = im.shape[:2] # orig hw h0, w0 = im.shape[:2] # orig hw
if rect_mode: # resize long side to imgsz while maintaining aspect ratio
r = self.imgsz / max(h0, w0) # ratio r = self.imgsz / max(h0, w0) # ratio
if r != 1: # if sizes are not equal if r != 1: # if sizes are not equal
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))
im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)), im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
interpolation=interp) elif not (h0 == w0 == self.imgsz): # resize by stretching image to square imgsz
im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)
# Add to buffer if training with augmentations # Add to buffer if training with augmentations
if self.augment: if self.augment:
@ -176,7 +177,7 @@ class BaseDataset(Dataset):
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
with ThreadPool(NUM_THREADS) as pool: with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(fcn, range(self.ni)) results = pool.imap(fcn, range(self.ni))
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0) pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
for i, x in pbar: for i, x in pbar:
if cache == 'disk': if cache == 'disk':
b += self.npy_files[i].stat().st_size b += self.npy_files[i].stat().st_size
@ -190,7 +191,7 @@ class BaseDataset(Dataset):
"""Saves an image as an *.npy file for faster loading.""" """Saves an image as an *.npy file for faster loading."""
f = self.npy_files[i] f = self.npy_files[i]
if not f.exists(): if not f.exists():
np.save(f.as_posix(), cv2.imread(self.im_files[i])) np.save(f.as_posix(), cv2.imread(self.im_files[i]), allow_pickle=False)
def check_cache_ram(self, safety_margin=0.5): def check_cache_ram(self, safety_margin=0.5):
"""Check image caching requirements vs available memory.""" """Check image caching requirements vs available memory."""

@ -7,7 +7,8 @@ from pathlib import Path
import cv2 import cv2
import numpy as np import numpy as np
from tqdm import tqdm
from ultralytics.utils import TQDM
def coco91_to_coco80_class(): def coco91_to_coco80_class():
@ -90,7 +91,7 @@ def convert_coco(labels_dir='../coco/annotations/', use_segments=False, use_keyp
imgToAnns[ann['image_id']].append(ann) imgToAnns[ann['image_id']].append(ann)
# Write labels file # Write labels file
for img_id, anns in tqdm(imgToAnns.items(), desc=f'Annotations {json_file}'): for img_id, anns in TQDM(imgToAnns.items(), desc=f'Annotations {json_file}'):
img = images[f'{img_id:d}'] img = images[f'{img_id:d}']
h, w, f = img['height'], img['width'], img['file_name'] h, w, f = img['height'], img['width'], img['file_name']
@ -222,7 +223,7 @@ def convert_dota_to_yolo_obb(dota_root_path: str):
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
image_paths = list(image_dir.iterdir()) image_paths = list(image_dir.iterdir())
for image_path in tqdm(image_paths, desc=f'Processing {phase} images'): for image_path in TQDM(image_paths, desc=f'Processing {phase} images'):
if image_path.suffix != '.png': if image_path.suffix != '.png':
continue continue
image_name_without_ext = image_path.stem image_name_without_ext = image_path.stem

@ -8,9 +8,8 @@ import cv2
import numpy as np import numpy as np
import torch import torch
import torchvision import torchvision
from tqdm import tqdm
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, colorstr, is_dir_writeable from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
from .base import BaseDataset from .base import BaseDataset
@ -60,7 +59,7 @@ class YOLODataset(BaseDataset):
iterable=zip(self.im_files, self.label_files, repeat(self.prefix), iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt), repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
repeat(ndim))) repeat(ndim)))
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT) pbar = TQDM(results, desc=desc, total=total)
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar: for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f nm += nm_f
nf += nf_f nf += nf_f
@ -107,7 +106,7 @@ class YOLODataset(BaseDataset):
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in (-1, 0): if exists and LOCAL_RANK in (-1, 0):
d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt' d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display results TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
if cache['msgs']: if cache['msgs']:
LOGGER.info('\n'.join(cache['msgs'])) # display warnings LOGGER.info('\n'.join(cache['msgs'])) # display warnings
@ -244,7 +243,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
im = self.samples[i][3] = cv2.imread(f) im = self.samples[i][3] = cv2.imread(f)
elif self.cache_disk: elif self.cache_disk:
if not fn.exists(): # load npy if not fn.exists(): # load npy
np.save(fn.as_posix(), cv2.imread(f)) np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
im = np.load(fn) im = np.load(fn)
else: # read image else: # read image
im = cv2.imread(f) # BGR im = cv2.imread(f) # BGR
@ -269,7 +268,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
nf, nc, n, samples = cache.pop('results') # found, missing, empty, corrupt, total nf, nc, n, samples = cache.pop('results') # found, missing, empty, corrupt, total
if LOCAL_RANK in (-1, 0): if LOCAL_RANK in (-1, 0):
d = f'{desc} {nf} images, {nc} corrupt' d = f'{desc} {nf} images, {nc} corrupt'
tqdm(None, desc=d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) TQDM(None, desc=d, total=n, initial=n)
if cache['msgs']: if cache['msgs']:
LOGGER.info('\n'.join(cache['msgs'])) # display warnings LOGGER.info('\n'.join(cache['msgs'])) # display warnings
return samples return samples
@ -278,7 +277,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
nf, nc, msgs, samples, x = 0, 0, [], [], {} nf, nc, msgs, samples, x = 0, 0, [], [], {}
with ThreadPool(NUM_THREADS) as pool: with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix))) results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
pbar = tqdm(results, desc=desc, total=len(self.samples), bar_format=TQDM_BAR_FORMAT) pbar = TQDM(results, desc=desc, total=len(self.samples))
for sample, nf_f, nc_f, msg in pbar: for sample, nf_f, nc_f, msg in pbar:
if nf_f: if nf_f:
samples.append(sample) samples.append(sample)

@ -15,11 +15,10 @@ from tarfile import is_tarfile
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image, ImageOps from PIL import Image, ImageOps
from tqdm import tqdm
from ultralytics.nn.autobackend import check_class_names from ultralytics.nn.autobackend import check_class_names
from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis, from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, TQDM, clean_url, colorstr,
yaml_load) emojis, yaml_load)
from ultralytics.utils.checks import check_file, check_font, is_ascii from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes from ultralytics.utils.ops import segments2boxes
@ -510,7 +509,7 @@ class HUBDatasetStats:
use_keypoints=self.task == 'pose') use_keypoints=self.task == 'pose')
x = np.array([ x = np.array([
np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc']) np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80) for label in TQDM(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
self.stats[split] = { self.stats[split] = {
'instance_stats': { 'instance_stats': {
'total': int(x.sum()), 'total': int(x.sum()),
@ -541,7 +540,7 @@ class HUBDatasetStats:
continue continue
dataset = YOLODataset(img_path=self.data[split], data=self.data) dataset = YOLODataset(img_path=self.data[split], data=self.data)
with ThreadPool(NUM_THREADS) as pool: with ThreadPool(NUM_THREADS) as pool:
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'): for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
pass pass
LOGGER.info(f'Done. All images saved to {self.im_dir}') LOGGER.info(f'Done. All images saved to {self.im_dir}')
return self.im_dir return self.im_dir
@ -614,7 +613,7 @@ def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annot
(path.parent / x).unlink() # remove existing (path.parent / x).unlink() # remove existing
LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only) LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
for i, img in tqdm(zip(indices, files), total=n): for i, img in TQDM(zip(indices, files), total=n):
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
with open(path.parent / txt[i], 'a') as f: with open(path.parent / txt[i], 'a') as f:
f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file

@ -21,13 +21,12 @@ from torch import distributed as dist
from torch import nn, optim from torch import nn, optim
from torch.cuda import amp from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from ultralytics.cfg import get_cfg, get_save_dir from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM_BAR_FORMAT, __version__, callbacks, clean_url, colorstr, from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis,
emojis, yaml_save) yaml_save)
from ultralytics.utils.autobatch import check_train_batch_size from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
@ -326,7 +325,7 @@ class BaseTrainer:
if RANK in (-1, 0): if RANK in (-1, 0):
LOGGER.info(self.progress_string()) LOGGER.info(self.progress_string())
pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT) pbar = TQDM(enumerate(self.train_loader), total=nb)
self.tloss = None self.tloss = None
self.optimizer.zero_grad() self.optimizer.zero_grad()
for i, batch in pbar: for i, batch in pbar:

@ -24,12 +24,11 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm
from ultralytics.cfg import get_cfg, get_save_dir from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, TQDM_BAR_FORMAT, callbacks, colorstr, emojis from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
@ -154,12 +153,7 @@ class BaseValidator:
model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup
dt = Profile(), Profile(), Profile(), Profile() dt = Profile(), Profile(), Profile(), Profile()
n_batches = len(self.dataloader) bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
desc = self.get_desc()
# NOTE: keeping `not self.training` in tqdm will eliminate pbar after segmentation evaluation during training,
# which may affect classification task since this arg is in yolov5/classify/val.py.
# bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
self.init_metrics(de_parallel(model)) self.init_metrics(de_parallel(model))
self.jdict = [] # empty before each val self.jdict = [] # empty before each val
for batch_i, batch in enumerate(bar): for batch_i, batch in enumerate(bar):

@ -9,10 +9,9 @@ import time
from pathlib import Path from pathlib import Path
import requests import requests
from tqdm import tqdm
from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM_BAR_FORMAT, TryExcept, from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM, TryExcept, __version__,
__version__, colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package) colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package)
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
PREFIX = colorstr('Ultralytics HUB: ') PREFIX = colorstr('Ultralytics HUB: ')
@ -80,7 +79,7 @@ def requests_with_progress(method, url, **kwargs):
response = requests.request(method, url, stream=True, **kwargs) response = requests.request(method, url, stream=True, **kwargs)
total = int(response.headers.get('content-length', 0)) # total size total = int(response.headers.get('content-length', 0)) # total size
try: try:
pbar = tqdm(total=total, unit='B', unit_scale=True, unit_divisor=1024, bar_format=TQDM_BAR_FORMAT) pbar = TQDM(total=total, unit='B', unit_scale=True, unit_divisor=1024)
for data in response.iter_content(chunk_size=1024): for data in response.iter_content(chunk_size=1024):
pbar.update(len(data)) pbar.update(len(data))
pbar.close() pbar.close()

@ -8,9 +8,8 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from tqdm import tqdm
from ultralytics.utils import TQDM_BAR_FORMAT from ultralytics.utils import TQDM
class FastSAMPrompt: class FastSAMPrompt:
@ -87,7 +86,7 @@ class FastSAMPrompt:
retina=False, retina=False,
withContours=True): withContours=True):
n = len(annotations) n = len(annotations)
pbar = tqdm(annotations, total=n, bar_format=TQDM_BAR_FORMAT) pbar = TQDM(annotations, total=n)
for ann in pbar: for ann in pbar:
result_name = os.path.basename(ann.path) result_name = os.path.basename(ann.path)
image = ann.orig_img image = ann.orig_img
@ -156,7 +155,7 @@ class FastSAMPrompt:
save_path.parent.mkdir(exist_ok=True, parents=True) save_path.parent.mkdir(exist_ok=True, parents=True)
cv2.imwrite(str(save_path), img_array) cv2.imwrite(str(save_path), img_array)
plt.close() plt.close()
pbar.set_description('Saving {} to {}'.format(result_name, save_path)) pbar.set_description(f'Saving {result_name} to {save_path}')
@staticmethod @staticmethod
def fast_show_mask( def fast_show_mask(

@ -2,8 +2,6 @@
from pathlib import Path from pathlib import Path
import cv2
import numpy as np
import torch import torch
from ultralytics.data import YOLODataset from ultralytics.data import YOLODataset
@ -21,30 +19,9 @@ class RTDETRDataset(YOLODataset):
super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs) super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)
# NOTE: add stretch version load_image for rtdetr mosaic # NOTE: add stretch version load_image for rtdetr mosaic
def load_image(self, i): def load_image(self, i, rect_mode=False):
"""Loads 1 image from dataset index 'i', returns (im, resized hw).""" """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] return super().load_image(i=i, rect_mode=rect_mode)
if im is None: # not cached in RAM
if fn.exists(): # load npy
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
if im is None:
raise FileNotFoundError(f'Image Not Found {f}')
h0, w0 = im.shape[:2] # orig hw
im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)
# Add to buffer if training with augmentations
if self.augment:
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
self.buffer.append(i)
if len(self.buffer) >= self.max_buffer_length:
j = self.buffer.pop(0)
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
return im, (h0, w0), im.shape[:2]
return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None): def build_transforms(self, hyp=None):
"""Temporary, only for evaluation.""" """Temporary, only for evaluation."""

@ -69,7 +69,7 @@ while True:
## Change tracker parameters ## Change tracker parameters
You can change the tracker parameters by eding the `tracker.yaml` file which is located in the ultralytics/cfg/trackers folder. You can change the tracker parameters by editing the `tracker.yaml` file which is located in the ultralytics/cfg/trackers folder.
## Command Line Interface (CLI) ## Command Line Interface (CLI)
@ -81,6 +81,8 @@ yolo segment track source=... tracker=...
yolo pose track source=... tracker=... yolo pose track source=... tracker=...
``` ```
By default, trackers will use the configuration in `ultralytics/cfg/trackers`. By default, trackers will use the configuration in `ultralytics/cfg/trackers`. We also support using a modified tracker config file. Please refer to the tracker config files in `ultralytics/cfg/trackers`.
We also support using a modified tracker config file. Please refer to the tracker config files
in `ultralytics/cfg/trackers`.<br> ## Contributing New Trackers
If you've developed a new tracker architecture or have improvements for existing trackers that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing).

@ -20,6 +20,7 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import yaml import yaml
from tqdm import tqdm as tqdm_original
from ultralytics import __version__ from ultralytics import __version__
@ -35,7 +36,7 @@ DEFAULT_CFG_PATH = ROOT / 'cfg/default.yaml'
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' if VERBOSE else None # tqdm bar format
LOGGING_NAME = 'ultralytics' LOGGING_NAME = 'ultralytics'
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows']) # environment booleans MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows']) # environment booleans
ARM64 = platform.machine() in ('arm64', 'aarch64') # ARM64 booleans ARM64 = platform.machine() in ('arm64', 'aarch64') # ARM64 booleans
@ -106,6 +107,22 @@ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress verbose TF compiler warnings in Colab os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress verbose TF compiler warnings in Colab
class TQDM(tqdm_original):
"""
Custom Ultralytics tqdm class with different default arguments.
Args:
(*args): Positional arguments passed to original tqdm.
(**kwargs): Keyword arguments, with custom defaults applied.
"""
def __init__(self, *args, **kwargs):
# Set new default values (these can still be overridden when calling TQDM)
kwargs['disable'] = not VERBOSE or kwargs.get('disable', False) # logical 'and' with default value if passed
kwargs.setdefault('bar_format', TQDM_BAR_FORMAT) # override default value if passed
super().__init__(*args, **kwargs)
class SimpleClass: class SimpleClass:
""" """
Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute

@ -32,12 +32,11 @@ from pathlib import Path
import numpy as np import numpy as np
import torch.cuda import torch.cuda
from tqdm import tqdm
from ultralytics import YOLO from ultralytics import YOLO
from ultralytics.cfg import TASK2DATA, TASK2METRIC from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.engine.exporter import export_formats from ultralytics.engine.exporter import export_formats
from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, SETTINGS from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, SETTINGS, TQDM
from ultralytics.utils.checks import check_requirements, check_yolo from ultralytics.utils.checks import check_requirements, check_yolo
from ultralytics.utils.files import file_size from ultralytics.utils.files import file_size
from ultralytics.utils.torch_utils import select_device from ultralytics.utils.torch_utils import select_device
@ -285,7 +284,7 @@ class ProfileModels:
# Timed runs # Timed runs
run_times = [] run_times = []
for _ in tqdm(range(num_runs), desc=engine_file): for _ in TQDM(range(num_runs), desc=engine_file):
results = model(input_data, imgsz=self.imgsz, verbose=False) results = model(input_data, imgsz=self.imgsz, verbose=False)
run_times.append(results[0].speed['inference']) # Convert to milliseconds run_times.append(results[0].speed['inference']) # Convert to milliseconds
@ -336,7 +335,7 @@ class ProfileModels:
# Timed runs # Timed runs
run_times = [] run_times = []
for _ in tqdm(range(num_runs), desc=onnx_file): for _ in TQDM(range(num_runs), desc=onnx_file):
start_time = time.time() start_time = time.time()
sess.run([output_name], {input_name: input_data}) sess.run([output_name], {input_name: input_data})
run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds

@ -9,6 +9,7 @@ try:
from clearml import Task from clearml import Task
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from clearml.binding.matplotlib_bind import PatchedMatplotlib from clearml.binding.matplotlib_bind import PatchedMatplotlib
assert hasattr(clearml, '__version__') # verify package is not directory assert hasattr(clearml, '__version__') # verify package is not directory
except (ImportError, AssertionError): except (ImportError, AssertionError):

@ -6,6 +6,7 @@ try:
assert not TESTS_RUNNING # do not log pytest assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['mlflow'] is True # verify integration is enabled assert SETTINGS['mlflow'] is True # verify integration is enabled
import mlflow import mlflow
assert hasattr(mlflow, '__version__') # verify package is not directory assert hasattr(mlflow, '__version__') # verify package is not directory
import os import os

@ -11,9 +11,8 @@ from urllib import parse, request
import requests import requests
import torch import torch
from tqdm import tqdm
from ultralytics.utils import LOGGER, TQDM_BAR_FORMAT, checks, clean_url, emojis, is_online, url2file from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets # Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
GITHUB_ASSETS_REPO = 'ultralytics/assets' GITHUB_ASSETS_REPO = 'ultralytics/assets'
@ -101,11 +100,7 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p
zip_file = directory.with_suffix('.zip') zip_file = directory.with_suffix('.zip')
compression = ZIP_DEFLATED if compress else ZIP_STORED compression = ZIP_DEFLATED if compress else ZIP_STORED
with ZipFile(zip_file, 'w', compression) as f: with ZipFile(zip_file, 'w', compression) as f:
for file in tqdm(files_to_zip, for file in TQDM(files_to_zip, desc=f'Zipping {directory} to {zip_file}...', unit='file', disable=not progress):
desc=f'Zipping {directory} to {zip_file}...',
unit='file',
disable=not progress,
bar_format=TQDM_BAR_FORMAT):
f.write(file, file.relative_to(directory)) f.write(file, file.relative_to(directory))
return zip_file # return path to zip file return zip_file # return path to zip file
@ -163,11 +158,7 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
LOGGER.warning(f'WARNING ⚠ Skipping {file} unzip as destination directory {path} is not empty.') LOGGER.warning(f'WARNING ⚠ Skipping {file} unzip as destination directory {path} is not empty.')
return path return path
for f in tqdm(files, for f in TQDM(files, desc=f'Unzipping {file} to {Path(path).resolve()}...', unit='file', disable=not progress):
desc=f'Unzipping {file} to {Path(path).resolve()}...',
unit='file',
disable=not progress,
bar_format=TQDM_BAR_FORMAT):
zipObj.extract(f, path=extract_path) zipObj.extract(f, path=extract_path)
return path # return unzip dir return path # return unzip dir
@ -297,13 +288,12 @@ def safe_download(url,
if method == 'torch': if method == 'torch':
torch.hub.download_url_to_file(url, f, progress=progress) torch.hub.download_url_to_file(url, f, progress=progress)
else: else:
with request.urlopen(url) as response, tqdm(total=int(response.getheader('Content-Length', 0)), with request.urlopen(url) as response, TQDM(total=int(response.getheader('Content-Length', 0)),
desc=desc, desc=desc,
disable=not progress, disable=not progress,
unit='B', unit='B',
unit_scale=True, unit_scale=True,
unit_divisor=1024, unit_divisor=1024) as pbar:
bar_format=TQDM_BAR_FORMAT) as pbar:
with open(f, 'wb') as f_opened: with open(f, 'wb') as f_opened:
for data in response: for data in response:
f_opened.write(data) f_opened.write(data)

Loading…
Cancel
Save