diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index f696ffb3b..47a7f07a2 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -128,8 +128,10 @@ class Model(nn.Module): if self.is_hub_model(model): # Fetch model from HUB checks.check_requirements("hub-sdk>=0.0.8") - self.session = HUBTrainingSession.create_session(model) - model = self.session.model_file + session = HUBTrainingSession.create_session(model) + model = session.model_file + if session.train_args: # training sent from HUB + self.session = session # Check if Triton Server model elif self.is_triton_model(model): diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index 31e955efc..d93c96281 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -50,6 +50,7 @@ class HUBTrainingSession: self.model = None self.model_url = None self.model_file = None + self.train_args = None # Parse input api_key, model_id, self.filename = self._parse_identifier(identifier)