`ultralytics 8.0.175` StreamLoader wait for missing frames (#4814)

pull/4822/head v8.0.175
Glenn Jocher 1 year ago committed by GitHub
parent dd0782bd8d
commit 3c88bebc95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      docs/datasets/detect/sku-110k.md
  2. 130
      docs/guides/hyperparameter-tuning.md
  3. 2
      docs/modes/train.md
  4. 2
      docs/modes/val.md
  5. 43
      tests/conftest.py
  6. 2
      ultralytics/__init__.py
  7. 4
      ultralytics/cfg/__init__.py
  8. 6
      ultralytics/data/build.py
  9. 35
      ultralytics/data/loaders.py
  10. 2
      ultralytics/engine/predictor.py
  11. 2
      ultralytics/hub/session.py

@ -1,6 +1,6 @@
---
comments: true
description: 'Explore the SKU-110k dataset: densely packed retail shelf images for object detection research. Learn how to use it with Ultralytics.'
description: Explore the SKU-110k dataset of densely packed retail shelf images for object detection research. Learn how to use it with Ultralytics.
keywords: SKU-110k dataset, object detection, retail shelf images, Ultralytics, YOLO, computer vision, deep learning models
---

@ -14,16 +14,16 @@ Hyperparameter tuning is not just a one-time set-up but an iterative process aim
Hyperparameters are high-level, structural settings for the algorithm. They are set prior to the training phase and remain constant during it. Here are some commonly tuned hyperparameters in Ultralytics YOLO:
- **Learning Rate**: Determines the step size at each iteration while moving towards a minimum in the loss function.
- **Batch Size**: Number of training samples utilized in one iteration.
- **Number of Epochs**: An epoch is one complete forward and backward pass of all the training examples.
- **Architecture Specifics**: Such as anchor box sizes, number of layers, types of activation functions, etc.
- **Learning Rate** `lr0`: Determines the step size at each iteration while moving towards a minimum in the loss function.
- **Batch Size** `batch`: Number of images processed simultaneously in a forward pass.
- **Number of Epochs** `epochs`: An epoch is one complete forward and backward pass of all the training examples.
- **Architecture Specifics**: Such as channel counts, number of layers, types of activation functions, etc.
<p align="center">
<img width="1000" src="https://user-images.githubusercontent.com/26833433/263858934-4f109a2f-82d9-4d08-8bd6-6fd1ff520bcd.png" alt="Hyperparameter Tuning Visual">
<img width="640" src="https://user-images.githubusercontent.com/26833433/263858934-4f109a2f-82d9-4d08-8bd6-6fd1ff520bcd.png" alt="Hyperparameter Tuning Visual">
</p>
For a full list of augmentation hyperparameters used in YOLOv8 please refer to https://docs.ultralytics.com/usage/cfg/#augmentation.
For a full list of augmentation hyperparameters used in YOLOv8 please refer to [https://docs.ultralytics.com/usage/cfg/#augmentation)(https://docs.ultralytics.com/usage/cfg/#augmentation).
### Genetic Evolution and Mutation
@ -67,7 +67,7 @@ The process is repeated until either the set number of iterations is reached or
## Usage Example
Here's how to use the `model.tune()` method to utilize the `Tuner` class for hyperparameter tuning:
Here's how to use the `model.tune()` method to utilize the `Tuner` class for hyperparameter tuning of YOLOv8n on COCO8 for 30 epochs with an AdamW optimizer and skipping plotting, checkpointing and validation other than on final epoch for faster Tuning.
!!! example ""
@ -79,10 +79,120 @@ Here's how to use the `model.tune()` method to utilize the `Tuner` class for hyp
# Initialize the YOLO model
model = YOLO('yolov8n.pt')
# Perform hyperparameter tuning
model.tune(data='coco8.yaml', imgsz=640, epochs=30, iterations=300)
# Tune hyperparameters on COCO8 for 30 epochs
model.tune(data='coco8.yaml', epochs=30, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
```
## Results
After you've successfully completed the hyperparameter tuning process, you will obtain several files and directories that encapsulate the results of the tuning. The following describes each:
### File Structure
Here's what the directory structure of the results will look like. Training directories like `train1/` contain individual tuning iterations, i.e. one model trained with one set of hyperparameters. The `tune/` directory contains tuning results from all the individual model trainings:
```plaintext
runs/
└── detect/
├── train1/
├── train2/
├── ...
└── tune/
├── best_hyperparameters.yaml
├── best_fitness.png
├── tune_results.csv
├── tune_scatter_plots.png
└── weights/
├── last.pt
└── best.pt
```
### File Descriptions
#### best_hyperparameters.yaml
This YAML file contains the best-performing hyperparameters found during the tuning process. You can use this file to initialize future trainings with these optimized settings.
- **Format**: YAML
- **Usage**: Hyperparameter results
- **Example**:
```yaml
# 558/900 iterations complete ✅ (45536.81s)
# Results saved to /usr/src/ultralytics/runs/detect/tune
# Best fitness=0.64297 observed at iteration 498
# Best fitness metrics are {'metrics/precision(B)': 0.87247, 'metrics/recall(B)': 0.71387, 'metrics/mAP50(B)': 0.79106, 'metrics/mAP50-95(B)': 0.62651, 'val/box_loss': 2.79884, 'val/cls_loss': 2.72386, 'val/dfl_loss': 0.68503, 'fitness': 0.64297}
# Best fitness model is /usr/src/ultralytics/runs/detect/train498
# Best fitness hyperparameters are printed below.
lr0: 0.00269
lrf: 0.00288
momentum: 0.73375
weight_decay: 0.00015
warmup_epochs: 1.22935
warmup_momentum: 0.1525
box: 18.27875
cls: 1.32899
dfl: 0.56016
hsv_h: 0.01148
hsv_s: 0.53554
hsv_v: 0.13636
degrees: 0.0
translate: 0.12431
scale: 0.07643
shear: 0.0
perspective: 0.0
flipud: 0.0
fliplr: 0.08631
mosaic: 0.42551
mixup: 0.0
copy_paste: 0.0
```
#### best_fitness.png
This is a plot displaying fitness (typically a performance metric like AP50) against the number of iterations. It helps you visualize how well the genetic algorithm performed over time.
- **Format**: PNG
- **Usage**: Performance visualization
<p align="center">
<img width="640" src="https://user-images.githubusercontent.com/26833433/266847423-9d0aea13-d5c4-4771-b06e-0b817a498260.png" alt="Hyperparameter Tuning Fitness vs Iteration">
</p>
#### tune_results.csv
A CSV file containing detailed results of each iteration during the tuning. Each row in the file represents one iteration, and it includes metrics like fitness score, precision, recall, as well as the hyperparameters used.
- **Format**: CSV
- **Usage**: Per-iteration results tracking.
- **Example**:
```csv
fitness,lr0,lrf,momentum,weight_decay,warmup_epochs,warmup_momentum,box,cls,dfl,hsv_h,hsv_s,hsv_v,degrees,translate,scale,shear,perspective,flipud,fliplr,mosaic,mixup,copy_paste
0.05021,0.01,0.01,0.937,0.0005,3.0,0.8,7.5,0.5,1.5,0.015,0.7,0.4,0.0,0.1,0.5,0.0,0.0,0.0,0.5,1.0,0.0,0.0
0.07217,0.01003,0.00967,0.93897,0.00049,2.79757,0.81075,7.5,0.50746,1.44826,0.01503,0.72948,0.40658,0.0,0.0987,0.4922,0.0,0.0,0.0,0.49729,1.0,0.0,0.0
0.06584,0.01003,0.00855,0.91009,0.00073,3.42176,0.95,8.64301,0.54594,1.72261,0.01503,0.59179,0.40658,0.0,0.0987,0.46955,0.0,0.0,0.0,0.49729,0.80187,0.0,0.0
```
#### tune_scatter_plots.png
This file contains scatter plots generated from `tune_results.csv`, helping you visualize relationships between different hyperparameters and performance metrics. Note that hyperparameters initialized to 0 will not be tuned, such as `degrees` and `shear` below.
- **Format**: PNG
- **Usage**: Exploratory data analysis
<p align="center">
<img width="1000" src="https://user-images.githubusercontent.com/26833433/266847488-ec382f3d-79bc-4087-a0e0-42fb8b62cad2.png" alt="Hyperparameter Tuning Scatter Plots">
</p>
#### weights/
This directory contains the saved PyTorch models for the last and the best iterations during the hyperparameter tuning process.
- **`last.pt`**: The last.pt weights for the iteration that achieved the best fitness score.
- **`best.pt`**: The best.pt weights for the iteration that achieved the best fitness score.
Using these results, you can make more informed decisions for your future model trainings and analyses. Feel free to consult these artifacts to understand how well your model performed and how you might improve it further.
## Conclusion
The hyperparameter tuning process in Ultralytics YOLO is simplified yet powerful, thanks to its genetic algorithm-based approach focused on mutation. Following the steps outlined in this guide will assist you in systematically tuning your model to achieve better performance.
@ -93,4 +203,4 @@ The hyperparameter tuning process in Ultralytics YOLO is simplified yet powerful
2. [YOLOv5 Hyperparameter Evolution Guide](https://docs.ultralytics.com/yolov5/tutorials/hyperparameter_evolution/)
3. [Efficient Hyperparameter Tuning with Ray Tune and YOLOv8](https://docs.ultralytics.com/integrations/ray-tune/)
For deeper insights, you can explore the `Tuner` class source code and accompanying documentation. Should you have any questions, feature requests, or need further assistance, feel free to reach out to our support team.
For deeper insights, you can explore the `Tuner` class source code and accompanying documentation. Should you have any questions, feature requests, or need further assistance, feel free to reach out to us on [GitHub](https://github.com/ultralytics/ultralytics/issues/new/choose) or [Discord](https://ultralytics.com/discord)

@ -1,6 +1,6 @@
---
comments: true
description: Step-by-step guide to train YOLOv8 models with Ultralytics YOLO with examples of single-GPU and multi-GPU training. Efficient way for object detection training.
description: Step-by-step guide to train YOLOv8 models with Ultralytics YOLO including examples of single-GPU and multi-GPU training
keywords: Ultralytics, YOLOv8, YOLO, object detection, train mode, custom dataset, GPU training, multi-GPU, hyperparameters, CLI examples, Python examples
---

@ -1,6 +1,6 @@
---
comments: true
description: 'Guide for Validating YOLOv8 Models: Learn how to evaluate the performance of your YOLO models using validation settings and metrics with Python and CLI examples.'
description: Guide for Validating YOLOv8 Models. Learn how to evaluate the performance of your YOLO models using validation settings and metrics with Python and CLI examples.
keywords: Ultralytics, YOLO Docs, YOLOv8, validation, model evaluation, hyperparameters, accuracy, metrics, Python, CLI
---

@ -12,18 +12,32 @@ TMP = (ROOT / '../tests/tmp').resolve() # temp directory for test files
def pytest_addoption(parser):
parser.addoption('--runslow', action='store_true', default=False, help='run slow tests')
"""Add custom command-line options to pytest.
Args:
parser (pytest.config.Parser): The pytest parser object.
"""
parser.addoption('--slow', action='store_true', default=False, help='Run slow tests')
def pytest_configure(config):
"""Register custom markers to avoid pytest warnings.
Args:
config (pytest.config.Config): The pytest config object.
"""
config.addinivalue_line('markers', 'slow: mark test as slow to run')
def pytest_collection_modifyitems(config, items):
if config.getoption('--runslow'):
# --runslow given in cli: do not skip slow tests
return
skip_slow = pytest.mark.skip(reason='need --runslow option to run')
"""Modify collected test items based on custom command-line options.
Args:
config (pytest.config.Config): The pytest config object.
items (list): List of collected test items.
"""
if not config.getoption('--slow'):
skip_slow = pytest.mark.skip(reason="remove this test because it's slow")
for item in items:
if 'slow' in item.keywords:
item.add_marker(skip_slow)
@ -31,7 +45,13 @@ def pytest_collection_modifyitems(config, items):
def pytest_sessionstart(session):
"""
Called after the 'Session' object has been created and before performing test collection.
Initialize session configurations for pytest.
This function is automatically called by pytest after the 'Session' object has been created but before performing
test collection. It sets the initial seeds and prepares the temporary directory for the test session.
Args:
session (pytest.Session): The pytest session object.
"""
init_seeds()
shutil.rmtree(TMP, ignore_errors=True) # delete any existing tests/tmp directory
@ -39,6 +59,17 @@ def pytest_sessionstart(session):
def pytest_terminal_summary(terminalreporter, exitstatus, config):
"""
Cleanup operations after pytest session.
This function is automatically called by pytest at the end of the entire test session. It removes certain files
and directories used during testing.
Args:
terminalreporter (pytest.terminal.TerminalReporter): The terminal reporter object.
exitstatus (int): The exit status of the test run.
config (pytest.config.Config): The pytest config object.
"""
# Remove files
for file in ['bus.jpg', 'decelera_landscape_min.mov']:
Path(file).unlink(missing_ok=True)

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.174'
__version__ = '8.0.175'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM

@ -442,9 +442,11 @@ def entrypoint(debug=''):
LOGGER.warning(f"WARNING ⚠ 'format' is missing. Using default 'format={overrides['format']}'.")
# Run command in python
# getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml
getattr(model, mode)(**overrides) # default args from model
# Show help
LOGGER.info(f'💡 Learn more at https://docs.ultralytics.com/modes/{mode}')
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_cfg():

@ -135,7 +135,7 @@ def check_source(source):
return source, webcam, screenshot, from_img, in_memory, tensor
def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=False):
def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False):
"""
Loads an inference source for object detection and applies necessary transformations.
@ -143,7 +143,7 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=Fa
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
imgsz (int, optional): The size of the image for inference. Default is 640.
vid_stride (int, optional): The frame interval for video sources. Default is 1.
stream_buffer (bool, optional): Determined whether stream frames will be buffered. Default is False.
buffer (bool, optional): Determined whether stream frames will be buffered. Default is False.
Returns:
dataset (Dataset): A dataset object for the specified input source.
@ -157,7 +157,7 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=Fa
elif in_memory:
dataset = source
elif webcam:
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, stream_buffer=stream_buffer)
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, buffer=buffer)
elif screenshot:
dataset = LoadScreenshots(source, imgsz=imgsz)
elif from_img:

@ -31,10 +31,10 @@ class SourceTypes:
class LoadStreams:
"""YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, stream_buffer=False):
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, buffer=False):
"""Initialize instance variables and check for consistent input stream shapes."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.stream_buffer = stream_buffer # buffer input streams
self.buffer = buffer # buffer input streams
self.running = True # running flag for Thread
self.mode = 'stream'
self.imgsz = imgsz
@ -42,7 +42,7 @@ class LoadStreams:
sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
n = len(sources)
self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
self.imgs, self.fps, self.frames, self.threads, self.shape = [[]] * n, [0] * n, [0] * n, [None] * n, [None] * n
self.imgs, self.fps, self.frames, self.threads, self.shape = [[]] * n, [0] * n, [0] * n, [None] * n, [[]] * n
self.caps = [None] * n # video capture objects
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
@ -81,8 +81,7 @@ class LoadStreams:
"""Read stream `i` frames in daemon thread."""
n, f = 0, self.frames[i] # frame number, frame array
while self.running and cap.isOpened() and n < (f - 1):
# Only read a new frame if the buffer is empty
if not self.imgs[i] or not self.stream_buffer:
if len(self.imgs[i]) < 30: # keep a <=30-image buffer
n += 1
cap.grab() # .read() = .grab() followed by .retrieve()
if n % self.vid_stride == 0:
@ -91,7 +90,10 @@ class LoadStreams:
im = np.zeros(self.shape[i], dtype=np.uint8)
LOGGER.warning('WARNING ⚠ Video stream unresponsive, please check your IP camera connection.')
cap.open(stream) # re-open stream if signal was lost
self.imgs[i].append(im) # add image to buffer
if self.buffer:
self.imgs[i].append(im)
else:
self.imgs[i] = [im]
else:
time.sleep(0.01) # wait until the buffer is empty
@ -117,21 +119,24 @@ class LoadStreams:
"""Returns source paths, transformed and original images for processing."""
self.count += 1
images = []
for i, x in enumerate(self.imgs):
# Wait until a frame is available in each buffer
while not all(self.imgs):
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
while not x:
if not self.threads[i].is_alive() or cv2.waitKey(1) == ord('q'): # q to quit
self.close()
raise StopIteration
LOGGER.warning(f'WARNING ⚠ Waiting for stream {i}')
time.sleep(1 / min(self.fps))
# Get and remove the next frame from imgs buffer
if self.stream_buffer:
images = [x.pop(0) for x in self.imgs]
# Get and remove the first frame from imgs buffer
if self.buffer:
images.append(x.pop(0))
# Get the last frame, and clear the rest from the imgs buffer
else:
# Get the latest frame, and clear the rest from the imgs buffer
images = []
for x in self.imgs:
images.append(x.pop(-1) if x else None)
images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
x.clear()
return self.sources, images, None, ''

@ -207,7 +207,7 @@ class BasePredictor:
self.dataset = load_inference_source(source=source,
imgsz=self.imgsz,
vid_stride=self.args.vid_stride,
stream_buffer=self.args.stream_buffer)
buffer=self.args.stream_buffer)
self.source_type = self.dataset.source_type
if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
len(self.dataset) > 1000 or # images

@ -117,7 +117,7 @@ class HUBTrainingSession:
if data['status'] == 'new': # new model to start training
self.train_args = {
# TODO deprecate before 8.1.0
# TODO deprecate 'batch_size' argument in favor of 'batch'
'batch': data['batch' if 'batch' in data else 'batch_size'],
'epochs': data['epochs'],
'imgsz': data['imgsz'],

Loading…
Cancel
Save