OpenMMLab Detection Toolbox and Benchmark https://mmdetection.readthedocs.io/
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

75 lines
2.5 KiB

import pytest
import torch
from mmdet.core.post_processing import mask_matrix_nms
def _create_mask(N, h, w):
    """Build random inputs for ``mask_matrix_nms``.

    Args:
        N (int): Number of instances.
        h (int): Mask height.
        w (int): Mask width.

    Returns:
        tuple[Tensor, Tensor, Tensor]: ``(masks, labels, scores)``, where
        ``masks`` is a bool tensor of shape ``(N, h, w)`` and ``labels`` and
        ``scores`` are float tensors of shape ``(N,)`` sampled from
        ``torch.rand``.
    """
    # Draw order (masks, then labels, then scores) is fixed so the outputs
    # are reproducible for a given torch seed.
    mask_shape = (N, h, w)
    binary_masks = torch.rand(mask_shape).gt(0.5)
    rand_labels, rand_scores = torch.rand(N), torch.rand(N)
    return binary_masks, rand_labels, rand_scores
def test_nms_input_errors():
    """Check input validation and output sizes of ``mask_matrix_nms``.

    Covers mismatched input lengths, a mismatched ``mask_area`` length, an
    unsupported kernel, empty input, the ``nms_pre`` / ``max_num`` limits,
    the linear kernel, and score decay combined with ``filter_thr``.
    """
    # masks (10) and labels/scores (11) have different lengths.
    with pytest.raises(AssertionError):
        mask_matrix_nms(
            torch.rand((10, 28, 28)), torch.rand(11), torch.rand(11))

    # masks, labels and scores all have length 10, but mask_area has only 8
    # entries. The lengths must agree here so the failure comes from the
    # mask_area check rather than from the length check exercised above.
    masks = torch.rand((10, 28, 28))
    with pytest.raises(AssertionError):
        mask_matrix_nms(
            masks,
            torch.rand(10),
            torch.rand(10),
            mask_area=masks.sum((1, 2)).float()[:8])

    # An unknown kernel name is rejected.
    with pytest.raises(NotImplementedError):
        mask_matrix_nms(
            torch.rand((10, 28, 28)),
            torch.rand(10),
            torch.rand(10),
            kernel='None')

    # Empty input yields empty outputs.
    masks, labels, scores = _create_mask(0, 28, 28)
    score, label, mask, keep_ind = \
        mask_matrix_nms(masks, labels, scores)
    assert len(score) == len(label) == \
        len(mask) == len(keep_ind) == 0

    # Without filter_thr, nms_pre or max_num, every instance is kept.
    masks, labels, scores = _create_mask(1000, 28, 28)
    score, label, mask, keep_ind = \
        mask_matrix_nms(masks, labels, scores)
    assert len(score) == len(label) == \
        len(mask) == len(keep_ind) == 1000

    # The linear kernel also keeps every instance when nothing filters or
    # truncates the results.
    score, label, mask, keep_ind = \
        mask_matrix_nms(masks, labels, scores, kernel='linear')
    assert len(score) == len(label) == \
        len(mask) == len(keep_ind) == 1000

    # nms_pre alone truncates to the top 500 before NMS.
    score, label, mask, keep_ind = \
        mask_matrix_nms(masks, labels, scores, nms_pre=500)
    assert len(score) == len(label) == \
        len(mask) == len(keep_ind) == 500

    # max_num further caps the final output at 100.
    score, label, mask, keep_ind = \
        mask_matrix_nms(masks, labels, scores,
                        nms_pre=500, max_num=100)
    assert len(score) == len(label) == \
        len(mask) == len(keep_ind) == 100

    # Scores must be decayed and filter_thr must take effect. With 1000 copies
    # of the same mask, the same label and all scores equal to 1, the first
    # instance keeps score 1 while the rest decay below filter_thr=0.5 and are
    # dropped.
    masks, labels, _ = _create_mask(1, 28, 28)
    scores = torch.Tensor([1.0])
    masks = masks.expand(1000, 28, 28)
    labels = labels.expand(1000)
    scores = scores.expand(1000)
    score, label, mask, keep_ind = \
        mask_matrix_nms(masks,
                        labels,
                        scores,
                        nms_pre=500,
                        max_num=100,
                        kernel='gaussian',
                        sigma=2.0,
                        filter_thr=0.5)
    assert len(score) == 1
    assert score[0] == 1