|
|
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
|
|
|
|
import threading |
|
|
import time |
|
|
from http import HTTPStatus |
|
|
from pathlib import Path |
|
|
|
|
|
import requests |
|
|
from hub_sdk import HUB_WEB_ROOT, HUBClient |
|
|
|
|
|
from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM |
|
|
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab |
|
|
from ultralytics.utils.errors import HUBModelError |
|
|
|
|
|
# Agent string reported to HUB in heartbeats; distinguishes Colab from local runs
AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
|
|
|
|
|
|
|
|
class HUBTrainingSession:
    """
    HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.

    Attributes:
        agent_id (str): Identifier for the instance communicating with the server.
        model_id (str): Identifier for the YOLO model being trained.
        model_url (str): URL for the model in Ultralytics HUB.
        api_url (str): API URL for the model in Ultralytics HUB.
        auth_header (dict): Authentication header for the Ultralytics HUB API requests.
        rate_limits (dict): Rate limits for different API calls (in seconds).
        timers (dict): Timers for rate limiting.
        metrics_queue (dict): Queue for the model's metrics.
        model (dict): Model data fetched from Ultralytics HUB.
        alive (bool): Indicates if the heartbeat loop is active.
    """

    def __init__(self, identifier):
        """
        Initialize the HUBTrainingSession with the provided model identifier.

        Args:
            identifier (str): Model identifier used to initialize the HUB training session.
                It can be a URL string or a model key with specific format.

        Raises:
            ValueError: If the provided model identifier is invalid.
            ConnectionError: If connecting with global API key is not supported.
        """
        self.rate_limits = {
            "metrics": 3.0,
            "ckpt": 900.0,
            "heartbeat": 300.0,
        }  # rate limits (seconds)
        self.metrics_queue = {}  # holds metrics for each epoch until upload
        self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py

        # Parse input
        api_key, model_id, self.filename = self._parse_identifier(identifier)

        # Get credentials: identifier-embedded key takes priority over saved settings
        active_key = api_key or SETTINGS.get("api_key")
        credentials = {"api_key": active_key} if active_key else None  # set credentials

        # Initialize client
        self.client = HUBClient(credentials)

        if model_id:
            self.load_model(model_id)  # load existing model
        else:
            self.model = self.client.model()  # load empty model

    def load_model(self, model_id):
        """
        Load an existing model from Ultralytics HUB, set its training arguments, and start heartbeats.

        Args:
            model_id (str): Identifier of the model to fetch from HUB.
        """
        self.model = self.client.model(model_id)
        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

        self._set_train_args()

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits["heartbeat"])
        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def create_model(self, model_args):
        """
        Create a new model on Ultralytics HUB from local training arguments and start heartbeats.

        Args:
            model_args (dict): Training arguments; recognized keys are 'batch', 'epochs', 'imgsz',
                'patience', 'device', 'cache' and 'data'. Missing keys fall back to HUB defaults.
        """
        payload = {
            "config": {
                "batchSize": model_args.get("batch", -1),
                "epochs": model_args.get("epochs", 300),
                "imageSize": model_args.get("imgsz", 640),
                "patience": model_args.get("patience", 100),
                "device": model_args.get("device", ""),
                "cache": model_args.get("cache", "ram"),
            },
            "dataset": {"name": model_args.get("data")},
            "lineage": {
                "architecture": {
                    "name": self.filename.replace(".pt", "").replace(".yaml", ""),
                },
                "parent": {},
            },
            "meta": {"name": self.filename},
        }

        # A *.pt file implies pretrained weights, which HUB records as the lineage parent
        if self.filename.endswith(".pt"):
            payload["lineage"]["parent"]["name"] = self.filename

        self.model.create_model(payload)

        # Model could not be created
        # TODO: improve error handling
        if not self.model.id:
            return

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits["heartbeat"])

        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def _parse_identifier(self, identifier):
        """
        Parses the given identifier to determine the type of identifier and extract relevant components.

        The method supports different identifier formats:
            - A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
            - An identifier containing an API key and a model ID separated by an underscore
            - An identifier that is solely a model ID of a fixed length
            - A local filename that ends with '.pt' or '.yaml'

        Args:
            identifier (str): The identifier string to be parsed.

        Returns:
            (tuple): A tuple containing the API key, model ID, and filename as applicable.

        Raises:
            HUBModelError: If the identifier format is not recognized.
        """
        # Initialize variables
        api_key, model_id, filename = None, None, None

        # Check if identifier is a HUB URL
        if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
            # Extract the model_id after the HUB_WEB_ROOT URL
            model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
        else:
            # Split the identifier based on underscores only if it's not a HUB URL
            parts = identifier.split("_")

            # Check if identifier is in the format of API key and model ID
            # (API keys are 42 chars, model IDs are 20 chars)
            if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
                api_key, model_id = parts
            # Check if identifier is a single model ID
            elif len(parts) == 1 and len(parts[0]) == 20:
                model_id = parts[0]
            # Check if identifier is a local filename
            elif identifier.endswith((".pt", ".yaml")):
                filename = identifier
            else:
                raise HUBModelError(
                    f"model='{identifier}' could not be parsed. Check format is correct. "
                    f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
                )

        return api_key, model_id, filename

    def _set_train_args(self, **kwargs):
        """
        Initialize self.train_args, self.model_file and self.model_id from the fetched HUB model.

        Resumable models reuse their last checkpoint; otherwise the training config stored on HUB is used.
        Note: **kwargs is currently unused but retained for interface compatibility.

        Raises:
            ValueError: If the model is already trained, or its dataset is still processing.
        """
        if self.model.is_trained():
            # Model is already trained
            raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))

        if self.model.is_resumable():
            # Model has saved weights
            self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
            self.model_file = self.model.get_weights_url("last")
        else:
            # Model has no saved weights
            def get_train_args(config):
                """Map the HUB camelCase config dict to local trainer argument names."""
                return {
                    "batch": config["batchSize"],
                    "epochs": config["epochs"],
                    "imgsz": config["imageSize"],
                    "patience": config["patience"],
                    "device": config["device"],
                    "cache": config["cache"],
                    "data": self.model.get_dataset_url(),
                }

            self.train_args = get_train_args(self.model.data.get("config"))
            # Set the model file as either a *.pt or *.yaml file
            self.model_file = (
                self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
            )

        if not self.train_args.get("data"):
            raise ValueError("Dataset may still be processing. Please wait a minute and try again.")  # RF fix

        self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
        self.model_id = self.model.id

    def request_queue(
        self,
        request_func,
        retry=3,
        timeout=30,
        thread=True,
        verbose=True,
        progress_total=None,
        *args,
        **kwargs,
    ):
        """
        Execute request_func with exponential-backoff retries, optionally in a daemon thread.

        Args:
            request_func (callable): Function performing the HTTP request; receives *args/**kwargs.
            retry (int): Number of retries after the initial attempt.
            timeout (int): Overall time budget in seconds across all attempts.
            thread (bool): If True, run asynchronously in a daemon thread and return None.
            verbose (bool): If True, log a failure message on the first failed attempt.
            progress_total (int | None): Total content size; if set, show an upload progress bar.

        Returns:
            (requests.Response | None): Final response when thread=False, else None.
        """

        def retry_request():
            """Attempt the request up to retry+1 times with exponential backoff and a total timeout."""
            t0 = time.time()  # Record the start time for the timeout
            for i in range(retry + 1):
                if (time.time() - t0) > timeout:
                    LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
                    break  # Timeout reached, exit loop

                response = request_func(*args, **kwargs)
                if progress_total:
                    self._show_upload_progress(progress_total, response)

                if response is None:
                    LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
                    time.sleep(2**i)  # Exponential backoff before retrying
                    continue  # Skip further processing and retry

                if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
                    return response  # Success, no need to retry

                if i == 0:
                    # Initial attempt, check status code and provide messages
                    message = self._get_failure_message(response, retry, timeout)

                    if verbose:
                        LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")

                if not self._should_retry(response.status_code):
                    # BUGFIX: closing parenthesis was missing after the status code in this message
                    LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code})")
                    break  # Not an error that should be retried, exit loop

                time.sleep(2**i)  # Exponential backoff for retries

            return response

        if thread:
            # Start a new thread to run the retry_request function
            threading.Thread(target=retry_request, daemon=True).start()
        else:
            # If running in the main thread, call retry_request directly
            return retry_request()

    def _should_retry(self, status_code):
        """Return True if the HTTP status code indicates a transient error worth retrying."""
        # Status codes that trigger retries
        retry_codes = {
            HTTPStatus.REQUEST_TIMEOUT,
            HTTPStatus.BAD_GATEWAY,
            HTTPStatus.GATEWAY_TIMEOUT,
        }
        return status_code in retry_codes

    def _get_failure_message(self, response: "requests.Response", retry: int, timeout: int):
        """
        Generate a retry message based on the response status code.

        Args:
            response: The HTTP response object.
            retry: The number of retry attempts allowed.
            timeout: The maximum timeout duration.

        Returns:
            str: The retry message.
        """
        if self._should_retry(response.status_code):
            return f"Retrying {retry}x for {timeout}s." if retry else ""
        elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
            headers = response.headers
            return (
                f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
                f"Please retry after {headers['Retry-After']}s."
            )
        else:
            try:
                return response.json().get("message", "No JSON message.")
            except AttributeError:
                return "Unable to read JSON."

    def upload_metrics(self):
        """Upload model metrics to Ultralytics HUB."""
        return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)

    def upload_model(
        self,
        epoch: int,
        weights: str,
        is_best: bool = False,
        map: float = 0.0,
        final: bool = False,
    ) -> None:
        """
        Upload a model checkpoint to Ultralytics HUB.

        Args:
            epoch (int): The current training epoch.
            weights (str): Path to the model weights file.
            is_best (bool): Indicates if the current model is the best one so far.
            map (float): Mean average precision of the model.
            final (bool): Indicates if the model is the final model after training.
        """
        if Path(weights).is_file():
            progress_total = Path(weights).stat().st_size if final else None  # Only show progress if final
            # Final uploads run synchronously with a large retry/timeout budget so training
            # does not exit before the checkpoint lands on HUB
            self.request_queue(
                self.model.upload_model,
                epoch=epoch,
                weights=weights,
                is_best=is_best,
                map=map,
                final=final,
                retry=10,
                timeout=3600,
                thread=not final,
                progress_total=progress_total,
            )
        else:
            LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")

    def _show_upload_progress(self, content_length: int, response: "requests.Response") -> None:
        """
        Display a progress bar to track the upload progress of a file download.

        Args:
            content_length (int): The total size of the content to be downloaded in bytes.
            response (requests.Response): The response object from the file download request.

        Returns:
            (None)
        """
        with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
            for data in response.iter_content(chunk_size=1024):
                pbar.update(len(data))
|
|
|
|