Finish unittests

own
Bobholamovic 2 years ago
parent 228cab666e
commit 61f818411c
  1. 1
      deploy/export/README.md
  2. 22
      paddlers/custom_models/cd/bit.py
  3. 12
      paddlers/custom_models/cd/fc_ef.py
  4. 16
      paddlers/custom_models/cd/fc_siam_conc.py
  5. 16
      paddlers/custom_models/cd/fc_siam_diff.py
  6. 2
      paddlers/custom_models/cd/snunet.py
  7. 9
      paddlers/custom_models/cls/condensenet_v2.py
  8. 31
      paddlers/custom_models/seg/farseg.py
  9. 26
      paddlers/deploy/predictor.py
  10. 2
      paddlers/models/ppdet/modeling/post_process.py
  11. 21
      paddlers/tasks/change_detector.py
  12. 13
      paddlers/tasks/classifier.py
  13. 15
      paddlers/tasks/object_detector.py
  14. 13
      paddlers/tasks/segmenter.py
  15. 22
      paddlers/transforms/__init__.py
  16. 85
      paddlers/transforms/operators.py
  17. 110
      tests/deploy/test_predictor.py

@ -60,4 +60,3 @@ python deploy/export_model.py --model_dir=./output/deeplabv3p/best_model/ --save
- 对于检测模型中的YOLO/PPYOLO系列模型,请保证输入影像的`w`和`h`有相同取值、且均为32的倍数;指定`--fixed_input_shape`时,R-CNN模型的`w`和`h`也均需为32的倍数。
- 指定`[w,h]`时,请使用半角逗号(`,`)分隔`w`和`h`,二者之间不允许存在空格等其它字符。
- 将`w`和`h`设得越大,则模型在推理过程中的耗时和内存/显存占用越高。不过,如果`w`和`h`过小,则可能对模型的精度存在较大负面影响。
- 对于变化检测模型BIT,请保证指定`--fixed_input_shape`,并且数值不包含负数,因为BIT用到空间注意力,需要从tensor中获取`b,c,h,w`的属性,若为负数则报错。

