From 73e6861d950fc5210d2a2b851918b99eb26229f7 Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Sat, 5 Oct 2024 18:08:37 +0500 Subject: [PATCH] Update `workouts_monitoring` solution (#16706) Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- docs/en/guides/object-counting.md | 4 +- docs/en/guides/workouts-monitoring.md | 89 +++++-------- tests/test_solutions.py | 6 +- ultralytics/cfg/solutions/default.yaml | 4 + ultralytics/solutions/ai_gym.py | 172 +++++++++---------------- ultralytics/solutions/solutions.py | 12 +- ultralytics/utils/plotting.py | 120 +++++++---------- 7 files changed, 162 insertions(+), 245 deletions(-) diff --git a/docs/en/guides/object-counting.md b/docs/en/guides/object-counting.md index 1abb7d4933..cefc9ae281 100644 --- a/docs/en/guides/object-counting.md +++ b/docs/en/guides/object-counting.md @@ -286,7 +286,7 @@ def count_objects_in_region(video_path, output_video_path, model_path): if not success: print("Video frame is empty or video processing has been successfully completed.") break - im0 = counter.start_counting(im0) + im0 = counter.count(im0) video_writer.write(im0) cap.release() @@ -334,7 +334,7 @@ def count_specific_classes(video_path, output_video_path, model_path, classes_to if not success: print("Video frame is empty or video processing has been successfully completed.") break - im0 = counter.start_counting(im0) + im0 = counter.count(im0) video_writer.write(im0) cap.release() diff --git a/docs/en/guides/workouts-monitoring.md b/docs/en/guides/workouts-monitoring.md index af996894b3..78d894e81d 100644 --- a/docs/en/guides/workouts-monitoring.md +++ b/docs/en/guides/workouts-monitoring.md @@ -41,18 +41,16 @@ Monitoring workouts through pose estimation with [Ultralytics YOLO11](https://gi ```python import cv2 - from ultralytics import YOLO, solutions + from ultralytics import solutions - model = YOLO("yolo11n-pose.pt") cap = cv2.VideoCapture("path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) - gym_object = solutions.AIGym( - line_thickness=2, - view_img=True, - pose_type="pushup", - kpts_to_check=[6, 8, 10], + gym = solutions.AIGym( + model="yolo11n-pose.pt", + show=True, + kpts=[6, 8, 10], ) while cap.isOpened(): @@ -60,9 +58,7 @@ Monitoring workouts through pose estimation with [Ultralytics YOLO11](https://gi if not success: print("Video frame is empty or video processing has been successfully completed.") break - results = model.track(im0, verbose=False) # Tracking recommended - # results = model.predict(im0) # Prediction also supported - im0 = gym_object.start_counting(im0, results) + im0 = gym.monitor(im0) cv2.destroyAllWindows() ``` @@ -72,20 +68,17 @@ Monitoring workouts through pose estimation with [Ultralytics YOLO11](https://gi ```python import cv2 - from ultralytics import YOLO, solutions + from ultralytics import solutions - model = YOLO("yolo11n-pose.pt") cap = cv2.VideoCapture("path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) video_writer = cv2.VideoWriter("workouts.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) - gym_object = solutions.AIGym( - line_thickness=2, - view_img=True, - pose_type="pushup", - kpts_to_check=[6, 8, 10], + gym = solutions.AIGym( + show=True, + kpts=[6, 8, 10], ) while cap.isOpened(): @@ -93,33 +86,26 @@ Monitoring workouts through pose estimation with [Ultralytics YOLO11](https://gi if not success: print("Video frame is empty or video processing has been successfully completed.") break - results = model.track(im0, verbose=False) # Tracking recommended - # results = model.predict(im0) # Prediction also supported - im0 = gym_object.start_counting(im0, results) + im0 = gym.monitor(im0) video_writer.write(im0) cv2.destroyAllWindows() video_writer.release() ``` -???+ tip "Support" - - "pushup", "pullup" and "abworkout" supported - ### KeyPoints Map ![keyPoints Order Ultralytics YOLO11 Pose](https://github.com/ultralytics/docs/releases/download/0/keypoints-order-ultralytics-yolov8-pose.avif) ### Arguments `AIGym` -| Name | Type | Default | Description | -| ----------------- | ------- | -------- | -------------------------------------------------------------------------------------- | -| `kpts_to_check` | `list` | `None` | List of three keypoints index, for counting specific workout, followed by keypoint Map | -| `line_thickness` | `int` | `2` | Thickness of the lines drawn. | -| `view_img` | `bool` | `False` | Flag to display the image. | -| `pose_up_angle` | `float` | `145.0` | Angle threshold for the 'up' pose. | -| `pose_down_angle` | `float` | `90.0` | Angle threshold for the 'down' pose. | -| `pose_type` | `str` | `pullup` | Type of pose to detect (`'pullup`', `pushup`, `abworkout`, `squat`). | +| Name | Type | Default | Description | +| ------------ | ------- | ------- | -------------------------------------------------------------------------------------- | +| `kpts` | `list` | `None` | List of three keypoints index, for counting specific workout, followed by keypoint Map | +| `line_width` | `int` | `2` | Thickness of the lines drawn. | +| `show` | `bool` | `False` | Flag to display the image. | +| `up_angle` | `float` | `145.0` | Angle threshold for the 'up' pose. | +| `down_angle` | `float` | `90.0` | Angle threshold for the 'down' pose. | ### Arguments `model.predict` @@ -138,18 +124,16 @@ To monitor your workouts using Ultralytics YOLO11, you can utilize the pose esti ```python import cv2 -from ultralytics import YOLO, solutions +from ultralytics import solutions -model = YOLO("yolo11n-pose.pt") cap = cv2.VideoCapture("path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) -gym_object = solutions.AIGym( - line_thickness=2, - view_img=True, - pose_type="pushup", - kpts_to_check=[6, 8, 10], +gym = solutions.AIGym( + line_width=2, + show=True, + kpts=[6, 8, 10], ) while cap.isOpened(): @@ -157,8 +141,7 @@ while cap.isOpened(): if not success: print("Video frame is empty or video processing has been successfully completed.") break - results = model.track(im0, verbose=False) - im0 = gym_object.start_counting(im0, results) + im0 = gym.monitor(im0) cv2.destroyAllWindows() ``` @@ -188,11 +171,10 @@ Yes, Ultralytics YOLO11 can be adapted for custom workout routines. The `AIGym` ```python from ultralytics import solutions -gym_object = solutions.AIGym( - line_thickness=2, - view_img=True, - pose_type="squat", - kpts_to_check=[6, 8, 10], +gym = solutions.AIGym( + line_width=2, + show=True, + kpts=[6, 8, 10], ) ``` @@ -205,20 +187,18 @@ To save the workout monitoring output, you can modify the code to include a vide ```python import cv2 -from ultralytics import YOLO, solutions +from ultralytics import solutions -model = YOLO("yolo11n-pose.pt") cap = cv2.VideoCapture("path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) video_writer = cv2.VideoWriter("workouts.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) -gym_object = solutions.AIGym( - line_thickness=2, - view_img=True, - pose_type="pushup", - kpts_to_check=[6, 8, 10], +gym = solutions.AIGym( + line_width=2, + show=True, + kpts=[6, 8, 10], ) while cap.isOpened(): @@ -226,8 +206,7 @@ while cap.isOpened(): if not success: print("Video frame is empty or video processing has been successfully completed.") break - results = model.track(im0, verbose=False) - im0 = gym_object.start_counting(im0, results) + im0 = gym.monitor(im0) video_writer.write(im0) cv2.destroyAllWindows() diff --git a/tests/test_solutions.py b/tests/test_solutions.py index 55b8efc326..485c795ee4 100644 --- a/tests/test_solutions.py +++ b/tests/test_solutions.py @@ -41,16 +41,14 @@ def test_major_solutions(): def test_aigym(): """Test the workouts monitoring solution.""" safe_download(url=WORKOUTS_SOLUTION_DEMO) - model = YOLO("yolo11n-pose.pt") cap = cv2.VideoCapture("solution_ci_pose_demo.mp4") assert cap.isOpened(), "Error reading video file" - gym_object = solutions.AIGym(line_thickness=2, pose_type="squat", kpts_to_check=[5, 11, 13]) + gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13]) while cap.isOpened(): success, im0 = cap.read() if not success: break - results = model.track(im0, verbose=False) - _ = gym_object.start_counting(im0, results) + _ = gym.monitor(im0) cap.release() cv2.destroyAllWindows() diff --git a/ultralytics/cfg/solutions/default.yaml b/ultralytics/cfg/solutions/default.yaml index ccbc234793..f22dce2c91 100644 --- a/ultralytics/cfg/solutions/default.yaml +++ b/ultralytics/cfg/solutions/default.yaml @@ -10,3 +10,7 @@ show: True # Flag to control whether to display output image or not show_in: True # Flag to display objects moving *into* the defined region show_out: True # Flag to display objects moving *out of* the defined region classes: # To count specific classes + +up_angle: 145.0 # workouts up_angle for counts, 145.0 is default value +down_angle: 90 # workouts down_angle for counts, 90 is default value +kpts: [6, 8, 10] # keypoints for workouts monitoring diff --git a/ultralytics/solutions/ai_gym.py b/ultralytics/solutions/ai_gym.py index 349e46e8f0..26f22d7032 100644 --- a/ultralytics/solutions/ai_gym.py +++ b/ultralytics/solutions/ai_gym.py @@ -1,127 +1,79 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -import cv2 - -from ultralytics.utils.checks import check_imshow +from ultralytics.solutions.solutions import BaseSolution # Import a parent class from ultralytics.utils.plotting import Annotator -class AIGym: +class AIGym(BaseSolution): """A class to manage the gym steps of people in a real-time video stream based on their poses.""" - def __init__( - self, - kpts_to_check, - line_thickness=2, - view_img=False, - pose_up_angle=145.0, - pose_down_angle=90.0, - pose_type="pullup", - ): + def __init__(self, **kwargs): + """Initialization function for AiGYM class, a child class of BaseSolution class, can be used for workouts + monitoring. """ - Initializes the AIGym class with the specified parameters. - - Args: - kpts_to_check (list): Indices of keypoints to check. - line_thickness (int, optional): Thickness of the lines drawn. Defaults to 2. - view_img (bool, optional): Flag to display the image. Defaults to False. - pose_up_angle (float, optional): Angle threshold for the 'up' pose. Defaults to 145.0. - pose_down_angle (float, optional): Angle threshold for the 'down' pose. Defaults to 90.0. - pose_type (str, optional): Type of pose to detect ('pullup', 'pushup', 'abworkout'). Defaults to "pullup". + # Check if the model name ends with '-pose' + if "model" in kwargs and "-pose" not in kwargs["model"]: + kwargs["model"] = "yolo11n-pose.pt" + elif "model" not in kwargs: + kwargs["model"] = "yolo11n-pose.pt" + + super().__init__(**kwargs) + self.count = [] # List for counts, necessary where there are multiple objects in frame + self.angle = [] # List for angle, necessary where there are multiple objects in frame + self.stage = [] # List for stage, necessary where there are multiple objects in frame + + # Extract details from CFG single time for usage later + self.initial_stage = None + self.up_angle = float(self.CFG["up_angle"]) # Pose up predefined angle to consider up pose + self.down_angle = float(self.CFG["down_angle"]) # Pose down predefined angle to consider down pose + self.kpts = self.CFG["kpts"] # User selected kpts of workouts storage for further usage + self.lw = self.CFG["line_width"] # Store line_width for usage + + def monitor(self, im0): """ - # Image and line thickness - self.im0 = None - self.tf = line_thickness - - # Keypoints and count information - self.keypoints = None - self.poseup_angle = pose_up_angle - self.posedown_angle = pose_down_angle - self.threshold = 0.001 - - # Store stage, count and angle information - self.angle = None - self.count = None - self.stage = None - self.pose_type = pose_type - self.kpts_to_check = kpts_to_check - - # Visual Information - self.view_img = view_img - self.annotator = None - - # Check if environment supports imshow - self.env_check = check_imshow(warn=True) - self.count = [] - self.angle = [] - self.stage = [] - - def start_counting(self, im0, results): - """ - Function used to count the gym steps. + Monitor the workouts using Ultralytics YOLOv8 Pose Model: https://docs.ultralytics.com/tasks/pose/. Args: - im0 (ndarray): Current frame from the video stream. - results (list): Pose estimation data. + im0 (ndarray): The input image that will be used for processing + Returns + im0 (ndarray): The processed image for more usage """ - self.im0 = im0 - - if not len(results[0]): - return self.im0 - - if len(results[0]) > len(self.count): - new_human = len(results[0]) - len(self.count) - self.count += [0] * new_human - self.angle += [0] * new_human - self.stage += ["-"] * new_human - - self.keypoints = results[0].keypoints.data - self.annotator = Annotator(im0, line_width=self.tf) - - for ind, k in enumerate(reversed(self.keypoints)): - # Estimate angle and draw specific points based on pose type - if self.pose_type in {"pushup", "pullup", "abworkout", "squat"}: - self.angle[ind] = self.annotator.estimate_pose_angle( - k[int(self.kpts_to_check[0])].cpu(), - k[int(self.kpts_to_check[1])].cpu(), - k[int(self.kpts_to_check[2])].cpu(), - ) - self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10) - - # Check and update pose stages and counts based on angle - if self.pose_type in {"abworkout", "pullup"}: - if self.angle[ind] > self.poseup_angle: - self.stage[ind] = "down" - if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down": - self.stage[ind] = "up" - self.count[ind] += 1 - - elif self.pose_type in {"pushup", "squat"}: - if self.angle[ind] > self.poseup_angle: - self.stage[ind] = "up" - if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up": - self.stage[ind] = "down" + # Extract tracks + tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])[0] + + if tracks.boxes.id is not None: + # Extract and check keypoints + if len(tracks) > len(self.count): + new_human = len(tracks) - len(self.count) + self.angle += [0] * new_human + self.count += [0] * new_human + self.stage += ["-"] * new_human + + # Initialize annotator + self.annotator = Annotator(im0, line_width=self.lw) + + # Enumerate over keypoints + for ind, k in enumerate(reversed(tracks.keypoints.data)): + # Get keypoints and estimate the angle + kpts = [k[int(self.kpts[i])].cpu() for i in range(3)] + self.angle[ind] = self.annotator.estimate_pose_angle(*kpts) + im0 = self.annotator.draw_specific_points(k, self.kpts, radius=self.lw * 3) + + # Determine stage and count logic based on angle thresholds + if self.angle[ind] < self.down_angle: + if self.stage[ind] == "up": self.count[ind] += 1 + self.stage[ind] = "down" + elif self.angle[ind] > self.up_angle: + self.stage[ind] = "up" + # Display angle, count, and stage text self.annotator.plot_angle_and_count_and_stage( - angle_text=self.angle[ind], - count_text=self.count[ind], - stage_text=self.stage[ind], - center_kpt=k[int(self.kpts_to_check[1])], + angle_text=self.angle[ind], # angle text for display + count_text=self.count[ind], # count text for workouts + stage_text=self.stage[ind], # stage position text + center_kpt=k[int(self.kpts[1])], # center keypoint for display ) - # Draw keypoints - self.annotator.kpts(k, shape=(640, 640), radius=1, kpt_line=True) - - # Display the image if environment supports it and view_img is True - if self.env_check and self.view_img: - cv2.imshow("Ultralytics YOLOv8 AI GYM", self.im0) - if cv2.waitKey(1) & 0xFF == ord("q"): - return - - return self.im0 - - -if __name__ == "__main__": - kpts_to_check = [0, 1, 2] # example keypoints - aigym = AIGym(kpts_to_check) + self.display_output(im0) # Display output image, if environment support display + return im0 # return an image for writing or further usage diff --git a/ultralytics/solutions/solutions.py b/ultralytics/solutions/solutions.py index b122d9da93..ed53de654b 100644 --- a/ultralytics/solutions/solutions.py +++ b/ultralytics/solutions/solutions.py @@ -4,11 +4,13 @@ from collections import defaultdict from pathlib import Path import cv2 -from shapely.geometry import LineString, Polygon from ultralytics import YOLO -from ultralytics.utils import yaml_load -from ultralytics.utils.checks import check_imshow +from ultralytics.utils import LOGGER, yaml_load +from ultralytics.utils.checks import check_imshow, check_requirements + +check_requirements("shapely>=2.0.0") +from shapely.geometry import LineString, Polygon DEFAULT_SOL_CFG_PATH = Path(__file__).resolve().parents[1] / "cfg/solutions/default.yaml" @@ -25,7 +27,7 @@ class BaseSolution: # Load config and update with args self.CFG = yaml_load(DEFAULT_SOL_CFG_PATH) self.CFG.update(kwargs) - print("Ultralytics Solutions: ✅", self.CFG) + LOGGER.info(f"Ultralytics Solutions: ✅ {self.CFG}") self.region = self.CFG["region"] # Store region data for other classes usage self.line_width = self.CFG["line_width"] # Store line_width for usage @@ -54,6 +56,8 @@ class BaseSolution: self.boxes = self.track_data.xyxy.cpu() self.clss = self.track_data.cls.cpu().tolist() self.track_ids = self.track_data.id.int().cpu().tolist() + else: + LOGGER.warning("WARNING ⚠️ tracks none, no keypoints will be considered.") def store_tracking_history(self, track_id, box): """ diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index 9d3051239c..116bdb841c 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -697,14 +697,13 @@ class Annotator: angle = 360 - angle return angle - def draw_specific_points(self, keypoints, indices=None, shape=(640, 640), radius=2, conf_thres=0.25): + def draw_specific_points(self, keypoints, indices=None, radius=2, conf_thres=0.25): """ Draw specific keypoints for gym steps counting. Args: keypoints (list): Keypoints data to be plotted. indices (list, optional): Keypoint indices to be plotted. Defaults to [2, 5, 7]. - shape (tuple, optional): Image size for model inference. Defaults to (640, 640). radius (int, optional): Keypoint radius. Defaults to 2. conf_thres (float, optional): Confidence threshold for keypoints. Defaults to 0.25. @@ -715,90 +714,71 @@ class Annotator: Keypoint format: [x, y] or [x, y, confidence]. Modifies self.im in-place. """ - if indices is None: - indices = [2, 5, 7] - for i, k in enumerate(keypoints): - if i in indices: - x_coord, y_coord = k[0], k[1] - if x_coord % shape[1] != 0 and y_coord % shape[0] != 0: - if len(k) == 3: - conf = k[2] - if conf < conf_thres: - continue - cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, (0, 255, 0), -1, lineType=cv2.LINE_AA) + indices = indices or [2, 5, 7] + points = [(int(k[0]), int(k[1])) for i, k in enumerate(keypoints) if i in indices and k[2] >= conf_thres] + + # Draw lines between consecutive points + for start, end in zip(points[:-1], points[1:]): + cv2.line(self.im, start, end, (0, 255, 0), 2, lineType=cv2.LINE_AA) + + # Draw circles for keypoints + for pt in points: + cv2.circle(self.im, pt, radius, (0, 0, 255), -1, lineType=cv2.LINE_AA) + return self.im - def plot_angle_and_count_and_stage( - self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255) - ): + def plot_workout_information(self, display_text, position, color=(104, 31, 17), txt_color=(255, 255, 255)): """ - Plot the pose angle, count value and step stage. + Draw text with a background on the image. Args: - angle_text (str): angle value for workout monitoring - count_text (str): counts value for workout monitoring - stage_text (str): stage decision for workout monitoring - center_kpt (list): centroid pose index for workout monitoring - color (tuple): text background color for workout monitoring - txt_color (tuple): text foreground color for workout monitoring + display_text (str): The text to be displayed. + position (tuple): Coordinates (x, y) on the image where the text will be placed. + color (tuple, optional): Text background color + txt_color (tuple, optional): Text foreground color """ - angle_text, count_text, stage_text = (f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}") + (text_width, text_height), _ = cv2.getTextSize(display_text, 0, self.sf, self.tf) - # Draw angle - (angle_text_width, angle_text_height), _ = cv2.getTextSize(angle_text, 0, self.sf, self.tf) - angle_text_position = (int(center_kpt[0]), int(center_kpt[1])) - angle_background_position = (angle_text_position[0], angle_text_position[1] - angle_text_height - 5) - angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (self.tf * 2)) + # Draw background rectangle cv2.rectangle( self.im, - angle_background_position, - ( - angle_background_position[0] + angle_background_size[0], - angle_background_position[1] + angle_background_size[1], - ), + (position[0], position[1] - text_height - 5), + (position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf), color, -1, ) - cv2.putText(self.im, angle_text, angle_text_position, 0, self.sf, txt_color, self.tf) - - # Draw Counts - (count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, self.sf, self.tf) - count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20) - count_background_position = ( - angle_background_position[0], - angle_background_position[1] + angle_background_size[1] + 5, - ) - count_background_size = (count_text_width + 10, count_text_height + 10 + self.tf) + # Draw text + cv2.putText(self.im, display_text, position, 0, self.sf, txt_color, self.tf) - cv2.rectangle( - self.im, - count_background_position, - ( - count_background_position[0] + count_background_size[0], - count_background_position[1] + count_background_size[1], - ), - color, - -1, - ) - cv2.putText(self.im, count_text, count_text_position, 0, self.sf, txt_color, self.tf) + return text_height - # Draw Stage - (stage_text_width, stage_text_height), _ = cv2.getTextSize(stage_text, 0, self.sf, self.tf) - stage_text_position = (int(center_kpt[0]), int(center_kpt[1]) + angle_text_height + count_text_height + 40) - stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5) - stage_background_size = (stage_text_width + 10, stage_text_height + 10) + def plot_angle_and_count_and_stage( + self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255) + ): + """ + Plot the pose angle, count value, and step stage. - cv2.rectangle( - self.im, - stage_background_position, - ( - stage_background_position[0] + stage_background_size[0], - stage_background_position[1] + stage_background_size[1], - ), - color, - -1, + Args: + angle_text (str): Angle value for workout monitoring + count_text (str): Counts value for workout monitoring + stage_text (str): Stage decision for workout monitoring + center_kpt (list): Centroid pose index for workout monitoring + color (tuple, optional): Text background color + txt_color (tuple, optional): Text foreground color + """ + # Format text + angle_text, count_text, stage_text = f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}" + + # Draw angle, count and stage text + angle_height = self.plot_workout_information( + angle_text, (int(center_kpt[0]), int(center_kpt[1])), color, txt_color + ) + count_height = self.plot_workout_information( + count_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + 20), color, txt_color + ) + self.plot_workout_information( + stage_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + count_height + 40), color, txt_color ) - cv2.putText(self.im, stage_text, stage_text_position, 0, self.sf, txt_color, self.tf) def seg_bbox(self, mask, mask_color=(255, 0, 255), label=None, txt_color=(255, 255, 255)): """