`ultralytics 8.1.33` fix HUB model checks (#9153)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/9270/head v8.1.33
Kalen Michael 8 months ago committed by GitHub
parent fc6c66a4a4
commit ec1d110689
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      pyproject.toml
  2. 2
      ultralytics/__init__.py
  3. 24
      ultralytics/engine/model.py
  4. 2
      ultralytics/hub/__init__.py
  5. 2
      ultralytics/nn/autobackend.py
  6. 4
      ultralytics/utils/checks.py
  7. 2
      ultralytics/utils/downloads.py

@ -117,7 +117,7 @@ logging = [
"dvclive>=2.12.0", "dvclive>=2.12.0",
] ]
extra = [ extra = [
"hub-sdk>=0.0.2", # Ultralytics HUB "hub-sdk>=0.0.5", # Ultralytics HUB
"ipython", # interactive notebook "ipython", # interactive notebook
"albumentations>=1.0.3", # training augmentations "albumentations>=1.0.3", # training augmentations
"pycocotools>=2.0.7", # COCO mAP "pycocotools>=2.0.7", # COCO mAP

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.32" __version__ = "8.1.33"
from ultralytics.data.explorer.explorer import Explorer from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

@ -119,30 +119,27 @@ class Model(nn.Module):
self.metrics = None # validation/training metrics self.metrics = None # validation/training metrics
self.session = None # HUB session self.session = None # HUB session
self.task = task # task type self.task = task # task type
self.model_name = model = str(model).strip() # strip spaces model = str(model).strip()
# Check if Ultralytics HUB model from https://hub.ultralytics.com # Check if Ultralytics HUB model from https://hub.ultralytics.com
if self.is_hub_model(model): if self.is_hub_model(model):
# Fetch model from HUB # Fetch model from HUB
checks.check_requirements("hub-sdk>=0.0.5") checks.check_requirements("hub-sdk>=0.0.6")
self.session = self._get_hub_session(model) self.session = self._get_hub_session(model)
model = self.session.model_file model = self.session.model_file
# Check if Triton Server model # Check if Triton Server model
elif self.is_triton_model(model): elif self.is_triton_model(model):
self.model = model self.model_name = self.model = model
self.task = task self.task = task
return return
# Load or create new YOLO model # Load or create new YOLO model
model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt
if Path(model).suffix in (".yaml", ".yml"): if Path(model).suffix in (".yaml", ".yml"):
self._new(model, task=task, verbose=verbose) self._new(model, task=task, verbose=verbose)
else: else:
self._load(model, task=task) self._load(model, task=task)
self.model_name = model
def __call__( def __call__(
self, self,
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
@ -190,8 +187,8 @@ class Model(nn.Module):
return any( return any(
( (
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID [len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODELID len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL
) )
) )
@ -215,6 +212,7 @@ class Model(nn.Module):
# Below added to allow export from YAMLs # Below added to allow export from YAMLs
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args) self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
self.model.task = self.task self.model.task = self.task
self.model_name = cfg
def _load(self, weights: str, task=None) -> None: def _load(self, weights: str, task=None) -> None:
""" """
@ -224,19 +222,23 @@ class Model(nn.Module):
weights (str): model checkpoint to be loaded weights (str): model checkpoint to be loaded
task (str | None): model task task (str | None): model task
""" """
suffix = Path(weights).suffix if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
if suffix == ".pt": weights = checks.check_file(weights) # automatically download and return local filename
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt
if Path(weights).suffix == ".pt":
self.model, self.ckpt = attempt_load_one_weight(weights) self.model, self.ckpt = attempt_load_one_weight(weights)
self.task = self.model.args["task"] self.task = self.model.args["task"]
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
self.ckpt_path = self.model.pt_path self.ckpt_path = self.model.pt_path
else: else:
weights = checks.check_file(weights) weights = checks.check_file(weights) # runs in all cases, not redundant with above call
self.model, self.ckpt = weights, None self.model, self.ckpt = weights, None
self.task = task or guess_model_task(weights) self.task = task or guess_model_task(weights)
self.ckpt_path = weights self.ckpt_path = weights
self.overrides["model"] = weights self.overrides["model"] = weights
self.overrides["task"] = self.task self.overrides["task"] = self.task
self.model_name = weights
def _check_is_pytorch_model(self) -> None: def _check_is_pytorch_model(self) -> None:
"""Raises TypeError is model is not a PyTorch model.""" """Raises TypeError is model is not a PyTorch model."""

@ -23,7 +23,7 @@ def login(api_key: str = None, save=True) -> bool:
Returns: Returns:
(bool): True if authentication is successful, False otherwise. (bool): True if authentication is successful, False otherwise.
""" """
checks.check_requirements("hub-sdk>=0.0.2") checks.check_requirements("hub-sdk>=0.0.6")
from hub_sdk import HUBClient from hub_sdk import HUBClient
api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL

@ -603,7 +603,7 @@ class AutoBackend(nn.Module):
from ultralytics.engine.exporter import export_formats from ultralytics.engine.exporter import export_formats
sf = list(export_formats().Suffix) # export suffixes sf = list(export_formats().Suffix) # export suffixes
if not is_url(p, check=False) and not isinstance(p, str): if not is_url(p) and not isinstance(p, str):
check_suffix(p, sf) # checks check_suffix(p, sf) # checks
name = Path(p).name name = Path(p).name
types = [s in name for s in sf] types = [s in name for s in sf]

@ -315,7 +315,7 @@ def check_font(font="Arial.ttf"):
# Download to USER_CONFIG_DIR if missing # Download to USER_CONFIG_DIR if missing
url = f"https://ultralytics.com/assets/{name}" url = f"https://ultralytics.com/assets/{name}"
if downloads.is_url(url): if downloads.is_url(url, check=True):
downloads.safe_download(url=url, file=file) downloads.safe_download(url=url, file=file)
return file return file
@ -498,7 +498,7 @@ def check_file(file, suffix="", download=True, hard=True):
raise FileNotFoundError(f"'{file}' does not exist") raise FileNotFoundError(f"'{file}' does not exist")
elif len(files) > 1 and hard: elif len(files) > 1 and hard:
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}") raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
return files[0] if len(files) else [] # return file return files[0] if len(files) else [] if hard else file # return file
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True): def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):

@ -33,7 +33,7 @@ GITHUB_ASSETS_NAMES = (
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES] GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
def is_url(url, check=True): def is_url(url, check=False):
""" """
Validates if the given string is a URL and optionally checks if the URL exists online. Validates if the given string is a URL and optionally checks if the URL exists online.

Loading…
Cancel
Save