`ultralytics 8.1.16` OBB ConfusionMatrix support (#8299)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/8305/head^2 v8.1.16
Laughing 9 months ago committed by GitHub
parent 42744a1717
commit de01212465
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      .github/workflows/docker.yaml
  2. 2
      docker/Dockerfile
  3. 2
      docker/Dockerfile-cpu
  4. 2
      docker/Dockerfile-python
  5. 2
      docs/en/modes/train.md
  6. 2
      docs/en/usage/cfg.md
  7. 2
      ultralytics/__init__.py
  8. 2
      ultralytics/cfg/default.yaml
  9. 6
      ultralytics/models/yolo/detect/val.py
  10. 9
      ultralytics/models/yolo/obb/val.py
  11. 1
      ultralytics/utils/downloads.py
  12. 14
      ultralytics/utils/metrics.py

@ -115,7 +115,7 @@ jobs:
uses: nick-invision/retry@v3
with:
timeout_minutes: 60
retry_wait_seconds: 0
retry_wait_seconds: 30
max_attempts: 2 # retry once
command: |
docker build \

@ -30,7 +30,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8n.pt /u
# Install pip packages
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache -e ".[export]" albumentations comet pycocotools lancedb pytest-cov
RUN pip install --no-cache -e ".[export]" albumentations comet pycocotools
# Run exports to AutoInstall packages
RUN yolo export model=tmp/yolov8n.pt format=edgetpu imgsz=32

@ -28,7 +28,7 @@ RUN rm -rf /usr/lib/python3.11/EXTERNALLY-MANAGED
# Install pip packages
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache -e ".[export]" lancedb --extra-index-url https://download.pytorch.org/whl/cpu
RUN pip install --no-cache -e ".[export]" --extra-index-url https://download.pytorch.org/whl/cpu
# Run exports to AutoInstall packages
RUN yolo export model=tmp/yolov8n.pt format=edgetpu imgsz=32

@ -28,7 +28,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8n.pt /u
# Install pip packages
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache -e ".[export]" lancedb --extra-index-url https://download.pytorch.org/whl/cpu
RUN pip install --no-cache -e ".[export]" --extra-index-url https://download.pytorch.org/whl/cpu
# Run exports to AutoInstall packages
RUN yolo export model=tmp/yolov8n.pt format=edgetpu imgsz=32

@ -181,7 +181,7 @@ Training settings for YOLO models refer to the various hyperparameters and confi
| `data` | `None` | Path to the dataset configuration file (e.g., `coco128.yaml`). This file contains dataset-specific parameters, including paths to training and validation data, class names, and number of classes. |
| `epochs` | `100` | Total number of training epochs. Each epoch represents a full pass over the entire dataset. Adjusting this value can affect training duration and model performance. |
| `time` | `None` | Maximum training time in hours. If set, this overrides the `epochs` argument, allowing training to automatically stop after the specified duration. Useful for time-constrained training scenarios. |
| `patience` | `50` | Number of epochs to wait without improvement in validation metrics before early stopping the training. Helps prevent overfitting by stopping training when performance plateaus. |
| `patience` | `100` | Number of epochs to wait without improvement in validation metrics before early stopping the training. Helps prevent overfitting by stopping training when performance plateaus. |
| `batch` | `16` | Batch size for training, indicating how many images are processed before the model's internal parameters are updated. AutoBatch (`batch=-1`) dynamically adjusts the batch size based on GPU memory availability. |
| `imgsz` | `640` | Target image size for training. All images are resized to this dimension before being fed into the model. Affects model accuracy and computational complexity. |
| `save` | `True` | Enables saving of training checkpoints and final model weights. Useful for resuming training or model deployment. |

@ -89,7 +89,7 @@ The training settings for YOLO models encompass various hyperparameters and conf
| `data` | `None` | Path to the dataset configuration file (e.g., `coco128.yaml`). This file contains dataset-specific parameters, including paths to training and validation data, class names, and number of classes. |
| `epochs` | `100` | Total number of training epochs. Each epoch represents a full pass over the entire dataset. Adjusting this value can affect training duration and model performance. |
| `time` | `None` | Maximum training time in hours. If set, this overrides the `epochs` argument, allowing training to automatically stop after the specified duration. Useful for time-constrained training scenarios. |
| `patience` | `50` | Number of epochs to wait without improvement in validation metrics before early stopping the training. Helps prevent overfitting by stopping training when performance plateaus. |
| `patience` | `100` | Number of epochs to wait without improvement in validation metrics before early stopping the training. Helps prevent overfitting by stopping training when performance plateaus. |
| `batch` | `16` | Batch size for training, indicating how many images are processed before the model's internal parameters are updated. AutoBatch (`batch=-1`) dynamically adjusts the batch size based on GPU memory availability. |
| `imgsz` | `640` | Target image size for training. All images are resized to this dimension before being fed into the model. Affects model accuracy and computational complexity. |
| `save` | `True` | Enables saving of training checkpoints and final model weights. Useful for resuming training or model deployment. |

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.15"
__version__ = "8.1.16"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

@ -9,7 +9,7 @@ model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
data: # (str, optional) path to data file, i.e. coco128.yaml
epochs: 100 # (int) number of epochs to train for
time: # (float, optional) number of hours to train for, overrides epochs if supplied
patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
patience: 100 # (int) epochs to wait for no observable improvement for early stopping of training
batch: 16 # (int) number of images per batch (-1 for AutoBatch)
imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
save: True # (bool) save train checkpoints and predict results

@ -132,8 +132,7 @@ class DetectionValidator(BaseValidator):
if nl:
for k in self.stats.keys():
self.stats[k].append(stat[k])
# TODO: obb has not supported confusion_matrix yet.
if self.args.plots and self.args.task != "obb":
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
continue
@ -147,8 +146,7 @@ class DetectionValidator(BaseValidator):
# Evaluate
if nl:
stat["tp"] = self._process_batch(predn, bbox, cls)
# TODO: obb has not supported confusion_matrix yet.
if self.args.plots and self.args.task != "obb":
if self.args.plots:
self.confusion_matrix.process_batch(predn, bbox, cls)
for k in self.stats.keys():
self.stats[k].append(stat[k])

@ -55,10 +55,11 @@ class OBBValidator(DetectionValidator):
Return correct prediction matrix.
Args:
detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
Each detection is of the format: x1, y1, x2, y2, conf, class.
labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
Each label is of the format: class, x1, y1, x2, y2.
detections (torch.Tensor): Tensor of shape [N, 7] representing detections.
Each detection is of the format: x1, y1, x2, y2, conf, class, angle.
gt_bboxes (torch.Tensor): Tensor of shape [M, 5] representing rotated boxes.
Each box is of the format: x1, y1, x2, y2, angle.
labels (torch.Tensor): Tensor of shape [M] representing labels.
Returns:
(torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.

@ -26,6 +26,7 @@ GITHUB_ASSETS_NAMES = (
+ [f"FastSAM-{k}.pt" for k in "sx"]
+ [f"rtdetr-{k}.pt" for k in "lx"]
+ ["mobile_sam.pt"]
+ ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"]
)
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]

@ -326,9 +326,10 @@ class ConfusionMatrix:
Update confusion matrix for object detection task.
Args:
detections (Array[N, 6]): Detected bounding boxes and their associated information.
Each row should contain (x1, y1, x2, y2, conf, class).
gt_bboxes (Array[M, 4]): Ground truth bounding boxes with xyxy format.
detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
Each row should contain (x1, y1, x2, y2, conf, class)
or with an additional element `angle` when it's obb.
gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
gt_cls (Array[M]): The class labels.
"""
if gt_cls.shape[0] == 0: # Check if labels is empty
@ -347,7 +348,12 @@ class ConfusionMatrix:
detections = detections[detections[:, 4] > self.conf]
gt_classes = gt_cls.int()
detection_classes = detections[:, 5].int()
iou = box_iou(gt_bboxes, detections[:, :4])
is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5 # with additional `angle` dimension
iou = (
batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
if is_obb
else box_iou(gt_bboxes, detections[:, :4])
)
x = torch.where(iou > self.iou_thres)
if x[0].shape[0]:

Loading…
Cancel
Save