`ultralytics 8.3.81` Fix Metrics `on_plot` circular references (#19318)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: remipiche <rpiche@flyscan.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
pull/19475/head^2 v8.3.81
RemiPT 2 weeks ago committed by GitHub
parent c15aabe762
commit 3fceec57be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/__init__.py
  2. 2
      ultralytics/models/fastsam/val.py
  3. 4
      ultralytics/models/yolo/detect/val.py
  4. 2
      ultralytics/models/yolo/obb/val.py
  5. 2
      ultralytics/models/yolo/pose/val.py
  6. 2
      ultralytics/models/yolo/segment/val.py
  7. 40
      ultralytics/utils/metrics.py

@ -1,6 +1,6 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
__version__ = "8.3.80"
__version__ = "8.3.81"
import os

@ -37,4 +37,4 @@ class FastSAMValidator(SegmentationValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = "segment"
self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.metrics = SegmentMetrics(save_dir=self.save_dir)

@ -37,7 +37,7 @@ class DetectionValidator(BaseValidator):
self.is_lvis = False
self.class_map = None
self.args.task = "detect"
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.metrics = DetMetrics(save_dir=self.save_dir)
self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
self.lb = [] # for autolabelling
@ -187,7 +187,7 @@ class DetectionValidator(BaseValidator):
self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
stats.pop("target_img", None)
if len(stats):
self.metrics.process(**stats)
self.metrics.process(**stats, on_plot=self.on_plot)
return self.metrics.results_dict
def print_results(self):

@ -28,7 +28,7 @@ class OBBValidator(DetectionValidator):
"""Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = "obb"
self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)
self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
def init_metrics(self, model):
"""Initialize evaluation metrics for YOLO."""

@ -32,7 +32,7 @@ class PoseValidator(DetectionValidator):
self.sigma = None
self.kpt_shape = None
self.args.task = "pose"
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.metrics = PoseMetrics(save_dir=self.save_dir)
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning(
"WARNING ⚠ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "

@ -34,7 +34,7 @@ class SegmentationValidator(DetectionValidator):
self.plot_masks = None
self.process = None
self.args.task = "segment"
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.metrics = SegmentMetrics(save_dir=self.save_dir)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""

@ -803,13 +803,11 @@ class DetMetrics(SimpleClass):
Args:
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty tuple.
Attributes:
save_dir (Path): A path to the directory where the output plots will be saved.
plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
names (dict of str): A dict of strings that represents the names of the classes.
box (Metric): An instance of the Metric class for storing the results of the detection metrics.
speed (dict): A dictionary for storing the execution time of different parts of the detection process.
@ -827,17 +825,16 @@ class DetMetrics(SimpleClass):
curves_results: TODO
"""
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names={}) -> None:
def __init__(self, save_dir=Path("."), plot=False, names={}) -> None:
"""Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
self.task = "detect"
def process(self, tp, conf, pred_cls, target_cls):
def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
"""Process predicted results for object detection and update metrics."""
results = ap_per_class(
tp,
@ -847,7 +844,7 @@ class DetMetrics(SimpleClass):
plot=self.plot,
save_dir=self.save_dir,
names=self.names,
on_plot=self.on_plot,
on_plot=on_plot,
)[2:]
self.box.nc = len(self.names)
self.box.update(results)
@ -903,13 +900,11 @@ class SegmentMetrics(SimpleClass):
Args:
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
plot (bool): Whether to save the detection and segmentation plots. Default is False.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
names (list): List of class names. Default is an empty list.
Attributes:
save_dir (Path): Path to the directory where the output plots should be saved.
plot (bool): Whether to save the detection and segmentation plots.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
names (list): List of class names.
box (Metric): An instance of the Metric class to calculate box detection metrics.
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
@ -925,18 +920,17 @@ class SegmentMetrics(SimpleClass):
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
"""
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
"""Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names."""
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.seg = Metric()
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
self.task = "segment"
def process(self, tp, tp_m, conf, pred_cls, target_cls):
def process(self, tp, tp_m, conf, pred_cls, target_cls, on_plot=None):
"""
Processes the detection and segmentation metrics over the given set of predictions.
@ -946,6 +940,7 @@ class SegmentMetrics(SimpleClass):
conf (list): List of confidence scores.
pred_cls (list): List of predicted classes.
target_cls (list): List of target classes.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
"""
results_mask = ap_per_class(
tp_m,
@ -953,7 +948,7 @@ class SegmentMetrics(SimpleClass):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
on_plot=on_plot,
save_dir=self.save_dir,
names=self.names,
prefix="Mask",
@ -966,7 +961,7 @@ class SegmentMetrics(SimpleClass):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
on_plot=on_plot,
save_dir=self.save_dir,
names=self.names,
prefix="Box",
@ -1043,13 +1038,11 @@ class PoseMetrics(SegmentMetrics):
Args:
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
plot (bool): Whether to save the detection and segmentation plots. Default is False.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
names (list): List of class names. Default is an empty list.
Attributes:
save_dir (Path): Path to the directory where the output plots should be saved.
plot (bool): Whether to save the detection and segmentation plots.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
names (list): List of class names.
box (Metric): An instance of the Metric class to calculate box detection metrics.
pose (Metric): An instance of the Metric class to calculate mask segmentation metrics.
@ -1065,19 +1058,18 @@ class PoseMetrics(SegmentMetrics):
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
"""
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
"""Initialize the PoseMetrics class with directory path, class names, and plotting options."""
super().__init__(save_dir, plot, names)
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.pose = Metric()
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
self.task = "pose"
def process(self, tp, tp_p, conf, pred_cls, target_cls):
def process(self, tp, tp_p, conf, pred_cls, target_cls, on_plot=None):
"""
Processes the detection and pose metrics over the given set of predictions.
@ -1087,6 +1079,7 @@ class PoseMetrics(SegmentMetrics):
conf (list): List of confidence scores.
pred_cls (list): List of predicted classes.
target_cls (list): List of target classes.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
"""
results_pose = ap_per_class(
tp_p,
@ -1094,7 +1087,7 @@ class PoseMetrics(SegmentMetrics):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
on_plot=on_plot,
save_dir=self.save_dir,
names=self.names,
prefix="Pose",
@ -1107,7 +1100,7 @@ class PoseMetrics(SegmentMetrics):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
on_plot=on_plot,
save_dir=self.save_dir,
names=self.names,
prefix="Box",
@ -1226,16 +1219,15 @@ class ClassifyMetrics(SimpleClass):
class OBBMetrics(SimpleClass):
"""Metrics for evaluating oriented bounding box (OBB) detection, see https://arxiv.org/pdf/2106.06072.pdf."""
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
"""Initialize an OBBMetrics instance with directory, plotting, callback, and class names."""
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
def process(self, tp, conf, pred_cls, target_cls):
def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
"""Process predicted results for object detection and update metrics."""
results = ap_per_class(
tp,
@ -1245,7 +1237,7 @@ class OBBMetrics(SimpleClass):
plot=self.plot,
save_dir=self.save_dir,
names=self.names,
on_plot=self.on_plot,
on_plot=on_plot,
)[2:]
self.box.nc = len(self.names)
self.box.update(results)

Loading…
Cancel
Save