|
|
|
@ -23,15 +23,15 @@ import paddle |
|
|
|
|
import paddle.nn.functional as F |
|
|
|
|
from paddle.static import InputSpec |
|
|
|
|
|
|
|
|
|
import paddlers.models.ppseg as paddleseg |
|
|
|
|
import paddlers.rs_models.seg as cmseg |
|
|
|
|
import paddlers |
|
|
|
|
import paddlers.models.ppseg as ppseg |
|
|
|
|
import paddlers.rs_models.seg as cmseg |
|
|
|
|
from paddlers.utils import get_single_card_bs, DisablePrint |
|
|
|
|
import paddlers.utils.logging as logging |
|
|
|
|
from .base import BaseModel |
|
|
|
|
from .utils import seg_metrics as metrics |
|
|
|
|
from paddlers.utils.checkpoint import seg_pretrain_weights_dict |
|
|
|
|
from paddlers.transforms import Resize, decode_image |
|
|
|
|
from .base import BaseModel |
|
|
|
|
from .utils import seg_metrics as metrics |
|
|
|
|
|
|
|
|
|
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"] |
|
|
|
|
|
|
|
|
@ -46,7 +46,7 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
if 'with_net' in self.init_params: |
|
|
|
|
del self.init_params['with_net'] |
|
|
|
|
super(BaseSegmenter, self).__init__('segmenter') |
|
|
|
|
if not hasattr(paddleseg.models, model_name) and \ |
|
|
|
|
if not hasattr(ppseg.models, model_name) and \ |
|
|
|
|
not hasattr(cmseg, model_name): |
|
|
|
|
raise ValueError("ERROR: There is no model named {}.".format( |
|
|
|
|
model_name)) |
|
|
|
@ -63,9 +63,8 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
def build_net(self, **params): |
|
|
|
|
# TODO: when using paddle.utils.unique_name.guard, |
|
|
|
|
# DeepLabv3p and HRNet will raise a error |
|
|
|
|
net = dict(paddleseg.models.__dict__, |
|
|
|
|
**cmseg.__dict__)[self.model_name]( |
|
|
|
|
num_classes=self.num_classes, **params) |
|
|
|
|
net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name]( |
|
|
|
|
num_classes=self.num_classes, **params) |
|
|
|
|
return net |
|
|
|
|
|
|
|
|
|
def _fix_transforms_shape(self, image_shape): |
|
|
|
@ -143,7 +142,7 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
origin_shape = [label.shape[-2:]] |
|
|
|
|
pred = self._postprocess( |
|
|
|
|
pred, origin_shape, transforms=inputs[2])[0] # NCHW |
|
|
|
|
intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area( |
|
|
|
|
intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area( |
|
|
|
|
pred, label, self.num_classes) |
|
|
|
|
outputs['intersect_area'] = intersect_area |
|
|
|
|
outputs['pred_area'] = pred_area |
|
|
|
@ -161,16 +160,13 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
if isinstance(self.use_mixed_loss, bool): |
|
|
|
|
if self.use_mixed_loss: |
|
|
|
|
losses = [ |
|
|
|
|
paddleseg.models.CrossEntropyLoss(), |
|
|
|
|
paddleseg.models.LovaszSoftmaxLoss() |
|
|
|
|
ppseg.models.CrossEntropyLoss(), |
|
|
|
|
ppseg.models.LovaszSoftmaxLoss() |
|
|
|
|
] |
|
|
|
|
coef = [.8, .2] |
|
|
|
|
loss_type = [ |
|
|
|
|
paddleseg.models.MixedLoss( |
|
|
|
|
losses=losses, coef=coef), |
|
|
|
|
] |
|
|
|
|
loss_type = [ppseg.models.MixedLoss(losses=losses, coef=coef), ] |
|
|
|
|
else: |
|
|
|
|
loss_type = [paddleseg.models.CrossEntropyLoss()] |
|
|
|
|
loss_type = [ppseg.models.CrossEntropyLoss()] |
|
|
|
|
else: |
|
|
|
|
losses, coef = list(zip(*self.use_mixed_loss)) |
|
|
|
|
if not set(losses).issubset( |
|
|
|
@ -178,11 +174,8 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
raise ValueError( |
|
|
|
|
"Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported." |
|
|
|
|
) |
|
|
|
|
losses = [getattr(paddleseg.models, loss)() for loss in losses] |
|
|
|
|
loss_type = [ |
|
|
|
|
paddleseg.models.MixedLoss( |
|
|
|
|
losses=losses, coef=list(coef)) |
|
|
|
|
] |
|
|
|
|
losses = [getattr(ppseg.models, loss)() for loss in losses] |
|
|
|
|
loss_type = [ppseg.models.MixedLoss(losses=losses, coef=list(coef))] |
|
|
|
|
if self.model_name == 'FastSCNN': |
|
|
|
|
loss_type *= 2 |
|
|
|
|
loss_coef = [1.0, 0.4] |
|
|
|
@ -475,13 +468,13 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
pred_area_all = pred_area_all + pred_area |
|
|
|
|
label_area_all = label_area_all + label_area |
|
|
|
|
conf_mat_all.append(conf_mat) |
|
|
|
|
class_iou, miou = paddleseg.utils.metrics.mean_iou( |
|
|
|
|
class_iou, miou = ppseg.utils.metrics.mean_iou( |
|
|
|
|
intersect_area_all, pred_area_all, label_area_all) |
|
|
|
|
# TODO 确认是按oacc还是macc |
|
|
|
|
class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all, |
|
|
|
|
pred_area_all) |
|
|
|
|
kappa = paddleseg.utils.metrics.kappa(intersect_area_all, pred_area_all, |
|
|
|
|
label_area_all) |
|
|
|
|
class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all, |
|
|
|
|
pred_area_all) |
|
|
|
|
kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all, |
|
|
|
|
label_area_all) |
|
|
|
|
category_f1score = metrics.f1_score(intersect_area_all, pred_area_all, |
|
|
|
|
label_area_all) |
|
|
|
|
eval_metrics = OrderedDict( |
|
|
|
@ -613,15 +606,15 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
ysize = int(height - yoff) |
|
|
|
|
im = src_data.ReadAsArray(int(xoff), int(yoff), xsize, |
|
|
|
|
ysize).transpose((1, 2, 0)) |
|
|
|
|
# fill |
|
|
|
|
# Fill |
|
|
|
|
h, w = im.shape[:2] |
|
|
|
|
im_fill = np.zeros( |
|
|
|
|
(block_size[1], block_size[0], bands), dtype=im.dtype) |
|
|
|
|
im_fill[:h, :w, :] = im |
|
|
|
|
# predict |
|
|
|
|
# Predict |
|
|
|
|
pred = self.predict(im_fill, |
|
|
|
|
transforms)["label_map"].astype("uint8") |
|
|
|
|
# overlap |
|
|
|
|
# Overlap |
|
|
|
|
rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) |
|
|
|
|
mask = (rd_block == pred[:h, :w]) | (rd_block == 255) |
|
|
|
|
temp = pred[:h, :w].copy() |
|
|
|
@ -818,7 +811,7 @@ class DeepLabV3P(BaseSegmenter): |
|
|
|
|
"{'ResNet50_vd', 'ResNet101_vd'}.".format(backbone)) |
|
|
|
|
if params.get('with_net', True): |
|
|
|
|
with DisablePrint(): |
|
|
|
|
backbone = getattr(paddleseg.models, backbone)( |
|
|
|
|
backbone = getattr(ppseg.models, backbone)( |
|
|
|
|
input_channel=input_channel, output_stride=output_stride) |
|
|
|
|
else: |
|
|
|
|
backbone = None |
|
|
|
@ -864,7 +857,7 @@ class HRNet(BaseSegmenter): |
|
|
|
|
self.backbone_name = 'HRNet_W{}'.format(width) |
|
|
|
|
if params.get('with_net', True): |
|
|
|
|
with DisablePrint(): |
|
|
|
|
backbone = getattr(paddleseg.models, self.backbone_name)( |
|
|
|
|
backbone = getattr(ppseg.models, self.backbone_name)( |
|
|
|
|
align_corners=align_corners) |
|
|
|
|
else: |
|
|
|
|
backbone = None |
|
|
|
|