From a6a2c256d481645ce80988a512d364d4bd8ef35e Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Tue, 9 Jan 2024 02:54:09 +0800 Subject: [PATCH] Add dota8.yaml and O tests (#7394) Co-authored-by: Glenn Jocher --- README.md | 2 +- docs/en/datasets/obb/dota-v2.md | 29 ++++++++- docs/en/datasets/obb/dota8.md | 81 ++++++++++++++++++++++++ docs/en/tasks/obb.md | 16 ++--- docs/mkdocs.yml | 1 + tests/test_cli.py | 6 +- tests/test_python.py | 3 + ultralytics/cfg/datasets/DOTAv1.5.yaml | 2 +- ultralytics/cfg/datasets/DOTAv1.yaml | 2 +- ultralytics/cfg/datasets/dota8.yaml | 34 ++++++++++ ultralytics/engine/results.py | 3 + ultralytics/models/yolo/obb/val.py | 11 ++++ ultralytics/models/yolo/segment/train.py | 2 +- 13 files changed, 176 insertions(+), 16 deletions(-) create mode 100644 docs/en/datasets/obb/dota8.md create mode 100644 ultralytics/cfg/datasets/dota8.yaml diff --git a/README.md b/README.md index 943d510402..42c5b85f15 100644 --- a/README.md +++ b/README.md @@ -199,7 +199,7 @@ See [OBB Docs](https://docs.ultralytics.com/tasks/obb/) for usage examples with | [YOLOv8l-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-obb.pt) | 1024 | 80.7 | 1278.42 | 11.83 | 44.5 | 433.8 | | [YOLOv8x-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-obb.pt) | 1024 | 81.36 | 1759.10 | 13.23 | 69.5 | 676.7 | -- **mAPtest** values are for single-model multi-scale on [DOTAv1](https://captain-whu.github.io/DOTA/index.html) dataset.
Reproduce by `yolo val obb data=DOTAv1.yaml device=0 split=test` +- **mAPtest** values are for single-model multi-scale on [DOTAv1](https://captain-whu.github.io/DOTA/index.html) dataset.
Reproduce by `yolo val obb data=DOTAv1.yaml device=0 split=test` and submit merged results to [DOTA evaluation](https://captain-whu.github.io/DOTA/evaluation.html). - **Speed** averaged over DOTAv1 val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance.
Reproduce by `yolo val obb data=DOTAv1.yaml batch=1 device=0|cpu` diff --git a/docs/en/datasets/obb/dota-v2.md b/docs/en/datasets/obb/dota-v2.md index 81e9a398a2..42d64f653a 100644 --- a/docs/en/datasets/obb/dota-v2.md +++ b/docs/en/datasets/obb/dota-v2.md @@ -66,13 +66,40 @@ Typically, datasets incorporate a YAML (Yet Another Markup Language) file detail --8<-- "ultralytics/cfg/datasets/DOTAv1.yaml" ``` +## Split DOTA images + +To train DOTA dataset, We split original DOTA images with high-resolution into images with 1024x1024 resolution in multi-scale way. + +!!! Example "Split images" + + === "Python" + + ```python + from ultralytics.data.split_dota import split_trainval, split_test + + # split train and val set, with labels. + split_trainval( + data_root='path/to/DOTAv1.0/', + save_dir='path/to/DOTAv1.0-split/', + rates=[0.5, 1.0, 1.5], # multi-scale + gap=500 + ) + # split test set, without labels. + split_test( + data_root='path/to/DOTAv1.0/', + save_dir='path/to/DOTAv1.0-split/', + rates=[0.5, 1.0, 1.5], # multi-scale + gap=500 + ) + ``` + ## Usage To train a model on the DOTA v1 dataset, you can utilize the following code snippets. Always refer to your model's documentation for a thorough list of available arguments. !!! Warning - Please note that all images and associated annotations in the DOTAv2 dataset can be used for academic purposes, but commercial use is prohibited. Your understanding and respect for the dataset creators' wishes are greatly appreciated! + Please note that all images and associated annotations in the DOTAv1 dataset can be used for academic purposes, but commercial use is prohibited. Your understanding and respect for the dataset creators' wishes are greatly appreciated! !!! Example "Train Example" diff --git a/docs/en/datasets/obb/dota8.md b/docs/en/datasets/obb/dota8.md new file mode 100644 index 0000000000..c246d6d207 --- /dev/null +++ b/docs/en/datasets/obb/dota8.md @@ -0,0 +1,81 @@ +--- +comments: true +description: Discover the versatile DOTA8 dataset, perfect for testing and debugging oriented detection models. Learn how to get started with YOLOv8-obb model training. +keywords: Ultralytics, YOLOv8, oriented detection, DOTA8 dataset, dataset, model training, YAML +--- + +# DOTA8 Dataset + +## Introduction + +[Ultralytics](https://ultralytics.com) DOTA8 is a small, but versatile oriented object detection dataset composed of the first 8 images of 8 images of the split DOTAv1 set, 4 for training and 4 for validation. This dataset is ideal for testing and debugging object detection models, or for experimenting with new detection approaches. With 8 images, it is small enough to be easily manageable, yet diverse enough to test training pipelines for errors and act as a sanity check before training larger datasets. + +This dataset is intended for use with Ultralytics [HUB](https://hub.ultralytics.com) and [YOLOv8](https://github.com/ultralytics/ultralytics). + +## Dataset YAML + +A YAML (Yet Another Markup Language) file is used to define the dataset configuration. It contains information about the dataset's paths, classes, and other relevant information. In the case of the DOTA8 dataset, the `dota8.yaml` file is maintained at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/dota8.yaml](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/dota8.yaml). + +!!! Example "ultralytics/cfg/datasets/dota8.yaml" + + ```yaml + --8<-- "ultralytics/cfg/datasets/dota8.yaml" + ``` + +## Usage + +To train a YOLOv8n-obb model on the DOTA8 dataset for 100 epochs with an image size of 640, you can use the following code snippets. For a comprehensive list of available arguments, refer to the model [Training](../../modes/train.md) page. + +!!! Example "Train Example" + + === "Python" + + ```python + from ultralytics import YOLO + + # Load a model + model = YOLO('yolov8n-obb.pt') # load a pretrained model (recommended for training) + + # Train the model + results = model.train(data='dota8.yaml', epochs=100, imgsz=640) + ``` + + === "CLI" + + ```bash + # Start training from a pretrained *.pt model + yolo obb train data=dota8.yaml model=yolov8n-obb.pt epochs=100 imgsz=640 + ``` + +## Sample Images and Annotations + +Here are some examples of images from the DOTA8 dataset, along with their corresponding annotations: + +Dataset sample image + +- **Mosaiced Image**: This image demonstrates a training batch composed of mosaiced dataset images. Mosaicing is a technique used during training that combines multiple images into a single image to increase the variety of objects and scenes within each training batch. This helps improve the model's ability to generalize to different object sizes, aspect ratios, and contexts. + +The example showcases the variety and complexity of the images in the DOTA8 dataset and the benefits of using mosaicing during the training process. + +## Citations and Acknowledgments + +If you use the DOTA dataset in your research or development work, please cite the following paper: + +!!! Quote "" + + === "BibTeX" + + ```bibtex + @article{9560031, + author={Ding, Jian and Xue, Nan and Xia, Gui-Song and Bai, Xiang and Yang, Wen and Yang, Michael and Belongie, Serge and Luo, Jiebo and Datcu, Mihai and Pelillo, Marcello and Zhang, Liangpei}, + journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, + title={Object Detection in Aerial Images: A Large-Scale Benchmark and Challenges}, + year={2021}, + volume={}, + number={}, + pages={1-1}, + doi={10.1109/TPAMI.2021.3117983} + } + ``` + +A special note of gratitude to the team behind the DOTA datasets for their commendable effort in curating this dataset. For an exhaustive understanding of the dataset and its nuances, please visit the [official DOTA website](https://captain-whu.github.io/DOTA/index.html). diff --git a/docs/en/tasks/obb.md b/docs/en/tasks/obb.md index 2dcfca9675..4c3d8091ae 100644 --- a/docs/en/tasks/obb.md +++ b/docs/en/tasks/obb.md @@ -32,14 +32,12 @@ YOLOv8 pretrained OBB models are shown here, which are pretrained on the [DOTAv1 | [YOLOv8l-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-obb.pt) | 1024 | 80.7 | 1278.42 | 11.83 | 44.5 | 433.8 | | [YOLOv8x-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-obb.pt) | 1024 | 81.36 | 1759.10 | 13.23 | 69.5 | 676.7 | -- **mAPtest** values are for single-model multi-scale on [DOTAv1 test](http://cocodataset.org) dataset.
Reproduce by `yolo val obb data=DOTAv1.yaml device=0 split=test` +- **mAPtest** values are for single-model multi-scale on [DOTAv1 test](https://captain-whu.github.io/DOTA/index.html) dataset.
Reproduce by `yolo val obb data=DOTAv1.yaml device=0 split=test` and submit merged results to [DOTA evaluation](https://captain-whu.github.io/DOTA/evaluation.html). - **Speed** averaged over DOTAv1 val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance.
Reproduce by `yolo val obb data=DOTAv1.yaml batch=1 device=0|cpu` ## Train - - -Train YOLOv8n-obb on the dota128.yaml dataset for 100 epochs at image size 640. For a full list of available arguments see the [Configuration](../usage/cfg.md) page. +Train YOLOv8n-obb on the dota8.yaml dataset for 100 epochs at image size 640. For a full list of available arguments see the [Configuration](../usage/cfg.md) page. !!! Example @@ -54,19 +52,19 @@ Train YOLOv8n-obb on the dota128.yaml dataset for 100 epochs at image size 640. model = YOLO('yolov8n-obb.yaml').load('yolov8n.pt') # build from YAML and transfer weights # Train the model - results = model.train(data='dota128-obb.yaml', epochs=100, imgsz=640) + results = model.train(data='dota8-obb.yaml', epochs=100, imgsz=640) ``` === "CLI" ```bash # Build a new model from YAML and start training from scratch - yolo obb train data=dota128-obb.yaml model=yolov8n-obb.yaml epochs=100 imgsz=640 + yolo obb train data=dota8-obb.yaml model=yolov8n-obb.yaml epochs=100 imgsz=640 # Start training from a pretrained *.pt model - yolo obb train data=dota128-obb.yaml model=yolov8n-obb.pt epochs=100 imgsz=640 + yolo obb train data=dota8-obb.yaml model=yolov8n-obb.pt epochs=100 imgsz=640 # Build a new model from YAML, transfer pretrained weights to it and start training - yolo obb train data=dota128-obb.yaml model=yolov8n-obb.yaml pretrained=yolov8n-obb.pt epochs=100 imgsz=640 + yolo obb train data=dota8-obb.yaml model=yolov8n-obb.yaml pretrained=yolov8n-obb.pt epochs=100 imgsz=640 ``` ### Dataset format @@ -75,7 +73,7 @@ OBB dataset format can be found in detail in the [Dataset Guide](../datasets/obb ## Val -Validate trained YOLOv8n-obb model accuracy on the dota128-obb dataset. No argument need to passed as the `model` +Validate trained YOLOv8n-obb model accuracy on the dota8-obb dataset. No argument need to passed as the `model` retains it's training `data` and arguments as model attributes. !!! Example diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index c556f41748..e484e7094f 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -260,6 +260,7 @@ nav: - Oriented Bounding Boxes (OBB): - datasets/obb/index.md - DOTAv2: datasets/obb/dota-v2.md + - DOTA8: datasets/obb/dota8.md - Multi-Object Tracking: - datasets/track/index.md - Guides: diff --git a/tests/test_cli.py b/tests/test_cli.py index 7a3fd9929f..994ce5a28e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -13,12 +13,14 @@ TASK_ARGS = [ ('detect', 'yolov8n', 'coco8.yaml'), ('segment', 'yolov8n-seg', 'coco8-seg.yaml'), ('classify', 'yolov8n-cls', 'imagenet10'), - ('pose', 'yolov8n-pose', 'coco8-pose.yaml'), ] # (task, model, data) + ('pose', 'yolov8n-pose', 'coco8-pose.yaml'), + ('obb', 'yolov8n-obb', 'dota8.yaml'), ] # (task, model, data) EXPORT_ARGS = [ ('yolov8n', 'torchscript'), ('yolov8n-seg', 'torchscript'), ('yolov8n-cls', 'torchscript'), - ('yolov8n-pose', 'torchscript'), ] # (model, format) + ('yolov8n-pose', 'torchscript'), + ('yolov8n-obb', 'torchscript'), ] # (model, format) def run(cmd): diff --git a/tests/test_python.py b/tests/test_python.py index 8032602ef4..e20b7bf232 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -77,6 +77,7 @@ def test_predict_img(): seg_model = YOLO(WEIGHTS_DIR / 'yolov8n-seg.pt') cls_model = YOLO(WEIGHTS_DIR / 'yolov8n-cls.pt') pose_model = YOLO(WEIGHTS_DIR / 'yolov8n-pose.pt') + obb_model = YOLO(WEIGHTS_DIR / 'yolov8n-obb.pt') im = cv2.imread(str(SOURCE)) assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1 # PIL assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1 # ndarray @@ -105,6 +106,8 @@ def test_predict_img(): assert len(results) == t.shape[0] results = pose_model(t, imgsz=32) assert len(results) == t.shape[0] + results = obb_model(t, imgsz=32) + assert len(results) == t.shape[0] def test_predict_grey_and_4ch(): diff --git a/ultralytics/cfg/datasets/DOTAv1.5.yaml b/ultralytics/cfg/datasets/DOTAv1.5.yaml index 1480ad0f54..7e9b4d4109 100644 --- a/ultralytics/cfg/datasets/DOTAv1.5.yaml +++ b/ultralytics/cfg/datasets/DOTAv1.5.yaml @@ -5,7 +5,7 @@ # parent # ├── ultralytics # └── datasets -# └── dota2 ← downloads here (2GB) +# └── dota1.5 ← downloads here (2GB) # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] path: ../datasets/DOTAv1.5 # dataset root dir diff --git a/ultralytics/cfg/datasets/DOTAv1.yaml b/ultralytics/cfg/datasets/DOTAv1.yaml index fa13404817..7fedfd30a7 100644 --- a/ultralytics/cfg/datasets/DOTAv1.yaml +++ b/ultralytics/cfg/datasets/DOTAv1.yaml @@ -5,7 +5,7 @@ # parent # ├── ultralytics # └── datasets -# └── dota2 ← downloads here (2GB) +# └── dota1 ← downloads here (2GB) # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] path: ../datasets/DOTAv1 # dataset root dir diff --git a/ultralytics/cfg/datasets/dota8.yaml b/ultralytics/cfg/datasets/dota8.yaml new file mode 100644 index 0000000000..cbf6361ab3 --- /dev/null +++ b/ultralytics/cfg/datasets/dota8.yaml @@ -0,0 +1,34 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# DOTA8 dataset 8 images from split DOTAv1 dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/obb/dota8/ +# Example usage: yolo train model=yolov8n-obb.pt data=dota8.yaml +# parent +# ├── ultralytics +# └── datasets +# └── dota8 ← downloads here (1MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/dota8 # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images + +# Classes for DOTA 1.0 +names: + 0: plane + 1: ship + 2: storage tank + 3: baseball diamond + 4: tennis court + 5: basketball court + 6: ground track field + 7: harbor + 8: bridge + 9: large vehicle + 10: small vehicle + 11: helicopter + 12: roundabout + 13: soccer ball field + 14: swimming pool + +# Download script/URL (optional) +download: https://github.com/ultralytics/yolov5/releases/download/v1.0/dota8.zip diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py index 4444ef1c69..713658f5da 100644 --- a/ultralytics/engine/results.py +++ b/ultralytics/engine/results.py @@ -323,6 +323,9 @@ class Results(SimpleClass): if self.probs is not None: LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.') return + if self.obb is not None: + LOGGER.warning('WARNING ⚠️ OBB task do not support `save_crop`.') + return for d in self.boxes: save_one_box(d.xyxy, self.orig_img.copy(), diff --git a/ultralytics/models/yolo/obb/val.py b/ultralytics/models/yolo/obb/val.py index c9f70b1318..cbffd57d79 100644 --- a/ultralytics/models/yolo/obb/val.py +++ b/ultralytics/models/yolo/obb/val.py @@ -106,6 +106,17 @@ class OBBValidator(DetectionValidator): 'rbox': [round(x, 3) for x in r], 'poly': [round(x, 3) for x in b]}) + def save_one_txt(self, predn, save_conf, shape, file): + """Save YOLO detections to a txt file in normalized coordinates in a specific format.""" + gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh + for *xyxy, conf, cls, angle in predn.tolist(): + xywha = torch.tensor([*xyxy, angle]).view(1, 5) + xywha[:, :4] /= gn + xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist() # normalized xywh + line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy) # label format + with open(file, 'a') as f: + f.write(('%g ' * len(line)).rstrip() % line + '\n') + def eval_json(self, stats): """Evaluates YOLO output in JSON format and returns performance statistics.""" if self.args.save_json and self.is_dota and len(self.jdict): diff --git a/ultralytics/models/yolo/segment/train.py b/ultralytics/models/yolo/segment/train.py index 949f3cd6db..1d1227daf4 100644 --- a/ultralytics/models/yolo/segment/train.py +++ b/ultralytics/models/yolo/segment/train.py @@ -51,7 +51,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer): batch['batch_idx'], batch['cls'].squeeze(-1), batch['bboxes'], - batch['masks'], + masks=batch['masks'], paths=batch['im_file'], fname=self.save_dir / f'train_batch{ni}.jpg', on_plot=self.on_plot)