Strip `dfl_loss` from `BboxLoss` (#14041)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/14074/head^2
Laughing 8 months ago committed by GitHub
parent f533d77611
commit f5ccddf5df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. docs/en/guides/model-evaluation-insights.md (8 lines changed)
  2. docs/en/reference/utils/loss.md (4 lines changed)
  3. ultralytics/utils/loss.py (77 lines changed)

@@ -8,9 +8,9 @@ keywords: Model Evaluation, Machine Learning Model Evaluation, Fine Tuning Machi
## Introduction
Once youve [trained](./model-training-tips.md) your computer vision model, evaluating and refining it to perform optimally is essential. Just training your model isnt enough. You need to make sure that your model is accurate, efficient, and fulfills the [objective](./defining-project-goals.md) of your computer vision project. By evaluating and fine-tuning your model, you can identify weaknesses, improve its accuracy, and boost overall performance.
Once you've [trained](./model-training-tips.md) your computer vision model, evaluating and refining it to perform optimally is essential. Just training your model isn't enough. You need to make sure that your model is accurate, efficient, and fulfills the [objective](./defining-project-goals.md) of your computer vision project. By evaluating and fine-tuning your model, you can identify weaknesses, improve its accuracy, and boost overall performance.
In this guide, we’ll share insights on model evaluation and fine-tuning that’ll make this [step of a computer vision project](./steps-of-a-cv-project.md) more approachable. Well discuss how to understand evaluation metrics and implement fine-tuning techniques, giving you the knowledge to elevate your model's capabilities.
In this guide, we'll share insights on model evaluation and fine-tuning that'll make this [step of a computer vision project](./steps-of-a-cv-project.md) more approachable. We'll discuss how to understand evaluation metrics and implement fine-tuning techniques, giving you the knowledge to elevate your model's capabilities.
## Evaluating Model Performance Using Metrics
@@ -34,7 +34,7 @@ Intersection over Union (IoU) is a metric in object detection that measures how
Mean Average Precision (mAP) is a way to measure how well an object detection model performs. It looks at the precision of detecting each object class, averages these scores, and gives an overall number that shows how accurately the model can identify and classify objects.
Lets focus on two specific mAP metrics:
Let's focus on two specific mAP metrics:
- *mAP@.5:* Measures the average precision at a single IoU (Intersection over Union) threshold of 0.5. This metric checks if the model can correctly find objects with a looser accuracy requirement. It focuses on whether the object is roughly in the right place, not needing perfect placement. It helps see if the model is generally good at spotting objects.
- *mAP@.5:.95:* Averages the mAP values calculated at multiple IoU thresholds, from 0.5 to 0.95 in 0.05 increments. This metric is more detailed and strict. It gives a fuller picture of how accurately the model can find objects at different levels of strictness and is especially useful for applications that need precise object detection.
@@ -136,4 +136,4 @@ Sharing your ideas and questions with other computer vision enthusiasts can insp
## Final Thoughts
Evaluating and fine-tuning your computer vision model are important steps for successful model deployment. These steps help make sure that your model is accurate, efficient, and suited to your overall application. The key to training the best model possible is continuous experimentation and learning. Dont hesitate to tweak parameters, try new techniques, and explore different datasets. Keep experimenting and pushing the boundaries of what's possible!
Evaluating and fine-tuning your computer vision model are important steps for successful model deployment. These steps help make sure that your model is accurate, efficient, and suited to your overall application. The key to training the best model possible is continuous experimentation and learning. Don't hesitate to tweak parameters, try new techniques, and explore different datasets. Keep experimenting and pushing the boundaries of what's possible!

@@ -19,6 +19,10 @@ keywords: Ultralytics, loss functions, Varifocal Loss, Focal Loss, Bbox Loss, Ro
<br><br>
## ::: ultralytics.utils.loss.DFLoss
<br><br>
## ::: ultralytics.utils.loss.BboxLoss
<br><br>

@@ -61,39 +61,22 @@ class FocalLoss(nn.Module):
return loss.mean(1).sum()
class BboxLoss(nn.Module):
"""Criterion class for computing training losses during training."""
class DFLoss(nn.Module):
"""Criterion class for computing DFL losses during training."""
def __init__(self, reg_max, use_dfl=False):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
def __init__(self, reg_max=16) -> None:
"""Initialize the DFL module."""
super().__init__()
self.reg_max = reg_max
self.use_dfl = use_dfl
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
# DFL loss
if self.use_dfl:
target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
loss_dfl = loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
return loss_iou, loss_dfl
@staticmethod
def _df_loss(pred_dist, target):
def __call__(self, pred_dist, target):
"""
Return sum of left and right DFL losses.
Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
https://ieeexplore.ieee.org/document/9792391
"""
target = target.clamp_(0, self.reg_max - 1 - 0.01)
tl = target.long() # target left
tr = tl + 1 # target right
wl = tr - target # weight left
@@ -104,12 +87,37 @@ class BboxLoss(nn.Module):
).mean(-1, keepdim=True)
class BboxLoss(nn.Module):
"""Criterion class for computing training losses during training."""
def __init__(self, reg_max=16):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
super().__init__()
self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
# DFL loss
if self.dfl_loss:
target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
loss_dfl = loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
return loss_iou, loss_dfl
class RotatedBboxLoss(BboxLoss):
"""Criterion class for computing training losses during training."""
def __init__(self, reg_max, use_dfl=False):
def __init__(self, reg_max):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
super().__init__(reg_max, use_dfl)
super().__init__(reg_max)
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss."""
@@ -118,9 +126,9 @@ class RotatedBboxLoss(BboxLoss):
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
# DFL loss
if self.use_dfl:
target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
if self.dfl_loss:
target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
loss_dfl = loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
@@ -165,18 +173,19 @@ class v8DetectionLoss:
self.use_dfl = m.reg_max > 1
self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
self.bbox_loss = BboxLoss(m.reg_max).to(device)
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
def preprocess(self, targets, batch_size, scale_tensor):
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
if targets.shape[0] == 0:
out = torch.zeros(batch_size, 0, 5, device=self.device)
nl, ne = targets.shape
if nl == 0:
out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
else:
i = targets[:, 0] # image index
_, counts = i.unique(return_counts=True)
counts = counts.to(dtype=torch.int32)
out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
for j in range(batch_size):
matches = i == j
n = matches.sum()
@@ -592,7 +601,7 @@ class v8ClassificationLoss:
def __call__(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
loss_items = loss.detach()
return loss, loss_items
@@ -606,7 +615,7 @@ class v8OBBLoss(v8DetectionLoss):
"""
super().__init__(model)
self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)
self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
def preprocess(self, targets, batch_size, scale_tensor):
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
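
Note on the `DFLoss` hunk: the diff elides the lines between `wl = tr - target` and the final `.mean(-1, keepdim=True)`. The following is a rough, plain-Python sketch of what a Distribution Focal Loss computes for a single box side, assuming the usual formulation (cross-entropy against the two integer bins on either side of the target, each weighted by how close the target is to it). The helper names `log_softmax` and `dfl_single` are hypothetical and do not appear in the commit, which operates on PyTorch tensors instead.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def dfl_single(logits, target, reg_max=16):
    """Sketch of DFL for one box side (hypothetical helper, not the library API).

    logits: reg_max raw scores, one per distance bin (indices 0 .. reg_max - 1).
    target: ground-truth distance in bin units.
    """
    assert len(logits) == reg_max
    # Same clamp as the new DFLoss: keeps the right-hand bin inside the range.
    target = min(max(target, 0.0), reg_max - 1 - 0.01)
    tl = int(target)   # target left bin
    tr = tl + 1        # target right bin
    wl = tr - target   # weight left
    wr = 1.0 - wl      # weight right (assumed; this line is elided in the diff)
    logp = log_softmax(logits)
    # Weighted cross-entropy against the two neighbouring bins.
    return -logp[tl] * wl - logp[tr] * wr


if __name__ == "__main__":
    print(dfl_single([0.0] * 16, 3.7))
```

With `reg_max=16` the prediction has 16 bins and the clamped target never exceeds 14.99, so `tr` is at most 15. This is consistent with the new convention in the diff, where `BboxLoss` receives `m.reg_max` directly and the `- 1` moves into the `bbox2dist` call.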

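Note on the `v8DetectionLoss.preprocess` hunk: the hard-coded width `5` is replaced by `ne - 1`, where `ne` is the number of columns in `targets`. The sketch below illustrates the padding idea with nested lists rather than tensors. It assumes column 0 holds the image index and the remaining `ne - 1` columns hold the class and box values, and it omits the scaling step implied by `scale_tensor`. The function name `pad_targets` is hypothetical.

```python
def pad_targets(targets, batch_size, ne):
    """Group target rows by image index and zero-pad to a common length.

    targets: list of rows [image_index, value_1, ..., value_{ne-1}].
    Returns a nested list shaped (batch_size, max_count, ne - 1).
    Hypothetical list-based sketch; the commit works on torch tensors.
    """
    width = ne - 1  # everything except the image-index column
    if not targets:
        # Corresponds to torch.zeros(batch_size, 0, ne - 1) in the diff.
        return [[] for _ in range(batch_size)]
    counts = [0] * batch_size
    for row in targets:
        counts[int(row[0])] += 1
    max_count = max(counts)
    out = [[[0.0] * width for _ in range(max_count)] for _ in range(batch_size)]
    filled = [0] * batch_size
    for row in targets:
        j = int(row[0])
        out[j][filled[j]] = list(row[1:])  # drop the image-index column
        filled[j] += 1
    return out


if __name__ == "__main__":
    rows = [[0, 1, 0.1, 0.2, 0.3, 0.4], [1, 2, 0.5, 0.5, 0.2, 0.2]]
    print(pad_targets(rows, batch_size=2, ne=6))
```

Under these assumptions, a 6-column detection target yields rows of width 5, as the old code hard-coded, while a 7-column target (for example, one carrying an extra angle value) yields width 6 without a separate code path.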