`ultralytics 8.3.30` run TAL on CPU if `torch.OutOfMemoryError` (#17515)

Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/17257/head^2 v8.3.30
Laughing 2 weeks ago committed by GitHub
parent 1a5c35366e
commit f43c211ab4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/__init__.py
  2. 40
      ultralytics/utils/tal.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.29"
__version__ = "8.3.30"
import os

@ -58,17 +58,45 @@ class TaskAlignedAssigner(nn.Module):
"""
self.bs = pd_scores.shape[0]
self.n_max_boxes = gt_bboxes.shape[1]
device = gt_bboxes.device
if self.n_max_boxes == 0:
device = gt_bboxes.device
return (
torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
torch.zeros_like(pd_bboxes).to(device),
torch.zeros_like(pd_scores).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device),
torch.full_like(pd_scores[..., 0], self.bg_idx),
torch.zeros_like(pd_bboxes),
torch.zeros_like(pd_scores),
torch.zeros_like(pd_scores[..., 0]),
torch.zeros_like(pd_scores[..., 0]),
)
try:
return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
except torch.OutOfMemoryError:
# Move tensors to CPU, compute, then move back to original device
cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
result = self._forward(*cpu_tensors)
return tuple(t.to(device) for t in result)
def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
"""
Compute the task-aligned assignment. Reference code is available at
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
Args:
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
anc_points (Tensor): shape(num_total_anchors, 2)
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
Returns:
target_labels (Tensor): shape(bs, num_total_anchors)
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
fg_mask (Tensor): shape(bs, num_total_anchors)
target_gt_idx (Tensor): shape(bs, num_total_anchors)
"""
mask_pos, align_metric, overlaps = self.get_pos_mask(
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
)

Loading…
Cancel
Save