From 89108513c42f9edd98ceb14468386960779e0e4a Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Thu, 6 Jun 2024 13:16:58 +0500 Subject: [PATCH] Add area chart in `analytics` (#13391) Co-authored-by: UltralyticsAssistant --- docs/en/guides/analytics.md | 85 +++++++++++++++++----- ultralytics/solutions/analytics.py | 112 ++++++++++++++++++++++------- 2 files changed, 157 insertions(+), 40 deletions(-) diff --git a/docs/en/guides/analytics.md b/docs/en/guides/analytics.md index dc9bf8e7d..531f3ac5a 100644 --- a/docs/en/guides/analytics.md +++ b/docs/en/guides/analytics.md @@ -229,27 +229,80 @@ This guide provides a comprehensive overview of three fundamental types of data out.release() cv2.destroyAllWindows() ``` + + === "Area chart" + + ```python + import cv2 + from ultralytics import YOLO, solutions + model = YOLO("yolov8s.pt") + + cap = cv2.VideoCapture("path/to/video/file.mp4") + assert cap.isOpened(), "Error reading video file" + w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) + + out = cv2.VideoWriter("area_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h)) + + analytics = solutions.Analytics( + type="area", + writer=out, + im0_shape=(w, h), + view_img=True, + ) + + clswise_count = {} + frame_count = 0 + + while cap.isOpened(): + success, frame = cap.read() + if success: + + frame_count += 1 + results = model.track(frame, persist=True, verbose=True) + + if results[0].boxes.id is not None: + boxes = results[0].boxes.xyxy.cpu() + clss = results[0].boxes.cls.cpu().tolist() + + for box, cls in zip(boxes, clss): + if model.names[int(cls)] in clswise_count: + clswise_count[model.names[int(cls)]] += 1 + else: + clswise_count[model.names[int(cls)]] = 1 + + analytics.update_area(frame_count, clswise_count) + clswise_count = {} + if cv2.waitKey(1) & 0xFF == ord("q"): + break + else: + break + + cap.release() + out.release() + cv2.destroyAllWindows() + ``` ### Argument `Analytics` Here's a table with the `Analytics` arguments: -| Name | Type | Default | Description | -|--------------|-------------------|---------------|----------------------------------------------------------------------------------| -| `type` | `str` | `None` | Type of data or object. | -| `im0_shape` | `tuple` | `None` | Shape of the initial image. | -| `writer` | `cv2.VideoWriter` | `None` | Object for writing video files. | -| `title` | `str` | `ultralytics` | Title for the visualization. | -| `x_label` | `str` | `x` | Label for the x-axis. | -| `y_label` | `str` | `y` | Label for the y-axis. | -| `bg_color` | `str` | `white` | Background color. | -| `fg_color` | `str` | `black` | Foreground color. | -| `line_color` | `str` | `yellow` | Color of the lines. | -| `line_width` | `int` | `2` | Width of the lines. | -| `fontsize` | `int` | `13` | Font size for text. | -| `view_img` | `bool` | `False` | Flag to display the image or video. | -| `save_img` | `bool` | `True` | Flag to save the image or video. | -| `max_points` | `int` | `50` | For multiple lines, total points drawn on frame, before deleting initial points. | +| Name | Type | Default | Description | +|----------------|-------------------|---------------|----------------------------------------------------------------------------------| +| `type` | `str` | `None` | Type of data or object. | +| `im0_shape` | `tuple` | `None` | Shape of the initial image. | +| `writer` | `cv2.VideoWriter` | `None` | Object for writing video files. | +| `title` | `str` | `ultralytics` | Title for the visualization. | +| `x_label` | `str` | `x` | Label for the x-axis. | +| `y_label` | `str` | `y` | Label for the y-axis. | +| `bg_color` | `str` | `white` | Background color. | +| `fg_color` | `str` | `black` | Foreground color. | +| `line_color` | `str` | `yellow` | Color of the lines. | +| `line_width` | `int` | `2` | Width of the lines. | +| `fontsize` | `int` | `13` | Font size for text. | +| `view_img` | `bool` | `False` | Flag to display the image or video. | +| `save_img` | `bool` | `True` | Flag to save the image or video. | +| `max_points` | `int` | `50` | For multiple lines, total points drawn on frame, before deleting initial points. | +| `points_width` | `int` | `15` | Width of line points highlighter. | ### Arguments `model.track` diff --git a/ultralytics/solutions/analytics.py b/ultralytics/solutions/analytics.py index 98d0bb256..3715c21ab 100644 --- a/ultralytics/solutions/analytics.py +++ b/ultralytics/solutions/analytics.py @@ -11,7 +11,7 @@ from matplotlib.figure import Figure class Analytics: - """A class to create and update various types of charts (line, bar, pie) for visual analytics.""" + """A class to create and update various types of charts (line, bar, pie, area) for visual analytics.""" def __init__( self, @@ -25,6 +25,7 @@ class Analytics: fg_color="black", line_color="yellow", line_width=2, + points_width=10, fontsize=13, view_img=False, save_img=True, @@ -34,7 +35,7 @@ class Analytics: Initialize the Analytics class with various chart types. Args: - type (str): Type of chart to initialize ('line', 'bar', or 'pie'). + type (str): Type of chart to initialize ('line', 'bar', 'pie', or 'area'). writer (object): Video writer object to save the frames. im0_shape (tuple): Shape of the input image (width, height). title (str): Title of the chart. @@ -44,6 +45,7 @@ class Analytics: fg_color (str): Foreground (text) color of the chart. line_color (str): Line color for line charts. line_width (int): Width of the lines in line charts. + points_width (int): Width of line points highlighter fontsize (int): Font size for chart text. view_img (bool): Whether to display the image. save_img (bool): Whether to save the image. @@ -57,17 +59,24 @@ class Analytics: self.title = title self.writer = writer self.max_points = max_points + self.line_color = line_color + self.x_label = x_label + self.y_label = y_label + self.points_width = points_width + self.line_width = line_width + self.fontsize = fontsize # Set figure size based on image shape figsize = (im0_shape[0] / 100, im0_shape[1] / 100) - if type == "line": - # Initialize line plot + if type in {"line", "area"}: + # Initialize line or area plot self.lines = {} - fig = Figure(facecolor=self.bg_color, figsize=figsize) - self.canvas = FigureCanvas(fig) - self.ax = fig.add_subplot(111, facecolor=self.bg_color) - (self.line,) = self.ax.plot([], [], color=line_color, linewidth=line_width) + self.fig = Figure(facecolor=self.bg_color, figsize=figsize) + self.canvas = FigureCanvas(self.fig) + self.ax = self.fig.add_subplot(111, facecolor=self.bg_color) + if type == "line": + (self.line,) = self.ax.plot([], [], color=self.line_color, linewidth=self.line_width) elif type in {"bar", "pie"}: # Initialize bar or pie plot @@ -93,11 +102,73 @@ class Analytics: self.ax.axis("equal") if type == "pie" else None # Set common axis properties - self.ax.set_title(self.title, color=self.fg_color, fontsize=fontsize) - self.ax.set_xlabel(x_label, color=self.fg_color, fontsize=fontsize - 3) - self.ax.set_ylabel(y_label, color=self.fg_color, fontsize=fontsize - 3) + self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize) + self.ax.set_xlabel(x_label, color=self.fg_color, fontsize=self.fontsize - 3) + self.ax.set_ylabel(y_label, color=self.fg_color, fontsize=self.fontsize - 3) self.ax.tick_params(axis="both", colors=self.fg_color) + def update_area(self, frame_number, counts_dict): + """ + Update the area graph with new data for multiple classes. + + Args: + frame_number (int): The current frame number. + counts_dict (dict): Dictionary with class names as keys and counts as values. + """ + + x_data = np.array([]) + y_data_dict = {key: np.array([]) for key in counts_dict.keys()} + + if self.ax.lines: + x_data = self.ax.lines[0].get_xdata() + for line, key in zip(self.ax.lines, counts_dict.keys()): + y_data_dict[key] = line.get_ydata() + + x_data = np.append(x_data, float(frame_number)) + max_length = len(x_data) + + for key in counts_dict.keys(): + y_data_dict[key] = np.append(y_data_dict[key], float(counts_dict[key])) + if len(y_data_dict[key]) < max_length: + y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])), "constant") + + # Remove the oldest points if the number of points exceeds max_points + if len(x_data) > self.max_points: + x_data = x_data[1:] + for key in counts_dict.keys(): + y_data_dict[key] = y_data_dict[key][1:] + + self.ax.clear() + + colors = ["#E1FF25", "#0BDBEB", "#FF64DA", "#111F68", "#042AFF"] + color_cycle = cycle(colors) + + for key, y_data in y_data_dict.items(): + color = next(color_cycle) + self.ax.fill_between(x_data, y_data, color=color, alpha=0.6) + self.ax.plot( + x_data, + y_data, + color=color, + linewidth=self.line_width, + marker="o", + markersize=self.points_width, + label=f"{key} Data Points", + ) + + self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize) + self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3) + self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3) + legend = self.ax.legend(loc="upper left", fontsize=13, facecolor=self.bg_color, edgecolor=self.fg_color) + + # Set legend text color + for text in legend.get_texts(): + text.set_color(self.fg_color) + + self.canvas.draw() + im0 = np.array(self.canvas.renderer.buffer_rgba()) + self.write_and_display(im0) + def update_line(self, frame_number, total_counts): """ Update the line graph with new data. @@ -117,7 +188,7 @@ class Analytics: self.ax.autoscale_view() self.canvas.draw() im0 = np.array(self.canvas.renderer.buffer_rgba()) - self.write_and_display_line(im0) + self.write_and_display(im0) def update_multiple_lines(self, counts_dict, labels_list, frame_number): """ @@ -131,7 +202,7 @@ class Analytics: warnings.warn("Display is not supported for multiple lines, output will be stored normally!") for obj in labels_list: if obj not in self.lines: - (line,) = self.ax.plot([], [], label=obj, marker="o", markersize=15) + (line,) = self.ax.plot([], [], label=obj, marker="o", markersize=self.points_width) self.lines[obj] = line x_data = self.lines[obj].get_xdata() @@ -153,16 +224,14 @@ class Analytics: im0 = np.array(self.canvas.renderer.buffer_rgba()) self.view_img = False # for multiple line view_img not supported yet, coming soon! - self.write_and_display_line(im0) + self.write_and_display(im0) - def write_and_display_line(self, im0): + def write_and_display(self, im0): """ Write and display the line graph Args: im0 (ndarray): Image for processing """ - - # convert image to BGR format im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR) cv2.imshow(self.title, im0) if self.view_img else None self.writer.write(im0) if self.save_img else None @@ -204,10 +273,7 @@ class Analytics: canvas.draw() buf = canvas.buffer_rgba() im0 = np.asarray(buf) - im0 = cv2.cvtColor(im0, cv2.COLOR_RGBA2BGR) - - self.writer.write(im0) if self.save_img else None - cv2.imshow(self.title, im0) if self.view_img else None + self.write_and_display(im0) def update_pie(self, classes_dict): """ @@ -239,9 +305,7 @@ class Analytics: # Display and save the updated chart im0 = self.fig.canvas.draw() im0 = np.array(self.fig.canvas.renderer.buffer_rgba()) - im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR) - self.writer.write(im0) if self.save_img else None - cv2.imshow(self.title, im0) if self.view_img else None + self.write_and_display(im0) if __name__ == "__main__":