OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
47 lines
1.3 KiB
47 lines
1.3 KiB
import numpy as np |
|
import pytest |
|
import torch |
|
|
|
from mmdet.core.mask.structures import BitmapMasks, PolygonMasks |
|
from mmdet.core.utils import mask2ndarray |
|
|
|
|
|
def dummy_raw_polygon_masks(size):
    """Generate random single-polygon masks for testing.

    Each object gets exactly one polygon, stored as a flat 1-D coordinate
    array whose length is an even number between 6 and 14 (3 to 7 vertices).
    Every coordinate is drawn uniformly from ``[0, min(height, width))``.

    Args:
        size (tuple): expected shape of dummy masks, (N, H, W)

    Returns:
        list[list[ndarray]]: dummy masks, one single-polygon list per object
    """
    num_obj, height, width = size
    polygons = []
    for _ in range(num_obj):
        # np.random.randint(5) is in [0, 4], so this is an even coordinate
        # count in [6, 14]: at least 3 (x, y) vertices per polygon.
        num_coords = np.random.randint(5) * 2 + 6
        polygons.append(
            [np.random.uniform(0, min(height, width), num_coords)])
    return polygons
|
|
|
|
|
def test_mask2ndarray():
    """Check that ``mask2ndarray`` converts every supported mask type.

    Covers BitmapMasks, PolygonMasks, numpy arrays and torch tensors, and
    verifies that an unsupported input type raises ``TypeError``.
    """
    # BitmapMasks -> ndarray with the same values.
    raw_masks = np.ones((3, 28, 28))
    bitmap_mask = BitmapMasks(raw_masks, 28, 28)
    output_mask = mask2ndarray(bitmap_mask)
    assert isinstance(output_mask, np.ndarray)
    assert np.allclose(raw_masks, output_mask)

    # PolygonMasks -> ndarray. The polygons are random, so exact pixel
    # values are unpredictable; only the (N, H, W) shape is checked.
    raw_masks = dummy_raw_polygon_masks((3, 28, 28))
    polygon_masks = PolygonMasks(raw_masks, 28, 28)
    output_mask = mask2ndarray(polygon_masks)
    assert isinstance(output_mask, np.ndarray)
    assert output_mask.shape == (3, 28, 28)

    # A plain ndarray keeps its values.
    raw_masks = np.ones((3, 28, 28))
    output_mask = mask2ndarray(raw_masks)
    assert isinstance(output_mask, np.ndarray)
    assert np.allclose(raw_masks, output_mask)

    # torch.Tensor -> ndarray. np.allclose alone would also accept an
    # unconverted tensor, so the type check is what pins the conversion.
    raw_masks = torch.ones((3, 28, 28))
    output_mask = mask2ndarray(raw_masks)
    assert isinstance(output_mask, np.ndarray)
    assert np.allclose(raw_masks.numpy(), output_mask)

    # Unsupported input type.
    with pytest.raises(TypeError):
        mask2ndarray([])
|
|
|