@ -22,6 +22,15 @@ from .layers import Conv3x3, Conv1x1, get_norm_layer, Identity
from .param_init import KaimingInitMixin
def calc_product(*args):
if len(args) < 1:
raise ValueError
ret = args[0]
for arg in args[1:]:
ret *= arg
return ret
class BIT(nn.Layer):
"""
The BIT implementation based on PaddlePaddle.
@ -131,9 +140,10 @@ class BIT(nn.Layer):
def _get_semantic_tokens(self, x):
b, c = x.shape[:2]
att_map = self.conv_att(x)
att_map = att_map.reshape((b, self.token_len, 1, -1))
att_map = att_map.reshape(
(b, self.token_len, 1, calc_product(*att_map.shape[2:])))
att_map = F.softmax(att_map, axis=-1)
x = x.reshape((b, 1, c, -1))
x = x.reshape((b, 1, c, att_map.shape[-1]))
tokens = (x * att_map).sum(-1)
return tokens
@ -253,6 +263,7 @@ class CrossAttention(nn.Layer):
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.scale = dim**-0.5
self.apply_softmax = apply_softmax
@ -272,9 +283,10 @@ class CrossAttention(nn.Layer):
k = self.fc_k(ref)
v = self.fc_v(ref)
q = q.reshape((b, n, h, -1)).transpose((0, 2, 1, 3))
k = k.reshape((b, paddle.shape(ref)[1], h, -1)).transpose((0, 2, 1, 3))
v = v.reshape((b, paddle.shape(ref)[1], h, -1)).transpose((0, 2, 1, 3))
q = q.reshape((b, n, h, self.head_dim)).transpose((0, 2, 1, 3))
rn = ref.shape[1]
k = k.reshape((b, rn, h, self.head_dim)).transpose((0, 2, 1, 3))
v = v.reshape((b, rn, h, self.head_dim)).transpose((0, 2, 1, 3))
mult = paddle.matmul(q, k, transpose_y=True) * self.scale

@ -131,8 +131,7 @@ class FCEarlyFusion(nn.Layer):
# Stage 4d
x4d = self.upconv4(x4p)
pad4 = (0, paddle.shape(x43)[3] - paddle.shape(x4d)[3], 0,
paddle.shape(x43)[2] - paddle.shape(x4d)[2])
pad4 = (0, x43.shape[3] - x4d.shape[3], 0, x43.shape[2] - x4d.shape[2])
x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), x43], 1)
x43d = self.do43d(self.conv43d(x4d))
x42d = self.do42d(self.conv42d(x43d))
@ -140,8 +139,7 @@ class FCEarlyFusion(nn.Layer):
# Stage 3d
x3d = self.upconv3(x41d)
pad3 = (0, paddle.shape(x33)[3] - paddle.shape(x3d)[3], 0,
paddle.shape(x33)[2] - paddle.shape(x3d)[2])
pad3 = (0, x33.shape[3] - x3d.shape[3], 0, x33.shape[2] - x3d.shape[2])
x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), x33], 1)
x33d = self.do33d(self.conv33d(x3d))
x32d = self.do32d(self.conv32d(x33d))
@ -149,16 +147,14 @@ class FCEarlyFusion(nn.Layer):
# Stage 2d
x2d = self.upconv2(x31d)
pad2 = (0, paddle.shape(x22)[3] - paddle.shape(x2d)[3], 0,
paddle.shape(x22)[2] - paddle.shape(x2d)[2])
pad2 = (0, x22.shape[3] - x2d.shape[3], 0, x22.shape[2] - x2d.shape[2])
x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), x22], 1)
x22d = self.do22d(self.conv22d(x2d))
x21d = self.do21d(self.conv21d(x22d))
# Stage 1d
x1d = self.upconv1(x21d)
pad1 = (0, paddle.shape(x12)[3] - paddle.shape(x1d)[3], 0,
paddle.shape(x12)[2] - paddle.shape(x1d)[2])
pad1 = (0, x12.shape[3] - x1d.shape[3], 0, x12.shape[2] - x1d.shape[2])
x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), x12], 1)
x12d = self.do12d(self.conv12d(x1d))
x11d = self.conv11d(x12d)

@ -154,8 +154,8 @@ class FCSiamConc(nn.Layer):
# Decode
# Stage 4d
x4d = self.upconv4(x4p)
pad4 = (0, paddle.shape(x43_1)[3] - paddle.shape(x4d)[3], 0,
paddle.shape(x43_1)[2] - paddle.shape(x4d)[2])
pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0,
x43_1.shape[2] - x4d.shape[2])
x4d = paddle.concat(
[F.pad(x4d, pad=pad4, mode='replicate'), x43_1, x43_2], 1)
x43d = self.do43d(self.conv43d(x4d))
@ -164,8 +164,8 @@ class FCSiamConc(nn.Layer):
# Stage 3d
x3d = self.upconv3(x41d)
pad3 = (0, paddle.shape(x33_1)[3] - paddle.shape(x3d)[3], 0,
paddle.shape(x33_1)[2] - paddle.shape(x3d)[2])
pad3 = (0, x33_1.shape[3] - x3d.shape[3], 0,
x33_1.shape[2] - x3d.shape[2])
x3d = paddle.concat(
[F.pad(x3d, pad=pad3, mode='replicate'), x33_1, x33_2], 1)
x33d = self.do33d(self.conv33d(x3d))
@ -174,8 +174,8 @@ class FCSiamConc(nn.Layer):
# Stage 2d
x2d = self.upconv2(x31d)
pad2 = (0, paddle.shape(x22_1)[3] - paddle.shape(x2d)[3], 0,
paddle.shape(x22_1)[2] - paddle.shape(x2d)[2])
pad2 = (0, x22_1.shape[3] - x2d.shape[3], 0,
x22_1.shape[2] - x2d.shape[2])
x2d = paddle.concat(
[F.pad(x2d, pad=pad2, mode='replicate'), x22_1, x22_2], 1)
x22d = self.do22d(self.conv22d(x2d))
@ -183,8 +183,8 @@ class FCSiamConc(nn.Layer):
# Stage 1d
x1d = self.upconv1(x21d)
pad1 = (0, paddle.shape(x12_1)[3] - paddle.shape(x1d)[3], 0,
paddle.shape(x12_1)[2] - paddle.shape(x1d)[2])
pad1 = (0, x12_1.shape[3] - x1d.shape[3], 0,
x12_1.shape[2] - x1d.shape[2])
x1d = paddle.concat(
[F.pad(x1d, pad=pad1, mode='replicate'), x12_1, x12_2], 1)
x12d = self.do12d(self.conv12d(x1d))

@ -154,8 +154,8 @@ class FCSiamDiff(nn.Layer):
# Decode
# Stage 4d
x4d = self.upconv4(x4p)
pad4 = (0, paddle.shape(x43_1)[3] - paddle.shape(x4d)[3], 0,
paddle.shape(x43_1)[2] - paddle.shape(x4d)[2])
pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0,
x43_1.shape[2] - x4d.shape[2])
x4d = F.pad(x4d, pad=pad4, mode='replicate')
x4d = paddle.concat([x4d, paddle.abs(x43_1 - x43_2)], 1)
x43d = self.do43d(self.conv43d(x4d))
@ -164,8 +164,8 @@ class FCSiamDiff(nn.Layer):
# Stage 3d
x3d = self.upconv3(x41d)
pad3 = (0, paddle.shape(x33_1)[3] - paddle.shape(x3d)[3], 0,
paddle.shape(x33_1)[2] - paddle.shape(x3d)[2])
pad3 = (0, x33_1.shape[3] - x3d.shape[3], 0,
x33_1.shape[2] - x3d.shape[2])
x3d = F.pad(x3d, pad=pad3, mode='replicate')
x3d = paddle.concat([x3d, paddle.abs(x33_1 - x33_2)], 1)
x33d = self.do33d(self.conv33d(x3d))
@ -174,8 +174,8 @@ class FCSiamDiff(nn.Layer):
# Stage 2d
x2d = self.upconv2(x31d)
pad2 = (0, paddle.shape(x22_1)[3] - paddle.shape(x2d)[3], 0,
paddle.shape(x22_1)[2] - paddle.shape(x2d)[2])
pad2 = (0, x22_1.shape[3] - x2d.shape[3], 0,
x22_1.shape[2] - x2d.shape[2])
x2d = F.pad(x2d, pad=pad2, mode='replicate')
x2d = paddle.concat([x2d, paddle.abs(x22_1 - x22_2)], 1)
x22d = self.do22d(self.conv22d(x2d))
@ -183,8 +183,8 @@ class FCSiamDiff(nn.Layer):
# Stage 1d
x1d = self.upconv1(x21d)
pad1 = (0, paddle.shape(x12_1)[3] - paddle.shape(x1d)[3], 0,
paddle.shape(x12_1)[2] - paddle.shape(x1d)[2])
pad1 = (0, x12_1.shape[3] - x1d.shape[3], 0,
x12_1.shape[2] - x1d.shape[2])
x1d = F.pad(x1d, pad=pad1, mode='replicate')
x1d = paddle.concat([x1d, paddle.abs(x12_1 - x12_2)], 1)
x12d = self.do12d(self.conv12d(x1d))

@ -132,7 +132,7 @@ class SNUNet(nn.Layer, KaimingInitMixin):
out = paddle.concat([x0_1, x0_2, x0_3, x0_4], 1)
intra = paddle.sum(paddle.stack([x0_1, x0_2, x0_3, x0_4]), axis=0)
intra = x0_1 + x0_2 + x0_3 + x0_4
m_intra = self.ca_intra(intra)
out = self.ca_inter(out) * (out + paddle.tile(m_intra, (1, 4, 1, 1)))

@ -39,7 +39,7 @@ class SELayer(nn.Layer):
b, c, _, _ = x.shape
y = self.avg_pool(x).reshape((b, c))
y = self.fc(y).reshape((b, c, 1, 1))
return x * y.expand_as(x)
return x * paddle.expand(y, shape=x.shape)
class HS(nn.Layer):
@ -92,7 +92,7 @@ def ShuffleLayer(x, groups):
# transpose
x = x.transpose((0, 2, 1, 3, 4))
# reshape
x = x.reshape((batchsize, -1, height, width))
x = x.reshape((batchsize, groups * channels_per_group, height, width))
return x
@ -104,7 +104,7 @@ def ShuffleLayerTrans(x, groups):
# transpose
x = x.transpose((0, 2, 1, 3, 4))
# reshape
x = x.reshape((batchsize, -1, height, width))
x = x.reshape((batchsize, channels_per_group * groups, height, width))
return x
@ -374,7 +374,8 @@ class CondenseNetV2(nn.Layer):
def forward(self, x):
features = self.features(x)
out = features.reshape((features.shape[0], -1))
out = features.reshape((features.shape[0], features.shape[1] *
features.shape[2] * features.shape[3]))
out = self.fc(out)
out = self.fc_act(out)

@ -41,38 +41,35 @@ class FPN(nn.Layer):
conv_block=ConvReLU,
top_blocks=None):
super(FPN, self).__init__()
self.inner_blocks = []
self.layer_blocks = []
inner_blocks = []
layer_blocks = []
for idx, in_channels in enumerate(in_channels_list, 1):
inner_block = "fpn_inner{}".format(idx)
layer_block = "fpn_layer{}".format(idx)
if in_channels == 0:
continue
inner_block_module = conv_block(in_channels, out_channels, 1)
layer_block_module = conv_block(out_channels, out_channels, 3, 1)
self.add_sublayer(inner_block, inner_block_module)
self.add_sublayer(layer_block, layer_block_module)
for module in [inner_block_module, layer_block_module]:
for m in module.sublayers():
if isinstance(m, nn.Conv2D):
kaiming_normal_init(m.weight)
self.inner_blocks.append(inner_block)
self.layer_blocks.append(layer_block)
inner_blocks.append(inner_block_module)
layer_blocks.append(layer_block_module)
self.inner_blocks = nn.LayerList(inner_blocks)
self.layer_blocks = nn.LayerList(layer_blocks)
self.top_blocks = top_blocks
def forward(self, x):
last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
results = [getattr(self, self.layer_blocks[-1])(last_inner)]
for feature, inner_block, layer_block in zip(
x[:-1][::-1], self.inner_blocks[:-1][::-1],
self.layer_blocks[:-1][::-1]):
if not inner_block:
continue
last_inner = self.inner_blocks[-1](x[-1])
results = [self.layer_blocks[-1](last_inner)]
for i, feature in enumerate(x[-2::-1]):
inner_block = self.inner_blocks[len(self.inner_blocks) - 2 - i]
layer_block = self.layer_blocks[len(self.layer_blocks) - 2 - i]
inner_top_down = F.interpolate(
last_inner, scale_factor=2, mode="nearest")
inner_lateral = getattr(self, inner_block)(feature)
inner_lateral = inner_block(feature)
last_inner = inner_lateral + inner_top_down
results.insert(0, getattr(self, layer_block)(last_inner))
results.insert(0, layer_block(last_inner))
if isinstance(self.top_blocks, LastLevelP6P7):
last_results = self.top_blocks(x[-1], results[-1])
results.extend(last_results)

@ -252,22 +252,26 @@ class Predictor(object):
transforms=None,
warmup_iters=0,
repeats=1):
""" 图片预测
"""
Do prediction.
Args:
img_file(List[str or tuple or np.ndarray], str, tuple, or np.ndarray):
对于场景分类图像复原目标检测和语义分割任务来说该参数可为单一图像路径或是解码后的排列格式为H, W, C
且具有float32类型的BGR图像表示为numpy的ndarray形式或者是一组图像路径或np.ndarray对象构成的列表对于变化检测
任务来说该参数可以为图像路径二元组分别表示前后两个时相影像路径或是两幅图像组成的二元组或者是上述两种二元组
之一构成的列表
topk(int): 场景分类模型预测时使用表示预测前topk的结果默认值为1
transforms (paddlers.transforms): 数据预处理操作默认值为None, 即使用`model.yml`中保存的数据预处理操作
warmup_iters (int): 预热轮数用于评估模型推理以及前后处理速度若大于1会预先重复预测warmup_iters而后才开始正式的预测及其速度评估默认为0
repeats (int): 重复次数用于评估模型推理以及前后处理速度若大于1会预测repeats次取时间平均值默认值为1
img_file(list[str | tuple | np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration,
object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict
, a decoded image (a `np.ndarray`, which should be consistent with what you get from passing image path to
`paddlers.transforms.decode_image()`), or a list of image paths or decoded images. For change detection tasks,
`img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
transforms (paddlers.transforms.Compose | None, optional): Pipeline of data preprocessing. If None, load transforms
from `model.yml`. Defaults to None.
warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0.
repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than
1, the reported time consumption is the average of all repeats. Defaults to 1.
"""
if repeats < 1:
logging.error("`repeats` must be greater than 1.", exit=True)
if transforms is None and not hasattr(self._model, 'test_transforms'):
raise Exception("Transforms need to be defined, now is None.")
raise ValueError("Transforms need to be defined, now is None.")
if transforms is None:
transforms = self._model.test_transforms
if isinstance(img_file, tuple) and len(img_file) != 2:

