# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
YOLO-NAS model interface.

Usage - Predict:
    from ultralytics import NAS

    model = NAS('yolo_nas_s')
    results = model.predict('ultralytics/assets/bus.jpg')
"""

from pathlib import Path

import torch

from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz

from ...yolo.utils.torch_utils import model_info, smart_inference_mode
from .predict import NASPredictor
from .val import NASValidator


class NAS:
    """
    YOLO-NAS detection model wrapper exposing predict / val / export / info.

    The wrapped model is either unpickled from a '.pt' file or fetched by name through
    `super_gradients.training.models.get(..., pretrained_weights='coco')`. Training is not supported.

    Attributes:
        model: The wrapped model object, with extra attributes attached in `__init__`
            (args, fuse, stride, names, is_fused, yaml, pt_path, task).
        predictor: Cached NASPredictor, created on the first `predict()` call (None until then).
        task (str): Always 'detect'.
        metrics: Validation metrics; only present after `val()` has been called.

    Methods:
        predict(source=None, stream=False, **kwargs): Run inference on a source.
        val(**kwargs): Validate the model on a dataset.
        export(**kwargs): Export the model.
        info(detailed=False, verbose=True): Log model information.
        train(**kwargs): Not supported; raises NotImplementedError.
    """

    def __init__(self, model='yolo_nas_s.pt') -> None:
        """
        Load or create a NAS model and attach the attributes the rest of the pipeline reads.

        Args:
            model (str | Path): Either a path ending in '.pt' (loaded with `torch.load`) or a bare
                model name without any suffix (fetched via super_gradients with COCO weights).

        Raises:
            ValueError: If `model` has a suffix other than '.pt'.
        """
        # Imported lazily inside the constructor rather than at module import time.
        import super_gradients

        self.predictor = None
        suffix = Path(model).suffix
        if suffix == '.pt':
            self._load(model)
        elif suffix == '':
            self.model = super_gradients.training.models.get(model, pretrained_weights='coco')
        else:
            # Without this branch self.model is never assigned and the next line fails inside
            # __getattr__ with a misleading "no attribute 'model'" error.
            raise ValueError(f"Unsupported model '{model}': expected a '.pt' checkpoint path "
                             f"or a bare YOLO-NAS model name such as 'yolo_nas_s'.")
        self.task = 'detect'
        # Attach a shallow copy so later edits to model.args cannot alter the global defaults dict.
        self.model.args = dict(DEFAULT_CFG_DICT)  # attach args to model

        # Standardize model
        self.model.fuse = lambda verbose=True: self.model
        self.model.stride = torch.tensor([32])
        self.model.names = dict(enumerate(self.model._class_names))
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = model  # for export()
        self.model.task = 'detect'  # for export()
        self.info()

    @smart_inference_mode()
    def _load(self, weights: str):
        """
        Load a model object from a checkpoint file into `self.model`.

        Args:
            weights (str): Path to the '.pt' file.
        """
        # SECURITY NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
        # Only load checkpoints from trusted sources.
        self.model = torch.load(weights)

    @smart_inference_mode()
    def predict(self, source=None, stream=False, **kwargs):
        """
        Perform prediction using the YOLO model.

        Args:
            source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
                Accepts all source types accepted by the YOLO model.
            stream (bool): Whether to stream the predictions or not. Defaults to False.
            **kwargs : Additional keyword arguments passed to the predictor.
                Check the 'configuration' section in the documentation for all available options.

        Returns:
            (List[ultralytics.yolo.engine.results.Results]): The prediction results.
        """
        if source is None:
            # Fall back to the bundled assets inside a git checkout, otherwise to a hosted sample image.
            source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
            LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
        overrides = dict(conf=0.25, task='detect', mode='predict')
        overrides.update(kwargs)  # prefer kwargs
        if not self.predictor:
            # First call: build the predictor once and cache it on the instance.
            self.predictor = NASPredictor(overrides=overrides)
            self.predictor.setup_model(model=self.model)
        else:  # only update args if predictor is already setup
            self.predictor.args = get_cfg(self.predictor.args, overrides)
        return self.predictor(source, stream=stream)

    def train(self, **kwargs):
        """
        Function trains models but raises an error as NAS models do not support training.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("NAS models don't support training")

    def val(self, **kwargs):
        """
        Run validation given dataset.

        Args:
            **kwargs : Configuration overrides; they take precedence over the task='detect', mode='val' defaults.

        Returns:
            The validator's metrics, which are also stored on `self.metrics`.
        """
        overrides = dict(task='detect', mode='val')
        overrides.update(kwargs)  # prefer kwargs
        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
        args.imgsz = check_imgsz(args.imgsz, max_dim=1)
        validator = NASValidator(args=args)
        validator(model=self.model)
        self.metrics = validator.metrics
        return validator.metrics

    @smart_inference_mode()
    def export(self, **kwargs):
        """
        Export model.

        Args:
            **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs

        Returns:
            Whatever the Exporter call returns for this model.
        """
        overrides = dict(task='detect')
        overrides.update(kwargs)
        overrides['mode'] = 'export'  # set after update so callers cannot override the mode
        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
        args.task = self.task
        if args.imgsz == DEFAULT_CFG.imgsz:
            args.imgsz = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
        if args.batch == DEFAULT_CFG.batch:
            args.batch = 1  # default to 1 if not modified
        return Exporter(overrides=args)(model=self.model)

    def info(self, detailed=False, verbose=True):
        """
        Logs model info.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.

        Returns:
            The value returned by `model_info` (called with imgsz=640).
        """
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    def __call__(self, source=None, stream=False, **kwargs):
        """Calls the 'predict' function with given arguments to perform object detection."""
        return self.predict(source, stream, **kwargs)

    def __getattr__(self, attr):
        """
        Raises error if object has no requested attribute.

        Python only calls this after normal attribute lookup has failed. The message appends the
        class docstring so the caller can see the valid attributes.

        Raises:
            AttributeError: Always.
        """
        name = self.__class__.__name__
        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")