From 1b52e5e6932312aae3cbaaf8dc4e9fea6422e7cd Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Sun, 13 Oct 2024 19:46:35 +0500 Subject: [PATCH] Update `analytics` solution (#16823) Co-authored-by: UltralyticsAssistant --- docs/en/guides/analytics.md | 368 +++++++++------------- ultralytics/cfg/solutions/default.yaml | 1 + ultralytics/solutions/analytics.py | 415 +++++++++---------------- 3 files changed, 298 insertions(+), 486 deletions(-) diff --git a/docs/en/guides/analytics.md b/docs/en/guides/analytics.md index 1b7049c601..d073cd25b5 100644 --- a/docs/en/guides/analytics.md +++ b/docs/en/guides/analytics.md @@ -40,103 +40,32 @@ This guide provides a comprehensive overview of three fundamental types of [data ```python import cv2 - from ultralytics import YOLO, solutions - - model = YOLO("yolo11n.pt") + from ultralytics import solutions cap = cv2.VideoCapture("Path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" - w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) - out = cv2.VideoWriter("line_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h)) + w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) - analytics = solutions.Analytics( - type="line", - writer=out, - im0_shape=(w, h), - view_img=True, + out = cv2.VideoWriter( + "ultralytics_analytics.avi", + cv2.VideoWriter_fourcc(*"MJPG"), + fps, + (1920, 1080), # This is fixed ) - total_counts = 0 - frame_count = 0 - - while cap.isOpened(): - success, frame = cap.read() - - if success: - frame_count += 1 - results = model.track(frame, persist=True, verbose=True) - - if results[0].boxes.id is not None: - boxes = results[0].boxes.xyxy.cpu() - for box in boxes: - total_counts += 1 - - analytics.update_line(frame_count, total_counts) - - total_counts = 0 - if cv2.waitKey(1) & 0xFF == ord("q"): - break - else: - break - - cap.release() - out.release() - cv2.destroyAllWindows() - ``` - - === "Multiple Lines" - - ```python - import cv2 - - from ultralytics import YOLO, solutions - - model = YOLO("yolo11n.pt") - - cap = cv2.VideoCapture("Path/to/video/file.mp4") - assert cap.isOpened(), "Error reading video file" - w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) - out = cv2.VideoWriter("multiple_line_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h)) analytics = solutions.Analytics( - type="line", - writer=out, - im0_shape=(w, h), - view_img=True, - max_points=200, + analytics_type="line", + show=True, ) frame_count = 0 - data = {} - labels = [] - while cap.isOpened(): - success, frame = cap.read() - + success, im0 = cap.read() if success: frame_count += 1 - - results = model.track(frame, persist=True) - - if results[0].boxes.id is not None: - boxes = results[0].boxes.xyxy.cpu() - track_ids = results[0].boxes.id.int().cpu().tolist() - clss = results[0].boxes.cls.cpu().tolist() - - for box, track_id, cls in zip(boxes, track_ids, clss): - # Store each class label - if model.names[int(cls)] not in labels: - labels.append(model.names[int(cls)]) - - # Store each class count - if model.names[int(cls)] in data: - data[model.names[int(cls)]] += 1 - else: - data[model.names[int(cls)]] = 0 - - # update lines every frame - analytics.update_multiple_lines(data, labels, frame_count) - data = {} # clear the data list for next frame + im0 = analytics.process_data(im0, frame_count) # update analytics graph every frame + out.write(im0) # write the video file else: break @@ -150,43 +79,32 @@ This guide provides a comprehensive overview of three fundamental types of [data ```python import cv2 - from ultralytics import YOLO, solutions - - model = YOLO("yolo11n.pt") + from ultralytics import solutions cap = cv2.VideoCapture("Path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" + w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) - out = cv2.VideoWriter("pie_chart.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h)) + out = cv2.VideoWriter( + "ultralytics_analytics.avi", + cv2.VideoWriter_fourcc(*"MJPG"), + fps, + (1920, 1080), # This is fixed + ) analytics = solutions.Analytics( - type="pie", - writer=out, - im0_shape=(w, h), - view_img=True, + analytics_type="pie", + show=True, ) - clswise_count = {} - + frame_count = 0 while cap.isOpened(): - success, frame = cap.read() + success, im0 = cap.read() if success: - results = model.track(frame, persist=True, verbose=True) - if results[0].boxes.id is not None: - boxes = results[0].boxes.xyxy.cpu() - clss = results[0].boxes.cls.cpu().tolist() - for box, cls in zip(boxes, clss): - if model.names[int(cls)] in clswise_count: - clswise_count[model.names[int(cls)]] += 1 - else: - clswise_count[model.names[int(cls)]] = 1 - - analytics.update_pie(clswise_count) - clswise_count = {} - - if cv2.waitKey(1) & 0xFF == ord("q"): - break + frame_count += 1 + im0 = analytics.process_data(im0, frame_count) # update analytics graph every frame + out.write(im0) # write the video file else: break @@ -200,43 +118,32 @@ This guide provides a comprehensive overview of three fundamental types of [data ```python import cv2 - from ultralytics import YOLO, solutions - - model = YOLO("yolo11n.pt") + from ultralytics import solutions cap = cv2.VideoCapture("Path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" + w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) - out = cv2.VideoWriter("bar_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h)) + out = cv2.VideoWriter( + "ultralytics_analytics.avi", + cv2.VideoWriter_fourcc(*"MJPG"), + fps, + (1920, 1080), # This is fixed + ) analytics = solutions.Analytics( - type="bar", - writer=out, - im0_shape=(w, h), - view_img=True, + analytics_type="bar", + show=True, ) - clswise_count = {} - + frame_count = 0 while cap.isOpened(): - success, frame = cap.read() + success, im0 = cap.read() if success: - results = model.track(frame, persist=True, verbose=True) - if results[0].boxes.id is not None: - boxes = results[0].boxes.xyxy.cpu() - clss = results[0].boxes.cls.cpu().tolist() - for box, cls in zip(boxes, clss): - if model.names[int(cls)] in clswise_count: - clswise_count[model.names[int(cls)]] += 1 - else: - clswise_count[model.names[int(cls)]] = 1 - - analytics.update_bar(clswise_count) - clswise_count = {} - - if cv2.waitKey(1) & 0xFF == ord("q"): - break + frame_count += 1 + im0 = analytics.process_data(im0, frame_count) # update analytics graph every frame + out.write(im0) # write the video file else: break @@ -250,46 +157,32 @@ This guide provides a comprehensive overview of three fundamental types of [data ```python import cv2 - from ultralytics import YOLO, solutions + from ultralytics import solutions - model = YOLO("yolo11n.pt") - - cap = cv2.VideoCapture("path/to/video/file.mp4") + cap = cv2.VideoCapture("Path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" + w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) - out = cv2.VideoWriter("area_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h)) + out = cv2.VideoWriter( + "ultralytics_analytics.avi", + cv2.VideoWriter_fourcc(*"MJPG"), + fps, + (1920, 1080), # This is fixed + ) analytics = solutions.Analytics( - type="area", - writer=out, - im0_shape=(w, h), - view_img=True, + analytics_type="area", + show=True, ) - clswise_count = {} frame_count = 0 - while cap.isOpened(): - success, frame = cap.read() + success, im0 = cap.read() if success: frame_count += 1 - results = model.track(frame, persist=True, verbose=True) - - if results[0].boxes.id is not None: - boxes = results[0].boxes.xyxy.cpu() - clss = results[0].boxes.cls.cpu().tolist() - - for box, cls in zip(boxes, clss): - if model.names[int(cls)] in clswise_count: - clswise_count[model.names[int(cls)]] += 1 - else: - clswise_count[model.names[int(cls)]] = 1 - - analytics.update_area(frame_count, clswise_count) - clswise_count = {} - if cv2.waitKey(1) & 0xFF == ord("q"): - break + im0 = analytics.process_data(im0, frame_count) # update analytics graph every frame + out.write(im0) # write the video file else: break @@ -302,23 +195,12 @@ This guide provides a comprehensive overview of three fundamental types of [data Here's a table with the `Analytics` arguments: -| Name | Type | Default | Description | -| -------------- | ----------------- | ------------- | -------------------------------------------------------------------------------- | -| `type` | `str` | `None` | Type of data or object. | -| `im0_shape` | `tuple` | `None` | Shape of the initial image. | -| `writer` | `cv2.VideoWriter` | `None` | Object for writing video files. | -| `title` | `str` | `ultralytics` | Title for the visualization. | -| `x_label` | `str` | `x` | Label for the x-axis. | -| `y_label` | `str` | `y` | Label for the y-axis. | -| `bg_color` | `str` | `white` | Background color. | -| `fg_color` | `str` | `black` | Foreground color. | -| `line_color` | `str` | `yellow` | Color of the lines. | -| `line_width` | `int` | `2` | Width of the lines. | -| `fontsize` | `int` | `13` | Font size for text. | -| `view_img` | `bool` | `False` | Flag to display the image or video. | -| `save_img` | `bool` | `True` | Flag to save the image or video. | -| `max_points` | `int` | `50` | For multiple lines, total points drawn on frame, before deleting initial points. | -| `points_width` | `int` | `15` | Width of line points highlighter. | +| Name | Type | Default | Description | +| ---------------- | ------ | ------- | ---------------------------------------------------- | +| `analytics_type` | `str` | `line` | Type of graph i.e "line", "bar", "area", "pie" | +| `model` | `str` | `None` | Path to Ultralytics YOLO Model File | +| `line_width` | `int` | `2` | Line thickness for bounding boxes. | +| `show` | `bool` | `False` | Flag to control whether to display the video stream. | ### Arguments `model.track` @@ -344,21 +226,33 @@ Example: ```python import cv2 -from ultralytics import YOLO, solutions +from ultralytics import solutions -model = YOLO("yolo11n.pt") cap = cv2.VideoCapture("Path/to/video/file.mp4") -out = cv2.VideoWriter("line_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h)) +assert cap.isOpened(), "Error reading video file" + +w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) -analytics = solutions.Analytics(type="line", writer=out, im0_shape=(w, h), view_img=True) +out = cv2.VideoWriter( + "ultralytics_analytics.avi", + cv2.VideoWriter_fourcc(*"MJPG"), + fps, + (1920, 1080), # This is fixed +) +analytics = solutions.Analytics( + analytics_type="line", + show=True, +) + +frame_count = 0 while cap.isOpened(): - success, frame = cap.read() + success, im0 = cap.read() if success: - results = model.track(frame, persist=True) - total_counts = sum([1 for box in results[0].boxes.xyxy]) - analytics.update_line(frame_count, total_counts) - if cv2.waitKey(1) & 0xFF == ord("q"): + frame_count += 1 + im0 = analytics.process_data(im0, frame_count) # update analytics graph every frame + out.write(im0) # write the video file + else: break cap.release() @@ -382,24 +276,33 @@ Use the following example to generate a bar plot: ```python import cv2 -from ultralytics import YOLO, solutions +from ultralytics import solutions -model = YOLO("yolo11n.pt") cap = cv2.VideoCapture("Path/to/video/file.mp4") -out = cv2.VideoWriter("bar_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h)) +assert cap.isOpened(), "Error reading video file" + +w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) -analytics = solutions.Analytics(type="bar", writer=out, im0_shape=(w, h), view_img=True) +out = cv2.VideoWriter( + "ultralytics_analytics.avi", + cv2.VideoWriter_fourcc(*"MJPG"), + fps, + (1920, 1080), # This is fixed +) +analytics = solutions.Analytics( + analytics_type="bar", + show=True, +) + +frame_count = 0 while cap.isOpened(): - success, frame = cap.read() + success, im0 = cap.read() if success: - results = model.track(frame, persist=True) - clswise_count = { - model.names[int(cls)]: boxes.size(0) - for cls, boxes in zip(results[0].boxes.cls.tolist(), results[0].boxes.xyxy) - } - analytics.update_bar(clswise_count) - if cv2.waitKey(1) & 0xFF == ord("q"): + frame_count += 1 + im0 = analytics.process_data(im0, frame_count) # update analytics graph every frame + out.write(im0) # write the video file + else: break cap.release() @@ -423,24 +326,33 @@ Here's a quick example: ```python import cv2 -from ultralytics import YOLO, solutions +from ultralytics import solutions -model = YOLO("yolo11n.pt") cap = cv2.VideoCapture("Path/to/video/file.mp4") -out = cv2.VideoWriter("pie_chart.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h)) +assert cap.isOpened(), "Error reading video file" + +w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) -analytics = solutions.Analytics(type="pie", writer=out, im0_shape=(w, h), view_img=True) +out = cv2.VideoWriter( + "ultralytics_analytics.avi", + cv2.VideoWriter_fourcc(*"MJPG"), + fps, + (1920, 1080), # This is fixed +) +analytics = solutions.Analytics( + analytics_type="pie", + show=True, +) + +frame_count = 0 while cap.isOpened(): - success, frame = cap.read() + success, im0 = cap.read() if success: - results = model.track(frame, persist=True) - clswise_count = { - model.names[int(cls)]: boxes.size(0) - for cls, boxes in zip(results[0].boxes.cls.tolist(), results[0].boxes.xyxy) - } - analytics.update_pie(clswise_count) - if cv2.waitKey(1) & 0xFF == ord("q"): + frame_count += 1 + im0 = analytics.process_data(im0, frame_count) # update analytics graph every frame + out.write(im0) # write the video file + else: break cap.release() @@ -459,21 +371,33 @@ Example for tracking and updating a line graph: ```python import cv2 -from ultralytics import YOLO, solutions +from ultralytics import solutions -model = YOLO("yolo11n.pt") cap = cv2.VideoCapture("Path/to/video/file.mp4") -out = cv2.VideoWriter("line_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h)) +assert cap.isOpened(), "Error reading video file" + +w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) + +out = cv2.VideoWriter( + "ultralytics_analytics.avi", + cv2.VideoWriter_fourcc(*"MJPG"), + fps, + (1920, 1080), # This is fixed +) -analytics = solutions.Analytics(type="line", writer=out, im0_shape=(w, h), view_img=True) +analytics = solutions.Analytics( + analytics_type="line", + show=True, +) +frame_count = 0 while cap.isOpened(): - success, frame = cap.read() + success, im0 = cap.read() if success: - results = model.track(frame, persist=True) - total_counts = sum([1 for box in results[0].boxes.xyxy]) - analytics.update_line(frame_count, total_counts) - if cv2.waitKey(1) & 0xFF == ord("q"): + frame_count += 1 + im0 = analytics.process_data(im0, frame_count) # update analytics graph every frame + out.write(im0) # write the video file + else: break cap.release() diff --git a/ultralytics/cfg/solutions/default.yaml b/ultralytics/cfg/solutions/default.yaml index a98ae52749..e4e1b845a0 100644 --- a/ultralytics/cfg/solutions/default.yaml +++ b/ultralytics/cfg/solutions/default.yaml @@ -14,3 +14,4 @@ up_angle: 145.0 # Workouts up_angle for counts, 145.0 is default value. You can down_angle: 90 # Workouts down_angle for counts, 90 is default value. You can change it for different workouts, based on position of keypoints. kpts: [6, 8, 10] # Keypoints for workouts monitoring, i.e. If you want to consider keypoints for pushups that have mostly values of [6, 8, 10]. colormap: # Colormap for heatmap, Only OPENCV supported colormaps can be used. By default COLORMAP_PARULA will be used for visualization. +analytics_type: "line" # Analytics type i.e "line", "pie", "bar" or "area" charts. By default, "line" analytics will be used for processing. diff --git a/ultralytics/solutions/analytics.py b/ultralytics/solutions/analytics.py index c299009778..ade3431bf1 100644 --- a/ultralytics/solutions/analytics.py +++ b/ultralytics/solutions/analytics.py @@ -1,6 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -import warnings from itertools import cycle import cv2 @@ -9,299 +8,187 @@ import numpy as np from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas from matplotlib.figure import Figure +from ultralytics.solutions.solutions import BaseSolution # Import a parent class -class Analytics: + +class Analytics(BaseSolution): """A class to create and update various types of charts (line, bar, pie, area) for visual analytics.""" - def __init__( - self, - type, - writer, - im0_shape, - title="ultralytics", - x_label="x", - y_label="y", - bg_color="white", - fg_color="black", - line_color="yellow", - line_width=2, - points_width=10, - fontsize=13, - view_img=False, - save_img=True, - max_points=50, - ): - """ - Initialize the Analytics class with various chart types. + def __init__(self, **kwargs): + """Initialize the Analytics class with various chart types.""" + super().__init__(**kwargs) - Args: - type (str): Type of chart to initialize ('line', 'bar', 'pie', or 'area'). - writer (object): Video writer object to save the frames. - im0_shape (tuple): Shape of the input image (width, height). - title (str): Title of the chart. - x_label (str): Label for the x-axis. - y_label (str): Label for the y-axis. - bg_color (str): Background color of the chart. - fg_color (str): Foreground (text) color of the chart. - line_color (str): Line color for line charts. - line_width (int): Width of the lines in line charts. - points_width (int): Width of line points highlighter - fontsize (int): Font size for chart text. - view_img (bool): Whether to display the image. - save_img (bool): Whether to save the image. - max_points (int): Specifies when to remove the oldest points in a graph for multiple lines. - """ - self.bg_color = bg_color - self.fg_color = fg_color - self.view_img = view_img - self.save_img = save_img - self.title = title - self.writer = writer - self.max_points = max_points - self.line_color = line_color - self.x_label = x_label - self.y_label = y_label - self.points_width = points_width - self.line_width = line_width - self.fontsize = fontsize + self.type = self.CFG["analytics_type"] # extract type of analytics + self.x_label = "Classes" if self.type in {"bar", "pie"} else "Frame#" + self.y_label = "Total Counts" + + # Predefined data + self.bg_color = "#00F344" # background color of frame + self.fg_color = "#111E68" # foreground color of frame + self.title = "Ultralytics Solutions" # window name + self.max_points = 45 # maximum points to be drawn on window + self.fontsize = 25 # text font size for display + figsize = (19.2, 10.8) # Set output image size 1920 * 1080 + self.color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"]) - # Set figure size based on image shape - figsize = (im0_shape[0] / 100, im0_shape[1] / 100) + self.total_counts = 0 # count variable for storing total counts i.e for line + self.clswise_count = {} # dictionary for classwise counts - if type in {"line", "area"}: - # Initialize line or area plot + # Ensure line and area chart + if self.type in {"line", "area"}: self.lines = {} self.fig = Figure(facecolor=self.bg_color, figsize=figsize) - self.canvas = FigureCanvas(self.fig) + self.canvas = FigureCanvas(self.fig) # Set common axis properties self.ax = self.fig.add_subplot(111, facecolor=self.bg_color) - if type == "line": - (self.line,) = self.ax.plot([], [], color=self.line_color, linewidth=self.line_width) - - elif type in {"bar", "pie"}: + if self.type == "line": + (self.line,) = self.ax.plot([], [], color="cyan", linewidth=self.line_width) + elif self.type in {"bar", "pie"}: # Initialize bar or pie plot self.fig, self.ax = plt.subplots(figsize=figsize, facecolor=self.bg_color) + self.canvas = FigureCanvas(self.fig) # Set common axis properties self.ax.set_facecolor(self.bg_color) - color_palette = [ - (31, 119, 180), - (255, 127, 14), - (44, 160, 44), - (214, 39, 40), - (148, 103, 189), - (140, 86, 75), - (227, 119, 194), - (127, 127, 127), - (188, 189, 34), - (23, 190, 207), - ] - self.color_palette = [(r / 255, g / 255, b / 255, 1) for r, g, b in color_palette] - self.color_cycle = cycle(self.color_palette) self.color_mapping = {} + self.ax.axis("equal") if type == "pie" else None # Ensure pie chart is circular - # Ensure pie chart is circular - self.ax.axis("equal") if type == "pie" else None - - # Set common axis properties - self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize) - self.ax.set_xlabel(x_label, color=self.fg_color, fontsize=self.fontsize - 3) - self.ax.set_ylabel(y_label, color=self.fg_color, fontsize=self.fontsize - 3) - self.ax.tick_params(axis="both", colors=self.fg_color) + def process_data(self, im0, frame_number): + """ + Process the image data, run object tracking. - def update_area(self, frame_number, counts_dict): + Args: + im0 (ndarray): Input image for processing. + frame_number (int): Video frame # for plotting the data. + """ + self.extract_tracks(im0) # Extract tracks + + if self.type == "line": + for box in self.boxes: + self.total_counts += 1 + im0 = self.update_graph(frame_number=frame_number) + self.total_counts = 0 + elif self.type == "pie" or self.type == "bar" or self.type == "area": + self.clswise_count = {} + for box, cls in zip(self.boxes, self.clss): + if self.names[int(cls)] in self.clswise_count: + self.clswise_count[self.names[int(cls)]] += 1 + else: + self.clswise_count[self.names[int(cls)]] = 1 + im0 = self.update_graph(frame_number=frame_number, count_dict=self.clswise_count, plot=self.type) + else: + raise ModuleNotFoundError(f"{self.type} chart is not supported ❌") + return im0 + + def update_graph(self, frame_number, count_dict=None, plot="line"): """ - Update the area graph with new data for multiple classes. + Update the graph (line or area) with new data for single or multiple classes. Args: frame_number (int): The current frame number. - counts_dict (dict): Dictionary with class names as keys and counts as values. + count_dict (dict, optional): Dictionary with class names as keys and counts as values for multiple classes. + If None, updates a single line graph. + plot (str): Type of the plot i.e. line, bar or area. """ - x_data = np.array([]) - y_data_dict = {key: np.array([]) for key in counts_dict.keys()} - - if self.ax.lines: - x_data = self.ax.lines[0].get_xdata() - for line, key in zip(self.ax.lines, counts_dict.keys()): - y_data_dict[key] = line.get_ydata() - - x_data = np.append(x_data, float(frame_number)) - max_length = len(x_data) - - for key in counts_dict.keys(): - y_data_dict[key] = np.append(y_data_dict[key], float(counts_dict[key])) - if len(y_data_dict[key]) < max_length: - y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])), "constant") - - # Remove the oldest points if the number of points exceeds max_points - if len(x_data) > self.max_points: - x_data = x_data[1:] - for key in counts_dict.keys(): - y_data_dict[key] = y_data_dict[key][1:] - - self.ax.clear() - - colors = ["#E1FF25", "#0BDBEB", "#FF64DA", "#111F68", "#042AFF"] - color_cycle = cycle(colors) - - for key, y_data in y_data_dict.items(): - color = next(color_cycle) - self.ax.fill_between(x_data, y_data, color=color, alpha=0.6) - self.ax.plot( - x_data, - y_data, - color=color, - linewidth=self.line_width, - marker="o", - markersize=self.points_width, - label=f"{key} Data Points", - ) - + if count_dict is None: + # Single line update + x_data = np.append(self.line.get_xdata(), float(frame_number)) + y_data = np.append(self.line.get_ydata(), float(self.total_counts)) + + if len(x_data) > self.max_points: + x_data, y_data = x_data[-self.max_points :], y_data[-self.max_points :] + + self.line.set_data(x_data, y_data) + self.line.set_label("Counts") + self.line.set_color("#7b0068") # Pink color + self.line.set_marker("*") + self.line.set_markersize(self.line_width * 5) + else: + labels = list(count_dict.keys()) + counts = list(count_dict.values()) + if plot == "area": + color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"]) + # Multiple lines or area update + x_data = self.ax.lines[0].get_xdata() if self.ax.lines else np.array([]) + y_data_dict = {key: np.array([]) for key in count_dict.keys()} + if self.ax.lines: + for line, key in zip(self.ax.lines, count_dict.keys()): + y_data_dict[key] = line.get_ydata() + + x_data = np.append(x_data, float(frame_number)) + max_length = len(x_data) + for key in count_dict.keys(): + y_data_dict[key] = np.append(y_data_dict[key], float(count_dict[key])) + if len(y_data_dict[key]) < max_length: + y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])), "constant") + if len(x_data) > self.max_points: + x_data = x_data[1:] + for key in count_dict.keys(): + y_data_dict[key] = y_data_dict[key][1:] + + self.ax.clear() + for key, y_data in y_data_dict.items(): + color = next(color_cycle) + self.ax.fill_between(x_data, y_data, color=color, alpha=0.7) + self.ax.plot( + x_data, + y_data, + color=color, + linewidth=self.line_width, + marker="o", + markersize=self.line_width * 5, + label=f"{key} Data Points", + ) + if plot == "bar": + self.ax.clear() # clear bar data + for label in labels: # Map labels to colors + if label not in self.color_mapping: + self.color_mapping[label] = next(self.color_cycle) + colors = [self.color_mapping[label] for label in labels] + bars = self.ax.bar(labels, counts, color=colors) + for bar, count in zip(bars, counts): + self.ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height(), + str(count), + ha="center", + va="bottom", + color=self.fg_color, + ) + # Create the legend using labels from the bars + for bar, label in zip(bars, labels): + bar.set_label(label) # Assign label to each bar + self.ax.legend(loc="upper left", fontsize=13, facecolor=self.fg_color, edgecolor=self.fg_color) + if plot == "pie": + total = sum(counts) + percentages = [size / total * 100 for size in counts] + start_angle = 90 + self.ax.clear() + + # Create pie chart and create legend labels with percentages + wedges, autotexts = self.ax.pie( + counts, labels=labels, startangle=start_angle, textprops={"color": self.fg_color}, autopct=None + ) + legend_labels = [f"{label} ({percentage:.1f}%)" for label, percentage in zip(labels, percentages)] + + # Assign the legend using the wedges and manually created labels + self.ax.legend(wedges, legend_labels, title="Classes", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1)) + self.fig.subplots_adjust(left=0.1, right=0.75) # Adjust layout to fit the legend + + # Common plot settings + self.ax.set_facecolor("#f0f0f0") # Set to light gray or any other color you like self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize) self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3) self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3) - legend = self.ax.legend(loc="upper left", fontsize=13, facecolor=self.bg_color, edgecolor=self.fg_color) - # Set legend text color + # Add and format legend + legend = self.ax.legend(loc="upper left", fontsize=13, facecolor=self.bg_color, edgecolor=self.bg_color) for text in legend.get_texts(): text.set_color(self.fg_color) - self.canvas.draw() - im0 = np.array(self.canvas.renderer.buffer_rgba()) - self.write_and_display(im0) - - def update_line(self, frame_number, total_counts): - """ - Update the line graph with new data. - - Args: - frame_number (int): The current frame number. - total_counts (int): The total counts to plot. - """ - # Update line graph data - x_data = self.line.get_xdata() - y_data = self.line.get_ydata() - x_data = np.append(x_data, float(frame_number)) - y_data = np.append(y_data, float(total_counts)) - self.line.set_data(x_data, y_data) + # Redraw graph, update view, capture, and display the updated plot self.ax.relim() self.ax.autoscale_view() self.canvas.draw() im0 = np.array(self.canvas.renderer.buffer_rgba()) - self.write_and_display(im0) - - def update_multiple_lines(self, counts_dict, labels_list, frame_number): - """ - Update the line graph with multiple classes. - - Args: - counts_dict (int): Dictionary include each class counts. - labels_list (int): list include each classes names. - frame_number (int): The current frame number. - """ - warnings.warn("Display is not supported for multiple lines, output will be stored normally!") - for obj in labels_list: - if obj not in self.lines: - (line,) = self.ax.plot([], [], label=obj, marker="o", markersize=self.points_width) - self.lines[obj] = line - - x_data = self.lines[obj].get_xdata() - y_data = self.lines[obj].get_ydata() - - # Remove the initial point if the number of points exceeds max_points - if len(x_data) >= self.max_points: - x_data = np.delete(x_data, 0) - y_data = np.delete(y_data, 0) - - x_data = np.append(x_data, float(frame_number)) # Ensure frame_number is converted to float - y_data = np.append(y_data, float(counts_dict.get(obj, 0))) # Ensure total_count is converted to float - self.lines[obj].set_data(x_data, y_data) - - self.ax.relim() - self.ax.autoscale_view() - self.ax.legend() - self.canvas.draw() - - im0 = np.array(self.canvas.renderer.buffer_rgba()) - self.view_img = False # for multiple line view_img not supported yet, coming soon! - self.write_and_display(im0) - - def write_and_display(self, im0): - """ - Write and display the line graph - Args: - im0 (ndarray): Image for processing. - """ im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR) - cv2.imshow(self.title, im0) if self.view_img else None - self.writer.write(im0) if self.save_img else None - - def update_bar(self, count_dict): - """ - Update the bar graph with new data. - - Args: - count_dict (dict): Dictionary containing the count data to plot. - """ - # Update bar graph data - self.ax.clear() - self.ax.set_facecolor(self.bg_color) - labels = list(count_dict.keys()) - counts = list(count_dict.values()) - - # Map labels to colors - for label in labels: - if label not in self.color_mapping: - self.color_mapping[label] = next(self.color_cycle) - - colors = [self.color_mapping[label] for label in labels] - - bars = self.ax.bar(labels, counts, color=colors) - for bar, count in zip(bars, counts): - self.ax.text( - bar.get_x() + bar.get_width() / 2, - bar.get_height(), - str(count), - ha="center", - va="bottom", - color=self.fg_color, - ) - - # Display and save the updated graph - canvas = FigureCanvas(self.fig) - canvas.draw() - buf = canvas.buffer_rgba() - im0 = np.asarray(buf) - self.write_and_display(im0) - - def update_pie(self, classes_dict): - """ - Update the pie chart with new data. - - Args: - classes_dict (dict): Dictionary containing the class data to plot. - """ - # Update pie chart data - labels = list(classes_dict.keys()) - sizes = list(classes_dict.values()) - total = sum(sizes) - percentages = [size / total * 100 for size in sizes] - start_angle = 90 - self.ax.clear() - - # Create pie chart without labels inside the slices - wedges, autotexts = self.ax.pie(sizes, autopct=None, startangle=start_angle, textprops={"color": self.fg_color}) - - # Construct legend labels with percentages - legend_labels = [f"{label} ({percentage:.1f}%)" for label, percentage in zip(labels, percentages)] - self.ax.legend(wedges, legend_labels, title="Classes", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1)) - - # Adjust layout to fit the legend - self.fig.tight_layout() - self.fig.subplots_adjust(left=0.1, right=0.75) - - # Display and save the updated chart - im0 = self.fig.canvas.draw() - im0 = np.array(self.fig.canvas.renderer.buffer_rgba()) - self.write_and_display(im0) - + self.display_output(im0) -if __name__ == "__main__": - Analytics("line", writer=None, im0_shape=None) + return im0 # Return the image