@ -209,7 +209,7 @@ class MaskPostProcess(object):
# TODO: support bs > 1 and mask output dtype is bool
pred_result = paddle.zeros(
[num_mask, origin_shape[0][0], origin_shape[0][1]], dtype='int32')
if bbox_num == 1 and bboxes[0][0] == -1:
if (len(bbox_num) == 1 and bbox_num[0] == 1) and bboxes[0][0] == -1:
return pred_result
# TODO: optimize chunk paste

@ -29,7 +29,7 @@ import paddlers.custom_models.cd as cmcd
import paddlers.utils.logging as logging
import paddlers.models.ppseg as paddleseg
from paddlers.transforms import arrange_transforms
from paddlers.transforms import DecodeImg, Resize
from paddlers.transforms import Resize, decode_image
from paddlers.utils import get_single_card_bs, DisablePrint
from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from .base import BaseModel
@ -502,8 +502,8 @@ class BaseChangeDetector(BaseModel):
Args:
Args:
img_file(List[tuple], Tuple[str or np.ndarray]):
Tuple of image paths or decoded image data in a BGR format for bi-temporal images, which also could constitute
a list, meaning all image pairs to be predicted as a mini-batch.
Tuple of image paths or decoded image data for bi-temporal images, which also could constitute a list,
meaning all image pairs to be predicted as a mini-batch.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
@ -646,15 +646,12 @@ class BaseChangeDetector(BaseModel):
batch_im1, batch_im2 = list(), list()
batch_ori_shape = list()
for im1, im2 in images:
sample = {'image_t1': im1, 'image_t2': im2}
if isinstance(sample['image_t1'], str) or \
isinstance(sample['image_t2'], str):
sample = DecodeImg(to_rgb=False)(sample)
sample['image'] = sample['image'].astype('float32')
sample['image2'] = sample['image2'].astype('float32')
ori_shape = sample['image'].shape[:2]
else:
ori_shape = im1.shape[:2]
if isinstance(im1, str) or isinstance(im2, str):
im1 = decode_image(im1, to_rgb=False)
im2 = decode_image(im2, to_rgb=False)
ori_shape = im1.shape[:2]
# XXX: sample do not contain 'image_t1' and 'image_t2'.
sample = {'image': im1, 'image2': im2}
im1, im2 = transforms(sample)[:2]
batch_im1.append(im1)
batch_im2.append(im2)

