|
|
@ -1,12 +1,10 @@ |
|
|
|
from pathlib import Path |
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ultralytics import yolo # noqa |
|
|
|
from ultralytics import yolo # noqa |
|
|
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights |
|
|
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights |
|
|
|
from ultralytics.yolo.configs import get_config |
|
|
|
from ultralytics.yolo.configs import get_config |
|
|
|
from ultralytics.yolo.engine.exporter import Exporter |
|
|
|
from ultralytics.yolo.engine.exporter import Exporter |
|
|
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER, yaml_load |
|
|
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load |
|
|
|
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml |
|
|
|
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml |
|
|
|
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode |
|
|
|
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode |
|
|
|
|
|
|
|
|
|
|
@ -55,6 +53,9 @@ class YOLO: |
|
|
|
# Load or create new YOLO model |
|
|
|
# Load or create new YOLO model |
|
|
|
{'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model) |
|
|
|
{'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, source): |
|
|
|
|
|
|
|
return self.predict(source) |
|
|
|
|
|
|
|
|
|
|
|
def _new(self, cfg: str, verbose=True): |
|
|
|
def _new(self, cfg: str, verbose=True): |
|
|
|
""" |
|
|
|
""" |
|
|
|
Initializes a new model and infers the task type from the model definitions. |
|
|
|
Initializes a new model and infers the task type from the model definitions. |
|
|
@ -211,14 +212,6 @@ class YOLO: |
|
|
|
|
|
|
|
|
|
|
|
return model_class, trainer_class, validator_class, predictor_class |
|
|
|
return model_class, trainer_class, validator_class, predictor_class |
|
|
|
|
|
|
|
|
|
|
|
@smart_inference_mode() |
|
|
|
|
|
|
|
def __call__(self, imgs): |
|
|
|
|
|
|
|
device = next(self.model.parameters()).device # get model device |
|
|
|
|
|
|
|
return self.model(imgs.to(device)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, imgs): |
|
|
|
|
|
|
|
return self.__call__(imgs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
@staticmethod |
|
|
|
def _reset_ckpt_args(args): |
|
|
|
def _reset_ckpt_args(args): |
|
|
|
args.pop("device", None) |
|
|
|
args.pop("device", None) |
|
|
|