Added a `max_size` parameter to the `plot_images` function (#14002)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/14011/head
bobyard-com 9 months ago committed by GitHub
parent 87dba199b2
commit 08acdd198a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 84
      ultralytics/utils/plotting.py

@ -4,6 +4,7 @@ import contextlib
import math import math
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import cv2 import cv2
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -579,7 +580,8 @@ class Annotator:
def display_analytics(self, im0, text, txt_color, bg_color, margin): def display_analytics(self, im0, text, txt_color, bg_color, margin):
""" """
Display the overall statistics for parking lots Display the overall statistics for parking lots.
Args: Args:
im0 (ndarray): inference image im0 (ndarray): inference image
text (dict): labels dictionary text (dict): labels dictionary
@ -661,7 +663,7 @@ class Annotator:
angle_text (str): angle value for workout monitoring angle_text (str): angle value for workout monitoring
count_text (str): counts value for workout monitoring count_text (str): counts value for workout monitoring
stage_text (str): stage decision for workout monitoring stage_text (str): stage decision for workout monitoring
center_kpt (int): centroid pose index for workout monitoring center_kpt (list): centroid pose index for workout monitoring
color (tuple): text background color for workout monitoring color (tuple): text background color for workout monitoring
txt_color (tuple): text foreground color for workout monitoring txt_color (tuple): text foreground color for workout monitoring
""" """
@ -917,22 +919,49 @@ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False,
@threaded @threaded
def plot_images( def plot_images(
images, images: Union[torch.Tensor, np.ndarray],
batch_idx, batch_idx: Union[torch.Tensor, np.ndarray],
cls, cls: Union[torch.Tensor, np.ndarray],
bboxes=np.zeros(0, dtype=np.float32), bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
confs=None, confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
masks=np.zeros(0, dtype=np.uint8), masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
kpts=np.zeros((0, 51), dtype=np.float32), kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
paths=None, paths: Optional[List[str]] = None,
fname="images.jpg", fname: str = "images.jpg",
names=None, names: Optional[Dict[int, str]] = None,
on_plot=None, on_plot: Optional[Callable] = None,
max_subplots=16, max_size: int = 1920,
save=True, max_subplots: int = 16,
conf_thres=0.25, save: bool = True,
): conf_thres: float = 0.25,
"""Plot image grid with labels.""" ) -> Optional[np.ndarray]:
"""
Plot image grid with labels, bounding boxes, masks, and keypoints.
Args:
images: Batch of images to plot. Shape: (batch_size, channels, height, width).
batch_idx: Batch indices for each detection. Shape: (num_detections,).
cls: Class labels for each detection. Shape: (num_detections,).
bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
confs: Confidence scores for each detection. Shape: (num_detections,).
masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
kpts: Keypoints for each detection. Shape: (num_detections, 51).
paths: List of file paths for each image in the batch.
fname: Output filename for the plotted image grid.
names: Dictionary mapping class indices to class names.
on_plot: Optional callback function to be called after saving the plot.
max_size: Maximum size of the output image grid.
max_subplots: Maximum number of subplots in the image grid.
save: Whether to save the plotted image grid to a file.
conf_thres: Confidence threshold for displaying detections.
Returns:
np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise.
Note:
This function supports both tensor and numpy array inputs. It will automatically
convert tensor inputs to numpy arrays for processing.
"""
if isinstance(images, torch.Tensor): if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy() images = images.cpu().float().numpy()
if isinstance(cls, torch.Tensor): if isinstance(cls, torch.Tensor):
@ -946,7 +975,6 @@ def plot_images(
if isinstance(batch_idx, torch.Tensor): if isinstance(batch_idx, torch.Tensor):
batch_idx = batch_idx.cpu().numpy() batch_idx = batch_idx.cpu().numpy()
max_size = 1920 # max image size
bs, _, h, w = images.shape # batch size, _, height, width bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images bs = min(bs, max_subplots) # limit plot images
ns = np.ceil(bs**0.5) # number of subplots (square) ns = np.ceil(bs**0.5) # number of subplots (square)
@ -1166,6 +1194,12 @@ def plot_tune_results(csv_file="tune_results.csv"):
import pandas as pd # scope for faster 'import ultralytics' import pandas as pd # scope for faster 'import ultralytics'
from scipy.ndimage import gaussian_filter1d from scipy.ndimage import gaussian_filter1d
def _save_one_file(file):
"""Save one matplotlib plot to 'file'."""
plt.savefig(file, dpi=200)
plt.close()
LOGGER.info(f"Saved {file}")
# Scatter plots for each hyperparameter # Scatter plots for each hyperparameter
csv_file = Path(csv_file) csv_file = Path(csv_file)
data = pd.read_csv(csv_file) data = pd.read_csv(csv_file)
@ -1186,11 +1220,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8 plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8
if i % n != 0: if i % n != 0:
plt.yticks([]) plt.yticks([])
_save_one_file(csv_file.with_name("tune_scatter_plots.png"))
file = csv_file.with_name("tune_scatter_plots.png") # filename
plt.savefig(file, dpi=200)
plt.close()
LOGGER.info(f"Saved {file}")
# Fitness vs iteration # Fitness vs iteration
x = range(1, len(fitness) + 1) x = range(1, len(fitness) + 1)
@ -1202,11 +1232,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
plt.ylabel("Fitness") plt.ylabel("Fitness")
plt.grid(True) plt.grid(True)
plt.legend() plt.legend()
_save_one_file(csv_file.with_name("tune_fitness.png"))
file = csv_file.with_name("tune_fitness.png") # filename
plt.savefig(file, dpi=200)
plt.close()
LOGGER.info(f"Saved {file}")
def output_to_target(output, max_det=300): def output_to_target(output, max_det=300):

Loading…
Cancel
Save