diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py
index 655f2455ca..160a549d64 100644
--- a/ultralytics/engine/validator.py
+++ b/ultralytics/engine/validator.py
@@ -110,7 +110,8 @@ class BaseValidator:
         if self.training:
             self.device = trainer.device
             self.data = trainer.data
-            self.args.half = self.device.type != "cpu"  # force FP16 val during training
+            # force FP16 val during training
+            self.args.half = self.device.type != "cpu" and self.args.amp
             model = trainer.ema.ema or trainer.model
             model = model.half() if self.args.half else model.float()
             # self.model = model
diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
index 6b308bc146..70d3d088b4 100644
--- a/ultralytics/utils/checks.py
+++ b/ultralytics/utils/checks.py
@@ -656,9 +656,10 @@ def check_amp(model):
 
     def amp_allclose(m, im):
         """All close FP32 vs AMP results."""
-        a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
+        batch = [im] * 8
+        a = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # FP32 inference
         with autocast(enabled=True):
-            b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
+            b = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # AMP inference
         del m
         return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
 
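For context, the batched `amp_allclose` check in the second hunk reduces to a generic FP32-vs-autocast comparison. The following is a minimal standalone sketch of that pattern, not the actual `check_amp` code: `amp_allclose_sketch`, `model`, and `x` are hypothetical stand-ins for the YOLO model and image batch, and a CUDA device is assumed.

```python
import torch
from torch.cuda.amp import autocast


def amp_allclose_sketch(model: torch.nn.Module, x: torch.Tensor, atol: float = 0.5) -> bool:
    """Sketch of an FP32-vs-AMP closeness check (hypothetical helper, CUDA assumed)."""
    model = model.cuda().eval()
    x = x.cuda()
    with torch.inference_mode():
        a = model(x)  # FP32 inference
        with autocast(enabled=True):
            b = model(x)  # AMP (autocast) inference
    # Pass if outputs match in shape and agree element-wise within an absolute tolerance
    return a.shape == b.shape and torch.allclose(a, b.float(), atol=atol)
```

Running the comparison over a batch of 8 copies at `imgsz=128`, as the diff does, presumably keeps the check fast while still exercising batched AMP kernels rather than a single-image forward pass.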