Add predictor unittests

own
Bobholamovic 2 years ago
parent 40a5edc039
commit 1a35b297af
  1. 36
      deploy/export/export_model.py
  2. 6
      paddlers/deploy/predictor.py
  3. 2
      paddlers/tasks/change_detector.py
  4. 1
      paddlers/tasks/classifier.py
  5. 8
      paddlers/tasks/object_detector.py
  6. 1
      paddlers/tasks/segmenter.py
  7. 13
      tests/deploy/test_export.py
  8. 13
      tests/deploy/test_predict.py
  9. 351
      tests/deploy/test_predictor.py
  10. 2
      tests/tools/test_match.py
  11. 2
      tests/tools/test_oif.py
  12. 2
      tests/tools/test_pca.py
  13. 2
      tests/tools/test_split.py
  14. 4
      tests/transforms/test_functions.py
  15. 58
      tests/transforms/test_operators.py

@ -21,9 +21,23 @@ from paddlers.tasks import load_model
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', '-m', type=str, default=None, help='model directory path')
parser.add_argument('--save_dir', '-s', type=str, default=None, help='path to save inference model')
parser.add_argument('--fixed_input_shape', '-fs', type=str, default=None,
parser.add_argument(
'--model_dir',
'-m',
type=str,
default=None,
help='model directory path')
parser.add_argument(
'--save_dir',
'-s',
type=str,
default=None,
help='path to save inference model')
parser.add_argument(
'--fixed_input_shape',
'-fs',
type=str,
default=None,
help="export inference model with fixed input shape: [w,h] or [n,c,w,h]")
return parser
@ -39,13 +53,17 @@ if __name__ == '__main__':
fixed_input_shape = literal_eval(args.fixed_input_shape)
# Check validaty
if not isinstance(fixed_input_shape, list):
raise ValueError("fixed_input_shape should be of None or list type.")
raise ValueError(
"fixed_input_shape should be of None or list type.")
if len(fixed_input_shape) not in (2, 4):
raise ValueError("fixed_input_shape contains an incorrect number of elements.")
raise ValueError(
"fixed_input_shape contains an incorrect number of elements.")
if fixed_input_shape[-1] <= 0 or fixed_input_shape[-2] <= 0:
raise ValueError("the input width and height must be positive integers.")
if len(fixed_input_shape)==4 and fixed_input_shape[1] <= 0:
raise ValueError("the number of input channels must be a positive integer.")
raise ValueError(
"Input width and height must be positive integers.")
if len(fixed_input_shape) == 4 and fixed_input_shape[1] <= 0:
raise ValueError(
"The number of input channels must be a positive integer.")
# Set environment variables
os.environ['PADDLEX_EXPORT_STAGE'] = 'True'
@ -56,4 +74,4 @@ if __name__ == '__main__':
# Do dynamic-to-static cast
# XXX: Invoke a protected (single underscore) method outside of subclasses.
model._export_inference_model(args.save_dir, fixed_input_shape)
model._export_inference_model(args.save_dir, fixed_input_shape)

@ -175,9 +175,9 @@ class Predictor(object):
if self._model._postprocess is None:
self._model.build_postprocess_from_labels(topk)
# XXX: Convert ndarray to tensor as self._model._postprocess requires
net_outputs = paddle.to_tensor(net_outputs)
assert net_outputs.shape[1] == 1
outputs = self._model._postprocess(net_outputs.squeeze(1))
assert len(net_outputs) == 1
net_outputs = paddle.to_tensor(net_outputs[0])
outputs = self._model._postprocess(net_outputs)
class_ids = map(itemgetter('class_ids'), outputs)
scores = map(itemgetter('scores'), outputs)
label_names = map(itemgetter('label_names'), outputs)

@ -650,6 +650,8 @@ class BaseChangeDetector(BaseModel):
if isinstance(sample['image_t1'], str) or \
isinstance(sample['image_t2'], str):
sample = DecodeImg(to_rgb=False)(sample)
sample['image'] = sample['image'].astype('float32')
sample['image2'] = sample['image2'].astype('float32')
ori_shape = sample['image'].shape[:2]
else:
ori_shape = im1.shape[:2]

@ -468,6 +468,7 @@ class BaseClassifier(BaseModel):
sample = {'image': im}
if isinstance(sample['image'], str):
sample = DecodeImg(to_rgb=False)(sample)
sample['image'] = sample['image'].astype('float32')
ori_shape = sample['image'].shape[:2]
im = transforms(sample)
batch_im.append(im)

@ -27,7 +27,7 @@ import paddlers.models.ppdet as ppdet
from paddlers.models.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
import paddlers
import paddlers.utils.logging as logging
from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad, DecodeImg
from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
_BatchPad, _Gt2YoloTarget
from paddlers.transforms import arrange_transforms
@ -550,7 +550,11 @@ class BaseDetector(BaseModel):
batch_samples = list()
for im in images:
sample = {'image': im}
batch_samples.append(transforms(sample))
if isinstance(sample['image'], str):
sample = DecodeImg(to_rgb=False)(sample)
sample['image'] = sample['image'].astype('float32')
sample = transforms(sample)
batch_samples.append(sample)
batch_transforms = self._compose_batch_transform(transforms, 'test')
batch_samples = batch_transforms(batch_samples)
if to_tensor:

@ -614,6 +614,7 @@ class BaseSegmenter(BaseModel):
sample = {'image': im}
if isinstance(sample['image'], str):
sample = DecodeImg(to_rgb=False)(sample)
sample['image'] = sample['image'].astype('float32')
ori_shape = sample['image'].shape[:2]
im = transforms(sample)[0]
batch_im.append(im)

@ -1,13 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -1,13 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,351 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest.mock as mock
import cv2
import paddle
import paddlers as pdrs
from testing_utils import CommonTest, run_script
class TestPredictor(CommonTest):
MODULE = pdrs.tasks
TRAINER_NAME_TO_EXPORT_OPTS = {}
@staticmethod
def add_tests(cls):
def _test_predictor(trainer_name):
def _test_predictor_impl(self):
trainer_class = getattr(self.MODULE, trainer_name)
# Construct trainer with default parameters
trainer = trainer_class()
with tempfile.TemporaryDirectory() as td:
dynamic_model_dir = f"{td}/dynamic"
static_model_dir = f"{td}/static"
# HACK: BaseModel.save_model() requires BaseModel().optimizer to be set
optimizer = mock.Mock()
optimizer.state_dict.return_value = {'foo': 'bar'}
trainer.optimizer = optimizer
trainer.save_model(dynamic_model_dir)
export_cmd = f"python export_model.py --model_dir {dynamic_model_dir} --save_dir {static_model_dir} "
if trainer_name in self.TRAINER_NAME_TO_EXPORT_OPTS:
export_cmd += self.TRAINER_NAME_TO_EXPORT_OPTS[
trainer_name]
elif '_default' in self.TRAINER_NAME_TO_EXPORT_OPTS:
export_cmd += self.TRAINER_NAME_TO_EXPORT_OPTS[
'_default']
run_script(export_cmd, wd="../deploy/export")
# Construct predictor
# TODO: Test trt and mkl
predictor = pdrs.deploy.Predictor(
static_model_dir,
use_gpu=paddle.device.get_device().startswith('gpu'))
self.check_predictor(predictor, trainer)
return _test_predictor_impl
for trainer_name in cls.MODULE.__all__:
setattr(cls, 'test_' + trainer_name, _test_predictor(trainer_name))
return cls
def check_predictor(self, predictor, trainer):
raise NotImplementedError
def check_dict_equal(self, dict_, expected_dict):
if isinstance(dict_, list):
self.assertIsInstance(expected_dict, list)
self.assertEqual(len(dict_), len(expected_dict))
for d1, d2 in zip(dict_, expected_dict):
self.check_dict_equal(d1, d2)
else:
assert isinstance(dict_, dict)
assert isinstance(expected_dict, dict)
self.assertEqual(dict_.keys(), expected_dict.keys())
for key in dict_.keys():
self.check_output_equal(dict_[key], expected_dict[key])
@TestPredictor.add_tests
class TestCDPredictor(TestPredictor):
MODULE = pdrs.tasks.change_detector
TRAINER_NAME_TO_EXPORT_OPTS = {
'BIT': "--fixed_input_shape [1,3,256,256]",
'_default': "--fixed_input_shape [-1,3,256,256]"
}
def check_predictor(self, predictor, trainer):
t1_path = "data/ssmt/optical_t1.bmp"
t2_path = "data/ssmt/optical_t2.bmp"
single_input = (t1_path, t2_path)
num_inputs = 2
transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
# Expected failure
with self.assertRaises(ValueError):
predictor.predict(t1_path, transforms=transforms)
# Single input (file paths)
input_ = single_input
out_single_file_p = predictor.predict(input_, transforms=transforms)
out_single_file_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_file_p, out_single_file_t)
out_single_file_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_file_list_p), 1)
self.check_dict_equal(out_single_file_list_p[0], out_single_file_p)
out_single_file_list_t = trainer.predict(
[input_], transforms=transforms)
self.check_dict_equal(out_single_file_list_p[0],
out_single_file_list_t[0])
# Single input (ndarrays)
input_ = (cv2.imread(t1_path).astype('float32'),
cv2.imread(t2_path).astype('float32')
) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_array_t)
out_single_array_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_array_list_p), 1)
self.check_dict_equal(out_single_array_list_p[0], out_single_array_p)
out_single_array_list_t = trainer.predict(
[input_], transforms=transforms)
self.check_dict_equal(out_single_array_list_p[0],
out_single_array_list_t[0])
if isinstance(trainer, pdrs.tasks.change_detector.BIT):
return
# Multiple inputs (file paths)
input_ = [single_input] * num_inputs # Reuse the name `input_`
out_multi_file_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), num_inputs)
out_multi_file_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_multi_file_p, out_multi_file_t)
# Multiple inputs (ndarrays)
input_ = [(cv2.imread(t1_path).astype('float32'), cv2.imread(t2_path)
.astype('float32'))] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_multi_array_p, out_multi_array_t)
@TestPredictor.add_tests
class TestClasPredictor(TestPredictor):
MODULE = pdrs.tasks.classifier
TRAINER_NAME_TO_EXPORT_OPTS = {
'_default': "--fixed_input_shape [-1,3,256,256]"
}
def check_predictor(self, predictor, trainer):
single_input = "data/ssmt/optical_t1.bmp"
num_inputs = 2
transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
labels = list(range(2))
trainer.labels = labels
predictor._model.labels = labels
# Single input (file paths)
input_ = single_input
out_single_file_p = predictor.predict(input_, transforms=transforms)
out_single_file_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_file_p, out_single_file_t)
out_single_file_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_file_list_p), 1)
self.check_dict_equal(out_single_file_list_p[0], out_single_file_p)
out_single_file_list_t = trainer.predict(
[input_], transforms=transforms)
self.check_dict_equal(out_single_file_list_p[0],
out_single_file_list_t[0])
# Single input (ndarrays)
input_ = cv2.imread(single_input).astype(
'float32') # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_array_t)
out_single_array_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_array_list_p), 1)
self.check_dict_equal(out_single_array_list_p[0], out_single_array_p)
out_single_array_list_t = trainer.predict(
[input_], transforms=transforms)
self.check_dict_equal(out_single_array_list_p[0],
out_single_array_list_t[0])
# Multiple inputs (file paths)
input_ = [single_input] * num_inputs # Reuse the name `input_`
out_multi_file_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), num_inputs)
out_multi_file_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), len(out_multi_file_t))
self.check_dict_equal(out_multi_file_p, out_multi_file_t)
# Multiple inputs (ndarrays)
input_ = [cv2.imread(single_input).astype('float32')
] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), len(out_multi_array_t))
self.check_dict_equal(out_multi_array_p, out_multi_array_t)
@TestPredictor.add_tests
class TestDetPredictor(TestPredictor):
MODULE = pdrs.tasks.object_detector
TRAINER_NAME_TO_EXPORT_OPTS = {
'_default': "--fixed_input_shape [-1,3,256,256]"
}
def check_predictor(self, predictor, trainer):
single_input = "data/ssmt/optical_t1.bmp"
num_inputs = 2
transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
labels = list(range(80))
trainer.labels = labels
predictor._model.labels = labels
# Single input (file paths)
input_ = single_input
out_single_file_p = predictor.predict(input_, transforms=transforms)
out_single_file_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_file_p, out_single_file_t)
out_single_file_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_file_list_p), 1)
self.check_dict_equal(out_single_file_list_p[0], out_single_file_p)
out_single_file_list_t = trainer.predict(
[input_], transforms=transforms)
self.check_dict_equal(out_single_file_list_p[0],
out_single_file_list_t[0])
# Single input (ndarrays)
input_ = cv2.imread(single_input).astype(
'float32') # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_array_t)
out_single_array_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_array_list_p), 1)
self.check_dict_equal(out_single_array_list_p[0], out_single_array_p)
out_single_array_list_t = trainer.predict(
[input_], transforms=transforms)
self.check_dict_equal(out_single_array_list_p[0],
out_single_array_list_t[0])
# Single input (ndarrays)
input_ = cv2.imread(single_input).astype(
'float32') # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_array_t)
out_single_array_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_array_list_p), 1)
self.check_dict_equal(out_single_array_list_p[0], out_single_array_p)
out_single_array_list_t = trainer.predict(
[input_], transforms=transforms)
self.check_dict_equal(out_single_array_list_p[0],
out_single_array_list_t[0])
# Multiple inputs (file paths)
input_ = [single_input] * num_inputs # Reuse the name `input_`
out_multi_file_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), num_inputs)
out_multi_file_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), len(out_multi_file_t))
self.check_dict_equal(out_multi_file_p, out_multi_file_t)
# Multiple inputs (ndarrays)
input_ = [cv2.imread(single_input).astype('float32')
] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), len(out_multi_array_t))
self.check_dict_equal(out_multi_array_p, out_multi_array_t)
@TestPredictor.add_tests
class TestSegPredictor(TestPredictor):
MODULE = pdrs.tasks.segmenter
TRAINER_NAME_TO_EXPORT_OPTS = {
'_default': "--fixed_input_shape [-1,3,256,256]"
}
def check_predictor(self, predictor, trainer):
single_input = "data/ssmt/optical_t1.bmp"
num_inputs = 2
transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
# Single input (file paths)
input_ = single_input
out_single_file_p = predictor.predict(input_, transforms=transforms)
out_single_file_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_file_p, out_single_file_t)
out_single_file_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_file_list_p), 1)
self.check_dict_equal(out_single_file_list_p[0], out_single_file_p)
out_single_file_list_t = trainer.predict(
[input_], transforms=transforms)
self.check_dict_equal(out_single_file_list_p[0],
out_single_file_list_t[0])
# Single input (ndarrays)
input_ = cv2.imread(single_input).astype(
'float32') # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_array_t)
out_single_array_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_array_list_p), 1)
self.check_dict_equal(out_single_array_list_p[0], out_single_array_p)
out_single_array_list_t = trainer.predict(
[input_], transforms=transforms)
self.check_dict_equal(out_single_array_list_p[0],
out_single_array_list_t[0])
# Multiple inputs (file paths)
input_ = [single_input] * num_inputs # Reuse the name `input_`
out_multi_file_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), num_inputs)
out_multi_file_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), len(out_multi_file_t))
self.check_dict_equal(out_multi_file_p, out_multi_file_t)
# Multiple inputs (ndarrays)
input_ = [cv2.imread(single_input).astype('float32')
] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), len(out_multi_array_t))
self.check_dict_equal(out_multi_array_p, out_multi_array_t)

