parent
568783d2eb
commit
4eeb430504
2 changed files with 316 additions and 0 deletions
@ -0,0 +1,59 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import copy |
||||
|
||||
import paddlers.transforms as T |
||||
from testing_utils import CpuCommonTest |
||||
from data import build_input_from_file |
||||
|
||||
|
||||
class TestMatchHistograms(CpuCommonTest): |
||||
def setUp(self): |
||||
self.inputs = [ |
||||
build_input_from_file( |
||||
'data/ssmt/test_mixed_binary.txt', prefix='./data/ssmt') |
||||
] |
||||
|
||||
def test_output_shape(self): |
||||
decoder = T.DecodeImg() |
||||
for input in copy.deepcopy(self.inputs): |
||||
for sample in input: |
||||
sample = decoder(sample) |
||||
im_out = T.functions.match_histograms(sample['image'], |
||||
sample['image2']) |
||||
self.check_output_equal(im_out.shape, sample['image2'].shape) |
||||
im_out = T.functions.match_histograms(sample['image2'], |
||||
sample['image']) |
||||
self.check_output_equal(im_out.shape, sample['image'].shape) |
||||
|
||||
|
||||
class TestMatchByRegression(CpuCommonTest): |
||||
def setUp(self): |
||||
self.inputs = [ |
||||
build_input_from_file( |
||||
'data/ssmt/test_mixed_binary.txt', prefix='./data/ssmt') |
||||
] |
||||
|
||||
def test_output_shape(self): |
||||
decoder = T.DecodeImg() |
||||
for input in copy.deepcopy(self.inputs): |
||||
for sample in input: |
||||
sample = decoder(sample) |
||||
im_out = T.functions.match_by_regression(sample['image'], |
||||
sample['image2']) |
||||
self.check_output_equal(im_out.shape, sample['image2'].shape) |
||||
im_out = T.functions.match_by_regression(sample['image2'], |
||||
sample['image']) |
||||
self.check_output_equal(im_out.shape, sample['image'].shape) |
@ -0,0 +1,257 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import inspect |
||||
import copy |
||||
|
||||
import numpy as np |
||||
|
||||
import paddlers.transforms as T |
||||
from testing_utils import CpuCommonTest |
||||
from data import build_input_from_file |
||||
|
||||
WHITE_LIST = [] |
||||
|
||||
|
||||
def _add_op_tests(cls): |
||||
""" |
||||
Automatically patch testing functions for transform operators. |
||||
""" |
||||
|
||||
for op_name in T.operators.__all__: |
||||
op_class = getattr(T.operators, op_name) |
||||
if isinstance(op_class, type) and issubclass(op_class, |
||||
T.operators.Transform): |
||||
if op_class is T.DecodeImg or op_class in WHITE_LIST or op_name in WHITE_LIST: |
||||
continue |
||||
attr_name = 'test_' + op_name |
||||
if hasattr(cls, attr_name): |
||||
continue |
||||
# If the operator cannot be initialized with default parameters, skip it |
||||
for key, param in inspect.signature( |
||||
op_class.__init__).parameters.items(): |
||||
if key == 'self': |
||||
continue |
||||
if param.default is param.empty: |
||||
break |
||||
else: |
||||
filter_ = OP2FILTER.get(op_name, None) |
||||
setattr( |
||||
cls, attr_name, make_test_func( |
||||
op_class, filter_=filter_)) |
||||
return cls |
||||
|
||||
|
||||
def make_test_func(op_class, |
||||
*args, |
||||
in_hook=None, |
||||
out_hook=None, |
||||
filter_=None, |
||||
**kwargs): |
||||
def _test_func(self): |
||||
op = op_class(*args, **kwargs) |
||||
decoder = T.DecodeImg() |
||||
inputs = map(decoder, copy.deepcopy(self.inputs)) |
||||
for i, input_ in enumerate(inputs): |
||||
if filter_ is not None: |
||||
input_ = filter_(input_) |
||||
with self.subTest(i=i): |
||||
for sample in input_: |
||||
if in_hook: |
||||
sample = in_hook(sample) |
||||
sample = op(sample) |
||||
if out_hook: |
||||
sample = out_hook(sample) |
||||
|
||||
return _test_func |
||||
|
||||
|
||||
class _InputFilter(object): |
||||
def __init__(self, conds): |
||||
self.conds = conds |
||||
|
||||
def __call__(self, samples): |
||||
for sample in samples: |
||||
for cond in self.conds: |
||||
if cond(sample): |
||||
yield sample |
||||
|
||||
def __or__(self, filter): |
||||
return _InputFilter(self.conds + filter.conds) |
||||
|
||||
def __and__(self, filter): |
||||
return _InputFilter( |
||||
[cond for cond in self.conds if cond in filter.conds]) |
||||
|
||||
def get_sample(self, input): |
||||
return input[0] |
||||
|
||||
|
||||
def _is_optical(sample): |
||||
return sample['image'].shape[2] == 3 |
||||
|
||||
|
||||
def _is_sar(sample): |
||||
return sample['image'].shape[2] == 1 |
||||
|
||||
|
||||
def _is_multispectral(sample): |
||||
return sample['image'].shape[2] > 3 |
||||
|
||||
|
||||
def _is_mt(sample): |
||||
return 'image2' in sample |
||||
|
||||
|
||||
_filter_only_optical = _InputFilter([_is_optical]) |
||||
_filter_only_sar = _InputFilter([_is_sar]) |
||||
_filter_only_multispectral = _InputFilter([_is_multispectral]) |
||||
_filter_no_multispectral = _filter_only_optical | _filter_only_sar |
||||
_filter_no_sar = _filter_only_optical | _filter_only_multispectral |
||||
_filter_no_optical = _filter_only_sar | _filter_only_multispectral |
||||
_filter_only_mt = _InputFilter([_is_mt]) |
||||
|
||||
OP2FILTER = { |
||||
'RandomSwap': _filter_only_mt, |
||||
'SelectBand': _filter_no_sar, |
||||
'Dehaze': _filter_only_optical, |
||||
'Normalize': _filter_only_optical, |
||||
'RandomDistort': _filter_only_optical |
||||
} |
||||
|
||||
|
||||
@_add_op_tests |
||||
class TestTransform(CpuCommonTest): |
||||
def setUp(self): |
||||
self.inputs = [ |
||||
build_input_from_file( |
||||
'data/ssst/test_optical_clas.txt', |
||||
prefix='./data/ssst'), build_input_from_file( |
||||
'data/ssst/test_sar_clas.txt', |
||||
prefix='./data/ssst'), build_input_from_file( |
||||
'data/ssst/test_multispectral_clas.txt', |
||||
prefix='./data/ssst'), build_input_from_file( |
||||
'data/ssst/test_optical_seg.txt', |
||||
prefix='./data/ssst'), build_input_from_file( |
||||
'data/ssst/test_sar_seg.txt', |
||||
prefix='./data/ssst'), build_input_from_file( |
||||
'data/ssst/test_multispectral_seg.txt', |
||||
prefix='./data/ssst'), |
||||
build_input_from_file( |
||||
'data/ssst/test_optical_det.txt', |
||||
prefix='./data/ssst', |
||||
label_list='data/ssst/labels_det.txt'), build_input_from_file( |
||||
'data/ssst/test_sar_det.txt', |
||||
prefix='./data/ssst', |
||||
label_list='data/ssst/labels_det.txt'), |
||||
build_input_from_file( |
||||
'data/ssst/test_multispectral_det.txt', |
||||
prefix='./data/ssst', |
||||
label_list='data/ssst/labels_det.txt'), build_input_from_file( |
||||
'data/ssmt/test_mixed_binary.txt', |
||||
prefix='./data/ssmt'), build_input_from_file( |
||||
'data/ssmt/test_mixed_multiclass.txt', |
||||
prefix='./data/ssmt'), build_input_from_file( |
||||
'data/ssmt/test_mixed_multitask.txt', |
||||
prefix='./data/ssmt') |
||||
] |
||||
|
||||
def test_DecodeImg(self): |
||||
decoder = T.DecodeImg(to_rgb=True) |
||||
for i, input in enumerate(self.inputs): |
||||
with self.subTest(i=i): |
||||
for sample in input: |
||||
sample = decoder(sample) |
||||
# Check type |
||||
self.assertIsInstance(sample['image'], np.ndarray) |
||||
if 'mask' in sample: |
||||
self.assertIsInstance(sample['mask'], np.ndarray) |
||||
if 'aux_masks' in sample: |
||||
for aux_mask in sample['aux_masks']: |
||||
self.assertIsInstance(aux_mask, np.ndarray) |
||||
# TODO: Check dtype |
||||
|
||||
def test_Resize(self): |
||||
TARGET_SIZE = (128, 128) |
||||
|
||||
def _in_hook(sample): |
||||
self.image_shape = sample['image'].shape |
||||
if 'mask' in sample: |
||||
self.mask_shape = sample['mask'].shape |
||||
self.mask_values = set(sample['mask'].ravel()) |
||||
if 'aux_masks' in sample: |
||||
self.aux_mask_shapes = [ |
||||
aux_mask.shape for aux_mask in sample['aux_masks'] |
||||
] |
||||
self.aux_mask_values = [ |
||||
set(aux_mask.ravel()) for aux_mask in sample['aux_masks'] |
||||
] |
||||
return sample |
||||
|
||||
def _out_hook_not_keep_ratio(sample): |
||||
self.check_output_equal(sample['image'].shape[:2], TARGET_SIZE) |
||||
if 'image2' in sample: |
||||
self.check_output_equal(sample['image2'].shape[:2], TARGET_SIZE) |
||||
if 'mask' in sample: |
||||
self.check_output_equal(sample['mask'].shape[:2], TARGET_SIZE) |
||||
self.assertLessEqual( |
||||
set(sample['mask'].ravel()), self.mask_values) |
||||
if 'aux_masks' in sample: |
||||
for aux_mask in sample['aux_masks']: |
||||
self.check_output_equal(aux_mask.shape[:2], TARGET_SIZE) |
||||
for aux_mask, amv in zip(sample['aux_masks'], |
||||
self.aux_mask_values): |
||||
self.assertLessEqual(set(aux_mask.ravel()), amv) |
||||
# TODO: Test gt_bbox and gt_poly |
||||
return sample |
||||
|
||||
def _out_hook_keep_ratio(sample): |
||||
def __check_ratio(shape1, shape2): |
||||
self.check_output_equal(shape1[0] / shape1[1], |
||||
shape2[0] / shape2[1]) |
||||
|
||||
__check_ratio(sample['image'].shape, self.image_shape) |
||||
if 'image2' in sample: |
||||
__check_ratio(sample['image2'].shape, self.image_shape) |
||||
if 'mask' in sample: |
||||
__check_ratio(sample['mask'].shape, self.mask_shape) |
||||
if 'aux_masks' in sample: |
||||
for aux_mask, ori_aux_mask_shape in zip(sample['aux_masks'], |
||||
self.aux_mask_shapes): |
||||
__check_ratio(aux_mask.shape, ori_aux_mask_shape) |
||||
# TODO: Test gt_bbox and gt_poly |
||||
return sample |
||||
|
||||
test_func_not_keep_ratio = make_test_func( |
||||
T.Resize, |
||||
in_hook=_in_hook, |
||||
out_hook=_out_hook_not_keep_ratio, |
||||
target_size=TARGET_SIZE, |
||||
keep_ratio=False) |
||||
test_func_not_keep_ratio(self) |
||||
test_func_keep_ratio = make_test_func( |
||||
T.Resize, |
||||
in_hook=_in_hook, |
||||
out_hook=_out_hook_keep_ratio, |
||||
target_size=TARGET_SIZE, |
||||
keep_ratio=True) |
||||
test_func_keep_ratio(self) |
||||
|
||||
|
||||
class TestCompose(CpuCommonTest): |
||||
pass |
||||
|
||||
|
||||
class TestArrange(CpuCommonTest): |
||||
pass |
Loading…
Reference in new issue