|
|
|
@ -119,30 +119,27 @@ class Model(nn.Module): |
|
|
|
|
self.metrics = None # validation/training metrics |
|
|
|
|
self.session = None # HUB session |
|
|
|
|
self.task = task # task type |
|
|
|
|
self.model_name = model = str(model).strip() # strip spaces |
|
|
|
|
model = str(model).strip() |
|
|
|
|
|
|
|
|
|
# Check if Ultralytics HUB model from https://hub.ultralytics.com |
|
|
|
|
if self.is_hub_model(model): |
|
|
|
|
# Fetch model from HUB |
|
|
|
|
checks.check_requirements("hub-sdk>=0.0.5") |
|
|
|
|
checks.check_requirements("hub-sdk>=0.0.6") |
|
|
|
|
self.session = self._get_hub_session(model) |
|
|
|
|
model = self.session.model_file |
|
|
|
|
|
|
|
|
|
# Check if Triton Server model |
|
|
|
|
elif self.is_triton_model(model): |
|
|
|
|
self.model = model |
|
|
|
|
self.model_name = self.model = model |
|
|
|
|
self.task = task |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
# Load or create new YOLO model |
|
|
|
|
model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt |
|
|
|
|
if Path(model).suffix in (".yaml", ".yml"): |
|
|
|
|
self._new(model, task=task, verbose=verbose) |
|
|
|
|
else: |
|
|
|
|
self._load(model, task=task) |
|
|
|
|
|
|
|
|
|
self.model_name = model |
|
|
|
|
|
|
|
|
|
def __call__( |
|
|
|
|
self, |
|
|
|
|
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, |
|
|
|
@ -190,8 +187,8 @@ class Model(nn.Module): |
|
|
|
|
return any( |
|
|
|
|
( |
|
|
|
|
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID |
|
|
|
|
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID |
|
|
|
|
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODELID |
|
|
|
|
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL |
|
|
|
|
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
@ -215,6 +212,7 @@ class Model(nn.Module): |
|
|
|
|
# Below added to allow export from YAMLs |
|
|
|
|
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args) |
|
|
|
|
self.model.task = self.task |
|
|
|
|
self.model_name = cfg |
|
|
|
|
|
|
|
|
|
def _load(self, weights: str, task=None) -> None: |
|
|
|
|
""" |
|
|
|
@ -224,19 +222,23 @@ class Model(nn.Module): |
|
|
|
|
weights (str): model checkpoint to be loaded |
|
|
|
|
task (str | None): model task |
|
|
|
|
""" |
|
|
|
|
suffix = Path(weights).suffix |
|
|
|
|
if suffix == ".pt": |
|
|
|
|
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): |
|
|
|
|
weights = checks.check_file(weights) # automatically download and return local filename |
|
|
|
|
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt |
|
|
|
|
|
|
|
|
|
if Path(weights).suffix == ".pt": |
|
|
|
|
self.model, self.ckpt = attempt_load_one_weight(weights) |
|
|
|
|
self.task = self.model.args["task"] |
|
|
|
|
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) |
|
|
|
|
self.ckpt_path = self.model.pt_path |
|
|
|
|
else: |
|
|
|
|
weights = checks.check_file(weights) |
|
|
|
|
weights = checks.check_file(weights) # runs in all cases, not redundant with above call |
|
|
|
|
self.model, self.ckpt = weights, None |
|
|
|
|
self.task = task or guess_model_task(weights) |
|
|
|
|
self.ckpt_path = weights |
|
|
|
|
self.overrides["model"] = weights |
|
|
|
|
self.overrides["task"] = self.task |
|
|
|
|
self.model_name = weights |
|
|
|
|
|
|
|
|
|
def _check_is_pytorch_model(self) -> None: |
|
|
|
|
"""Raises TypeError is model is not a PyTorch model.""" |
|
|
|
|