From 0f9f7b806c709f762483bc7ab6c56a72b357c58b Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 29 Aug 2024 17:15:24 +0200 Subject: [PATCH] fix HUB download and train (#15896) Signed-off-by: UltralyticsAssistant Co-authored-by: UltralyticsAssistant --- ultralytics/engine/model.py | 6 ++++-- ultralytics/hub/session.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index f696ffb3b..47a7f07a2 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -128,8 +128,10 @@ class Model(nn.Module): if self.is_hub_model(model): # Fetch model from HUB checks.check_requirements("hub-sdk>=0.0.8") - self.session = HUBTrainingSession.create_session(model) - model = self.session.model_file + session = HUBTrainingSession.create_session(model) + model = session.model_file + if session.train_args: # training sent from HUB + self.session = session # Check if Triton Server model elif self.is_triton_model(model): diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index 31e955efc..d93c96281 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -50,6 +50,7 @@ class HUBTrainingSession: self.model = None self.model_url = None self.model_file = None + self.train_args = None # Parse input api_key, model_id, self.filename = self._parse_identifier(identifier)