[Fix] Allow mixed precision training with SimOTAAssigner (#7516)

* Convert valid_pred_scores to float32 in sim_ota_assigner

Workaround to resolve `RuntimeError: "sqrt" "_vml_cpu" not implemented for 'Half'`

* Add unit test for SimOTAAssigner

* Cast output of binary_cross_entropy back to float16 if needed

Also fix yapf lint issue.

* Lint for yapf

* More lint fixes on test_assigner.py

* Cast cls_scores back to float16 directly

Co-Authored-By: Wenwei Zhang <ZwwWayne@users.noreply.github.com>

Co-authored-by: Wenwei Zhang <ZwwWayne@users.noreply.github.com>
pull/7502/head
Wei Ji 3 years ago committed by GitHub
parent b252574e20
commit 0af7f25a73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 9
      mmdet/core/bbox/assigners/sim_ota_assigner.py
  2. 19
      tests/test_utils/test_assigner.py

@ -157,9 +157,12 @@ class SimOTAAssigner(BaseAssigner):
num_valid, 1, 1))
valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)
cls_cost = F.binary_cross_entropy(
valid_pred_scores.sqrt_(), gt_onehot_label,
reduction='none').sum(-1)
cls_cost = (
F.binary_cross_entropy(
valid_pred_scores.to(dtype=torch.float32).sqrt_(),
gt_onehot_label,
reduction='none',
).sum(-1).to(dtype=valid_pred_scores.dtype))
cost_matrix = (
cls_cost * self.cls_weight + iou_cost * self.iou_weight +

@ -11,8 +11,8 @@ import torch
from mmdet.core.bbox.assigners import (ApproxMaxIoUAssigner,
CenterRegionAssigner, HungarianAssigner,
MaskHungarianAssigner, MaxIoUAssigner,
PointAssigner, TaskAlignedAssigner,
UniformAssigner)
PointAssigner, SimOTAAssigner,
TaskAlignedAssigner, UniformAssigner)
def test_max_iou_assigner():
@ -500,6 +500,21 @@ def test_uniform_assigner_with_empty_boxes():
assert len(assign_result.gt_inds) == 0
def test_sim_ota_assigner():
self = SimOTAAssigner(
center_radius=2.5, candidate_topk=1, iou_weight=3.0, cls_weight=1.0)
pred_scores = torch.FloatTensor([[0.2], [0.8]])
priors = torch.Tensor([[0, 12, 23, 34], [4, 5, 6, 7]])
decoded_bboxes = torch.Tensor([[[30, 40, 50, 60]], [[4, 5, 6, 7]]])
gt_bboxes = torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]])
gt_labels = torch.LongTensor([2])
assign_result = self.assign(pred_scores, priors, decoded_bboxes, gt_bboxes,
gt_labels)
expected_gt_inds = torch.LongTensor([0, 0])
assert torch.all(assign_result.gt_inds == expected_gt_inds)
def test_task_aligned_assigner():
with pytest.raises(AssertionError):
TaskAlignedAssigner(topk=0)

Loading…
Cancel
Save