`ultralytics 8.3.31` add `max_num_obj` factor for `AutoBatch` (#17514)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Branches/tags: prune, v8.3.31
Authored by Laughing · committed by GitHub
parent e100484422 · commit 4453ddab93
Files changed:

1. ultralytics/__init__.py (2 changes)
2. ultralytics/engine/trainer.py (17 changes)
3. ultralytics/models/yolo/detect/train.py (7 changes)
4. ultralytics/utils/autobatch.py (12 changes)
5. ultralytics/utils/tal.py (2 changes)
6. ultralytics/utils/torch_utils.py (10 changes)
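
For context on what these diffs touch: AutoBatch kicks in when a training run requests a batch size below 1, and this PR threads a new `max_num_obj` factor through it so the memory estimate accounts for label density. A minimal way to exercise that path (the model and dataset names are the usual small defaults, not taken from this diff):

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")  # any detection checkpoint
# batch=-1 enables AutoBatch at ~60% of CUDA memory; a float such as
# batch=0.70 targets that fraction of memory instead.
model.train(data="coco8.yaml", epochs=1, imgsz=640, batch=-1)
```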

ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.3.30"
+__version__ = "8.3.31"

 import os

ultralytics/engine/trainer.py

@@ -279,12 +279,7 @@ class BaseTrainer:
         # Batch size
         if self.batch_size < 1 and RANK == -1:  # single-GPU only, estimate best batch size
-            self.args.batch = self.batch_size = check_train_batch_size(
-                model=self.model,
-                imgsz=self.args.imgsz,
-                amp=self.amp,
-                batch=self.batch_size,
-            )
+            self.args.batch = self.batch_size = self.auto_batch()

         # Dataloaders
         batch_size = self.batch_size // max(world_size, 1)
@@ -478,6 +473,16 @@ class BaseTrainer:
         self._clear_memory()
         self.run_callbacks("teardown")

+    def auto_batch(self, max_num_obj=0):
+        """Get batch size by calculating memory occupation of model."""
+        return check_train_batch_size(
+            model=self.model,
+            imgsz=self.args.imgsz,
+            amp=self.amp,
+            batch=self.batch_size,
+            max_num_obj=max_num_obj,
+        )  # returns batch size
+
     def _get_memory(self):
         """Get accelerator memory utilization in GB."""
         if self.device.type == "mps":
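
The new hook exists so subclasses can feed in a dataset-dependent object count (the detection trainer's override appears in the next hunk). The resulting call chain, sketched as comments rather than verbatim source:

```python
# BaseTrainer._setup_train():
#     if self.batch_size < 1 and RANK == -1:
#         self.args.batch = self.batch_size = self.auto_batch()
#
# DetectionTrainer.auto_batch()            # override: derives max_num_obj from labels
#     -> BaseTrainer.auto_batch(max_num_obj)
#     -> check_train_batch_size(..., max_num_obj=max_num_obj)
#     -> autobatch(..., max_num_obj=max_num_obj)
#     -> profile(..., max_num_obj=max_num_obj)  # allocates the simulated prediction tensor
```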

ultralytics/models/yolo/detect/train.py

@@ -141,3 +141,10 @@ class DetectionTrainer(BaseTrainer):
         boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
         cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
         plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
+
+    def auto_batch(self):
+        """Get batch size by calculating memory occupation of model."""
+        train_dataset = self.build_dataset(self.trainset, mode="train", batch=16)
+        # 4 for mosaic augmentation
+        max_num_obj = max(len(l["cls"]) for l in train_dataset.labels) * 4
+        return super().auto_batch(max_num_obj)
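
The ×4 multiplier reflects mosaic augmentation: each mosaic sample is stitched from four source images, so the worst-case label count per training sample is four times that of the densest single image. A toy reproduction of the same computation (the label dicts are invented, shaped like train_dataset.labels):

```python
import numpy as np

labels = [
    {"cls": np.zeros((3, 1))},   # 3 objects in this image
    {"cls": np.zeros((17, 1))},  # 17 objects: the densest image
    {"cls": np.zeros((5, 1))},
]
max_num_obj = max(len(lb["cls"]) for lb in labels) * 4  # 4 images per mosaic
print(max_num_obj)  # 68 objects in the worst-case mosaic sample
```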

ultralytics/utils/autobatch.py

@@ -11,7 +11,7 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
 from ultralytics.utils.torch_utils import autocast, profile

-def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
+def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1):
     """
     Compute optimal YOLO training batch size using the autobatch() function.

@@ -20,6 +20,7 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
         imgsz (int, optional): Image size used for training.
         amp (bool, optional): Use automatic mixed precision if True.
         batch (float, optional): Fraction of GPU memory to use. If -1, use default.
+        max_num_obj (int, optional): The maximum number of objects from dataset.

     Returns:
         (int): Optimal batch size computed using the autobatch() function.

@@ -29,10 +30,12 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
         Otherwise, a default fraction of 0.6 is used.
     """
     with autocast(enabled=amp):
-        return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)
+        return autobatch(
+            deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj
+        )


-def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
+def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch, max_num_obj=1):
     """
     Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
@@ -41,6 +44,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
         imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
         fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.60.
         batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
+        max_num_obj (int, optional): The maximum number of objects from dataset.

     Returns:
         (int): The optimal batch size.
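
Called directly, the new keyword threads straight through. A hedged sketch (the model setup is illustrative, and a CUDA device is needed for a real estimate):

```python
from ultralytics import YOLO
from ultralytics.utils.autobatch import check_train_batch_size

model = YOLO("yolo11n.yaml").model.cuda()  # bare DetectionModel, as the trainer passes it
best = check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=68)
print(best)  # optimal batch size for ~60% of this GPU's memory
```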
@@ -70,7 +74,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
     batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]
     try:
         img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
-        results = profile(img, model, n=1, device=device)
+        results = profile(img, model, n=1, device=device, max_num_obj=max_num_obj)

         # Fit a solution
         y = [x[2] for x in results if x]  # memory [2]
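
After profiling, autobatch fits a first-degree polynomial to the measured (batch size, reserved memory) points and solves it for the target memory fraction. A minimal numpy sketch of that fit, with invented measurements:

```python
import numpy as np

total_gib = 16.0  # total CUDA memory (illustrative)
fraction = 0.60   # target share of that memory
batch_sizes = [1, 2, 4, 8, 16]
mem_gib = [1.1, 1.9, 3.4, 6.5, 12.8]  # reserved memory per batch size

p = np.polyfit(batch_sizes, mem_gib, deg=1)          # mem ≈ p[0] * batch + p[1]
optimal = int((total_gib * fraction - p[1]) / p[0])  # batch size at the target
print(optimal)  # 11 with these made-up numbers
```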

ultralytics/utils/tal.py

@@ -3,6 +3,7 @@
 import torch
 import torch.nn as nn

+from . import LOGGER
 from .checks import check_version
 from .metrics import bbox_iou, probiou
 from .ops import xywhr2xyxyxyxy

@@ -73,6 +74,7 @@ class TaskAlignedAssigner(nn.Module):
             return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
         except torch.OutOfMemoryError:
             # Move tensors to CPU, compute, then move back to original device
+            LOGGER.warning("WARNING: CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
             cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
             result = self._forward(*cpu_tensors)
             return tuple(t.to(device) for t in result)
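
The warning lands in an existing out-of-memory fallback; the pattern in isolation looks roughly like this (a generic sketch, not the assigner's actual tensors, and it assumes fn returns a tuple of tensors):

```python
import torch

def forward_with_cpu_fallback(fn, *tensors):
    """Run fn on the tensors' device; retry on CPU if CUDA memory runs out."""
    try:
        return fn(*tensors)
    except torch.OutOfMemoryError:
        device = tensors[0].device
        result = fn(*(t.cpu() for t in tensors))    # recompute on CPU
        return tuple(t.to(device) for t in result)  # move results back
```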

ultralytics/utils/torch_utils.py

@@ -623,7 +623,7 @@ def convert_optimizer_state_dict_to_fp16(state_dict):
     return state_dict

-def profile(input, ops, n=10, device=None):
+def profile(input, ops, n=10, device=None, max_num_obj=0):
     """
     Ultralytics speed, memory and FLOPs profiler.

@@ -671,6 +671,14 @@ def profile(input, ops, n=10, device=None):
                         t[2] = float("nan")
                     tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                     tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
+                if max_num_obj:  # simulate training with predictions per image grid (for AutoBatch)
+                    torch.randn(
+                        x.shape[0],
+                        max_num_obj,
+                        int(sum([(x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist()])),
+                        device=device,
+                        dtype=torch.float32,
+                    )
                 mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0  # (GB)
                 s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
                 p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
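
The last dimension of the simulated tensor is the total grid-cell count across the model's detection strides. For a 640×640 input and the usual YOLO strides of 8, 16, and 32 (an assumption about m.stride, not stated in the diff), the arithmetic is:

```python
imgsz = 640
strides = [8, 16, 32]  # typical YOLO detection strides (assumed)
cells = int(sum((imgsz / s) ** 2 for s in strides))
print(cells)  # 80*80 + 40*40 + 20*20 = 6400 + 1600 + 400 = 8400
```

With max_num_obj=68 at batch size 16, the dummy allocation is a (16, 68, 8400) float32 tensor, about 37 MB, so the memory curve AutoBatch fits now grows with label density as well as image size.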
