`ultralytics 8.2.97` robust HUB model downloads (#16347)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
action-recog
Glenn Jocher 2 months ago committed by fcakyon
parent 4cb94273b0
commit ef1267bfec
  1. 2
      ultralytics/__init__.py
  2. 1
      ultralytics/cfg/__init__.py
  3. 22
      ultralytics/engine/model.py
  4. 50
      ultralytics/hub/session.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.96" __version__ = "8.2.97"
import os import os

@ -712,6 +712,7 @@ def entrypoint(debug=""):
"cfg": lambda: yaml_print(DEFAULT_CFG_PATH), "cfg": lambda: yaml_print(DEFAULT_CFG_PATH),
"hub": lambda: handle_yolo_hub(args[1:]), "hub": lambda: handle_yolo_hub(args[1:]),
"login": lambda: handle_yolo_hub(args), "login": lambda: handle_yolo_hub(args),
"logout": lambda: handle_yolo_hub(args),
"copy-cfg": copy_default_cfg, "copy-cfg": copy_default_cfg,
"explorer": lambda: handle_explorer(args[1:]), "explorer": lambda: handle_explorer(args[1:]),
"streamlit-predict": lambda: handle_streamlit_inference(), "streamlit-predict": lambda: handle_streamlit_inference(),

@ -206,33 +206,21 @@ class Model(nn.Module):
Check if the provided model is an Ultralytics HUB model. Check if the provided model is an Ultralytics HUB model.
This static method determines whether the given model string represents a valid Ultralytics HUB model This static method determines whether the given model string represents a valid Ultralytics HUB model
identifier. It checks for three possible formats: a full HUB URL, an API key and model ID combination, identifier.
or a standalone model ID.
Args: Args:
model (str): The model identifier to check. This can be a URL, an API key and model ID model (str): The model string to check.
combination, or a standalone model ID.
Returns: Returns:
(bool): True if the model is a valid Ultralytics HUB model, False otherwise. (bool): True if the model is a valid Ultralytics HUB model, False otherwise.
Examples: Examples:
>>> Model.is_hub_model("https://hub.ultralytics.com/models/example_model") >>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL")
True True
>>> Model.is_hub_model("api_key_example_model_id") >>> Model.is_hub_model("yolov8n.pt")
True
>>> Model.is_hub_model("example_model_id")
True
>>> Model.is_hub_model("not_a_hub_model.pt")
False False
""" """
return any( return model.startswith(f"{HUB_WEB_ROOT}/models/")
(
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL
)
)
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None: def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
""" """

@ -5,6 +5,7 @@ import threading
import time import time
from http import HTTPStatus from http import HTTPStatus
from pathlib import Path from pathlib import Path
from urllib.parse import parse_qs, urlparse
import requests import requests
@ -77,7 +78,6 @@ class HUBTrainingSession:
if not session.client.authenticated: if not session.client.authenticated:
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"): if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
LOGGER.warning(f"{PREFIX}WARNING ⚠ Login to Ultralytics HUB with 'yolo hub login API_KEY'.") LOGGER.warning(f"{PREFIX}WARNING ⚠ Login to Ultralytics HUB with 'yolo hub login API_KEY'.")
exit()
return None return None
if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
session.create_model(args) session.create_model(args)
@ -96,7 +96,8 @@ class HUBTrainingSession:
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
if self.model.is_trained(): if self.model.is_trained():
print(emojis(f"Loading trained HUB model {self.model_url} 🚀")) print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
self.model_file = self.model.get_weights_url("best") url = self.model.get_weights_url("best") # download URL with auth
self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
return return
# Set training args and start heartbeats for HUB to monitor agent # Set training args and start heartbeats for HUB to monitor agent
@ -146,9 +147,8 @@ class HUBTrainingSession:
Parses the given identifier to determine the type of identifier and extract relevant components. Parses the given identifier to determine the type of identifier and extract relevant components.
The method supports different identifier formats: The method supports different identifier formats:
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/' - A HUB model URL https://hub.ultralytics.com/models/MODEL
- An identifier containing an API key and a model ID separated by an underscore - A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY
- An identifier that is solely a model ID of a fixed length
- A local filename that ends with '.pt' or '.yaml' - A local filename that ends with '.pt' or '.yaml'
Args: Args:
@ -160,32 +160,26 @@ class HUBTrainingSession:
Raises: Raises:
HUBModelError: If the identifier format is not recognized. HUBModelError: If the identifier format is not recognized.
""" """
# Initialize variables
api_key, model_id, filename = None, None, None api_key, model_id, filename = None, None, None
# Check if identifier is a HUB URL # path = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # parts = path.split("_")
# Extract the model_id after the HUB_WEB_ROOT URL # if Path(path).suffix in {".pt", ".yaml"}:
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1] # filename = path
# elif len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
# api_key, model_id = parts
# elif len(path) == 20:
# model_id = path
if Path(identifier).suffix in {".pt", ".yaml"}:
filename = identifier
elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
parsed_url = urlparse(identifier)
model_id = Path(parsed_url.path).stem # handle possible final backslash robustly
query_params = parse_qs(parsed_url.query) # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
api_key = query_params.get("api_key", [None])[0]
else: else:
# Split the identifier based on underscores only if it's not a HUB URL raise HUBModelError(f"model='{identifier} invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
parts = identifier.split("_")
# Check if identifier is in the format of API key and model ID
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
api_key, model_id = parts
# Check if identifier is a single model ID
elif len(parts) == 1 and len(parts[0]) == 20:
model_id = parts[0]
# Check if identifier is a local filename
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
filename = identifier
else:
raise HUBModelError(
f"model='{identifier}' could not be parsed. Check format is correct. "
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
)
return api_key, model_id, filename return api_key, model_id, filename
def _set_train_args(self): def _set_train_args(self):

Loading…
Cancel
Save