Add more tracking args for solutions (#17878)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/17749/head^2
Muhammad Rizwan Munawar 4 days ago committed by GitHub
parent 5c2cdb6841
commit 69c12a4ba3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/solutions/ai_gym.py
  2. 6
      ultralytics/solutions/solutions.py

@ -71,7 +71,7 @@ class AIGym(BaseSolution):
>>> processed_image = gym.monitor(image) >>> processed_image = gym.monitor(image)
""" """
# Extract tracks # Extract tracks
tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])[0] tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)[0]
if tracks.boxes.id is not None: if tracks.boxes.id is not None:
# Extract and check keypoints # Extract and check keypoints

@ -74,6 +74,10 @@ class BaseSolution:
self.model = YOLO(self.CFG["model"]) self.model = YOLO(self.CFG["model"])
self.names = self.model.names self.names = self.model.names
self.track_add_args = { # Tracker additional arguments for advance configuration
k: self.CFG[k] for k in ["verbose", "iou", "conf", "device", "max_det", "half", "tracker"]
}
if IS_CLI and self.CFG["source"] is None: if IS_CLI and self.CFG["source"] is None:
d_s = "solutions_ci_demo.mp4" if "-pose" not in self.CFG["model"] else "solution_ci_pose_demo.mp4" d_s = "solutions_ci_demo.mp4" if "-pose" not in self.CFG["model"] else "solution_ci_pose_demo.mp4"
LOGGER.warning(f" WARNING: source not provided. using default source {ASSETS_URL}/{d_s}") LOGGER.warning(f" WARNING: source not provided. using default source {ASSETS_URL}/{d_s}")
@ -98,7 +102,7 @@ class BaseSolution:
>>> frame = cv2.imread("path/to/image.jpg") >>> frame = cv2.imread("path/to/image.jpg")
>>> solution.extract_tracks(frame) >>> solution.extract_tracks(frame)
""" """
self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"]) self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)
# Extract tracks for OBB or object detection # Extract tracks for OBB or object detection
self.track_data = self.tracks[0].obb or self.tracks[0].boxes self.track_data = self.tracks[0].obb or self.tracks[0].boxes

Loading…
Cancel
Save