|
|
|
@ -170,10 +170,19 @@ class HUBTrainingSession: |
|
|
|
|
|
|
|
|
|
return api_key, model_id, filename |
|
|
|
|
|
|
|
|
|
def _set_train_args(self, **kwargs): |
|
|
|
|
"""Initializes training arguments and creates a model entry on the Ultralytics HUB.""" |
|
|
|
|
def _set_train_args(self): |
|
|
|
|
""" |
|
|
|
|
Initializes training arguments and creates a model entry on the Ultralytics HUB. |
|
|
|
|
|
|
|
|
|
This method sets up training arguments based on the model's state and updates them with any additional |
|
|
|
|
arguments provided. It handles different states of the model, such as whether it's resumable, pretrained, |
|
|
|
|
or requires specific file setup. |
|
|
|
|
|
|
|
|
|
Raises: |
|
|
|
|
ValueError: If the model is already trained, if required dataset information is missing, or if there are |
|
|
|
|
issues with the provided training arguments. |
|
|
|
|
""" |
|
|
|
|
if self.model.is_trained(): |
|
|
|
|
# Model is already trained |
|
|
|
|
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀")) |
|
|
|
|
|
|
|
|
|
if self.model.is_resumable(): |
|
|
|
@ -182,26 +191,16 @@ class HUBTrainingSession: |
|
|
|
|
self.model_file = self.model.get_weights_url("last") |
|
|
|
|
else: |
|
|
|
|
# Model has no saved weights |
|
|
|
|
def get_train_args(config): |
|
|
|
|
"""Parses an identifier to extract API key, model ID, and filename if applicable.""" |
|
|
|
|
return { |
|
|
|
|
"batch": config["batchSize"], |
|
|
|
|
"epochs": config["epochs"], |
|
|
|
|
"imgsz": config["imageSize"], |
|
|
|
|
"patience": config["patience"], |
|
|
|
|
"device": config["device"], |
|
|
|
|
"cache": config["cache"], |
|
|
|
|
"data": self.model.get_dataset_url(), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
self.train_args = get_train_args(self.model.data.get("config")) |
|
|
|
|
self.train_args = self.model.data.get("train_args") # new response |
|
|
|
|
|
|
|
|
|
# Set the model file as either a *.pt or *.yaml file |
|
|
|
|
self.model_file = ( |
|
|
|
|
self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture() |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if not self.train_args.get("data"): |
|
|
|
|
raise ValueError("Dataset may still be processing. Please wait a minute and try again.") # RF fix |
|
|
|
|
if "data" not in self.train_args: |
|
|
|
|
# RF bug - datasets are sometimes not exported |
|
|
|
|
raise ValueError("Dataset may still be processing. Please wait a minute and try again.") |
|
|
|
|
|
|
|
|
|
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u |
|
|
|
|
self.model_id = self.model.id |
|
|
|
|