@ -22,4 +22,4 @@ class TestMatch(CpuCommonTest):
with tempfile.TemporaryDirectory() as td:
run_script(
f"python match.py --im1_path ../tests/data/ssmt/multispectral_t1.tif --im2_path ../tests/data/ssmt/multispectral_t1.tif --save_path {td}/out.tiff",
wd='../tools')
wd="../tools")

@ -21,4 +21,4 @@ class TestOIF(CpuCommonTest):
def test_script(self):
run_script(
f"python oif.py --im_path ../tests/data/ssst/multispectral.tif",
wd='../tools')
wd="../tools")

@ -22,4 +22,4 @@ class TestPCA(CpuCommonTest):
with tempfile.TemporaryDirectory() as td:
run_script(
f"python pca.py --im_path ../tests/data/ssst/multispectral.tif --save_dir {td} --dim 5",
wd='../tools')
wd="../tools")

@ -22,4 +22,4 @@ class TestSplit(CpuCommonTest):
with tempfile.TemporaryDirectory() as td:
run_script(
f"python split.py --image_path ../tests/data/ssst/multispectral.tif --mask_path ../tests/data/ssst/multiclass_gt2.png --block_size 128 --save_dir {td}",
wd='../tools')
wd="../tools")

@ -23,7 +23,7 @@ class TestMatchHistograms(CpuCommonTest):
def setUp(self):
self.inputs = [
build_input_from_file(
'data/ssmt/test_mixed_binary.txt', prefix='./data/ssmt')
"data/ssmt/test_mixed_binary.txt", prefix="./data/ssmt")
]
def test_output_shape(self):
@ -43,7 +43,7 @@ class TestMatchByRegression(CpuCommonTest):
def setUp(self):
self.inputs = [
build_input_from_file(
'data/ssmt/test_mixed_binary.txt', prefix='./data/ssmt')
"data/ssmt/test_mixed_binary.txt", prefix="./data/ssmt")
]
def test_output_shape(self):

