|
|
|
@ -23,15 +23,16 @@ import paddle |
|
|
|
|
import paddle.nn.functional as F |
|
|
|
|
from paddle.static import InputSpec |
|
|
|
|
|
|
|
|
|
import paddlers.models.ppseg as paddleseg |
|
|
|
|
import paddlers.rs_models.seg as cmseg |
|
|
|
|
import paddlers |
|
|
|
|
from paddlers.utils import get_single_card_bs, DisablePrint |
|
|
|
|
import paddlers.models.ppseg as ppseg |
|
|
|
|
import paddlers.rs_models.seg as cmseg |
|
|
|
|
import paddlers.utils.logging as logging |
|
|
|
|
from paddlers.models import seg_losses |
|
|
|
|
from paddlers.transforms import Resize, decode_image |
|
|
|
|
from paddlers.utils import get_single_card_bs, DisablePrint |
|
|
|
|
from paddlers.utils.checkpoint import seg_pretrain_weights_dict |
|
|
|
|
from .base import BaseModel |
|
|
|
|
from .utils import seg_metrics as metrics |
|
|
|
|
from paddlers.utils.checkpoint import seg_pretrain_weights_dict |
|
|
|
|
from paddlers.transforms import Resize, decode_image |
|
|
|
|
|
|
|
|
|
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"] |
|
|
|
|
|
|
|
|
@ -41,19 +42,20 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
model_name, |
|
|
|
|
num_classes=2, |
|
|
|
|
use_mixed_loss=False, |
|
|
|
|
losses=None, |
|
|
|
|
**params): |
|
|
|
|
self.init_params = locals() |
|
|
|
|
if 'with_net' in self.init_params: |
|
|
|
|
del self.init_params['with_net'] |
|
|
|
|
super(BaseSegmenter, self).__init__('segmenter') |
|
|
|
|
if not hasattr(paddleseg.models, model_name) and \ |
|
|
|
|
if not hasattr(ppseg.models, model_name) and \ |
|
|
|
|
not hasattr(cmseg, model_name): |
|
|
|
|
raise ValueError("ERROR: There is no model named {}.".format( |
|
|
|
|
model_name)) |
|
|
|
|
self.model_name = model_name |
|
|
|
|
self.num_classes = num_classes |
|
|
|
|
self.use_mixed_loss = use_mixed_loss |
|
|
|
|
self.losses = None |
|
|
|
|
self.losses = losses |
|
|
|
|
self.labels = None |
|
|
|
|
if params.get('with_net', True): |
|
|
|
|
params.pop('with_net', None) |
|
|
|
@ -63,9 +65,8 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
def build_net(self, **params): |
|
|
|
|
# TODO: when using paddle.utils.unique_name.guard, |
|
|
|
|
# DeepLabv3p and HRNet will raise a error |
|
|
|
|
net = dict(paddleseg.models.__dict__, |
|
|
|
|
**cmseg.__dict__)[self.model_name]( |
|
|
|
|
num_classes=self.num_classes, **params) |
|
|
|
|
net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name]( |
|
|
|
|
num_classes=self.num_classes, **params) |
|
|
|
|
return net |
|
|
|
|
|
|
|
|
|
def _fix_transforms_shape(self, image_shape): |
|
|
|
@ -143,7 +144,7 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
origin_shape = [label.shape[-2:]] |
|
|
|
|
pred = self._postprocess( |
|
|
|
|
pred, origin_shape, transforms=inputs[2])[0] # NCHW |
|
|
|
|
intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area( |
|
|
|
|
intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area( |
|
|
|
|
pred, label, self.num_classes) |
|
|
|
|
outputs['intersect_area'] = intersect_area |
|
|
|
|
outputs['pred_area'] = pred_area |
|
|
|
@ -161,16 +162,13 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
if isinstance(self.use_mixed_loss, bool): |
|
|
|
|
if self.use_mixed_loss: |
|
|
|
|
losses = [ |
|
|
|
|
paddleseg.models.CrossEntropyLoss(), |
|
|
|
|
paddleseg.models.LovaszSoftmaxLoss() |
|
|
|
|
seg_losses.CrossEntropyLoss(), |
|
|
|
|
seg_losses.LovaszSoftmaxLoss() |
|
|
|
|
] |
|
|
|
|
coef = [.8, .2] |
|
|
|
|
loss_type = [ |
|
|
|
|
paddleseg.models.MixedLoss( |
|
|
|
|
losses=losses, coef=coef), |
|
|
|
|
] |
|
|
|
|
loss_type = [seg_losses.MixedLoss(losses=losses, coef=coef), ] |
|
|
|
|
else: |
|
|
|
|
loss_type = [paddleseg.models.CrossEntropyLoss()] |
|
|
|
|
loss_type = [seg_losses.CrossEntropyLoss()] |
|
|
|
|
else: |
|
|
|
|
losses, coef = list(zip(*self.use_mixed_loss)) |
|
|
|
|
if not set(losses).issubset( |
|
|
|
@ -178,11 +176,8 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
raise ValueError( |
|
|
|
|
"Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported." |
|
|
|
|
) |
|
|
|
|
losses = [getattr(paddleseg.models, loss)() for loss in losses] |
|
|
|
|
loss_type = [ |
|
|
|
|
paddleseg.models.MixedLoss( |
|
|
|
|
losses=losses, coef=list(coef)) |
|
|
|
|
] |
|
|
|
|
losses = [getattr(seg_losses, loss)() for loss in losses] |
|
|
|
|
loss_type = [seg_losses.MixedLoss(losses=losses, coef=list(coef))] |
|
|
|
|
if self.model_name == 'FastSCNN': |
|
|
|
|
loss_type *= 2 |
|
|
|
|
loss_coef = [1.0, 0.4] |
|
|
|
@ -475,13 +470,13 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
pred_area_all = pred_area_all + pred_area |
|
|
|
|
label_area_all = label_area_all + label_area |
|
|
|
|
conf_mat_all.append(conf_mat) |
|
|
|
|
class_iou, miou = paddleseg.utils.metrics.mean_iou( |
|
|
|
|
class_iou, miou = ppseg.utils.metrics.mean_iou( |
|
|
|
|
intersect_area_all, pred_area_all, label_area_all) |
|
|
|
|
# TODO 确认是按oacc还是macc |
|
|
|
|
class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all, |
|
|
|
|
pred_area_all) |
|
|
|
|
kappa = paddleseg.utils.metrics.kappa(intersect_area_all, pred_area_all, |
|
|
|
|
label_area_all) |
|
|
|
|
class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all, |
|
|
|
|
pred_area_all) |
|
|
|
|
kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all, |
|
|
|
|
label_area_all) |
|
|
|
|
category_f1score = metrics.f1_score(intersect_area_all, pred_area_all, |
|
|
|
|
label_area_all) |
|
|
|
|
eval_metrics = OrderedDict( |
|
|
|
@ -613,15 +608,15 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
ysize = int(height - yoff) |
|
|
|
|
im = src_data.ReadAsArray(int(xoff), int(yoff), xsize, |
|
|
|
|
ysize).transpose((1, 2, 0)) |
|
|
|
|
# fill |
|
|
|
|
# Fill |
|
|
|
|
h, w = im.shape[:2] |
|
|
|
|
im_fill = np.zeros( |
|
|
|
|
(block_size[1], block_size[0], bands), dtype=im.dtype) |
|
|
|
|
im_fill[:h, :w, :] = im |
|
|
|
|
# predict |
|
|
|
|
# Predict |
|
|
|
|
pred = self.predict(im_fill, |
|
|
|
|
transforms)["label_map"].astype("uint8") |
|
|
|
|
# overlap |
|
|
|
|
# Overlap |
|
|
|
|
rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) |
|
|
|
|
mask = (rd_block == pred[:h, :w]) | (rd_block == 255) |
|
|
|
|
temp = pred[:h, :w].copy() |
|
|
|
@ -778,12 +773,18 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
raise TypeError( |
|
|
|
|
"`transforms.arrange` must be an ArrangeSegmenter object.") |
|
|
|
|
|
|
|
|
|
def set_losses(self, losses, weights=None): |
|
|
|
|
if weights is None: |
|
|
|
|
weights = [1. for _ in range(len(losses))] |
|
|
|
|
self.losses = {'types': losses, 'coef': weights} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UNet(BaseSegmenter): |
|
|
|
|
def __init__(self, |
|
|
|
|
input_channel=3, |
|
|
|
|
num_classes=2, |
|
|
|
|
use_mixed_loss=False, |
|
|
|
|
losses=None, |
|
|
|
|
use_deconv=False, |
|
|
|
|
align_corners=False, |
|
|
|
|
**params): |
|
|
|
@ -796,6 +797,7 @@ class UNet(BaseSegmenter): |
|
|
|
|
input_channel=input_channel, |
|
|
|
|
num_classes=num_classes, |
|
|
|
|
use_mixed_loss=use_mixed_loss, |
|
|
|
|
losses=losses, |
|
|
|
|
**params) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -805,6 +807,7 @@ class DeepLabV3P(BaseSegmenter): |
|
|
|
|
num_classes=2, |
|
|
|
|
backbone='ResNet50_vd', |
|
|
|
|
use_mixed_loss=False, |
|
|
|
|
losses=None, |
|
|
|
|
output_stride=8, |
|
|
|
|
backbone_indices=(0, 3), |
|
|
|
|
aspp_ratios=(1, 12, 24, 36), |
|
|
|
@ -818,7 +821,7 @@ class DeepLabV3P(BaseSegmenter): |
|
|
|
|
"{'ResNet50_vd', 'ResNet101_vd'}.".format(backbone)) |
|
|
|
|
if params.get('with_net', True): |
|
|
|
|
with DisablePrint(): |
|
|
|
|
backbone = getattr(paddleseg.models, backbone)( |
|
|
|
|
backbone = getattr(ppseg.models, backbone)( |
|
|
|
|
input_channel=input_channel, output_stride=output_stride) |
|
|
|
|
else: |
|
|
|
|
backbone = None |
|
|
|
@ -833,6 +836,7 @@ class DeepLabV3P(BaseSegmenter): |
|
|
|
|
model_name='DeepLabV3P', |
|
|
|
|
num_classes=num_classes, |
|
|
|
|
use_mixed_loss=use_mixed_loss, |
|
|
|
|
losses=losses, |
|
|
|
|
**params) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -840,6 +844,7 @@ class FastSCNN(BaseSegmenter): |
|
|
|
|
def __init__(self, |
|
|
|
|
num_classes=2, |
|
|
|
|
use_mixed_loss=False, |
|
|
|
|
losses=None, |
|
|
|
|
align_corners=False, |
|
|
|
|
**params): |
|
|
|
|
params.update({'align_corners': align_corners}) |
|
|
|
@ -847,6 +852,7 @@ class FastSCNN(BaseSegmenter): |
|
|
|
|
model_name='FastSCNN', |
|
|
|
|
num_classes=num_classes, |
|
|
|
|
use_mixed_loss=use_mixed_loss, |
|
|
|
|
losses=losses, |
|
|
|
|
**params) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -855,6 +861,7 @@ class HRNet(BaseSegmenter): |
|
|
|
|
num_classes=2, |
|
|
|
|
width=48, |
|
|
|
|
use_mixed_loss=False, |
|
|
|
|
losses=None, |
|
|
|
|
align_corners=False, |
|
|
|
|
**params): |
|
|
|
|
if width not in (18, 48): |
|
|
|
@ -864,7 +871,7 @@ class HRNet(BaseSegmenter): |
|
|
|
|
self.backbone_name = 'HRNet_W{}'.format(width) |
|
|
|
|
if params.get('with_net', True): |
|
|
|
|
with DisablePrint(): |
|
|
|
|
backbone = getattr(paddleseg.models, self.backbone_name)( |
|
|
|
|
backbone = getattr(ppseg.models, self.backbone_name)( |
|
|
|
|
align_corners=align_corners) |
|
|
|
|
else: |
|
|
|
|
backbone = None |
|
|
|
@ -874,6 +881,7 @@ class HRNet(BaseSegmenter): |
|
|
|
|
model_name='FCN', |
|
|
|
|
num_classes=num_classes, |
|
|
|
|
use_mixed_loss=use_mixed_loss, |
|
|
|
|
losses=losses, |
|
|
|
|
**params) |
|
|
|
|
self.model_name = 'HRNet' |
|
|
|
|
|
|
|
|
@ -882,6 +890,7 @@ class BiSeNetV2(BaseSegmenter): |
|
|
|
|
def __init__(self, |
|
|
|
|
num_classes=2, |
|
|
|
|
use_mixed_loss=False, |
|
|
|
|
losses=None, |
|
|
|
|
align_corners=False, |
|
|
|
|
**params): |
|
|
|
|
params.update({'align_corners': align_corners}) |
|
|
|
@ -889,13 +898,19 @@ class BiSeNetV2(BaseSegmenter): |
|
|
|
|
model_name='BiSeNetV2', |
|
|
|
|
num_classes=num_classes, |
|
|
|
|
use_mixed_loss=use_mixed_loss, |
|
|
|
|
losses=losses, |
|
|
|
|
**params) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FarSeg(BaseSegmenter): |
|
|
|
|
def __init__(self, num_classes=2, use_mixed_loss=False, **params): |
|
|
|
|
def __init__(self, |
|
|
|
|
num_classes=2, |
|
|
|
|
use_mixed_loss=False, |
|
|
|
|
losses=None, |
|
|
|
|
**params): |
|
|
|
|
super(FarSeg, self).__init__( |
|
|
|
|
model_name='FarSeg', |
|
|
|
|
num_classes=num_classes, |
|
|
|
|
use_mixed_loss=use_mixed_loss, |
|
|
|
|
losses=losses, |
|
|
|
|
**params) |
|
|
|
|