Fix `torch.amp.autocast('cuda')` warnings (#14633)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
parent 23ce08791f
commit 0d7bf447eb
7 files changed:

  1. docs/en/reference/utils/torch_utils.md (4 lines changed)
  2. ultralytics/engine/trainer.py (10 lines changed)
  3. ultralytics/models/utils/ops.py (2 lines changed)
  4. ultralytics/utils/autobatch.py (4 lines changed)
  5. ultralytics/utils/checks.py (4 lines changed)
  6. ultralytics/utils/loss.py (3 lines changed)
  7. ultralytics/utils/torch_utils.py (31 lines changed)
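Recent PyTorch releases deprecate the CUDA-specific `torch.cuda.amp` entry points in favor of the device-agnostic `torch.amp` API, emitting FutureWarnings for the old calls. The sketch below contrasts the two forms this commit migrates between; it is illustrative only, and assumes a PyTorch release new enough to expose both `torch.amp.autocast` and the device-aware `torch.amp.GradScaler`:

```python
import torch

# Old, deprecated CUDA-specific forms (warn on recent PyTorch):
#   with torch.cuda.amp.autocast(enabled=True): ...
#   scaler = torch.cuda.amp.GradScaler(enabled=True)

# New device-agnostic forms adopted in this commit:
with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
    pass  # mixed-precision region; the device type is now an explicit argument

scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())
```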

@@ -27,6 +27,10 @@ keywords: Ultralytics, torch utils, model optimization, device selection, infere
 <br><br><hr><br>

+## ::: ultralytics.utils.torch_utils.autocast
+
+<br><br><hr><br>
+
 ## ::: ultralytics.utils.torch_utils.get_cpu_info

 <br><br><hr><br>

@@ -41,8 +41,10 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.utils.files import get_latest_run
 from ultralytics.utils.torch_utils import (
+    TORCH_1_13,
     EarlyStopping,
     ModelEMA,
+    autocast,
     convert_optimizer_state_dict_to_fp16,
     init_seeds,
     one_cycle,
@@ -264,7 +266,11 @@ class BaseTrainer:
         if RANK > -1 and world_size > 1:  # DDP
             dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
         self.amp = bool(self.amp)  # as boolean
-        self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
+        self.scaler = (
+            torch.amp.GradScaler("cuda", enabled=self.amp)
+            if TORCH_1_13
+            else torch.cuda.amp.GradScaler(enabled=self.amp)
+        )
         if world_size > 1:
             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
@@ -376,7 +382,7 @@ class BaseTrainer:
                             x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

                    # Forward
-                   with torch.cuda.amp.autocast(self.amp):
+                   with autocast(self.amp):
                        batch = self.preprocess_batch(batch)
                        self.loss, self.loss_items = self.model(batch)
                        if RANK != -1:
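Together, these two trainer changes preserve the standard PyTorch AMP recipe: the forward pass runs under autocast while the `GradScaler` handles loss scaling for the backward pass. Below is a minimal standalone sketch of that recipe, not Ultralytics code; `model`, `optimizer`, and the tensors are placeholders, and it assumes a recent PyTorch where `torch.amp.GradScaler` accepts a device string, as the gated branch above relies on:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.amp.GradScaler(device, enabled=use_amp)

x = torch.randn(4, 10, device=device)
y = torch.randn(4, 1, device=device)

optimizer.zero_grad()
with torch.amp.autocast(device, enabled=use_amp):  # forward under autocast
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)  # unscales gradients, then runs optimizer.step()
scaler.update()  # adapts the scale factor for the next iteration
```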

@@ -133,7 +133,7 @@ class HungarianMatcher(nn.Module):
         # sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
         # tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
         #
-        # with torch.cuda.amp.autocast(False):
+        # with torch.amp.autocast("cuda", enabled=False):
         #     # binary cross entropy cost
         #     pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
         #     neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')

@@ -7,7 +7,7 @@ import numpy as np
 import torch

 from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
-from ultralytics.utils.torch_utils import profile
+from ultralytics.utils.torch_utils import autocast, profile


 def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
@@ -23,7 +23,7 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
         (int): Optimal batch size computed using the autobatch() function.
     """
-    with torch.cuda.amp.autocast(amp):
+    with autocast(enabled=amp):
         return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)

@@ -641,6 +641,8 @@ def check_amp(model):
     Returns:
         (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
     """
+    from ultralytics.utils.torch_utils import autocast
+
     device = next(model.parameters()).device  # get model device
     if device.type in {"cpu", "mps"}:
         return False  # AMP only used on CUDA devices
@@ -648,7 +650,7 @@ def check_amp(model):
     def amp_allclose(m, im):
         """All close FP32 vs AMP results."""
         a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
-        with torch.cuda.amp.autocast(True):
+        with autocast(enabled=True):
             b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
         del m
         return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance

@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from ultralytics.utils.metrics import OKS_SIGMA
 from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
 from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
+from ultralytics.utils.torch_utils import autocast

 from .metrics import bbox_iou, probiou
 from .tal import bbox2dist
@@ -27,7 +28,7 @@ class VarifocalLoss(nn.Module):
     def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
         """Computes varifocal loss."""
         weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
-        with torch.cuda.amp.autocast(enabled=False):
+        with autocast(enabled=False):
             loss = (
                 (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
                 .mean(1)
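The `enabled=False` here predates this commit and is deliberate, not incidental to the API migration: binary cross-entropy is numerically fragile in float16 (PyTorch outright rejects plain `F.binary_cross_entropy` inside a CUDA autocast region), so the loss exits any enclosing autocast and upcasts its inputs with `.float()`. A standalone sketch of the pattern, with placeholder tensors:

```python
import torch
import torch.nn.functional as F

pred_score = torch.randn(2, 100, 80)  # placeholder logits
gt_score = torch.rand(2, 100, 80)  # placeholder soft targets

with torch.amp.autocast("cuda", enabled=False):  # force full precision
    loss = (
        F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none")
        .mean(1)
        .sum()
    )
```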

@@ -68,6 +68,37 @@ def smart_inference_mode():

     return decorate


+def autocast(enabled: bool, device: str = "cuda"):
+    """
+    Get the appropriate autocast context manager based on PyTorch version and AMP setting.
+
+    This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
+    older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
+
+    Args:
+        enabled (bool): Whether to enable automatic mixed precision.
+        device (str, optional): The device to use for autocast. Defaults to 'cuda'.
+
+    Returns:
+        (torch.amp.autocast): The appropriate autocast context manager.
+
+    Note:
+        - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
+        - For older versions, it uses `torch.cuda.amp.autocast`.
+
+    Example:
+        ```python
+        with autocast(enabled=True):
+            # Your mixed precision operations here
+            pass
+        ```
+    """
+    if TORCH_1_13:
+        return torch.amp.autocast(device, enabled=enabled)
+    else:
+        return torch.cuda.amp.autocast(enabled)
+
+
 def get_cpu_info():
     """Return a string with system CPU information, i.e. 'Apple M2'."""
     import cpuinfo  # pip install py-cpuinfo
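For completeness, a small usage sketch of the new helper as it might be called from user code. This snippet is illustrative rather than taken from the repository, and assumes an Ultralytics install that ships this commit; on a CPU-only machine the "cuda" autocast region simply has no effect:

```python
import torch
from ultralytics.utils.torch_utils import autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
conv = torch.nn.Conv2d(3, 16, 3).to(device)
x = torch.randn(1, 3, 64, 64, device=device)

with autocast(enabled=device == "cuda"):  # version-appropriate context manager
    y = conv(x)
print(y.dtype)  # torch.float16 under CUDA AMP, torch.float32 otherwise
```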