@ -136,47 +136,47 @@ class TestTransform(CpuCommonTest):
def setUp(self):
self.inputs = [
build_input_from_file(
'data/ssst/test_optical_clas.txt',
prefix='./data/ssst'),
"data/ssst/test_optical_clas.txt",
prefix="./data/ssst"),
build_input_from_file(
'data/ssst/test_sar_clas.txt',
prefix='./data/ssst'),
"data/ssst/test_sar_clas.txt",
prefix="./data/ssst"),
build_input_from_file(
'data/ssst/test_multispectral_clas.txt',
prefix='./data/ssst'),
"data/ssst/test_multispectral_clas.txt",
prefix="./data/ssst"),
build_input_from_file(
'data/ssst/test_optical_seg.txt',
prefix='./data/ssst'),
"data/ssst/test_optical_seg.txt",
prefix="./data/ssst"),
build_input_from_file(
'data/ssst/test_sar_seg.txt',
prefix='./data/ssst'),
"data/ssst/test_sar_seg.txt",
prefix="./data/ssst"),
build_input_from_file(
'data/ssst/test_multispectral_seg.txt',
prefix='./data/ssst'),
"data/ssst/test_multispectral_seg.txt",
prefix="./data/ssst"),
build_input_from_file(
'data/ssst/test_optical_det.txt',
prefix='./data/ssst',
label_list='data/ssst/labels_det.txt'),
"data/ssst/test_optical_det.txt",
prefix="./data/ssst",
label_list="data/ssst/labels_det.txt"),
build_input_from_file(
'data/ssst/test_sar_det.txt',
prefix='./data/ssst',
label_list='data/ssst/labels_det.txt'),
"data/ssst/test_sar_det.txt",
prefix="./data/ssst",
label_list="data/ssst/labels_det.txt"),
build_input_from_file(
'data/ssst/test_multispectral_det.txt',
prefix='./data/ssst',
label_list='data/ssst/labels_det.txt'),
"data/ssst/test_multispectral_det.txt",
prefix="./data/ssst",
label_list="data/ssst/labels_det.txt"),
build_input_from_file(
'data/ssst/test_det_coco.txt',
prefix='./data/ssst'),
"data/ssst/test_det_coco.txt",
prefix="./data/ssst"),
build_input_from_file(
'data/ssmt/test_mixed_binary.txt',
prefix='./data/ssmt'),
"data/ssmt/test_mixed_binary.txt",
prefix="./data/ssmt"),
build_input_from_file(
'data/ssmt/test_mixed_multiclass.txt',
prefix='./data/ssmt'),
"data/ssmt/test_mixed_multiclass.txt",
prefix="./data/ssmt"),
build_input_from_file(
'data/ssmt/test_mixed_multitask.txt',
prefix='./data/ssmt')
"data/ssmt/test_mixed_multitask.txt",
prefix="./data/ssmt")
]
def test_DecodeImg(self):

Loading…
Cancel
Save