|
|
|
@ -68,7 +68,7 @@ class HUBTrainingSession: |
|
|
|
|
self.model = self.client.model() # load empty model |
|
|
|
|
|
|
|
|
|
def load_model(self, model_id): |
|
|
|
|
# Initialize model |
|
|
|
|
"""Loads an existing model from Ultralytics HUB using the provided model identifier.""" |
|
|
|
|
self.model = self.client.model(model_id) |
|
|
|
|
if not self.model.data: # then model model does not exist |
|
|
|
|
raise ValueError(emojis(f"❌ The specified HUB model does not exist")) # TODO: improve error handling |
|
|
|
@ -82,7 +82,7 @@ class HUBTrainingSession: |
|
|
|
|
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") |
|
|
|
|
|
|
|
|
|
def create_model(self, model_args): |
|
|
|
|
# Initialize model |
|
|
|
|
"""Initializes a HUB training session with the specified model identifier.""" |
|
|
|
|
payload = { |
|
|
|
|
"config": { |
|
|
|
|
"batchSize": model_args.get("batch", -1), |
|
|
|
@ -168,6 +168,7 @@ class HUBTrainingSession: |
|
|
|
|
return api_key, model_id, filename |
|
|
|
|
|
|
|
|
|
def _set_train_args(self, **kwargs): |
|
|
|
|
"""Initializes training arguments and creates a model entry on the Ultralytics HUB.""" |
|
|
|
|
if self.model.is_trained(): |
|
|
|
|
# Model is already trained |
|
|
|
|
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀")) |
|
|
|
@ -179,6 +180,7 @@ class HUBTrainingSession: |
|
|
|
|
else: |
|
|
|
|
# Model has no saved weights |
|
|
|
|
def get_train_args(config): |
|
|
|
|
"""Parses an identifier to extract API key, model ID, and filename if applicable.""" |
|
|
|
|
return { |
|
|
|
|
"batch": config["batchSize"], |
|
|
|
|
"epochs": config["epochs"], |
|
|
|
@ -213,6 +215,7 @@ class HUBTrainingSession: |
|
|
|
|
**kwargs, |
|
|
|
|
): |
|
|
|
|
def retry_request(): |
|
|
|
|
"""Attempts to call `request_func` with retries, timeout, and optional threading.""" |
|
|
|
|
t0 = time.time() # Record the start time for the timeout |
|
|
|
|
for i in range(retry + 1): |
|
|
|
|
if (time.time() - t0) > timeout: |
|
|
|
@ -254,7 +257,7 @@ class HUBTrainingSession: |
|
|
|
|
return retry_request() |
|
|
|
|
|
|
|
|
|
def _should_retry(self, status_code): |
|
|
|
|
# Status codes that trigger retries |
|
|
|
|
"""Determines if a request should be retried based on the HTTP status code.""" |
|
|
|
|
retry_codes = { |
|
|
|
|
HTTPStatus.REQUEST_TIMEOUT, |
|
|
|
|
HTTPStatus.BAD_GATEWAY, |
|
|
|
|