experiment with early warmup

warmups
Burhan 6 months ago
parent e36df75721
commit 4c8229d328
  1. 13
      ultralytics/engine/model.py
  2. 5
      ultralytics/engine/predictor.py
  3. 2
      ultralytics/nn/autobackend.py

@ -450,6 +450,10 @@ class Model(nn.Module):
self.predictor.save_dir = get_save_dir(self.predictor.args)
if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
self.predictor.set_prompts(prompts)
cycle = int(args.get("warmup_epochs", 0))
if cycle > 0:
self.warmup(from_cli=is_cli, cycles=cycle, **args.copy())
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
def track(
@ -843,3 +847,12 @@ class Model(nn.Module):
task_map (dict): The map of model task to mode classes.
"""
raise NotImplementedError("Please provide task map for your model!")
def warmup(self, from_cli, cycles, **kwargs):
"""Warmup model."""
if self.predictor and not self.predictor.done_warmup:
self.predictor.args.verbose = self.predictor.args.save = False # silent
_ = [self.predictor(source=np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8),) for _ in range(cycles)]
self.predictor.args.verbose = from_cli or kwargs.get("verbose", True)
self.predictor.args.save = kwargs.get("save", from_cli) # reset
self.predictor.done_warmup = True

@ -231,7 +231,10 @@ class BasePredictor:
# Warmup model
if not self.done_warmup:
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz), cycles=self.args.warmup_epochs)
self.model.warmup(
imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz),
cycles=self.args.warmup_epochs,
)
self.done_warmup = True
self.seen, self.windows, self.batch = 0, [], None

@ -628,7 +628,7 @@ class AutoBackend(nn.Module):
if any(warmup_types) and (self.device.type != "cpu" or self.triton):
# im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
im = torch.randn(*imgsz, device=self.device, dtype=torch.half if self.fp16 else torch.float) # random input
for _ in range(2 if self.jit else int(cycles)):
for _ in range(2 if self.jit else (cycles or 1)):
self.forward(im) # warmup
@staticmethod

Loading…
Cancel
Save