|
|
|
@ -17,7 +17,7 @@ import torch |
|
|
|
|
|
|
|
|
|
from ultralytics.engine.model import Model |
|
|
|
|
from ultralytics.utils.downloads import attempt_download_asset |
|
|
|
|
from ultralytics.utils.torch_utils import model_info, smart_inference_mode |
|
|
|
|
from ultralytics.utils.torch_utils import model_info |
|
|
|
|
|
|
|
|
|
from .predict import NASPredictor |
|
|
|
|
from .val import NASValidator |
|
|
|
@ -50,16 +50,25 @@ class NAS(Model): |
|
|
|
|
assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models." |
|
|
|
|
super().__init__(model, task="detect") |
|
|
|
|
|
|
|
|
|
@smart_inference_mode() |
|
|
|
|
def _load(self, weights: str, task: str): |
|
|
|
|
def _load(self, weights: str, task=None) -> None: |
|
|
|
|
"""Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided.""" |
|
|
|
|
import super_gradients |
|
|
|
|
|
|
|
|
|
suffix = Path(weights).suffix |
|
|
|
|
if suffix == ".pt": |
|
|
|
|
self.model = torch.load(attempt_download_asset(weights)) |
|
|
|
|
|
|
|
|
|
elif suffix == "": |
|
|
|
|
self.model = super_gradients.training.models.get(weights, pretrained_weights="coco") |
|
|
|
|
|
|
|
|
|
# Override the forward method to ignore additional arguments |
|
|
|
|
def new_forward(x, *args, **kwargs): |
|
|
|
|
"""Ignore additional __call__ arguments.""" |
|
|
|
|
return self.model._original_forward(x) |
|
|
|
|
|
|
|
|
|
self.model._original_forward = self.model.forward |
|
|
|
|
self.model.forward = new_forward |
|
|
|
|
|
|
|
|
|
# Standardize model |
|
|
|
|
self.model.fuse = lambda verbose=True: self.model |
|
|
|
|
self.model.stride = torch.tensor([32]) |
|
|
|
|