Attempt to fix NAS models inference (#14630)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
pull/14621/head
Laughing 6 months ago committed by GitHub
parent fb20867262
commit 03225fce9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 15
      ultralytics/models/nas/model.py

@ -17,7 +17,7 @@ import torch
from ultralytics.engine.model import Model
from ultralytics.utils.downloads import attempt_download_asset
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from ultralytics.utils.torch_utils import model_info
from .predict import NASPredictor
from .val import NASValidator
@ -50,16 +50,25 @@ class NAS(Model):
assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
super().__init__(model, task="detect")
@smart_inference_mode()
def _load(self, weights: str, task: str):
def _load(self, weights: str, task=None) -> None:
"""Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
import super_gradients
suffix = Path(weights).suffix
if suffix == ".pt":
self.model = torch.load(attempt_download_asset(weights))
elif suffix == "":
self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
# Override the forward method to ignore additional arguments
def new_forward(x, *args, **kwargs):
"""Ignore additional __call__ arguments."""
return self.model._original_forward(x)
self.model._original_forward = self.model.forward
self.model.forward = new_forward
# Standardize model
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])

Loading…
Cancel
Save