From 0af7f25a73a2c56e8823489d4f4d4fcc11116b9c Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Wed, 6 Apr 2022 08:51:13 -0400 Subject: [PATCH] [Fix] Allow mixed precision training with SimOTAAssigner (#7516) * Convert valid_pred_scores to float32 in sim_ota_assigner Workaround to resolve `RuntimeError: "sqrt" "_vml_cpu" not implemented for 'Half'` * Add unit test for SimOTAAssigner * Cast output of binary_cross_entropy back to float16 if needed Also fix yapf lint issue. * Lint for yapf * More lint fixes on test_assigner.py * Cast cls_scores back to float16 directly Co-Authored-By: Wenwei Zhang Co-authored-by: Wenwei Zhang --- mmdet/core/bbox/assigners/sim_ota_assigner.py | 9 ++++++--- tests/test_utils/test_assigner.py | 19 +++++++++++++++++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/mmdet/core/bbox/assigners/sim_ota_assigner.py b/mmdet/core/bbox/assigners/sim_ota_assigner.py index 263abfcd8..79b3b719e 100644 --- a/mmdet/core/bbox/assigners/sim_ota_assigner.py +++ b/mmdet/core/bbox/assigners/sim_ota_assigner.py @@ -157,9 +157,12 @@ class SimOTAAssigner(BaseAssigner): num_valid, 1, 1)) valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1) - cls_cost = F.binary_cross_entropy( - valid_pred_scores.sqrt_(), gt_onehot_label, - reduction='none').sum(-1) + cls_cost = ( + F.binary_cross_entropy( + valid_pred_scores.to(dtype=torch.float32).sqrt_(), + gt_onehot_label, + reduction='none', + ).sum(-1).to(dtype=valid_pred_scores.dtype)) cost_matrix = ( cls_cost * self.cls_weight + iou_cost * self.iou_weight + diff --git a/tests/test_utils/test_assigner.py b/tests/test_utils/test_assigner.py index 0124c3b3f..07faa928b 100644 --- a/tests/test_utils/test_assigner.py +++ b/tests/test_utils/test_assigner.py @@ -11,8 +11,8 @@ import torch from mmdet.core.bbox.assigners import (ApproxMaxIoUAssigner, CenterRegionAssigner, HungarianAssigner, MaskHungarianAssigner, MaxIoUAssigner, - PointAssigner, TaskAlignedAssigner, - UniformAssigner) + PointAssigner, SimOTAAssigner, + TaskAlignedAssigner, UniformAssigner) def test_max_iou_assigner(): @@ -500,6 +500,21 @@ def test_uniform_assigner_with_empty_boxes(): assert len(assign_result.gt_inds) == 0 +def test_sim_ota_assigner(): + self = SimOTAAssigner( + center_radius=2.5, candidate_topk=1, iou_weight=3.0, cls_weight=1.0) + pred_scores = torch.FloatTensor([[0.2], [0.8]]) + priors = torch.Tensor([[0, 12, 23, 34], [4, 5, 6, 7]]) + decoded_bboxes = torch.Tensor([[[30, 40, 50, 60]], [[4, 5, 6, 7]]]) + gt_bboxes = torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]) + gt_labels = torch.LongTensor([2]) + assign_result = self.assign(pred_scores, priors, decoded_bboxes, gt_bboxes, + gt_labels) + + expected_gt_inds = torch.LongTensor([0, 0]) + assert torch.all(assign_result.gt_inds == expected_gt_inds) + + def test_task_aligned_assigner(): with pytest.raises(AssertionError): TaskAlignedAssigner(topk=0)