From 3c88bebc9514a4d7f70b771811ddfe3a625ef14d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 10 Sep 2023 23:59:43 +0200 Subject: [PATCH] `ultralytics 8.0.175` StreamLoader wait for missing frames (#4814) --- docs/datasets/detect/sku-110k.md | 2 +- docs/guides/hyperparameter-tuning.md | 130 ++++++++++++++++++++++++--- docs/modes/train.md | 2 +- docs/modes/val.md | 2 +- tests/conftest.py | 49 ++++++++-- ultralytics/__init__.py | 2 +- ultralytics/cfg/__init__.py | 4 +- ultralytics/data/build.py | 6 +- ultralytics/data/loaders.py | 45 +++++----- ultralytics/engine/predictor.py | 2 +- ultralytics/hub/session.py | 2 +- 11 files changed, 197 insertions(+), 49 deletions(-) diff --git a/docs/datasets/detect/sku-110k.md b/docs/datasets/detect/sku-110k.md index 07c35793b9..bfa2b49fa4 100644 --- a/docs/datasets/detect/sku-110k.md +++ b/docs/datasets/detect/sku-110k.md @@ -1,6 +1,6 @@ --- comments: true -description: 'Explore the SKU-110k dataset: densely packed retail shelf images for object detection research. Learn how to use it with Ultralytics.' +description: Explore the SKU-110k dataset of densely packed retail shelf images for object detection research. Learn how to use it with Ultralytics. keywords: SKU-110k dataset, object detection, retail shelf images, Ultralytics, YOLO, computer vision, deep learning models --- diff --git a/docs/guides/hyperparameter-tuning.md b/docs/guides/hyperparameter-tuning.md index 02d24872e7..e6a3f310bd 100644 --- a/docs/guides/hyperparameter-tuning.md +++ b/docs/guides/hyperparameter-tuning.md @@ -14,16 +14,16 @@ Hyperparameter tuning is not just a one-time set-up but an iterative process aim Hyperparameters are high-level, structural settings for the algorithm. They are set prior to the training phase and remain constant during it. Here are some commonly tuned hyperparameters in Ultralytics YOLO: -- **Learning Rate**: Determines the step size at each iteration while moving towards a minimum in the loss function. 
-- **Batch Size**: Number of training samples utilized in one iteration. -- **Number of Epochs**: An epoch is one complete forward and backward pass of all the training examples. -- **Architecture Specifics**: Such as anchor box sizes, number of layers, types of activation functions, etc. +- **Learning Rate** `lr0`: Determines the step size at each iteration while moving towards a minimum in the loss function. +- **Batch Size** `batch`: Number of images processed simultaneously in a forward pass. +- **Number of Epochs** `epochs`: An epoch is one complete forward and backward pass of all the training examples. +- **Architecture Specifics**: Such as channel counts, number of layers, types of activation functions, etc.

- Hyperparameter Tuning Visual + Hyperparameter Tuning Visual

-For a full list of augmentation hyperparameters used in YOLOv8 please refer to https://docs.ultralytics.com/usage/cfg/#augmentation. +For a full list of augmentation hyperparameters used in YOLOv8 please refer to [https://docs.ultralytics.com/usage/cfg/#augmentation](https://docs.ultralytics.com/usage/cfg/#augmentation). ### Genetic Evolution and Mutation @@ -67,7 +67,7 @@ The process is repeated until either the set number of iterations is reached or ## Usage Example -Here's how to use the `model.tune()` method to utilize the `Tuner` class for hyperparameter tuning: +Here's how to use the `model.tune()` method to utilize the `Tuner` class for hyperparameter tuning of YOLOv8n on COCO8 for 30 epochs with an AdamW optimizer and skipping plotting, checkpointing and validation other than on final epoch for faster Tuning. !!! example "" @@ -79,10 +79,120 @@ Here's how to use the `model.tune()` method to utilize the `Tuner` class for hyp # Initialize the YOLO model model = YOLO('yolov8n.pt') - # Perform hyperparameter tuning - model.tune(data='coco8.yaml', imgsz=640, epochs=30, iterations=300) + # Tune hyperparameters on COCO8 for 30 epochs + model.tune(data='coco8.yaml', epochs=30, iterations=300, optimizer='AdamW', plots=False, save=False, val=False) ``` +## Results + +After you've successfully completed the hyperparameter tuning process, you will obtain several files and directories that encapsulate the results of the tuning. The following describes each: + +### File Structure + +Here's what the directory structure of the results will look like. Training directories like `train1/` contain individual tuning iterations, i.e. one model trained with one set of hyperparameters. The `tune/` directory contains tuning results from all the individual model trainings: + +```plaintext +runs/ +└── detect/ + ├── train1/ + ├── train2/ + ├── ... 
+ └── tune/ + ├── best_hyperparameters.yaml + ├── best_fitness.png + ├── tune_results.csv + ├── tune_scatter_plots.png + └── weights/ + ├── last.pt + └── best.pt +``` + +### File Descriptions + +#### best_hyperparameters.yaml + +This YAML file contains the best-performing hyperparameters found during the tuning process. You can use this file to initialize future trainings with these optimized settings. + +- **Format**: YAML +- **Usage**: Hyperparameter results +- **Example**: + ```yaml + # 558/900 iterations complete ✅ (45536.81s) + # Results saved to /usr/src/ultralytics/runs/detect/tune + # Best fitness=0.64297 observed at iteration 498 + # Best fitness metrics are {'metrics/precision(B)': 0.87247, 'metrics/recall(B)': 0.71387, 'metrics/mAP50(B)': 0.79106, 'metrics/mAP50-95(B)': 0.62651, 'val/box_loss': 2.79884, 'val/cls_loss': 2.72386, 'val/dfl_loss': 0.68503, 'fitness': 0.64297} + # Best fitness model is /usr/src/ultralytics/runs/detect/train498 + # Best fitness hyperparameters are printed below. + + lr0: 0.00269 + lrf: 0.00288 + momentum: 0.73375 + weight_decay: 0.00015 + warmup_epochs: 1.22935 + warmup_momentum: 0.1525 + box: 18.27875 + cls: 1.32899 + dfl: 0.56016 + hsv_h: 0.01148 + hsv_s: 0.53554 + hsv_v: 0.13636 + degrees: 0.0 + translate: 0.12431 + scale: 0.07643 + shear: 0.0 + perspective: 0.0 + flipud: 0.0 + fliplr: 0.08631 + mosaic: 0.42551 + mixup: 0.0 + copy_paste: 0.0 + ``` + +#### best_fitness.png + +This is a plot displaying fitness (typically a performance metric like AP50) against the number of iterations. It helps you visualize how well the genetic algorithm performed over time. + +- **Format**: PNG +- **Usage**: Performance visualization + +

+ Hyperparameter Tuning Fitness vs Iteration +

+ +#### tune_results.csv + +A CSV file containing detailed results of each iteration during the tuning. Each row in the file represents one iteration, and it includes metrics like fitness score, precision, recall, as well as the hyperparameters used. + +- **Format**: CSV +- **Usage**: Per-iteration results tracking. +- **Example**: + ```csv + fitness,lr0,lrf,momentum,weight_decay,warmup_epochs,warmup_momentum,box,cls,dfl,hsv_h,hsv_s,hsv_v,degrees,translate,scale,shear,perspective,flipud,fliplr,mosaic,mixup,copy_paste + 0.05021,0.01,0.01,0.937,0.0005,3.0,0.8,7.5,0.5,1.5,0.015,0.7,0.4,0.0,0.1,0.5,0.0,0.0,0.0,0.5,1.0,0.0,0.0 + 0.07217,0.01003,0.00967,0.93897,0.00049,2.79757,0.81075,7.5,0.50746,1.44826,0.01503,0.72948,0.40658,0.0,0.0987,0.4922,0.0,0.0,0.0,0.49729,1.0,0.0,0.0 + 0.06584,0.01003,0.00855,0.91009,0.00073,3.42176,0.95,8.64301,0.54594,1.72261,0.01503,0.59179,0.40658,0.0,0.0987,0.46955,0.0,0.0,0.0,0.49729,0.80187,0.0,0.0 + ``` + +#### tune_scatter_plots.png + +This file contains scatter plots generated from `tune_results.csv`, helping you visualize relationships between different hyperparameters and performance metrics. Note that hyperparameters initialized to 0 will not be tuned, such as `degrees` and `shear` below. + +- **Format**: PNG +- **Usage**: Exploratory data analysis + +

+ Hyperparameter Tuning Scatter Plots +

+ +#### weights/ + +This directory contains the saved PyTorch models for the last and the best iterations during the hyperparameter tuning process. + +- **`last.pt`**: The last.pt weights for the iteration that achieved the best fitness score. +- **`best.pt`**: The best.pt weights for the iteration that achieved the best fitness score. + +Using these results, you can make more informed decisions for your future model trainings and analyses. Feel free to consult these artifacts to understand how well your model performed and how you might improve it further. + ## Conclusion The hyperparameter tuning process in Ultralytics YOLO is simplified yet powerful, thanks to its genetic algorithm-based approach focused on mutation. Following the steps outlined in this guide will assist you in systematically tuning your model to achieve better performance. @@ -93,4 +203,4 @@ The hyperparameter tuning process in Ultralytics YOLO is simplified yet powerful 2. [YOLOv5 Hyperparameter Evolution Guide](https://docs.ultralytics.com/yolov5/tutorials/hyperparameter_evolution/) 3. [Efficient Hyperparameter Tuning with Ray Tune and YOLOv8](https://docs.ultralytics.com/integrations/ray-tune/) -For deeper insights, you can explore the `Tuner` class source code and accompanying documentation. Should you have any questions, feature requests, or need further assistance, feel free to reach out to our support team. +For deeper insights, you can explore the `Tuner` class source code and accompanying documentation. 
Should you have any questions, feature requests, or need further assistance, feel free to reach out to us on [GitHub](https://github.com/ultralytics/ultralytics/issues/new/choose) or [Discord](https://ultralytics.com/discord) \ No newline at end of file diff --git a/docs/modes/train.md b/docs/modes/train.md index ad3b4642c6..0e60804e0e 100644 --- a/docs/modes/train.md +++ b/docs/modes/train.md @@ -1,6 +1,6 @@ --- comments: true -description: Step-by-step guide to train YOLOv8 models with Ultralytics YOLO with examples of single-GPU and multi-GPU training. Efficient way for object detection training. +description: Step-by-step guide to train YOLOv8 models with Ultralytics YOLO including examples of single-GPU and multi-GPU training keywords: Ultralytics, YOLOv8, YOLO, object detection, train mode, custom dataset, GPU training, multi-GPU, hyperparameters, CLI examples, Python examples --- diff --git a/docs/modes/val.md b/docs/modes/val.md index 7677267cdb..abdce4fe84 100644 --- a/docs/modes/val.md +++ b/docs/modes/val.md @@ -1,6 +1,6 @@ --- comments: true -description: 'Guide for Validating YOLOv8 Models: Learn how to evaluate the performance of your YOLO models using validation settings and metrics with Python and CLI examples.' +description: Guide for Validating YOLOv8 Models. Learn how to evaluate the performance of your YOLO models using validation settings and metrics with Python and CLI examples. keywords: Ultralytics, YOLO Docs, YOLOv8, validation, model evaluation, hyperparameters, accuracy, metrics, Python, CLI --- diff --git a/tests/conftest.py b/tests/conftest.py index 50ea353f0c..0afb652bf1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,26 +12,46 @@ TMP = (ROOT / '../tests/tmp').resolve() # temp directory for test files def pytest_addoption(parser): - parser.addoption('--runslow', action='store_true', default=False, help='run slow tests') + """Add custom command-line options to pytest. 
+ + Args: + parser (pytest.config.Parser): The pytest parser object. + """ + parser.addoption('--slow', action='store_true', default=False, help='Run slow tests') def pytest_configure(config): + """Register custom markers to avoid pytest warnings. + + Args: + config (pytest.config.Config): The pytest config object. + """ config.addinivalue_line('markers', 'slow: mark test as slow to run') def pytest_collection_modifyitems(config, items): - if config.getoption('--runslow'): - # --runslow given in cli: do not skip slow tests - return - skip_slow = pytest.mark.skip(reason='need --runslow option to run') - for item in items: - if 'slow' in item.keywords: - item.add_marker(skip_slow) + """Modify collected test items based on custom command-line options. + + Args: + config (pytest.config.Config): The pytest config object. + items (list): List of collected test items. + """ + if not config.getoption('--slow'): + skip_slow = pytest.mark.skip(reason="remove this test because it's slow") + for item in items: + if 'slow' in item.keywords: + item.add_marker(skip_slow) def pytest_sessionstart(session): """ - Called after the 'Session' object has been created and before performing test collection. + Initialize session configurations for pytest. + + This function is automatically called by pytest after the 'Session' object has been created but before performing + test collection. It sets the initial seeds and prepares the temporary directory for the test session. + + Args: + session (pytest.Session): The pytest session object. """ init_seeds() shutil.rmtree(TMP, ignore_errors=True) # delete any existing tests/tmp directory @@ -39,6 +59,17 @@ def pytest_sessionstart(session): def pytest_terminal_summary(terminalreporter, exitstatus, config): + """ + Cleanup operations after pytest session. + + This function is automatically called by pytest at the end of the entire test session. It removes certain files + and directories used during testing. 
+ + Args: + terminalreporter (pytest.terminal.TerminalReporter): The terminal reporter object. + exitstatus (int): The exit status of the test run. + config (pytest.config.Config): The pytest config object. + """ # Remove files for file in ['bus.jpg', 'decelera_landscape_min.mov']: Path(file).unlink(missing_ok=True) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 7a9f7f8cf6..ca8d475235 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.174' +__version__ = '8.0.175' from ultralytics.models import RTDETR, SAM, YOLO from ultralytics.models.fastsam import FastSAM diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index cc37ee560c..1e72e2db17 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -442,9 +442,11 @@ def entrypoint(debug=''): LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.") # Run command in python - # getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml getattr(model, mode)(**overrides) # default args from model + # Show help + LOGGER.info(f'💡 Learn more at https://docs.ultralytics.com/modes/{mode}') + # Special modes -------------------------------------------------------------------------------------------------------- def copy_default_cfg(): diff --git a/ultralytics/data/build.py b/ultralytics/data/build.py index 9d40e5a0a7..4186146201 100644 --- a/ultralytics/data/build.py +++ b/ultralytics/data/build.py @@ -135,7 +135,7 @@ def check_source(source): return source, webcam, screenshot, from_img, in_memory, tensor -def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=False): +def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False): """ Loads an inference source for object detection and applies necessary transformations. 
@@ -143,7 +143,7 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=Fa source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference. imgsz (int, optional): The size of the image for inference. Default is 640. vid_stride (int, optional): The frame interval for video sources. Default is 1. - stream_buffer (bool, optional): Determined whether stream frames will be buffered. Default is False. + buffer (bool, optional): Determines whether stream frames will be buffered. Default is False. Returns: dataset (Dataset): A dataset object for the specified input source. @@ -157,7 +157,7 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=Fa elif in_memory: dataset = source elif webcam: - dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, stream_buffer=stream_buffer) + dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, buffer=buffer) elif screenshot: dataset = LoadScreenshots(source, imgsz=imgsz) elif from_img: diff --git a/ultralytics/data/loaders.py b/ultralytics/data/loaders.py index 8aec7fad9c..8524f6d594 100644 --- a/ultralytics/data/loaders.py +++ b/ultralytics/data/loaders.py @@ -31,10 +31,10 @@ class SourceTypes: class LoadStreams: """YOLOv8 streamloader, i.e. 
`yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`.""" - def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, stream_buffer=False): + def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, buffer=False): """Initialize instance variables and check for consistent input stream shapes.""" torch.backends.cudnn.benchmark = True # faster for fixed-size inference - self.stream_buffer = stream_buffer # buffer input streams + self.buffer = buffer # buffer input streams self.running = True # running flag for Thread self.mode = 'stream' self.imgsz = imgsz @@ -42,7 +42,7 @@ class LoadStreams: sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources] n = len(sources) self.sources = [ops.clean_str(x) for x in sources] # clean source names for later - self.imgs, self.fps, self.frames, self.threads, self.shape = [[]] * n, [0] * n, [0] * n, [None] * n, [None] * n + self.imgs, self.fps, self.frames, self.threads, self.shape = [[]] * n, [0] * n, [0] * n, [None] * n, [[]] * n self.caps = [None] * n # video capture objects for i, s in enumerate(sources): # index, source # Start thread to read frames from video stream @@ -81,8 +81,7 @@ class LoadStreams: """Read stream `i` frames in daemon thread.""" n, f = 0, self.frames[i] # frame number, frame array while self.running and cap.isOpened() and n < (f - 1): - # Only read a new frame if the buffer is empty - if not self.imgs[i] or not self.stream_buffer: + if len(self.imgs[i]) < 30: # keep a <=30-image buffer n += 1 cap.grab() # .read() = .grab() followed by .retrieve() if n % self.vid_stride == 0: @@ -91,7 +90,10 @@ class LoadStreams: im = np.zeros(self.shape[i], dtype=np.uint8) LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.') cap.open(stream) # re-open stream if signal was lost - self.imgs[i].append(im) # add image to buffer + if self.buffer: + self.imgs[i].append(im) + else: + self.imgs[i] = [im] else: 
time.sleep(0.01) # wait until the buffer is empty @@ -117,21 +119,24 @@ class LoadStreams: """Returns source paths, transformed and original images for processing.""" self.count += 1 - # Wait until a frame is available in each buffer - while not all(self.imgs): - if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit - self.close() - raise StopIteration - time.sleep(1 / min(self.fps)) + images = [] + for i, x in enumerate(self.imgs): - # Get and remove the next frame from imgs buffer - if self.stream_buffer: - images = [x.pop(0) for x in self.imgs] - else: - # Get the latest frame, and clear the rest from the imgs buffer - images = [] - for x in self.imgs: - images.append(x.pop(-1) if x else None) + # Wait until a frame is available in each buffer + while not x: + if not self.threads[i].is_alive() or cv2.waitKey(1) == ord('q'): # q to quit + self.close() + raise StopIteration + LOGGER.warning(f'WARNING ⚠️ Waiting for stream {i}') + time.sleep(1 / min(self.fps)) + + # Get and remove the first frame from imgs buffer + if self.buffer: + images.append(x.pop(0)) + + # Get the last frame, and clear the rest from the imgs buffer + else: + images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8)) x.clear() return self.sources, images, None, '' diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py index 9da29ee8f6..7a832b720b 100644 --- a/ultralytics/engine/predictor.py +++ b/ultralytics/engine/predictor.py @@ -207,7 +207,7 @@ class BasePredictor: self.dataset = load_inference_source(source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride, - stream_buffer=self.args.stream_buffer) + buffer=self.args.stream_buffer) self.source_type = self.dataset.source_type if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams len(self.dataset) > 1000 or # images diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index 660d2c66df..f69690d512 100644 --- 
a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -117,7 +117,7 @@ class HUBTrainingSession: if data['status'] == 'new': # new model to start training self.train_args = { - # TODO deprecate before 8.1.0 + # TODO deprecate 'batch_size' argument in favor of 'batch' 'batch': data['batch' if 'batch' in data else 'batch_size'], 'epochs': data['epochs'], 'imgsz': data['imgsz'],