@ -33,7 +33,7 @@ from paddlers.models.ppcls.metric import build_metrics
from paddlers.models.ppcls.loss import build_loss
from paddlers.models.ppcls.data.postprocess import build_postprocess
from paddlers.utils.checkpoint import cls_pretrain_weights_dict
from paddlers.transforms import DecodeImg, Resize
from paddlers.transforms import Resize, decode_image
__all__ = [
"ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C", "CondenseNetV2_b"
@ -411,8 +411,8 @@ class BaseClassifier(BaseModel):
Args:
Args:
img_file(List[np.ndarray or str], str or np.ndarray):
Image path or decoded image data in a BGR format, which also could constitute a list,
meaning all images to be predicted as a mini-batch.
Image path or decoded image data, which also could constitute a list, meaning all images to be
predicted as a mini-batch.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
@ -465,11 +465,10 @@ class BaseClassifier(BaseModel):
batch_im = list()
batch_ori_shape = list()
for im in images:
if isinstance(im, str):
im = decode_image(im, to_rgb=False)
ori_shape = im.shape[:2]
sample = {'image': im}
if isinstance(sample['image'], str):
sample = DecodeImg(to_rgb=False)(sample)
sample['image'] = sample['image'].astype('float32')
ori_shape = sample['image'].shape[:2]
im = transforms(sample)
batch_im.append(im)
batch_ori_shape.append(ori_shape)

@ -27,7 +27,8 @@ import paddlers.models.ppdet as ppdet
from paddlers.models.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
import paddlers
import paddlers.utils.logging as logging
from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad, DecodeImg
from paddlers.transforms import decode_image
from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
_BatchPad, _Gt2YoloTarget
from paddlers.transforms import arrange_transforms
@ -37,8 +38,7 @@ from paddlers.models.ppdet.optimizer import ModelEMA
from paddlers.utils.checkpoint import det_pretrain_weights_dict
__all__ = [
"YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN",
"PicoDet"
"YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
]
@ -512,8 +512,8 @@ class BaseDetector(BaseModel):
Do inference.
Args:
img_file(List[np.ndarray or str], str or np.ndarray):
Image path or decoded image data in a BGR format, which also could constitute a list,
meaning all images to be predicted as a mini-batch.
Image path or decoded image data, which also could constitute a list,meaning all images to be
predicted as a mini-batch.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
Returns:
@ -549,10 +549,9 @@ class BaseDetector(BaseModel):
model_type=self.model_type, transforms=transforms, mode='test')
batch_samples = list()
for im in images:
if isinstance(im, str):
im = decode_image(im, to_rgb=False)
sample = {'image': im}
if isinstance(sample['image'], str):
sample = DecodeImg(to_rgb=False)(sample)
sample['image'] = sample['image'].astype('float32')
sample = transforms(sample)
batch_samples.append(sample)
batch_transforms = self._compose_batch_transform(transforms, 'test')

