Merge branch 'main' into cli-info

cli-info
Burhan 2 months ago committed by GitHub
commit f1326ddff6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. docker/Dockerfile: 3 lines changed
  2. docs/en/integrations/kaggle.md: 4 lines changed
  3. docs/en/integrations/openvino.md: 2 lines changed
  4. docs/en/reference/data/converter.md: 4 lines changed
  5. docs/en/yolov5/environments/docker_image_quickstart_tutorial.md: 2 lines changed
  6. pyproject.toml: 2 lines changed
  7. ultralytics/__init__.py: 2 lines changed
  8. ultralytics/data/converter.py: 73 lines changed
  9. ultralytics/data/explorer/gui/dash.py: 21 lines changed
  10. ultralytics/engine/model.py: 2 lines changed
  11. ultralytics/utils/autobatch.py: 4 lines changed
  12. ultralytics/utils/torch_utils.py: 8 lines changed

@@ -11,7 +11,8 @@ ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PIP_NO_CACHE_DIR=1 \
PIP_BREAK_SYSTEM_PACKAGES=1 \
MKL_THREADING_LAYER=GNU
MKL_THREADING_LAYER=GNU \
OMP_NUM_THREADS=1
# Downloads to user config dir
ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.ttf \

@@ -48,7 +48,7 @@ These options include:
When working with Kaggle, you might come across some common issues. Here are some points to help you navigate the platform smoothly:
- **Access to GPUs**: In your Kaggle notebooks, you can activate a GPU at any time, with usage allowed for up to 30 hours per week. Kaggle provides the Nvidia Tesla P100 GPU with 16GB of memory and also offers the option of using a Nvidia GPU T4 x2. Powerful hardware accelerates your machine-learning tasks, making model training and inference much faster.
- **Access to GPUs**: In your Kaggle notebooks, you can activate a GPU at any time, with usage allowed for up to 30 hours per week. Kaggle provides the NVIDIA Tesla P100 GPU with 16GB of memory and also offers the option of using a NVIDIA GPU T4 x2. Powerful hardware accelerates your machine-learning tasks, making model training and inference much faster.
- **Kaggle Kernels**: Kaggle Kernels are free Jupyter notebook servers that can integrate GPUs, allowing you to perform machine learning operations on cloud computers. You don't have to rely on your own computer's CPU, avoiding overload and freeing up your local resources.
- **Kaggle Datasets**: Kaggle datasets are free to download. However, it's important to check the license for each dataset to understand any usage restrictions. Some datasets may have limitations on academic publications or commercial use. You can download datasets directly to your Kaggle notebook or anywhere else via the Kaggle API.
- **Saving and Committing Notebooks**: To save and commit a notebook on Kaggle, click "Save Version." This saves the current state of your notebook. Once the background kernel finishes generating the output files, you can access them from the Output tab on the main notebook page.
@@ -101,7 +101,7 @@ Training a YOLO11 model on Kaggle is straightforward. First, access the [Kaggle
Kaggle offers several advantages for training YOLO11 models:
- **Free GPU Access**: Utilize powerful GPUs like Nvidia Tesla P100 or T4 x2 for up to 30 hours per week.
- **Free GPU Access**: Utilize powerful GPUs like NVIDIA Tesla P100 or T4 x2 for up to 30 hours per week.
- **Pre-installed Libraries**: Libraries like TensorFlow and PyTorch are pre-installed, simplifying the setup.
- **Community Collaboration**: Engage with a vast community of data scientists and machine learning enthusiasts.
- **Version Control**: Easily manage different versions of your notebooks and revert to previous versions if needed.

@@ -148,7 +148,7 @@ This table represents the benchmark results for five different models (YOLOv8n,
### Intel Arc GPU
Intel® Arc™ represents Intel's foray into the dedicated GPU market. The Arc™ series, designed to compete with leading GPU manufacturers like AMD and Nvidia, caters to both the laptop and desktop markets. The series includes mobile versions for compact devices like laptops, and larger, more powerful versions for desktop computers.
Intel® Arc™ represents Intel's foray into the dedicated GPU market. The Arc™ series, designed to compete with leading GPU manufacturers like AMD and NVIDIA, caters to both the laptop and desktop markets. The series includes mobile versions for compact devices like laptops, and larger, more powerful versions for desktop computers.
The Arc™ series is divided into three categories: Arc™ 3, Arc™ 5, and Arc™ 7, with each number indicating the performance level. Each category includes several models, and the 'M' in the GPU model name signifies a mobile, integrated variant.

@@ -41,4 +41,8 @@ keywords: Ultralytics, data conversion, YOLO models, COCO, DOTA, YOLO bbox2segme
## ::: ultralytics.data.converter.yolo_bbox2segment
<br><br><hr><br>
## ::: ultralytics.data.converter.create_synthetic_coco_dataset
<br><br>

@@ -12,7 +12,7 @@ You can also explore other quickstart options for YOLOv5, such as our [Colab Not
## Prerequisites
1. **NVIDIA Driver**: Version 455.23 or higher. Download from [Nvidia's website](https://www.nvidia.com/Download/index.aspx).
1. **NVIDIA Driver**: Version 455.23 or higher. Download from [NVIDIA's website](https://www.nvidia.com/Download/index.aspx).
2. **NVIDIA-Docker**: Allows Docker to interact with your local GPU. Installation instructions are available on the [NVIDIA-Docker GitHub repository](https://github.com/NVIDIA/nvidia-docker).
3. **Docker Engine - CE**: Version 19.03 or higher. Download and installation instructions can be found on the [Docker website](https://docs.docker.com/get-started/get-docker/).

@@ -19,7 +19,7 @@
# For comprehensive documentation and usage instructions, visit: https://docs.ultralytics.com
[build-system]
requires = ["setuptools>=57.0.0", "wheel"]
requires = ["setuptools>=70.0.0", "wheel"]
build-backend = "setuptools.build_meta"
# Project settings -----------------------------------------------------------------------------------------------------

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.6"
__version__ = "8.3.7"
import os

@@ -1,13 +1,18 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import json
import random
import shutil
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
from ultralytics.utils import LOGGER, TQDM
from ultralytics.utils import DATASETS_DIR, LOGGER, NUM_THREADS, TQDM
from ultralytics.utils.downloads import download
from ultralytics.utils.files import increment_path
@@ -588,15 +593,13 @@ def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"):
- im_dir
001.jpg
..
...
NNN.jpg
- labels
001.txt
..
...
NNN.txt
"""
from tqdm import tqdm
from ultralytics import SAM
from ultralytics.data import YOLODataset
from ultralytics.utils import LOGGER
@@ -610,7 +613,7 @@ def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"):
LOGGER.info("Detection labels detected, generating segment labels by SAM model!")
sam_model = SAM(sam_model)
for label in tqdm(dataset.labels, total=len(dataset.labels), desc="Generating segment labels"):
for label in TQDM(dataset.labels, total=len(dataset.labels), desc="Generating segment labels"):
h, w = label["shape"]
boxes = label["bboxes"]
if len(boxes) == 0: # skip empty labels
@@ -635,3 +638,61 @@ def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"):
with open(txt_file, "a") as f:
f.writelines(text + "\n" for text in texts)
LOGGER.info(f"Generated segment labels saved in {save_dir}")
def create_synthetic_coco_dataset():
"""
Creates a synthetic COCO dataset with random images based on filenames from label lists.
This function downloads COCO labels, reads image filenames from label list files,
creates synthetic images for train2017 and val2017 subsets, and organizes
them in the COCO dataset structure. It uses multithreading to generate images efficiently.
Examples:
>>> from ultralytics.data.converter import create_synthetic_coco_dataset
>>> create_synthetic_coco_dataset()
Notes:
- Requires internet connection to download label files.
- Generates random RGB images of varying sizes (480x480 to 640x640 pixels).
- Existing test2017 directory is removed as it's not needed.
- Reads image filenames from train2017.txt and val2017.txt files.
"""
def create_synthetic_image(image_file):
"""Generates synthetic images with random sizes and colors for dataset augmentation or testing purposes."""
if not image_file.exists():
size = (random.randint(480, 640), random.randint(480, 640))
Image.new(
"RGB",
size=size,
color=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
).save(image_file)
# Download labels
dir = DATASETS_DIR / "coco"
url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/"
label_zip = "coco2017labels-segments.zip"
download([url + label_zip], dir=dir.parent)
# Create synthetic images
shutil.rmtree(dir / "labels" / "test2017", ignore_errors=True) # Remove test2017 directory as not needed
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
for subset in ["train2017", "val2017"]:
subset_dir = dir / "images" / subset
subset_dir.mkdir(parents=True, exist_ok=True)
# Read image filenames from label list file
label_list_file = dir / f"{subset}.txt"
if label_list_file.exists():
with open(label_list_file, "r") as f:
image_files = [dir / line.strip() for line in f]
# Submit all tasks
futures = [executor.submit(create_synthetic_image, image_file) for image_file in image_files]
for _ in TQDM(as_completed(futures), total=len(futures), desc=f"Generating images for {subset}"):
pass # The actual work is done in the background
else:
print(f"Warning: Labels file {label_list_file} does not exist. Skipping image creation for {subset}.")
print("Synthetic COCO dataset created successfully.")
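
Review note: the fan-out pattern in `create_synthetic_coco_dataset` (submit one task per file, then drain `as_completed`) can be sketched in isolation as below. This is a minimal illustration under stated assumptions, not the committed code: it writes solid-color binary PPM files with the standard library instead of using Pillow, and the names `write_placeholder_image` and `generate_images` are hypothetical. One difference is deliberate: the sketch calls `future.result()`, which re-raises exceptions from worker threads, whereas a loop body of `pass` leaves any such exception unobserved.

```python
import random
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path


def write_placeholder_image(image_file: Path) -> bool:
    """Write a solid-color binary PPM (P6) image unless the file already exists."""
    if image_file.exists():
        return False  # skip existing files, mirroring the `if not image_file.exists()` guard
    width, height = random.randint(480, 640), random.randint(480, 640)
    color = bytes(random.randint(0, 255) for _ in range(3))  # one random RGB triple
    header = f"P6 {width} {height} 255\n".encode("ascii")
    image_file.write_bytes(header + color * (width * height))
    return True


def generate_images(image_files, max_workers=4):
    """Submit one task per file and count how many files were actually written."""
    created = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(write_placeholder_image, f) for f in image_files]
        for future in as_completed(futures):
            created += future.result()  # result() re-raises any worker exception
    return created


with tempfile.TemporaryDirectory() as tmp:
    files = [Path(tmp) / f"{i:012d}.ppm" for i in range(8)]
    first_run = generate_images(files)
    second_run = generate_images(files)  # every file exists now, so nothing is rewritten
    all_exist = all(f.exists() for f in files)
```

Because existing files are skipped, a second run over the same file list is effectively a no-op, which is what makes the function safe to call repeatedly.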

@@ -39,24 +39,11 @@ def init_explorer_form(data=None, model=None):
else:
ds = [data]
prefixes = ["yolov8", "yolo11"]
sizes = ["n", "s", "m", "l", "x"]
tasks = ["", "-seg", "-pose"]
if model is None:
models = [
"yolov8n.pt",
"yolov8s.pt",
"yolov8m.pt",
"yolov8l.pt",
"yolov8x.pt",
"yolov8n-seg.pt",
"yolov8s-seg.pt",
"yolov8m-seg.pt",
"yolov8l-seg.pt",
"yolov8x-seg.pt",
"yolov8n-pose.pt",
"yolov8s-pose.pt",
"yolov8m-pose.pt",
"yolov8l-pose.pt",
"yolov8x-pose.pt",
]
models = [f"{p}{s}{t}" for p in prefixes for s in sizes for t in tasks]
else:
models = [model]
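
Review note: the comprehension that replaces the hard-coded list can be checked on its own. The snippet below copies the three lists from this hunk and only adds the checks; it suggests two differences from the removed list that may be worth confirming with the author: the generated names carry no `.pt` suffix, and the order is size-major (`yolov8n`, `yolov8n-seg`, `yolov8n-pose`, ...) rather than task-major.

```python
from itertools import product

prefixes = ["yolov8", "yolo11"]
sizes = ["n", "s", "m", "l", "x"]
tasks = ["", "-seg", "-pose"]

# Same expression as in the hunk: prefix is the outer loop, task the inner loop
models = [f"{p}{s}{t}" for p in prefixes for s in sizes for t in tasks]

# The nested comprehension is equivalent to a Cartesian product of the three lists
same_as_product = models == ["".join(parts) for parts in product(prefixes, sizes, tasks)]

# The removed list used names such as "yolov8n.pt"; the generated names have no suffix
has_pt_suffix = any(name.endswith(".pt") for name in models)

print(len(models), models[:4])
```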

@@ -544,6 +544,8 @@ class Model(nn.Module):
if not self.predictor:
self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)
if predictor:
self.predictor.args = get_cfg(self.predictor.args, args)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, args)

@@ -69,7 +69,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
batch_sizes = [1, 2, 4, 8, 16]
try:
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
results = profile(img, model, n=3, device=device)
results = profile(img, model, n=1, device=device)
# Fit a solution
y = [x[2] for x in results if x] # memory [2]
@@ -89,3 +89,5 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
except Exception as e:
LOGGER.warning(f"{prefix}WARNING ⚠ error detected: {e}, using default batch-size {batch_size}.")
return batch_size
finally:
torch.cuda.empty_cache()

@@ -643,7 +643,8 @@ def profile(input, ops, n=10, device=None):
f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
f"{'input':>24s}{'output':>24s}"
)
gc.collect() # attempt to free unused memory
torch.cuda.empty_cache()
for x in input if isinstance(input, list) else [input]:
x = x.to(device)
x.requires_grad = True
@@ -677,8 +678,9 @@ def profile(input, ops, n=10, device=None):
except Exception as e:
LOGGER.info(e)
results.append(None)
gc.collect() # attempt to free unused memory
torch.cuda.empty_cache()
finally:
gc.collect() # attempt to free unused memory
torch.cuda.empty_cache()
return results
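
Review note: moving the cleanup from the `except` branch into `finally` means it now runs after successful iterations as well as failed ones. The sketch below illustrates that control flow without PyTorch; `profile_all` and `fake_memory_gb` are hypothetical stand-ins, and only `gc.collect()` is called because `torch.cuda.empty_cache()` would need a torch install.

```python
import gc


def profile_all(inputs, op):
    """Run `op` on each input, recording None on failure and cleaning up every time."""
    results, cleanups = [], 0
    for x in inputs:
        try:
            results.append(op(x))
        except Exception as e:
            print(f"skipped {x!r}: {e}")
            results.append(None)  # keep positions aligned with the inputs
        finally:
            gc.collect()  # runs after success and after failure alike
            cleanups += 1
    return results, cleanups


def fake_memory_gb(batch_size):
    """Hypothetical stand-in for a real measurement; fails for large batches."""
    if batch_size > 4:
        raise MemoryError("simulated out-of-memory")
    return 0.5 * batch_size


results, cleanups = profile_all([1, 2, 4, 8, 16], fake_memory_gb)
usable = [r for r in results if r is not None]  # callers drop the failed entries
```

The same language rule covers the `autobatch` hunk: a `finally` clause also executes when the `try` or `except` block exits through `return`, so the cache is cleared on every path out of the function.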
