|
|
|
@ -72,7 +72,7 @@ class STANet(nn.Layer): |
|
|
|
|
f1, f2 = self.attend(f1, f2) |
|
|
|
|
|
|
|
|
|
y = paddle.abs(f1- f2) |
|
|
|
|
y = F.interpolate(y, size=t1.shape[2:], mode='bilinear', align_corners=True) |
|
|
|
|
y = F.interpolate(y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True) |
|
|
|
|
|
|
|
|
|
pred = self.conv_out(y) |
|
|
|
|
return pred, |
|
|
|
@ -166,9 +166,9 @@ class Decoder(nn.Layer, KaimingInitMixin): |
|
|
|
|
f3 = self.dr3(feats[2]) |
|
|
|
|
f4 = self.dr4(feats[3]) |
|
|
|
|
|
|
|
|
|
f2 = F.interpolate(f2, size=f1.shape[2:], mode='bilinear', align_corners=True) |
|
|
|
|
f3 = F.interpolate(f3, size=f1.shape[2:], mode='bilinear', align_corners=True) |
|
|
|
|
f4 = F.interpolate(f4, size=f1.shape[2:], mode='bilinear', align_corners=True) |
|
|
|
|
f2 = F.interpolate(f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True) |
|
|
|
|
f3 = F.interpolate(f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True) |
|
|
|
|
f4 = F.interpolate(f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True) |
|
|
|
|
|
|
|
|
|
x = paddle.concat([f1, f2, f3, f4], axis=1) |
|
|
|
|
y = self.conv_out(x) |
|
|
|
@ -195,7 +195,7 @@ class BAM(nn.Layer): |
|
|
|
|
x = x.flatten(-2) |
|
|
|
|
x_rs = self.pool(x) |
|
|
|
|
|
|
|
|
|
b, c, h, w = x_rs.shape |
|
|
|
|
b, c, h, w = paddle.shape(x_rs) |
|
|
|
|
query = self.conv_q(x_rs).reshape((b,-1,h*w)).transpose((0,2,1)) |
|
|
|
|
key = self.conv_k(x_rs).reshape((b,-1,h*w)) |
|
|
|
|
energy = paddle.bmm(query, key) |
|
|
|
@ -236,7 +236,7 @@ class PAMBlock(nn.Layer): |
|
|
|
|
value = self.conv_v(x_rs) |
|
|
|
|
|
|
|
|
|
# Split the whole image into subregions. |
|
|
|
|
b, c, h, w = x_rs.shape |
|
|
|
|
b, c, h, w = paddle.shape(x_rs) |
|
|
|
|
query = self._split_subregions(query) |
|
|
|
|
key = self._split_subregions(key) |
|
|
|
|
value = self._split_subregions(value) |
|
|
|
@ -257,7 +257,7 @@ class PAMBlock(nn.Layer): |
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
def _split_subregions(self, x): |
|
|
|
|
b, c, h, w = x.shape |
|
|
|
|
b, c, h, w = paddle.shape(x) |
|
|
|
|
assert h % self.scale == 0 and w % self.scale == 0 |
|
|
|
|
x = x.reshape((b, c, self.scale, h//self.scale, self.scale, w//self.scale)) |
|
|
|
|
x = x.transpose((0,2,4,1,3,5)).reshape((b*self.scale*self.scale, c, -1)) |
|
|
|
|