|
|
|
@ -5,6 +5,7 @@ import threading |
|
|
|
|
import time |
|
|
|
|
from http import HTTPStatus |
|
|
|
|
from pathlib import Path |
|
|
|
|
from urllib.parse import parse_qs, urlparse |
|
|
|
|
|
|
|
|
|
import requests |
|
|
|
|
|
|
|
|
@ -77,7 +78,6 @@ class HUBTrainingSession: |
|
|
|
|
if not session.client.authenticated: |
|
|
|
|
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"): |
|
|
|
|
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Login to Ultralytics HUB with 'yolo hub login API_KEY'.") |
|
|
|
|
exit() |
|
|
|
|
return None |
|
|
|
|
if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL |
|
|
|
|
session.create_model(args) |
|
|
|
@ -96,7 +96,8 @@ class HUBTrainingSession: |
|
|
|
|
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" |
|
|
|
|
if self.model.is_trained(): |
|
|
|
|
print(emojis(f"Loading trained HUB model {self.model_url} 🚀")) |
|
|
|
|
self.model_file = self.model.get_weights_url("best") |
|
|
|
|
url = self.model.get_weights_url("best") # download URL with auth |
|
|
|
|
self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id) |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
# Set training args and start heartbeats for HUB to monitor agent |
|
|
|
@ -146,9 +147,8 @@ class HUBTrainingSession: |
|
|
|
|
Parses the given identifier to determine the type of identifier and extract relevant components. |
|
|
|
|
|
|
|
|
|
The method supports different identifier formats: |
|
|
|
|
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/' |
|
|
|
|
- An identifier containing an API key and a model ID separated by an underscore |
|
|
|
|
- An identifier that is solely a model ID of a fixed length |
|
|
|
|
- A HUB model URL https://hub.ultralytics.com/models/MODEL |
|
|
|
|
- A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY |
|
|
|
|
- A local filename that ends with '.pt' or '.yaml' |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
@ -160,32 +160,26 @@ class HUBTrainingSession: |
|
|
|
|
Raises: |
|
|
|
|
HUBModelError: If the identifier format is not recognized. |
|
|
|
|
""" |
|
|
|
|
# Initialize variables |
|
|
|
|
api_key, model_id, filename = None, None, None |
|
|
|
|
|
|
|
|
|
# Check if identifier is a HUB URL |
|
|
|
|
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"): |
|
|
|
|
# Extract the model_id after the HUB_WEB_ROOT URL |
|
|
|
|
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1] |
|
|
|
|
# path = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1] |
|
|
|
|
# parts = path.split("_") |
|
|
|
|
# if Path(path).suffix in {".pt", ".yaml"}: |
|
|
|
|
# filename = path |
|
|
|
|
# elif len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20: |
|
|
|
|
# api_key, model_id = parts |
|
|
|
|
# elif len(path) == 20: |
|
|
|
|
# model_id = path |
|
|
|
|
|
|
|
|
|
if Path(identifier).suffix in {".pt", ".yaml"}: |
|
|
|
|
filename = identifier |
|
|
|
|
elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"): |
|
|
|
|
parsed_url = urlparse(identifier) |
|
|
|
|
model_id = Path(parsed_url.path).stem # handle possible final backslash robustly |
|
|
|
|
query_params = parse_qs(parsed_url.query) # dictionary, i.e. {"api_key": ["API_KEY_HERE"]} |
|
|
|
|
api_key = query_params.get("api_key", [None])[0] |
|
|
|
|
else: |
|
|
|
|
# Split the identifier based on underscores only if it's not a HUB URL |
|
|
|
|
parts = identifier.split("_") |
|
|
|
|
|
|
|
|
|
# Check if identifier is in the format of API key and model ID |
|
|
|
|
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20: |
|
|
|
|
api_key, model_id = parts |
|
|
|
|
# Check if identifier is a single model ID |
|
|
|
|
elif len(parts) == 1 and len(parts[0]) == 20: |
|
|
|
|
model_id = parts[0] |
|
|
|
|
# Check if identifier is a local filename |
|
|
|
|
elif identifier.endswith(".pt") or identifier.endswith(".yaml"): |
|
|
|
|
filename = identifier |
|
|
|
|
else: |
|
|
|
|
raise HUBModelError( |
|
|
|
|
f"model='{identifier}' could not be parsed. Check format is correct. " |
|
|
|
|
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file." |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
raise HUBModelError(f"model='{identifier} invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID") |
|
|
|
|
return api_key, model_id, filename |
|
|
|
|
|
|
|
|
|
def _set_train_args(self): |
|
|
|
|