@ -32,7 +32,7 @@ import paddlers.utils.logging as logging
from .base import BaseModel
from .utils import seg_metrics as metrics
from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from paddlers.transforms import DecodeImg, Resize
from paddlers.transforms import Resize, decode_image
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
@ -479,8 +479,8 @@ class BaseSegmenter(BaseModel):
Args:
Args:
img_file(List[np.ndarray or str], str or np.ndarray):
Image path or decoded image data in a BGR format, which also could constitute a list,
meaning all images to be predicted as a mini-batch.
Image path or decoded image data, which also could constitute a list,meaning all images to be
predicted as a mini-batch.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
@ -611,11 +611,10 @@ class BaseSegmenter(BaseModel):
batch_im = list()
batch_ori_shape = list()
for im in images:
if isinstance(im, str):
im = decode_image(im, to_rgb=False)
ori_shape = im.shape[:2]
sample = {'image': im}
if isinstance(sample['image'], str):
sample = DecodeImg(to_rgb=False)(sample)
sample['image'] = sample['image'].astype('float32')
ori_shape = sample['image'].shape[:2]
im = transforms(sample)[0]
batch_im.append(im)
batch_ori_shape.append(ori_shape)

@ -12,11 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os.path as osp
from .operators import *
from .batch_operators import BatchRandomResize, BatchRandomResizeByShort, _BatchPad
from paddlers import transforms as T
def decode_image(im_path,
to_rgb=True,
to_uint8=True,
decode_rgb=True,
decode_sar=False):
# Do a presence check. `osp.exists` assumes `im_path` is a path-like object.
if not osp.exists(im_path):
raise ValueError(f"{im_path} does not exist!")
decoder = T.DecodeImg(
to_rgb=to_rgb,
to_uint8=to_uint8,
decode_rgb=decode_rgb,
decode_sar=decode_sar)
# Deepcopy to avoid inplace modification
sample = {'image': copy.deepcopy(im_path)}
sample = decoder(sample)
return sample['image']
def arrange_transforms(model_type, transforms, mode='train'):
# 给transforms添加arrange操作
if model_type == 'segmenter':

