From de01212465f92a60e0f9d2c1e838f5d5b38c017f Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Mon, 19 Feb 2024 23:59:24 +0800 Subject: [PATCH] `ultralytics 8.1.16` OBB ConfusionMatrix support (#8299) Signed-off-by: Glenn Jocher Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- .github/workflows/docker.yaml | 2 +- docker/Dockerfile | 2 +- docker/Dockerfile-cpu | 2 +- docker/Dockerfile-python | 2 +- docs/en/modes/train.md | 2 +- docs/en/usage/cfg.md | 2 +- ultralytics/__init__.py | 2 +- ultralytics/cfg/default.yaml | 2 +- ultralytics/models/yolo/detect/val.py | 6 ++---- ultralytics/models/yolo/obb/val.py | 9 +++++---- ultralytics/utils/downloads.py | 1 + ultralytics/utils/metrics.py | 14 ++++++++++---- 12 files changed, 26 insertions(+), 20 deletions(-) diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index a724b344d9..afd2a66384 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -115,7 +115,7 @@ jobs: uses: nick-invision/retry@v3 with: timeout_minutes: 60 - retry_wait_seconds: 0 + retry_wait_seconds: 30 max_attempts: 2 # retry once command: | docker build \ diff --git a/docker/Dockerfile b/docker/Dockerfile index af010290db..1dd8e4bb4a 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -30,7 +30,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8n.pt /u # Install pip packages RUN python3 -m pip install --upgrade pip wheel -RUN pip install --no-cache -e ".[export]" albumentations comet pycocotools lancedb pytest-cov +RUN pip install --no-cache -e ".[export]" albumentations comet pycocotools # Run exports to AutoInstall packages RUN yolo export model=tmp/yolov8n.pt format=edgetpu imgsz=32 diff --git a/docker/Dockerfile-cpu b/docker/Dockerfile-cpu index 712f721d03..f829ebfa24 100644 --- a/docker/Dockerfile-cpu +++ b/docker/Dockerfile-cpu @@ -28,7 +28,7 @@ RUN rm -rf /usr/lib/python3.11/EXTERNALLY-MANAGED # Install pip packages RUN python3 -m pip install --upgrade pip wheel -RUN pip install --no-cache -e ".[export]" lancedb --extra-index-url https://download.pytorch.org/whl/cpu +RUN pip install --no-cache -e ".[export]" --extra-index-url https://download.pytorch.org/whl/cpu # Run exports to AutoInstall packages RUN yolo export model=tmp/yolov8n.pt format=edgetpu imgsz=32 diff --git a/docker/Dockerfile-python b/docker/Dockerfile-python index 03f256b552..8423dbb813 100644 --- a/docker/Dockerfile-python +++ b/docker/Dockerfile-python @@ -28,7 +28,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8n.pt /u # Install pip packages RUN python3 -m pip install --upgrade pip wheel -RUN pip install --no-cache -e ".[export]" lancedb --extra-index-url https://download.pytorch.org/whl/cpu +RUN pip install --no-cache -e ".[export]" --extra-index-url https://download.pytorch.org/whl/cpu # Run exports to AutoInstall packages RUN yolo export model=tmp/yolov8n.pt format=edgetpu imgsz=32 diff --git a/docs/en/modes/train.md b/docs/en/modes/train.md index 4560125ae8..51e24840cd 100644 --- a/docs/en/modes/train.md +++ b/docs/en/modes/train.md @@ -181,7 +181,7 @@ Training settings for YOLO models refer to the various hyperparameters and confi | `data` | `None` | Path to the dataset configuration file (e.g., `coco128.yaml`). This file contains dataset-specific parameters, including paths to training and validation data, class names, and number of classes. | | `epochs` | `100` | Total number of training epochs. Each epoch represents a full pass over the entire dataset. Adjusting this value can affect training duration and model performance. | | `time` | `None` | Maximum training time in hours. If set, this overrides the `epochs` argument, allowing training to automatically stop after the specified duration. Useful for time-constrained training scenarios. | -| `patience` | `50` | Number of epochs to wait without improvement in validation metrics before early stopping the training. Helps prevent overfitting by stopping training when performance plateaus. | +| `patience` | `100` | Number of epochs to wait without improvement in validation metrics before early stopping the training. Helps prevent overfitting by stopping training when performance plateaus. | | `batch` | `16` | Batch size for training, indicating how many images are processed before the model's internal parameters are updated. AutoBatch (`batch=-1`) dynamically adjusts the batch size based on GPU memory availability. | | `imgsz` | `640` | Target image size for training. All images are resized to this dimension before being fed into the model. Affects model accuracy and computational complexity. | | `save` | `True` | Enables saving of training checkpoints and final model weights. Useful for resuming training or model deployment. | diff --git a/docs/en/usage/cfg.md b/docs/en/usage/cfg.md index a37c0bbca8..4631de2bab 100644 --- a/docs/en/usage/cfg.md +++ b/docs/en/usage/cfg.md @@ -89,7 +89,7 @@ The training settings for YOLO models encompass various hyperparameters and conf | `data` | `None` | Path to the dataset configuration file (e.g., `coco128.yaml`). This file contains dataset-specific parameters, including paths to training and validation data, class names, and number of classes. | | `epochs` | `100` | Total number of training epochs. Each epoch represents a full pass over the entire dataset. Adjusting this value can affect training duration and model performance. | | `time` | `None` | Maximum training time in hours. If set, this overrides the `epochs` argument, allowing training to automatically stop after the specified duration. Useful for time-constrained training scenarios. | -| `patience` | `50` | Number of epochs to wait without improvement in validation metrics before early stopping the training. Helps prevent overfitting by stopping training when performance plateaus. | +| `patience` | `100` | Number of epochs to wait without improvement in validation metrics before early stopping the training. Helps prevent overfitting by stopping training when performance plateaus. | | `batch` | `16` | Batch size for training, indicating how many images are processed before the model's internal parameters are updated. AutoBatch (`batch=-1`) dynamically adjusts the batch size based on GPU memory availability. | | `imgsz` | `640` | Target image size for training. All images are resized to this dimension before being fed into the model. Affects model accuracy and computational complexity. | | `save` | `True` | Enables saving of training checkpoints and final model weights. Useful for resuming training or model deployment. | diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index b70fd42a70..4d1725234d 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.1.15" +__version__ = "8.1.16" from ultralytics.data.explorer.explorer import Explorer from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld diff --git a/ultralytics/cfg/default.yaml b/ultralytics/cfg/default.yaml index fa4b45a71c..1c1d086d97 100644 --- a/ultralytics/cfg/default.yaml +++ b/ultralytics/cfg/default.yaml @@ -9,7 +9,7 @@ model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml data: # (str, optional) path to data file, i.e. coco128.yaml epochs: 100 # (int) number of epochs to train for time: # (float, optional) number of hours to train for, overrides epochs if supplied -patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training +patience: 100 # (int) epochs to wait for no observable improvement for early stopping of training batch: 16 # (int) number of images per batch (-1 for AutoBatch) imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes save: True # (bool) save train checkpoints and predict results diff --git a/ultralytics/models/yolo/detect/val.py b/ultralytics/models/yolo/detect/val.py index 33d1610fd2..4ca307b80f 100644 --- a/ultralytics/models/yolo/detect/val.py +++ b/ultralytics/models/yolo/detect/val.py @@ -132,8 +132,7 @@ class DetectionValidator(BaseValidator): if nl: for k in self.stats.keys(): self.stats[k].append(stat[k]) - # TODO: obb has not supported confusion_matrix yet. - if self.args.plots and self.args.task != "obb": + if self.args.plots: self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) continue @@ -147,8 +146,7 @@ class DetectionValidator(BaseValidator): # Evaluate if nl: stat["tp"] = self._process_batch(predn, bbox, cls) - # TODO: obb has not supported confusion_matrix yet. - if self.args.plots and self.args.task != "obb": + if self.args.plots: self.confusion_matrix.process_batch(predn, bbox, cls) for k in self.stats.keys(): self.stats[k].append(stat[k]) diff --git a/ultralytics/models/yolo/obb/val.py b/ultralytics/models/yolo/obb/val.py index accf3d72d5..c440fe2b19 100644 --- a/ultralytics/models/yolo/obb/val.py +++ b/ultralytics/models/yolo/obb/val.py @@ -55,10 +55,11 @@ class OBBValidator(DetectionValidator): Return correct prediction matrix. Args: - detections (torch.Tensor): Tensor of shape [N, 6] representing detections. - Each detection is of the format: x1, y1, x2, y2, conf, class. - labels (torch.Tensor): Tensor of shape [M, 5] representing labels. - Each label is of the format: class, x1, y1, x2, y2. + detections (torch.Tensor): Tensor of shape [N, 7] representing detections. + Each detection is of the format: x1, y1, x2, y2, conf, class, angle. + gt_bboxes (torch.Tensor): Tensor of shape [M, 5] representing rotated boxes. + Each box is of the format: x1, y1, x2, y2, angle. + labels (torch.Tensor): Tensor of shape [M] representing labels. Returns: (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels. diff --git a/ultralytics/utils/downloads.py b/ultralytics/utils/downloads.py index a91c24ccfe..213145624f 100644 --- a/ultralytics/utils/downloads.py +++ b/ultralytics/utils/downloads.py @@ -26,6 +26,7 @@ GITHUB_ASSETS_NAMES = ( + [f"FastSAM-{k}.pt" for k in "sx"] + [f"rtdetr-{k}.pt" for k in "lx"] + ["mobile_sam.pt"] + + ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"] ) GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES] diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py index 7d79df51cb..17c0782b38 100644 --- a/ultralytics/utils/metrics.py +++ b/ultralytics/utils/metrics.py @@ -326,9 +326,10 @@ class ConfusionMatrix: Update confusion matrix for object detection task. Args: - detections (Array[N, 6]): Detected bounding boxes and their associated information. - Each row should contain (x1, y1, x2, y2, conf, class). - gt_bboxes (Array[M, 4]): Ground truth bounding boxes with xyxy format. + detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information. + Each row should contain (x1, y1, x2, y2, conf, class) + or with an additional element `angle` when it's obb. + gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format. gt_cls (Array[M]): The class labels. """ if gt_cls.shape[0] == 0: # Check if labels is empty @@ -347,7 +348,12 @@ class ConfusionMatrix: detections = detections[detections[:, 4] > self.conf] gt_classes = gt_cls.int() detection_classes = detections[:, 5].int() - iou = box_iou(gt_bboxes, detections[:, :4]) + is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5 # with additional `angle` dimension + iou = ( + batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1)) + if is_obb + else box_iou(gt_bboxes, detections[:, :4]) + ) x = torch.where(iou > self.iou_thres) if x[0].shape[0]: