`ultralytics 8.2.20` new `Analytics` class with plotting visuals (#12955)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/13051/head v8.2.20
Muhammad Rizwan Munawar 8 months ago committed by GitHub
parent 1a4ac2c6ba
commit 03d380d730
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 207
      docs/en/guides/analytics.md
  2. 3
      docs/en/guides/index.md
  3. 2
      docs/en/guides/workouts-monitoring.md
  4. 2
      docs/en/modes/predict.md
  5. 3
      mkdocs.yml
  6. 2
      ultralytics/__init__.py
  7. 2
      ultralytics/solutions/__init__.py
  8. 6
      ultralytics/solutions/ai_gym.py
  9. 197
      ultralytics/solutions/analytics.py

@ -0,0 +1,207 @@
---
comments: true
description: Comprehensive Guide to Understanding and Creating Line Graphs, Bar Plots, and Pie Charts
keywords: Analytics, Data Visualization, Line Graphs, Bar Plots, Pie Charts, Quickstart Guide, Data Analysis, Python, Visualization Tools
---
# Analytics using Ultralytics YOLOv8 📊
## Introduction
This guide provides a comprehensive overview of three fundamental types of data visualizations: line graphs, bar plots, and pie charts. Each section includes step-by-step instructions and code snippets on how to create these visualizations using Python.
### Visual Samples
| Line Graph | Bar Plot | Pie Chart |
|:------------------------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------:|
| ![Line Graph](https://github.com/RizwanMunawar/RizwanMunawar/assets/62513924/eeabd90c-04fd-4e5b-aac9-c7777f892200) | ![Bar Plot](https://github.com/RizwanMunawar/RizwanMunawar/assets/62513924/c1da2d6a-99ff-43a8-b5dc-ca93127917f8) | ![Pie Chart](https://github.com/RizwanMunawar/RizwanMunawar/assets/62513924/9d8acce6-d9e4-4685-949d-cd4851483187) |
### Why Graphs are Important
- Line graphs are ideal for tracking changes over short and long periods and for comparing changes for multiple groups over the same period.
- Bar plots, on the other hand, are suitable for comparing quantities across different categories and showing relationships between a category and its numerical value.
- Lastly, pie charts are effective for illustrating proportions among categories and showing parts of a whole.
!!! Analytics "Analytics Examples"
=== "Line Graph"
```python
import cv2
from ultralytics import YOLO, solutions
model = YOLO("yolov8s.pt")
cap = cv2.VideoCapture("Path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
out = cv2.VideoWriter("line_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h))
analytics = solutions.Analytics(
type="line",
writer=out,
im0_shape=(w, h),
view_img=True,
)
total_counts = 0
frame_count = 0
while cap.isOpened():
success, frame = cap.read()
if success:
frame_count += 1
results = model.track(frame, persist=True, verbose=True)
if results[0].boxes.id is not None:
boxes = results[0].boxes.xyxy.cpu()
for box in boxes:
total_counts += 1
analytics.update_line(frame_count, total_counts)
total_counts = 0
if cv2.waitKey(1) & 0xFF == ord("q"):
break
else:
break
cap.release()
out.release()
cv2.destroyAllWindows()
```
=== "Pie Chart"
```python
import cv2
from ultralytics import YOLO, solutions
model = YOLO("yolov8s.pt")
cap = cv2.VideoCapture("Path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
out = cv2.VideoWriter("pie_chart.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h))
analytics = solutions.Analytics(
type="pie",
writer=out,
im0_shape=(w, h),
view_img=True,
)
clswise_count = {}
while cap.isOpened():
success, frame = cap.read()
if success:
results = model.track(frame, persist=True, verbose=True)
if results[0].boxes.id is not None:
boxes = results[0].boxes.xyxy.cpu()
clss = results[0].boxes.cls.cpu().tolist()
for box, cls in zip(boxes, clss):
if model.names[int(cls)] in clswise_count:
clswise_count[model.names[int(cls)]] += 1
else:
clswise_count[model.names[int(cls)]] = 1
analytics.update_pie(clswise_count)
clswise_count = {}
if cv2.waitKey(1) & 0xFF == ord("q"):
break
else:
break
cap.release()
out.release()
cv2.destroyAllWindows()
```
=== "Bar Plot"
```python
import cv2
from ultralytics import YOLO, solutions
model = YOLO("yolov8s.pt")
cap = cv2.VideoCapture("Path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
out = cv2.VideoWriter("bar_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h))
analytics = solutions.Analytics(
type="bar",
writer=out,
im0_shape=(w, h),
view_img=True,
)
clswise_count = {}
while cap.isOpened():
success, frame = cap.read()
if success:
results = model.track(frame, persist=True, verbose=True)
if results[0].boxes.id is not None:
boxes = results[0].boxes.xyxy.cpu()
clss = results[0].boxes.cls.cpu().tolist()
for box, cls in zip(boxes, clss):
if model.names[int(cls)] in clswise_count:
clswise_count[model.names[int(cls)]] += 1
else:
clswise_count[model.names[int(cls)]] = 1
analytics.update_bar(clswise_count)
clswise_count = {}
if cv2.waitKey(1) & 0xFF == ord("q"):
break
else:
break
cap.release()
out.release()
cv2.destroyAllWindows()
```
### Argument `Analytics`
Here's a table with the `Analytics` arguments:
| Name | Type | Default | Description |
|--------------|-------------------|---------------|---------------------------------------------|
| `type` | `str` | `None` | Type of data or object. |
| `im0_shape` | `tuple` | `None` | Shape of the initial image. |
| `writer` | `cv2.VideoWriter` | `None` | Object for writing video files. |
| `title` | `str` | `ultralytics` | Title for the visualization. |
| `x_label` | `str` | `x` | Label for the x-axis. |
| `y_label` | `str` | `y` | Label for the y-axis. |
| `bg_color` | `str` | `white` | Background color. |
| `fg_color` | `str` | `black` | Foreground color. |
| `line_color` | `str` | `yellow` | Color of the lines. |
| `line_width` | `int` | `2` | Width of the lines. |
| `fontsize` | `int` | `13` | Font size for text. |
| `view_img` | `bool` | `False` | Flag to display the image or video. |
| `save_img` | `bool` | `True` | Flag to save the image or video. |
### Arguments `model.track`
| Name | Type | Default | Description |
|-----------|---------|----------------|-------------------------------------------------------------|
| `source` | `im0` | `None` | source directory for images or videos |
| `persist` | `bool` | `False` | persisting tracks between frames |
| `tracker` | `str` | `botsort.yaml` | Tracking method 'bytetrack' or 'botsort' |
| `conf` | `float` | `0.3` | Confidence Threshold |
| `iou` | `float` | `0.5` | IOU Threshold |
| `classes` | `list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
| `verbose` | `bool` | `True` | Display the object tracking results |
## Conclusion
Understanding when and how to use different types of visualizations is crucial for effective data analysis. Line graphs, bar plots, and pie charts are fundamental tools that can help you convey your data's story more clearly and effectively.

@ -45,6 +45,8 @@ Here's a compilation of in-depth guides to help you master different aspects of
## Real-World Projects
![Ultralytics Solutions Thumbnail](https://github.com/RizwanMunawar/RizwanMunawar/assets/62513924/44c8b148-7a9d-43e4-b7bf-272a7ac4e636)
- [Object Counting](object-counting.md) 🚀 NEW: Explore the process of real-time object counting with Ultralytics YOLOv8 and acquire the knowledge to effectively count objects in a live video stream.
- [Object Cropping](object-cropping.md) 🚀 NEW: Explore object cropping using YOLOv8 for precise extraction of objects from images and videos.
- [Object Blurring](object-blurring.md) 🚀 NEW: Apply object blurring with YOLOv8 for privacy protection in image and video processing.
@ -58,6 +60,7 @@ Here's a compilation of in-depth guides to help you master different aspects of
- [Distance Calculation](distance-calculation.md) 🚀 NEW: Distance calculation, which involves measuring the separation between two objects within a defined space, is a crucial aspect. In the context of Ultralytics YOLOv8, the method employed for this involves using the bounding box centroid to determine the distance associated with user-highlighted bounding boxes.
- [Queue Management](queue-management.md) 🚀 NEW: Queue management is the practice of efficiently controlling and directing the flow of people or tasks, often through strategic planning and technology implementation, to minimize wait times and improve overall productivity.
- [Parking Management](parking-management.md) 🚀 NEW: Parking management involves efficiently organizing and directing the flow of vehicles in parking areas, often through strategic planning and technology integration, to optimize space utilization and enhance user experience.
- [Analytics](analytics.md) 📊 NEW: Analytics involves the systematic computational analysis of data or statistics. It is used for discovering, interpreting, and communicating significant patterns in data, and for applying data patterns towards effective decision-making. Analytics can be descriptive, predictive, or prescriptive in nature, and it is integral to data-driven strategies in various industries.
## Contribute to Our Guides

@ -121,7 +121,7 @@ Monitoring workouts through pose estimation with [Ultralytics YOLOv8](https://gi
| `view_img` | `bool` | `False` | Flag to display the image. |
| `pose_up_angle` | `float` | `145.0` | Angle threshold for the 'up' pose. |
| `pose_down_angle` | `float` | `90.0` | Angle threshold for the 'down' pose. |
| `pose_type` | `str` | `pullup` | Type of pose to detect (`'pullup`', `pushup`, `abworkout`). |
| `pose_type` | `str` | `pullup` | Type of pose to detect (`'pullup`', `pushup`, `abworkout`, `squat`). |
### Arguments `model.predict`

@ -406,7 +406,7 @@ The below table contains valid Ultralytics image formats.
| Image Suffixes | Example Predict Command | Reference |
|----------------|----------------------------------|----------------------------------------------------------------------------|
| `.bmp` | `yolo predict source=image.bmp` | [Microsoft BMP File Format](https://en.wikipedia.org/wiki/BMP_file_format) |
| `.dng` | `yolo predict source=image.dng` | [Adobe DNG](https://helpx.adobe.com/camera-raw/digital-negative.html) |
| `.dng` | `yolo predict source=image.dng` | [Adobe DNG](https://en.wikipedia.org/wiki/Digital_Negative) |
| `.jpeg` | `yolo predict source=image.jpeg` | [JPEG](https://en.wikipedia.org/wiki/JPEG) |
| `.jpg` | `yolo predict source=image.jpg` | [JPEG](https://en.wikipedia.org/wiki/JPEG) |
| `.mpo` | `yolo predict source=image.mpo` | [Multi Picture Object](https://fileinfo.com/extension/mpo) |

@ -158,7 +158,7 @@ nav:
- datasets/index.md
- Guides:
- guides/index.md
- New 🚀 Parking Management: guides/parking-management.md
- New 🚀 Analytics: guides/analytics.md
- Explorer:
- datasets/explorer/index.md
- Languages:
@ -291,6 +291,7 @@ nav:
- Viewing Inference Images in a Terminal: guides/view-results-in-terminal.md
- OpenVINO Latency vs Throughput modes: guides/optimizing-openvino-latency-vs-throughput-modes.md
- Real-World Projects:
- NEW 🚀 Analytics: guides/analytics.md
- Object Counting: guides/object-counting.md
- Object Cropping: guides/object-cropping.md
- Object Blurring: guides/object-blurring.md

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.19"
__version__ = "8.2.20"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .ai_gym import AIGym
from .analytics import Analytics
from .distance_calculation import DistanceCalculation
from .heatmap import Heatmap
from .object_counter import ObjectCounter
@ -16,4 +17,5 @@ __all__ = (
"ParkingManagement",
"QueueManager",
"SpeedEstimator",
"Analytics",
)

@ -73,11 +73,11 @@ class AIGym:
self.stage = ["-" for _ in results[0]]
self.keypoints = results[0].keypoints.data
self.annotator = Annotator(im0, line_width=2)
self.annotator = Annotator(im0, line_width=self.tf)
for ind, k in enumerate(reversed(self.keypoints)):
# Estimate angle and draw specific points based on pose type
if self.pose_type in {"pushup", "pullup", "abworkout"}:
if self.pose_type in {"pushup", "pullup", "abworkout", "squat"}:
self.angle[ind] = self.annotator.estimate_pose_angle(
k[int(self.kpts_to_check[0])].cpu(),
k[int(self.kpts_to_check[1])].cpu(),
@ -93,7 +93,7 @@ class AIGym:
self.stage[ind] = "up"
self.count[ind] += 1
elif self.pose_type == "pushup":
elif self.pose_type == "pushup" or self.pose_type == "squat":
if self.angle[ind] > self.poseup_angle:
self.stage[ind] = "up"
if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up":

@ -0,0 +1,197 @@
from itertools import cycle
import cv2
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
class Analytics:
"""A class to create and update various types of charts (line, bar, pie) for visual analytics."""
def __init__(
self,
type,
writer,
im0_shape,
title="ultralytics",
x_label="x",
y_label="y",
bg_color="white",
fg_color="black",
line_color="yellow",
line_width=2,
fontsize=13,
view_img=False,
save_img=True,
):
"""
Initialize the Analytics class with various chart types.
Args:
type (str): Type of chart to initialize ('line', 'bar', or 'pie').
writer: Video writer object to save the frames.
im0_shape (tuple): Shape of the input image (width, height).
title (str): Title of the chart.
x_label (str): Label for the x-axis.
y_label (str): Label for the y-axis.
bg_color (str): Background color of the chart.
fg_color (str): Foreground (text) color of the chart.
line_color (str): Line color for line charts.
line_width (int): Width of the lines in line charts.
fontsize (int): Font size for chart text.
view_img (bool): Whether to display the image.
save_img (bool): Whether to save the image.
"""
self.bg_color = bg_color
self.fg_color = fg_color
self.view_img = view_img
self.save_img = save_img
self.title = title
self.writer = writer
# Set figure size based on image shape
figsize = (im0_shape[0] / 100, im0_shape[1] / 100)
if type == "line":
# Initialize line plot
fig = Figure(facecolor=self.bg_color, figsize=figsize)
self.canvas = FigureCanvas(fig)
self.ax = fig.add_subplot(111, facecolor=self.bg_color)
(self.line,) = self.ax.plot([], [], color=line_color, linewidth=line_width)
elif type == "bar" or type == "pie":
# Initialize bar or pie plot
self.fig, self.ax = plt.subplots(figsize=figsize, facecolor=self.bg_color)
self.ax.set_facecolor(self.bg_color)
color_palette = [
(31, 119, 180),
(255, 127, 14),
(44, 160, 44),
(214, 39, 40),
(148, 103, 189),
(140, 86, 75),
(227, 119, 194),
(127, 127, 127),
(188, 189, 34),
(23, 190, 207),
]
self.color_palette = [(r / 255, g / 255, b / 255, 1) for r, g, b in color_palette]
self.color_cycle = cycle(self.color_palette)
self.color_mapping = {}
# Ensure pie chart is circular
self.ax.axis("equal") if type == "pie" else None
# Set common axis properties
self.ax.set_title(self.title, color=self.fg_color, fontsize=fontsize)
self.ax.set_xlabel(x_label, color=self.fg_color, fontsize=fontsize - 3)
self.ax.set_ylabel(y_label, color=self.fg_color, fontsize=fontsize - 3)
self.ax.tick_params(axis="both", colors=self.fg_color)
def update_line(self, frame_number, total_counts):
"""
Update the line graph with new data.
Args:
frame_number (int): The current frame number.
total_counts (int): The total counts to plot.
"""
# Update line graph data
x_data = self.line.get_xdata()
y_data = self.line.get_ydata()
x_data = np.append(x_data, float(frame_number))
y_data = np.append(y_data, float(total_counts))
self.line.set_data(x_data, y_data)
self.ax.relim()
self.ax.autoscale_view()
self.canvas.draw()
im0 = np.array(self.canvas.renderer.buffer_rgba())
im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
# Display and save the updated graph
cv2.imshow(self.title, im0) if self.view_img else None
self.writer.write(im0) if self.save_img else None
def update_bar(self, count_dict):
"""
Update the bar graph with new data.
Args:
count_dict (dict): Dictionary containing the count data to plot.
"""
# Update bar graph data
self.ax.clear()
self.ax.set_facecolor(self.bg_color)
labels = list(count_dict.keys())
counts = list(count_dict.values())
# Map labels to colors
for label in labels:
if label not in self.color_mapping:
self.color_mapping[label] = next(self.color_cycle)
colors = [self.color_mapping[label] for label in labels]
bars = self.ax.bar(labels, counts, color=colors)
for bar, count in zip(bars, counts):
self.ax.text(
bar.get_x() + bar.get_width() / 2,
bar.get_height(),
str(count),
ha="center",
va="bottom",
color=self.fg_color,
)
# Display and save the updated graph
canvas = FigureCanvas(self.fig)
canvas.draw()
buf = canvas.buffer_rgba()
im0 = np.asarray(buf)
im0 = cv2.cvtColor(im0, cv2.COLOR_RGBA2BGR)
self.writer.write(im0) if self.save_img else None
cv2.imshow(self.title, im0) if self.view_img else None
def update_pie(self, classes_dict):
"""
Update the pie chart with new data.
Args:
classes_dict (dict): Dictionary containing the class data to plot.
"""
# Update pie chart data
labels = list(classes_dict.keys())
sizes = list(classes_dict.values())
total = sum(sizes)
percentages = [size / total * 100 for size in sizes]
start_angle = 90
self.ax.clear()
# Create pie chart without labels inside the slices
wedges, autotexts = self.ax.pie(sizes, autopct=None, startangle=start_angle, textprops={"color": self.fg_color})
# Construct legend labels with percentages
legend_labels = [f"{label} ({percentage:.1f}%)" for label, percentage in zip(labels, percentages)]
self.ax.legend(wedges, legend_labels, title="Classes", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
# Adjust layout to fit the legend
self.fig.tight_layout()
self.fig.subplots_adjust(left=0.1, right=0.75)
# Display and save the updated chart
im0 = self.fig.canvas.draw()
im0 = np.array(self.fig.canvas.renderer.buffer_rgba())
im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
self.writer.write(im0) if self.save_img else None
cv2.imshow(self.title, im0) if self.view_img else None
if __name__ == "__main__":
Analytics("line", writer=None, im0_shape=None)
Loading…
Cancel
Save