Disable FP16 val on AMP fail and improve AMP checks (#16306)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
action-recog
Authored by Mohammed Yasin 2 months ago; committed by fcakyon
parent 639b2cced9
commit db960cac7f
  1. ultralytics/engine/validator.py (3 changed lines)
  2. ultralytics/utils/checks.py (5 changed lines)

@@ -110,7 +110,8 @@ class BaseValidator:
         if self.training:
             self.device = trainer.device
             self.data = trainer.data
-            self.args.half = self.device.type != "cpu"  # force FP16 val during training
+            # force FP16 val during training
+            self.args.half = self.device.type != "cpu" and self.args.amp
             model = trainer.ema.ema or trainer.model
             model = model.half() if self.args.half else model.float()
             # self.model = model

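As context for the validator change above: the FP16 validation flag is now gated on both the device type and the trainer's AMP setting, so a failed AMP check (args.amp == False) also disables half-precision validation. Below is a minimal standalone sketch of that gating logic, not the Ultralytics source; DummyArgs and decide_half are hypothetical names used only for illustration.

from dataclasses import dataclass

import torch


@dataclass
class DummyArgs:
    """Hypothetical stand-in for the validator args (not the Ultralytics class)."""

    half: bool = False
    amp: bool = True  # mirrors the trainer's AMP state after check_amp


def decide_half(device: torch.device, args: DummyArgs) -> bool:
    """Enable FP16 validation only on a non-CPU device AND when AMP is enabled."""
    return device.type != "cpu" and args.amp


if __name__ == "__main__":
    args = DummyArgs(amp=False)  # e.g. the AMP check failed, so AMP was turned off
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.half = decide_half(device, args)
    print(f"FP16 validation: {args.half}")  # False whenever AMP is off, even on GPU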
@@ -656,9 +656,10 @@ def check_amp(model):
     def amp_allclose(m, im):
         """All close FP32 vs AMP results."""
-        a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
+        batch = [im] * 8
+        a = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # FP32 inference
         with autocast(enabled=True):
-            b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
+            b = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # AMP inference
         del m
         return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance

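The updated check above runs the FP32 vs AMP comparison on a batch of 8 images at imgsz=128 instead of a single image, presumably to compare more samples at a small fixed resolution. Below is a minimal standalone sketch of the same idea using a toy Conv2d model in place of a YOLO model; amp_allclose_sketch is a hypothetical name and this is not the library's check_amp implementation.

import torch
import torch.nn as nn


def amp_allclose_sketch(model: nn.Module, im: torch.Tensor, atol: float = 0.5) -> bool:
    """Return True if FP32 and autocast (AMP) outputs match within an absolute tolerance."""
    batch = im.repeat(8, 1, 1, 1)  # batch of 8 copies, mirroring the batched check in the diff
    with torch.no_grad():
        a = model(batch)  # FP32 inference
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            b = model(batch)  # AMP inference (no-op on CPU)
    return a.shape == b.shape and torch.allclose(a, b.float(), atol=atol)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Conv2d(3, 8, 3, padding=1).to(device).eval()  # toy model, not YOLO
    im = torch.rand(1, 3, 128, 128, device=device)  # 128x128 input, matching imgsz=128 in the diff
    print(f"AMP allclose: {amp_allclose_sketch(model, im)}")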