`ultralytics 8.2.97` robust HUB model downloads (#16347)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
action-recog
Glenn Jocher 2 months ago committed by fcakyon
parent 4cb94273b0
commit ef1267bfec
  1. 2
      ultralytics/__init__.py
  2. 1
      ultralytics/cfg/__init__.py
  3. 22
      ultralytics/engine/model.py
  4. 50
      ultralytics/hub/session.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.96"
__version__ = "8.2.97"
import os

@ -712,6 +712,7 @@ def entrypoint(debug=""):
"cfg": lambda: yaml_print(DEFAULT_CFG_PATH),
"hub": lambda: handle_yolo_hub(args[1:]),
"login": lambda: handle_yolo_hub(args),
"logout": lambda: handle_yolo_hub(args),
"copy-cfg": copy_default_cfg,
"explorer": lambda: handle_explorer(args[1:]),
"streamlit-predict": lambda: handle_streamlit_inference(),

@ -206,33 +206,21 @@ class Model(nn.Module):
Check if the provided model is an Ultralytics HUB model.
This static method determines whether the given model string represents a valid Ultralytics HUB model
identifier. It checks for three possible formats: a full HUB URL, an API key and model ID combination,
or a standalone model ID.
identifier.
Args:
model (str): The model identifier to check. This can be a URL, an API key and model ID
combination, or a standalone model ID.
model (str): The model string to check.
Returns:
(bool): True if the model is a valid Ultralytics HUB model, False otherwise.
Examples:
>>> Model.is_hub_model("https://hub.ultralytics.com/models/example_model")
>>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL")
True
>>> Model.is_hub_model("api_key_example_model_id")
True
>>> Model.is_hub_model("example_model_id")
True
>>> Model.is_hub_model("not_a_hub_model.pt")
>>> Model.is_hub_model("yolov8n.pt")
False
"""
return any(
(
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL
)
)
return model.startswith(f"{HUB_WEB_ROOT}/models/")
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
"""

@ -5,6 +5,7 @@ import threading
import time
from http import HTTPStatus
from pathlib import Path
from urllib.parse import parse_qs, urlparse
import requests
@ -77,7 +78,6 @@ class HUBTrainingSession:
if not session.client.authenticated:
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
LOGGER.warning(f"{PREFIX}WARNING ⚠ Login to Ultralytics HUB with 'yolo hub login API_KEY'.")
exit()
return None
if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
session.create_model(args)
@ -96,7 +96,8 @@ class HUBTrainingSession:
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
if self.model.is_trained():
print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
self.model_file = self.model.get_weights_url("best")
url = self.model.get_weights_url("best") # download URL with auth
self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
return
# Set training args and start heartbeats for HUB to monitor agent
@ -146,9 +147,8 @@ class HUBTrainingSession:
Parses the given identifier to determine the type of identifier and extract relevant components.
The method supports different identifier formats:
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
- An identifier containing an API key and a model ID separated by an underscore
- An identifier that is solely a model ID of a fixed length
- A HUB model URL https://hub.ultralytics.com/models/MODEL
- A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY
- A local filename that ends with '.pt' or '.yaml'
Args:
@ -160,32 +160,26 @@ class HUBTrainingSession:
Raises:
HUBModelError: If the identifier format is not recognized.
"""
# Initialize variables
api_key, model_id, filename = None, None, None
# Check if identifier is a HUB URL
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
# Extract the model_id after the HUB_WEB_ROOT URL
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
# path = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
# parts = path.split("_")
# if Path(path).suffix in {".pt", ".yaml"}:
# filename = path
# elif len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
# api_key, model_id = parts
# elif len(path) == 20:
# model_id = path
if Path(identifier).suffix in {".pt", ".yaml"}:
filename = identifier
elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
parsed_url = urlparse(identifier)
model_id = Path(parsed_url.path).stem # handle possible final backslash robustly
query_params = parse_qs(parsed_url.query) # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
api_key = query_params.get("api_key", [None])[0]
else:
# Split the identifier based on underscores only if it's not a HUB URL
parts = identifier.split("_")
# Check if identifier is in the format of API key and model ID
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
api_key, model_id = parts
# Check if identifier is a single model ID
elif len(parts) == 1 and len(parts[0]) == 20:
model_id = parts[0]
# Check if identifier is a local filename
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
filename = identifier
else:
raise HUBModelError(
f"model='{identifier}' could not be parsed. Check format is correct. "
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
)
raise HUBModelError(f"model='{identifier} invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
return api_key, model_id, filename
def _set_train_args(self):

Loading…
Cancel
Save