From 0d7bf447eb6dd726ead6174e55c542b527b2cc77 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 23 Jul 2024 21:58:39 +0200
Subject: [PATCH] Fix `torch.amp.autocast('cuda')` warnings (#14633)

Signed-off-by: Glenn Jocher
Co-authored-by: UltralyticsAssistant
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
---
 docs/en/reference/utils/torch_utils.md |  4 ++++
 ultralytics/engine/trainer.py          | 10 +++++++--
 ultralytics/models/utils/ops.py        |  2 +-
 ultralytics/utils/autobatch.py         |  4 ++--
 ultralytics/utils/checks.py            |  4 +++-
 ultralytics/utils/loss.py              |  3 ++-
 ultralytics/utils/torch_utils.py       | 31 ++++++++++++++++++++++++++
 7 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/docs/en/reference/utils/torch_utils.md b/docs/en/reference/utils/torch_utils.md
index 6a48fec74..dd4c364d9 100644
--- a/docs/en/reference/utils/torch_utils.md
+++ b/docs/en/reference/utils/torch_utils.md
@@ -27,6 +27,10 @@ keywords: Ultralytics, torch utils, model optimization, device selection, infere
 
 <br><br>
 
+## ::: ultralytics.utils.torch_utils.autocast
+
+<br><br>
+
 ## ::: ultralytics.utils.torch_utils.get_cpu_info
 
 <br><br>
 
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 3fb3e0b85..4415ba94e 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -41,8 +41,10 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.utils.files import get_latest_run
 from ultralytics.utils.torch_utils import (
+    TORCH_1_13,
     EarlyStopping,
     ModelEMA,
+    autocast,
     convert_optimizer_state_dict_to_fp16,
     init_seeds,
     one_cycle,
@@ -264,7 +266,11 @@ class BaseTrainer:
         if RANK > -1 and world_size > 1:  # DDP
             dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
         self.amp = bool(self.amp)  # as boolean
-        self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
+        self.scaler = (
+            torch.amp.GradScaler("cuda", enabled=self.amp)
+            if TORCH_1_13
+            else torch.cuda.amp.GradScaler(enabled=self.amp)
+        )
         if world_size > 1:
             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
 
@@ -376,7 +382,7 @@ class BaseTrainer:
                             x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
 
                 # Forward
-                with torch.cuda.amp.autocast(self.amp):
+                with autocast(self.amp):
                     batch = self.preprocess_batch(batch)
                     self.loss, self.loss_items = self.model(batch)
                     if RANK != -1:
diff --git a/ultralytics/models/utils/ops.py b/ultralytics/models/utils/ops.py
index 4f66feef6..64d10e36b 100644
--- a/ultralytics/models/utils/ops.py
+++ b/ultralytics/models/utils/ops.py
@@ -133,7 +133,7 @@ class HungarianMatcher(nn.Module):
     #     sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
     #     tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
     #
-    #     with torch.cuda.amp.autocast(False):
+    #     with torch.amp.autocast("cuda", enabled=False):
     #         # binary cross entropy cost
     #         pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
     #         neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
diff --git a/ultralytics/utils/autobatch.py b/ultralytics/utils/autobatch.py
index 2f695df82..784210c57 100644
--- a/ultralytics/utils/autobatch.py
+++ b/ultralytics/utils/autobatch.py
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 
 from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
-from ultralytics.utils.torch_utils import profile
+from ultralytics.utils.torch_utils import autocast, profile
 
 
 def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
@@ -23,7 +23,7 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
         (int): Optimal batch size computed using the autobatch() function.
     """
 
-    with torch.cuda.amp.autocast(amp):
+    with autocast(enabled=amp):
         return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)
 
 
diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
index dfd792283..d94e157fb 100644
--- a/ultralytics/utils/checks.py
+++ b/ultralytics/utils/checks.py
@@ -641,6 +641,8 @@ def check_amp(model):
     Returns:
         (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
""" + from ultralytics.utils.torch_utils import autocast + device = next(model.parameters()).device # get model device if device.type in {"cpu", "mps"}: return False # AMP only used on CUDA devices @@ -648,7 +650,7 @@ def check_amp(model): def amp_allclose(m, im): """All close FP32 vs AMP results.""" a = m(im, device=device, verbose=False)[0].boxes.data # FP32 inference - with torch.cuda.amp.autocast(True): + with autocast(enabled=True): b = m(im, device=device, verbose=False)[0].boxes.data # AMP inference del m return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py index 3c3d3b71e..15bf92f9d 100644 --- a/ultralytics/utils/loss.py +++ b/ultralytics/utils/loss.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from ultralytics.utils.metrics import OKS_SIGMA from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors +from ultralytics.utils.torch_utils import autocast from .metrics import bbox_iou, probiou from .tal import bbox2dist @@ -27,7 +28,7 @@ class VarifocalLoss(nn.Module): def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0): """Computes varfocal loss.""" weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label - with torch.cuda.amp.autocast(enabled=False): + with autocast(enabled=False): loss = ( (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight) .mean(1) diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 21973d7e2..fcecd1481 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -68,6 +68,37 @@ def smart_inference_mode(): return decorate +def autocast(enabled: bool, device: str = "cuda"): + """ + Get the appropriate autocast context manager based on PyTorch version and AMP setting. + + This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both + older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions. + + Args: + enabled (bool): Whether to enable automatic mixed precision. + device (str, optional): The device to use for autocast. Defaults to 'cuda'. + + Returns: + (torch.amp.autocast): The appropriate autocast context manager. + + Note: + - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`. + - For older versions, it uses `torch.cuda.autocast`. + + Example: + ```python + with autocast(amp=True): + # Your mixed precision operations here + pass + ``` + """ + if TORCH_1_13: + return torch.amp.autocast(device, enabled=enabled) + else: + return torch.cuda.amp.autocast(enabled) + + def get_cpu_info(): """Return a string with system CPU information, i.e. 'Apple M2'.""" import cpuinfo # pip install py-cpuinfo