diff --git a/ultralytics/solutions/ai_gym.py b/ultralytics/solutions/ai_gym.py index 68e3697627..359ed9dd49 100644 --- a/ultralytics/solutions/ai_gym.py +++ b/ultralytics/solutions/ai_gym.py @@ -71,7 +71,7 @@ class AIGym(BaseSolution): >>> processed_image = gym.monitor(image) """ # Extract tracks - tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])[0] + tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)[0] if tracks.boxes.id is not None: # Extract and check keypoints diff --git a/ultralytics/solutions/solutions.py b/ultralytics/solutions/solutions.py index fc05d42d6e..f47223b570 100644 --- a/ultralytics/solutions/solutions.py +++ b/ultralytics/solutions/solutions.py @@ -74,6 +74,10 @@ class BaseSolution: self.model = YOLO(self.CFG["model"]) self.names = self.model.names + self.track_add_args = { # Tracker additional arguments for advance configuration + k: self.CFG[k] for k in ["verbose", "iou", "conf", "device", "max_det", "half", "tracker"] + } + if IS_CLI and self.CFG["source"] is None: d_s = "solutions_ci_demo.mp4" if "-pose" not in self.CFG["model"] else "solution_ci_pose_demo.mp4" LOGGER.warning(f"⚠️ WARNING: source not provided. using default source {ASSETS_URL}/{d_s}") @@ -98,7 +102,7 @@ class BaseSolution: >>> frame = cv2.imread("path/to/image.jpg") >>> solution.extract_tracks(frame) """ - self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"]) + self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args) # Extract tracks for OBB or object detection self.track_data = self.tracks[0].obb or self.tracks[0].boxes