From 69c12a4ba36d3dd9c6f541085125de588df15be3 Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Sun, 1 Dec 2024 23:19:36 +0500 Subject: [PATCH] Add more tracking args for solutions (#17878) Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- ultralytics/solutions/ai_gym.py | 2 +- ultralytics/solutions/solutions.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ultralytics/solutions/ai_gym.py b/ultralytics/solutions/ai_gym.py index 68e369762..359ed9dd4 100644 --- a/ultralytics/solutions/ai_gym.py +++ b/ultralytics/solutions/ai_gym.py @@ -71,7 +71,7 @@ class AIGym(BaseSolution): >>> processed_image = gym.monitor(image) """ # Extract tracks - tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])[0] + tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)[0] if tracks.boxes.id is not None: # Extract and check keypoints diff --git a/ultralytics/solutions/solutions.py b/ultralytics/solutions/solutions.py index fc05d42d6..f47223b57 100644 --- a/ultralytics/solutions/solutions.py +++ b/ultralytics/solutions/solutions.py @@ -74,6 +74,10 @@ class BaseSolution: self.model = YOLO(self.CFG["model"]) self.names = self.model.names + self.track_add_args = { # Tracker additional arguments for advance configuration + k: self.CFG[k] for k in ["verbose", "iou", "conf", "device", "max_det", "half", "tracker"] + } + if IS_CLI and self.CFG["source"] is None: d_s = "solutions_ci_demo.mp4" if "-pose" not in self.CFG["model"] else "solution_ci_pose_demo.mp4" LOGGER.warning(f"⚠️ WARNING: source not provided. using default source {ASSETS_URL}/{d_s}") @@ -98,7 +102,7 @@ class BaseSolution: >>> frame = cv2.imread("path/to/image.jpg") >>> solution.extract_tracks(frame) """ - self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"]) + self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args) # Extract tracks for OBB or object detection self.track_data = self.tracks[0].obb or self.tracks[0].boxes