OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.5 KiB
75 lines
2.5 KiB
import pytest |
|
import torch |
|
|
|
from mmdet.core.post_processing import mask_matrix_nms |
|
|
|
|
|
def _create_mask(N, h, w): |
|
masks = torch.rand((N, h, w)) > 0.5 |
|
labels = torch.rand(N) |
|
scores = torch.rand(N) |
|
return masks, labels, scores |
|
|
|
|
|
def _assert_result_len(results, expected_len):
    """Assert every output of ``mask_matrix_nms`` has ``expected_len``.

    Args:
        results (tuple): ``(scores, labels, masks, keep_inds)`` as returned
            by ``mask_matrix_nms``.
        expected_len (int): Required length of each of the four outputs.
    """
    score, label, mask, keep_ind = results
    assert len(score) == len(label) == \
        len(mask) == len(keep_ind) == expected_len


def test_nms_input_errors():
    """Check ``mask_matrix_nms`` input validation and output sizes.

    Despite the name, this also covers empty inputs, ``nms_pre``,
    ``max_num``, and score decay combined with ``filter_thr``.
    """
    # masks (10) vs labels/scores (11) length mismatch.
    with pytest.raises(AssertionError):
        mask_matrix_nms(
            torch.rand((10, 28, 28)), torch.rand(11), torch.rand(11))

    # Only mask_area (8 entries) is inconsistent here. Labels and scores
    # match the 10 masks, so the plain length mismatch above cannot be
    # what triggers the AssertionError.
    masks = torch.rand((10, 28, 28))
    with pytest.raises(AssertionError):
        mask_matrix_nms(
            masks,
            torch.rand(10),
            torch.rand(10),
            mask_area=masks.sum((1, 2)).float()[:8])

    # Unsupported kernel name.
    with pytest.raises(NotImplementedError):
        mask_matrix_nms(
            torch.rand((10, 28, 28)),
            torch.rand(10),
            torch.rand(10),
            kernel='None')

    # Empty input gives empty outputs.
    masks, labels, scores = _create_mask(0, 28, 28)
    _assert_result_len(mask_matrix_nms(masks, labels, scores), 0)

    # Without filter_thr, nms_pre and max_num, every instance is kept.
    masks, labels, scores = _create_mask(1000, 28, 28)
    _assert_result_len(mask_matrix_nms(masks, labels, scores), 1000)

    # nms_pre alone caps the number of candidates.
    _assert_result_len(
        mask_matrix_nms(masks, labels, scores, nms_pre=500), 500)

    # max_num further caps the number of returned instances.
    _assert_result_len(
        mask_matrix_nms(masks, labels, scores, nms_pre=500, max_num=100),
        100)

    # Check that scores decay and that filter_thr takes effect. With
    # identical masks and labels and all scores equal to 1, the top
    # instance keeps score 1 while all the others decay below
    # filter_thr=0.5 and are dropped.
    masks, labels, _ = _create_mask(1, 28, 28)
    masks = masks.expand(1000, 28, 28)
    labels = labels.expand(1000)
    scores = torch.ones(1).expand(1000)

    score, label, mask, keep_ind = mask_matrix_nms(
        masks,
        labels,
        scores,
        nms_pre=500,
        max_num=100,
        kernel='gaussian',
        sigma=2.0,
        filter_thr=0.5)
    _assert_result_len((score, label, mask, keep_ind), 1)
    assert score[0] == 1
|
|
|