From 6f2bb65953b27c931b6dce0e92a31c2f2783d45e Mon Sep 17 00:00:00 2001
From: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com>
Date: Wed, 18 Sep 2024 00:44:56 +0800
Subject: [PATCH] Disable FP16 val on AMP fail and improve AMP checks (#16306)

Co-authored-by: UltralyticsAssistant
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
---
 ultralytics/engine/validator.py | 3 ++-
 ultralytics/utils/checks.py     | 5 +++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py
index 655f2455c..160a549d6 100644
--- a/ultralytics/engine/validator.py
+++ b/ultralytics/engine/validator.py
@@ -110,7 +110,8 @@ class BaseValidator:
         if self.training:
             self.device = trainer.device
             self.data = trainer.data
-            self.args.half = self.device.type != "cpu"  # force FP16 val during training
+            # force FP16 val during training
+            self.args.half = self.device.type != "cpu" and self.args.amp
             model = trainer.ema.ema or trainer.model
             model = model.half() if self.args.half else model.float()
             # self.model = model
diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
index 6b308bc14..70d3d088b 100644
--- a/ultralytics/utils/checks.py
+++ b/ultralytics/utils/checks.py
@@ -656,9 +656,10 @@ def check_amp(model):
 
     def amp_allclose(m, im):
         """All close FP32 vs AMP results."""
-        a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
+        batch = [im] * 8
+        a = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # FP32 inference
         with autocast(enabled=True):
-            b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
+            b = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # AMP inference
         del m
         return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
 
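
Note (not part of the patch): below is a minimal, self-contained sketch of the FP32-vs-AMP
comparison the patched amp_allclose performs, for reproducing the check outside check_amp.
It assumes the ultralytics YOLO predict API; torch.cuda.amp.autocast stands in for the
autocast helper that checks.py imports, and the model and image choices are illustrative.

import torch

from ultralytics import YOLO


def amp_allclose_sketch(m, im, device=0):
    """Illustrative re-implementation of the patched amp_allclose comparison."""
    batch = [im] * 8  # a batch of 8 copies stresses AMP harder than a single image
    a = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # FP32 inference
    with torch.cuda.amp.autocast(enabled=True):  # stand-in for ultralytics' autocast helper
        b = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # AMP inference
    del m
    # AMP is considered safe if output shapes match and boxes agree within 0.5 absolute tolerance
    return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)


if __name__ == "__main__":
    model = YOLO("yolov8n.pt")  # any small detection checkpoint (illustrative)
    print("AMP check passed:", amp_allclose_sketch(model, "https://ultralytics.com/images/bus.jpg"))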