@ -124,15 +124,24 @@ class DecodeImg(Transform):
Decode image(s) in input.
Args:
to_rgb (bool, optional): If True, convert input images from BGR format to RGB format. Defaults to True.
to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True.
to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True.
decode_rgb (bool, optional): If the image to decode is a non-geo RGB image (e.g., jpeg images), set this argument to True. Defaults to True.
decode_sar (bool, optional): If the image to decode is a SAR image, set this argument to True. Defaults to False.
"""
def __init__(self, to_rgb=True, to_uint8=True):
def __init__(self,
to_rgb=True,
to_uint8=True,
decode_rgb=True,
decode_sar=False):
super(DecodeImg, self).__init__()
self.to_rgb = to_rgb
self.to_uint8 = to_uint8
self.decode_rgb = decode_rgb
self.decode_sar = decode_sar
def read_img(self, img_path, input_channel=3):
def read_img(self, img_path):
img_format = imghdr.what(img_path)
name, ext = os.path.splitext(img_path)
if img_format == 'tiff' or ext == '.img':
@ -141,24 +150,28 @@ class DecodeImg(Transform):
except:
try:
from osgeo import gdal
except:
raise Exception(
"Failed to import gdal! You can try use conda to install gdal"
except ImportError:
raise ImportError(
"Failed to import gdal! Please install GDAL library according to the document."
)
six.reraise(*sys.exc_info())
dataset = gdal.Open(img_path)
if dataset == None:
raise Exception('Can not open', img_path)
raise IOError('Can not open', img_path)
im_data = dataset.ReadAsArray()
if im_data.ndim == 2:
if self.decode_sar:
if im_data.ndim != 2:
raise ValueError(
f"SAR images should have exactly 2 channels, but the image has {im_data.ndim} channels."
)
im_data = to_intensity(im_data) # is read SAR
im_data = im_data[:, :, np.newaxis]
elif im_data.ndim == 3:
im_data = im_data.transpose((1, 2, 0))
else:
if im_data.ndim == 3:
im_data = im_data.transpose((1, 2, 0))
return im_data
elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
if input_channel == 3:
if self.decode_rgb:
return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
else:
@ -167,7 +180,7 @@ class DecodeImg(Transform):
elif ext == '.npy':
return np.load(img_path)
else:
raise Exception('Image format {} is not supported!'.format(ext))
raise TypeError('Image format {} is not supported!'.format(ext))
def apply_im(self, im_path):
if isinstance(im_path, str):
@ -193,7 +206,7 @@ class DecodeImg(Transform):
except:
raise ValueError("Cannot read the mask file {}!".format(mask))
if len(mask.shape) != 2:
raise Exception(
raise ValueError(
"Mask should be a 1-channel image, but recevied is a {}-channel image.".
format(mask.shape[2]))
return mask
@ -202,6 +215,7 @@ class DecodeImg(Transform):
"""
Args:
sample (dict): Input sample.
Returns:
dict: Decoded sample.
"""
@ -219,8 +233,8 @@ class DecodeImg(Transform):
im_height, im_width, _ = sample['image'].shape
se_height, se_width = sample['mask'].shape
if im_height != se_height or im_width != se_width:
raise Exception(
"The height or width of the im is not same as the mask")
raise ValueError(
"The height or width of the image is not same as the mask.")
if 'aux_masks' in sample:
sample['aux_masks'] = list(
map(self.apply_mask, sample['aux_masks']))
@ -595,6 +609,16 @@ class RandomFlipOrRotate(Transform):
mask = img_simple_rotate(mask, mode_id)
return mask
def apply_bbox(self, bbox, mode_id, flip_mode=True):
raise TypeError(
"Currently, `paddlers.transforms.RandomFlipOrRotate` is not available for object detection tasks."
)
def apply_segm(self, bbox, mode_id, flip_mode=True):
raise TypeError(
"Currently, `paddlers.transforms.RandomFlipOrRotate` is not available for object detection tasks."
)
def get_probs_range(self, probs):
'''
Change various probabilities into cumulative probabilities
@ -638,14 +662,43 @@ class RandomFlipOrRotate(Transform):
mode_p = random.random()
mode_id = self.judge_probs_range(mode_p, self.probsf)
sample['image'] = self.apply_im(sample['image'], mode_id, True)
if 'image2' in sample:
sample['image2'] = self.apply_im(sample['image2'], mode_id,
True)
if 'mask' in sample:
sample['mask'] = self.apply_mask(sample['mask'], mode_id, True)
if 'aux_masks' in sample:
sample['aux_masks'] = [
self.apply_mask(aux_mask, mode_id, True)
for aux_mask in sample['aux_masks']
]
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], mode_id,
True)
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
True)
elif p_m < self.probs[1]:
mode_p = random.random()
mode_id = self.judge_probs_range(mode_p, self.probsr)
sample['image'] = self.apply_im(sample['image'], mode_id, False)
if 'image2' in sample:
sample['image2'] = self.apply_im(sample['image2'], mode_id,
False)
if 'mask' in sample:
sample['mask'] = self.apply_mask(sample['mask'], mode_id, False)
if 'aux_masks' in sample:
sample['aux_masks'] = [
self.apply_mask(aux_mask, mode_id, False)
for aux_mask in sample['aux_masks']
]
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], mode_id,
False)
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
False)
return sample

@ -16,10 +16,10 @@ import os.path as osp
import tempfile
import unittest.mock as mock
import cv2
import paddle
import paddlers as pdrs
from paddlers.transforms import decode_image
from testing_utils import CommonTest, run_script
__all__ = [
@ -31,6 +31,7 @@ __all__ = [
class TestPredictor(CommonTest):
MODULE = pdrs.tasks
TRAINER_NAME_TO_EXPORT_OPTS = {}
WHITE_LIST = []
@staticmethod
def add_tests(cls):
@ -42,6 +43,7 @@ class TestPredictor(CommonTest):
def _test_predictor_impl(self):
trainer_class = getattr(self.MODULE, trainer_name)
# Construct trainer with default parameters
# TODO: Load pretrained weights to avoid numeric problems
trainer = trainer_class()
with tempfile.TemporaryDirectory() as td:
dynamic_model_dir = osp.join(td, "dynamic")
@ -69,6 +71,8 @@ class TestPredictor(CommonTest):
return _test_predictor_impl
for trainer_name in cls.MODULE.__all__:
if trainer_name in cls.WHITE_LIST:
continue
setattr(cls, 'test_' + trainer_name, _test_predictor(trainer_name))
return cls
@ -76,27 +80,44 @@ class TestPredictor(CommonTest):
def check_predictor(self, predictor, trainer):
raise NotImplementedError
def check_dict_equal(self, dict_, expected_dict):
def check_dict_equal(
self,
dict_,
expected_dict,
ignore_keys=('label_map', 'mask', 'category', 'category_id')):
# By default do not compare label_maps, masks, or categories,
# because numeric errors could result in large difference in labels.
if isinstance(dict_, list):
self.assertIsInstance(expected_dict, list)
self.assertEqual(len(dict_), len(expected_dict))
for d1, d2 in zip(dict_, expected_dict):
self.check_dict_equal(d1, d2)
self.check_dict_equal(d1, d2, ignore_keys=ignore_keys)
else:
assert isinstance(dict_, dict)
assert isinstance(expected_dict, dict)
self.assertEqual(dict_.keys(), expected_dict.keys())
ignore_keys = set() if ignore_keys is None else set(ignore_keys)
for key in dict_.keys():
self.check_output_equal(dict_[key], expected_dict[key])
if key in ignore_keys:
continue
if isinstance(dict_[key], (list, dict)):
self.check_dict_equal(
dict_[key], expected_dict[key], ignore_keys=ignore_keys)
else:
# Use higher tolerance
self.check_output_equal(
dict_[key], expected_dict[key], rtol=1.e-4, atol=1.e-6)
@TestPredictor.add_tests
class TestCDPredictor(TestPredictor):
MODULE = pdrs.tasks.change_detector
TRAINER_NAME_TO_EXPORT_OPTS = {
'BIT': "--fixed_input_shape [1,3,256,256]",
'_default': "--fixed_input_shape [-1,3,256,256]"
}
# HACK: Skip CDNet.
# These models are heavily affected by numeric errors.
WHITE_LIST = ['CDNet']
def check_predictor(self, predictor, trainer):
t1_path = "data/ssmt/optical_t1.bmp"
@ -124,9 +145,9 @@ class TestCDPredictor(TestPredictor):
out_single_file_list_t[0])
# Single input (ndarrays)
input_ = (
cv2.imread(t1_path).astype('float32'),
cv2.imread(t2_path).astype('float32')) # Reuse the name `input_`
input_ = (decode_image(
t1_path, to_rgb=False), decode_image(
t2_path, to_rgb=False)) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -140,23 +161,21 @@ class TestCDPredictor(TestPredictor):
self.check_dict_equal(out_single_array_list_p[0],
out_single_array_list_t[0])
if isinstance(trainer, pdrs.tasks.change_detector.BIT):
return
# Multiple inputs (file paths)
input_ = [single_input] * num_inputs # Reuse the name `input_`
out_multi_file_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), num_inputs)
out_multi_file_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_multi_file_p, out_multi_file_t)
self.assertEqual(len(out_multi_file_t), num_inputs)
# Multiple inputs (ndarrays)
input_ = [(cv2.imread(t1_path).astype('float32'), cv2.imread(t2_path)
.astype('float32'))] * num_inputs # Reuse the name `input_`
input_ = [(decode_image(
t1_path, to_rgb=False), decode_image(
t2_path, to_rgb=False))] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_multi_array_p, out_multi_array_t)
self.assertEqual(len(out_multi_array_t), num_inputs)
@TestPredictor.add_tests
@ -189,8 +208,8 @@ class TestClasPredictor(TestPredictor):
out_single_file_list_t[0])
# Single input (ndarray)
input_ = cv2.imread(single_input).astype(
'float32') # Reuse the name `input_`
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -209,16 +228,15 @@ class TestClasPredictor(TestPredictor):
out_multi_file_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), num_inputs)
out_multi_file_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), len(out_multi_file_t))
# Check value consistence
self.check_dict_equal(out_multi_file_p, out_multi_file_t)
# Multiple inputs (ndarrays)
input_ = [cv2.imread(single_input).astype('float32')
] * num_inputs # Reuse the name `input_`
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), len(out_multi_array_t))
self.check_dict_equal(out_multi_array_p, out_multi_array_t)
@ -230,6 +248,9 @@ class TestDetPredictor(TestPredictor):
}
def check_predictor(self, predictor, trainer):
# For detection tasks, do NOT ensure the consistence of bboxes.
# This is because the coordinates of bboxes were observed to be very sensitive to numeric errors,
# given that the network is (partially?) randomly initialized.
single_input = "data/ssmt/optical_t1.bmp"
num_inputs = 2
transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
@ -239,50 +260,41 @@ class TestDetPredictor(TestPredictor):
# Single input (file path)
input_ = single_input
out_single_file_p = predictor.predict(input_, transforms=transforms)
out_single_file_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_file_p, out_single_file_t)
predictor.predict(input_, transforms=transforms)
trainer.predict(input_, transforms=transforms)
out_single_file_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_file_list_p), 1)
self.check_dict_equal(out_single_file_list_p[0], out_single_file_p)
out_single_file_list_t = trainer.predict(
[input_], transforms=transforms)
self.check_dict_equal(out_single_file_list_p[0],
out_single_file_list_t[0])
self.assertEqual(len(out_single_file_list_t), 1)
# Single input (ndarray)
input_ = cv2.imread(single_input).astype(
'float32') # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_array_t)
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
predictor.predict(input_, transforms=transforms)
trainer.predict(input_, transforms=transforms)
out_single_array_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_array_list_p), 1)
self.check_dict_equal(out_single_array_list_p[0], out_single_array_p)
out_single_array_list_t = trainer.predict(
[input_], transforms=transforms)
self.check_dict_equal(out_single_array_list_p[0],
out_single_array_list_t[0])
self.assertEqual(len(out_single_array_list_t), 1)
# Multiple inputs (file paths)
input_ = [single_input] * num_inputs # Reuse the name `input_`
out_multi_file_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), num_inputs)
out_multi_file_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), len(out_multi_file_t))
self.check_dict_equal(out_multi_file_p, out_multi_file_t)
self.assertEqual(len(out_multi_file_t), num_inputs)
# Multiple inputs (ndarrays)
input_ = [cv2.imread(single_input).astype('float32')
] * num_inputs # Reuse the name `input_`
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), len(out_multi_array_t))
self.check_dict_equal(out_multi_array_p, out_multi_array_t)
self.assertEqual(len(out_multi_array_t), num_inputs)
@TestPredictor.add_tests
@ -312,8 +324,8 @@ class TestSegPredictor(TestPredictor):
out_single_file_list_t[0])
# Single input (ndarray)
input_ = cv2.imread(single_input).astype(
'float32') # Reuse the name `input_`
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -332,14 +344,12 @@ class TestSegPredictor(TestPredictor):
out_multi_file_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), num_inputs)
out_multi_file_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), len(out_multi_file_t))
self.check_dict_equal(out_multi_file_p, out_multi_file_t)
self.assertEqual(len(out_multi_file_t), num_inputs)
# Multiple inputs (ndarrays)
input_ = [cv2.imread(single_input).astype('float32')
] * num_inputs # Reuse the name `input_`
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), len(out_multi_array_t))
self.check_dict_equal(out_multi_array_p, out_multi_array_t)
self.assertEqual(len(out_multi_array_t), num_inputs)

Loading…
Cancel
Save