`ultralytics 8.1.30` add advanced HUB train arguments (#9110)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/8753/head v8.1.30
Kalen Michael 11 months ago committed by GitHub
parent a62cdab53a
commit 8617fcf32d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/__init__.py
  2. 2
      ultralytics/engine/model.py
  3. 35
      ultralytics/hub/session.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.29"
__version__ = "8.1.30"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

@ -124,7 +124,7 @@ class Model(nn.Module):
# Check if Ultralytics HUB model from https://hub.ultralytics.com
if self.is_hub_model(model):
# Fetch model from HUB
checks.check_requirements("hub-sdk>0.0.2")
checks.check_requirements("hub-sdk>=0.0.5")
self.session = self._get_hub_session(model)
model = self.session.model_file

@ -170,10 +170,19 @@ class HUBTrainingSession:
return api_key, model_id, filename
def _set_train_args(self, **kwargs):
"""Initializes training arguments and creates a model entry on the Ultralytics HUB."""
def _set_train_args(self):
"""
Initializes training arguments and creates a model entry on the Ultralytics HUB.
This method sets up training arguments based on the model's state and updates them with any additional
arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
or requires specific file setup.
Raises:
ValueError: If the model is already trained, if required dataset information is missing, or if there are
issues with the provided training arguments.
"""
if self.model.is_trained():
# Model is already trained
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
if self.model.is_resumable():
@ -182,26 +191,16 @@ class HUBTrainingSession:
self.model_file = self.model.get_weights_url("last")
else:
# Model has no saved weights
def get_train_args(config):
"""Parses an identifier to extract API key, model ID, and filename if applicable."""
return {
"batch": config["batchSize"],
"epochs": config["epochs"],
"imgsz": config["imageSize"],
"patience": config["patience"],
"device": config["device"],
"cache": config["cache"],
"data": self.model.get_dataset_url(),
}
self.train_args = get_train_args(self.model.data.get("config"))
self.train_args = self.model.data.get("train_args") # new response
# Set the model file as either a *.pt or *.yaml file
self.model_file = (
self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
)
if not self.train_args.get("data"):
raise ValueError("Dataset may still be processing. Please wait a minute and try again.") # RF fix
if "data" not in self.train_args:
# RF bug - datasets are sometimes not exported
raise ValueError("Dataset may still be processing. Please wait a minute and try again.")
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
self.model_id = self.model.id

Loading…
Cancel
Save