# code was heavily based on https://github.com/wtjiang98/PSGAN
#
# MIT License
#
# Copyright (c) 2020 Wentao Jiang

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddle.vision.models import resnet18


class ConvBNReLU(paddle.nn.Layer):
    """Conv2D -> BatchNorm2D -> ReLU, the basic building block used throughout."""

    def __init__(self,
                 in_chan,
                 out_chan,
                 ks=3,
                 stride=1,
                 padding=1,
                 *args,
                 **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2D(
            in_chan,
            out_chan,
            kernel_size=ks,
            stride=stride,
            padding=padding,
            bias_attr=False)  # bias is redundant before BatchNorm
        self.bn = nn.BatchNorm2D(out_chan)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class BiSeNetOutput(paddle.nn.Layer):
    """Segmentation head: a 3x3 ConvBNReLU followed by a 1x1 conv to n_classes."""

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2D(
            mid_chan, n_classes, kernel_size=1, bias_attr=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        return x


class AttentionRefinementModule(paddle.nn.Layer):
    """Channel attention (the ARM block from the BiSeNet paper): global average
    pooling produces per-channel weights that rescale the 3x3 conv features."""

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2D(
            out_chan, out_chan, kernel_size=1, bias_attr=False)
        self.bn_atten = nn.BatchNorm2D(out_chan)
        self.sigmoid_atten = nn.Sigmoid()

    def forward(self, x):
        feat = self.conv(x)
        # Global average pooling down to 1x1, then a 1x1 conv + sigmoid gate.
        atten = F.avg_pool2d(feat, feat.shape[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = feat * atten
        return out


class ContextPath(paddle.nn.Layer):
    """Context path built on an ImageNet-pretrained ResNet-18 backbone.

    Returns the raw 1/8 feature map plus ARM-refined maps upsampled to
    strides 8 and 16.
    """

    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.backbone = resnet18(pretrained=True)
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

    def backbone_forward(self, x):
        # Run the ResNet-18 stages by hand so the intermediate feature maps
        # at strides 8, 16 and 32 can be returned.
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x = self.backbone.layer1(x)
        c2 = self.backbone.layer2(x)   # 1/8 resolution, 128 channels
        c3 = self.backbone.layer3(c2)  # 1/16 resolution, 256 channels
        c4 = self.backbone.layer4(c3)  # 1/32 resolution, 512 channels
        return c2, c3, c4

    def forward(self, x):
        H0, W0 = x.shape[2:]
        feat8, feat16, feat32 = self.backbone_forward(x)
        H8, W8 = feat8.shape[2:]
        H16, W16 = feat16.shape[2:]
        H32, W32 = feat32.shape[2:]

        # Global context: average-pool the deepest features to 1x1, compress
        # to 128 channels and broadcast back to 1/32 resolution.
        avg = F.avg_pool2d(feat32, feat32.shape[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, size=(H32, W32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, size=(H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, size=(H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # strides 8, 8 and 16


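# Illustrative shape walkthrough, not part of the original file: a sketch of
# what ContextPath returns for a hypothetical 512x512 input, assuming the
# stride annotations in backbone_forward above.
#
#     cp = ContextPath()
#     x = paddle.randn([1, 3, 512, 512])
#     feat8, feat16_up, feat32_up = cp(x)
#     # feat8:     [1, 128, 64, 64]  (stride 8, raw backbone features)
#     # feat16_up: [1, 128, 64, 64]  (stride 8, refined and upsampled)
#     # feat32_up: [1, 128, 32, 32]  (stride 16)

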
class SpatialPath(paddle.nn.Layer):
    """Spatial path: three stride-2 convs (output stride 8, 128 channels).

    Note: BiSeNet below substitutes the backbone's 1/8 feature map for this
    path, so SpatialPath is kept for completeness but is not used there.
    """

    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat


class FeatureFusionModule(paddle.nn.Layer):
    """FFM: concatenate spatial and context features, then reweight the fused
    result with a squeeze-and-excitation style channel attention."""

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2D(
            out_chan,
            out_chan // 4,
            kernel_size=1,
            stride=1,
            padding=0,
            bias_attr=False)
        self.conv2 = nn.Conv2D(
            out_chan // 4,
            out_chan,
            kernel_size=1,
            stride=1,
            padding=0,
            bias_attr=False)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, fsp, fcp):
        fcat = paddle.concat([fsp, fcp], axis=1)
        feat = self.convblk(fcat)
        # Squeeze (global average pool), excite (1x1 convs), then gate,
        # with a residual connection around the attention branch.
        atten = F.avg_pool2d(feat, feat.shape[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = feat * atten
        feat_out = feat_atten + feat
        return feat_out


class BiSeNet(paddle.nn.Layer):
    """BiSeNet for face parsing: context path + feature fusion + three heads.

    The auxiliary heads (conv_out16 / conv_out32) supervise intermediate
    features during training; conv_out produces the main prediction.
    """

    def __init__(self, n_classes, *args, **kwargs):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)

    def forward(self, x):
        H, W = x.shape[2:]
        # The context path also returns the backbone's 1/8 (res3b1) feature map.
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)
        feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)

        # Upsample all three predictions back to the input resolution.
        feat_out = F.interpolate(feat_out, size=(H, W))
        feat_out16 = F.interpolate(feat_out16, size=(H, W))
        feat_out32 = F.interpolate(feat_out32, size=(H, W))
        return feat_out, feat_out16, feat_out32


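# A minimal smoke-test sketch, not part of the original file. It assumes a
# working PaddlePaddle install (with network access to fetch the pretrained
# ResNet-18 weights); n_classes=19 is an illustrative choice, typical for
# face-parsing label sets.
if __name__ == "__main__":
    net = BiSeNet(n_classes=19)
    net.eval()
    dummy = paddle.randn([1, 3, 512, 512])
    out, out16, out32 = net(dummy)
    # All three heads are upsampled to the input size.
    print(out.shape, out16.shape, out32.shape)  # each [1, 19, 512, 512]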