Add area chart in `analytics` (#13391)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/13370/head^2
Muhammad Rizwan Munawar 9 months ago committed by GitHub
parent 5376b1a42e
commit 89108513c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 85
      docs/en/guides/analytics.md
  2. 112
      ultralytics/solutions/analytics.py

@ -229,27 +229,80 @@ This guide provides a comprehensive overview of three fundamental types of data
out.release()
cv2.destroyAllWindows()
```
=== "Area chart"
```python
import cv2
from ultralytics import YOLO, solutions
model = YOLO("yolov8s.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
out = cv2.VideoWriter("area_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h))
analytics = solutions.Analytics(
type="area",
writer=out,
im0_shape=(w, h),
view_img=True,
)
clswise_count = {}
frame_count = 0
while cap.isOpened():
success, frame = cap.read()
if success:
frame_count += 1
results = model.track(frame, persist=True, verbose=True)
if results[0].boxes.id is not None:
boxes = results[0].boxes.xyxy.cpu()
clss = results[0].boxes.cls.cpu().tolist()
for box, cls in zip(boxes, clss):
if model.names[int(cls)] in clswise_count:
clswise_count[model.names[int(cls)]] += 1
else:
clswise_count[model.names[int(cls)]] = 1
analytics.update_area(frame_count, clswise_count)
clswise_count = {}
if cv2.waitKey(1) & 0xFF == ord("q"):
break
else:
break
cap.release()
out.release()
cv2.destroyAllWindows()
```
### Argument `Analytics`
Here's a table with the `Analytics` arguments:
| Name | Type | Default | Description |
|--------------|-------------------|---------------|----------------------------------------------------------------------------------|
| `type` | `str` | `None` | Type of data or object. |
| `im0_shape` | `tuple` | `None` | Shape of the initial image. |
| `writer` | `cv2.VideoWriter` | `None` | Object for writing video files. |
| `title` | `str` | `ultralytics` | Title for the visualization. |
| `x_label` | `str` | `x` | Label for the x-axis. |
| `y_label` | `str` | `y` | Label for the y-axis. |
| `bg_color` | `str` | `white` | Background color. |
| `fg_color` | `str` | `black` | Foreground color. |
| `line_color` | `str` | `yellow` | Color of the lines. |
| `line_width` | `int` | `2` | Width of the lines. |
| `fontsize` | `int` | `13` | Font size for text. |
| `view_img` | `bool` | `False` | Flag to display the image or video. |
| `save_img` | `bool` | `True` | Flag to save the image or video. |
| `max_points` | `int` | `50` | For multiple lines, total points drawn on frame, before deleting initial points. |
| Name | Type | Default | Description |
|----------------|-------------------|---------------|----------------------------------------------------------------------------------|
| `type` | `str` | `None` | Type of data or object. |
| `im0_shape` | `tuple` | `None` | Shape of the initial image. |
| `writer` | `cv2.VideoWriter` | `None` | Object for writing video files. |
| `title` | `str` | `ultralytics` | Title for the visualization. |
| `x_label` | `str` | `x` | Label for the x-axis. |
| `y_label` | `str` | `y` | Label for the y-axis. |
| `bg_color` | `str` | `white` | Background color. |
| `fg_color` | `str` | `black` | Foreground color. |
| `line_color` | `str` | `yellow` | Color of the lines. |
| `line_width` | `int` | `2` | Width of the lines. |
| `fontsize` | `int` | `13` | Font size for text. |
| `view_img` | `bool` | `False` | Flag to display the image or video. |
| `save_img` | `bool` | `True` | Flag to save the image or video. |
| `max_points` | `int` | `50` | For multiple lines, total points drawn on frame, before deleting initial points. |
| `points_width` | `int` | `15` | Width of line points highlighter. |
### Arguments `model.track`

@ -11,7 +11,7 @@ from matplotlib.figure import Figure
class Analytics:
"""A class to create and update various types of charts (line, bar, pie) for visual analytics."""
"""A class to create and update various types of charts (line, bar, pie, area) for visual analytics."""
def __init__(
self,
@ -25,6 +25,7 @@ class Analytics:
fg_color="black",
line_color="yellow",
line_width=2,
points_width=10,
fontsize=13,
view_img=False,
save_img=True,
@ -34,7 +35,7 @@ class Analytics:
Initialize the Analytics class with various chart types.
Args:
type (str): Type of chart to initialize ('line', 'bar', or 'pie').
type (str): Type of chart to initialize ('line', 'bar', 'pie', or 'area').
writer (object): Video writer object to save the frames.
im0_shape (tuple): Shape of the input image (width, height).
title (str): Title of the chart.
@ -44,6 +45,7 @@ class Analytics:
fg_color (str): Foreground (text) color of the chart.
line_color (str): Line color for line charts.
line_width (int): Width of the lines in line charts.
points_width (int): Width of line points highlighter
fontsize (int): Font size for chart text.
view_img (bool): Whether to display the image.
save_img (bool): Whether to save the image.
@ -57,17 +59,24 @@ class Analytics:
self.title = title
self.writer = writer
self.max_points = max_points
self.line_color = line_color
self.x_label = x_label
self.y_label = y_label
self.points_width = points_width
self.line_width = line_width
self.fontsize = fontsize
# Set figure size based on image shape
figsize = (im0_shape[0] / 100, im0_shape[1] / 100)
if type == "line":
# Initialize line plot
if type in {"line", "area"}:
# Initialize line or area plot
self.lines = {}
fig = Figure(facecolor=self.bg_color, figsize=figsize)
self.canvas = FigureCanvas(fig)
self.ax = fig.add_subplot(111, facecolor=self.bg_color)
(self.line,) = self.ax.plot([], [], color=line_color, linewidth=line_width)
self.fig = Figure(facecolor=self.bg_color, figsize=figsize)
self.canvas = FigureCanvas(self.fig)
self.ax = self.fig.add_subplot(111, facecolor=self.bg_color)
if type == "line":
(self.line,) = self.ax.plot([], [], color=self.line_color, linewidth=self.line_width)
elif type in {"bar", "pie"}:
# Initialize bar or pie plot
@ -93,11 +102,73 @@ class Analytics:
self.ax.axis("equal") if type == "pie" else None
# Set common axis properties
self.ax.set_title(self.title, color=self.fg_color, fontsize=fontsize)
self.ax.set_xlabel(x_label, color=self.fg_color, fontsize=fontsize - 3)
self.ax.set_ylabel(y_label, color=self.fg_color, fontsize=fontsize - 3)
self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)
self.ax.set_xlabel(x_label, color=self.fg_color, fontsize=self.fontsize - 3)
self.ax.set_ylabel(y_label, color=self.fg_color, fontsize=self.fontsize - 3)
self.ax.tick_params(axis="both", colors=self.fg_color)
def update_area(self, frame_number, counts_dict):
"""
Update the area graph with new data for multiple classes.
Args:
frame_number (int): The current frame number.
counts_dict (dict): Dictionary with class names as keys and counts as values.
"""
x_data = np.array([])
y_data_dict = {key: np.array([]) for key in counts_dict.keys()}
if self.ax.lines:
x_data = self.ax.lines[0].get_xdata()
for line, key in zip(self.ax.lines, counts_dict.keys()):
y_data_dict[key] = line.get_ydata()
x_data = np.append(x_data, float(frame_number))
max_length = len(x_data)
for key in counts_dict.keys():
y_data_dict[key] = np.append(y_data_dict[key], float(counts_dict[key]))
if len(y_data_dict[key]) < max_length:
y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])), "constant")
# Remove the oldest points if the number of points exceeds max_points
if len(x_data) > self.max_points:
x_data = x_data[1:]
for key in counts_dict.keys():
y_data_dict[key] = y_data_dict[key][1:]
self.ax.clear()
colors = ["#E1FF25", "#0BDBEB", "#FF64DA", "#111F68", "#042AFF"]
color_cycle = cycle(colors)
for key, y_data in y_data_dict.items():
color = next(color_cycle)
self.ax.fill_between(x_data, y_data, color=color, alpha=0.6)
self.ax.plot(
x_data,
y_data,
color=color,
linewidth=self.line_width,
marker="o",
markersize=self.points_width,
label=f"{key} Data Points",
)
self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)
self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3)
self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3)
legend = self.ax.legend(loc="upper left", fontsize=13, facecolor=self.bg_color, edgecolor=self.fg_color)
# Set legend text color
for text in legend.get_texts():
text.set_color(self.fg_color)
self.canvas.draw()
im0 = np.array(self.canvas.renderer.buffer_rgba())
self.write_and_display(im0)
def update_line(self, frame_number, total_counts):
"""
Update the line graph with new data.
@ -117,7 +188,7 @@ class Analytics:
self.ax.autoscale_view()
self.canvas.draw()
im0 = np.array(self.canvas.renderer.buffer_rgba())
self.write_and_display_line(im0)
self.write_and_display(im0)
def update_multiple_lines(self, counts_dict, labels_list, frame_number):
"""
@ -131,7 +202,7 @@ class Analytics:
warnings.warn("Display is not supported for multiple lines, output will be stored normally!")
for obj in labels_list:
if obj not in self.lines:
(line,) = self.ax.plot([], [], label=obj, marker="o", markersize=15)
(line,) = self.ax.plot([], [], label=obj, marker="o", markersize=self.points_width)
self.lines[obj] = line
x_data = self.lines[obj].get_xdata()
@ -153,16 +224,14 @@ class Analytics:
im0 = np.array(self.canvas.renderer.buffer_rgba())
self.view_img = False # for multiple line view_img not supported yet, coming soon!
self.write_and_display_line(im0)
self.write_and_display(im0)
def write_and_display_line(self, im0):
def write_and_display(self, im0):
"""
Write and display the line graph
Args:
im0 (ndarray): Image for processing
"""
# convert image to BGR format
im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
cv2.imshow(self.title, im0) if self.view_img else None
self.writer.write(im0) if self.save_img else None
@ -204,10 +273,7 @@ class Analytics:
canvas.draw()
buf = canvas.buffer_rgba()
im0 = np.asarray(buf)
im0 = cv2.cvtColor(im0, cv2.COLOR_RGBA2BGR)
self.writer.write(im0) if self.save_img else None
cv2.imshow(self.title, im0) if self.view_img else None
self.write_and_display(im0)
def update_pie(self, classes_dict):
"""
@ -239,9 +305,7 @@ class Analytics:
# Display and save the updated chart
im0 = self.fig.canvas.draw()
im0 = np.array(self.fig.canvas.renderer.buffer_rgba())
im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
self.writer.write(im0) if self.save_img else None
cv2.imshow(self.title, im0) if self.view_img else None
self.write_and_display(im0)
if __name__ == "__main__":

Loading…
Cancel
Save