|
|
|
@ -781,7 +781,7 @@ class BaseSegmenter(BaseModel): |
|
|
|
|
|
|
|
|
|
class UNet(BaseSegmenter): |
|
|
|
|
def __init__(self, |
|
|
|
|
input_channel=3, |
|
|
|
|
in_channels=3, |
|
|
|
|
num_classes=2, |
|
|
|
|
use_mixed_loss=False, |
|
|
|
|
use_deconv=False, |
|
|
|
@ -793,7 +793,7 @@ class UNet(BaseSegmenter): |
|
|
|
|
}) |
|
|
|
|
super(UNet, self).__init__( |
|
|
|
|
model_name='UNet', |
|
|
|
|
input_channel=input_channel, |
|
|
|
|
input_channel=in_channels, |
|
|
|
|
num_classes=num_classes, |
|
|
|
|
use_mixed_loss=use_mixed_loss, |
|
|
|
|
**params) |
|
|
|
@ -801,7 +801,7 @@ class UNet(BaseSegmenter): |
|
|
|
|
|
|
|
|
|
class DeepLabV3P(BaseSegmenter): |
|
|
|
|
def __init__(self, |
|
|
|
|
input_channel=3, |
|
|
|
|
in_channels=3, |
|
|
|
|
num_classes=2, |
|
|
|
|
backbone='ResNet50_vd', |
|
|
|
|
use_mixed_loss=False, |
|
|
|
@ -819,7 +819,7 @@ class DeepLabV3P(BaseSegmenter): |
|
|
|
|
if params.get('with_net', True): |
|
|
|
|
with DisablePrint(): |
|
|
|
|
backbone = getattr(paddleseg.models, backbone)( |
|
|
|
|
input_channel=input_channel, output_stride=output_stride) |
|
|
|
|
input_channel=in_channels, output_stride=output_stride) |
|
|
|
|
else: |
|
|
|
|
backbone = None |
|
|
|
|
params.update({ |
|
|
|
|