diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py index fd444f1389..90446c585e 100644 --- a/ultralytics/models/nas/model.py +++ b/ultralytics/models/nas/model.py @@ -17,7 +17,7 @@ import torch from ultralytics.engine.model import Model from ultralytics.utils.downloads import attempt_download_asset -from ultralytics.utils.torch_utils import model_info, smart_inference_mode +from ultralytics.utils.torch_utils import model_info from .predict import NASPredictor from .val import NASValidator @@ -50,16 +50,25 @@ class NAS(Model): assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models." super().__init__(model, task="detect") - @smart_inference_mode() - def _load(self, weights: str, task: str): + def _load(self, weights: str, task=None) -> None: """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided.""" import super_gradients suffix = Path(weights).suffix if suffix == ".pt": self.model = torch.load(attempt_download_asset(weights)) + elif suffix == "": self.model = super_gradients.training.models.get(weights, pretrained_weights="coco") + + # Override the forward method to ignore additional arguments + def new_forward(x, *args, **kwargs): + """Ignore additional __call__ arguments.""" + return self.model._original_forward(x) + + self.model._original_forward = self.model.forward + self.model.forward = new_forward + # Standardize model self.model.fuse = lambda verbose=True: self.model self.model.stride = torch.tensor([32])