Replace x.shape with paddle.shape(x)

own
Bobholamovic 3 years ago
parent d35cffa344
commit ab59d5cacb
  1. 10
      paddlers/models/cd/models/bit.py
  2. 14
      paddlers/models/cd/models/stanet.py

@ -122,7 +122,7 @@ class BIT(nn.Layer):
)
def _get_semantic_tokens(self, x):
b, c = x.shape[:2]
b, c = paddle.shape(x)[:2]
att_map = self.conv_att(x)
att_map = att_map.reshape((b,self.token_len,1,-1))
att_map = F.softmax(att_map, axis=-1)
@ -147,7 +147,7 @@ class BIT(nn.Layer):
return x
def decode(self, x, m):
b, c, h, w = x.shape
b, c, h, w = paddle.shape(x)
x = x.transpose((0,2,3,1)).flatten(1,2)
x = self.decoder(x, m)
x = x.transpose((0,2,1)).reshape((b,c,h,w))
@ -257,7 +257,7 @@ class CrossAttention(nn.Layer):
)
def forward(self, x, ref):
b, n = x.shape[:2]
b, n = paddle.shape(x)[:2]
h = self.n_heads
q = self.fc_q(x)
@ -265,8 +265,8 @@ class CrossAttention(nn.Layer):
v = self.fc_v(ref)
q = q.reshape((b,n,h,-1)).transpose((0,2,1,3))
k = k.reshape((b,ref.shape[1],h,-1)).transpose((0,2,1,3))
v = v.reshape((b,ref.shape[1],h,-1)).transpose((0,2,1,3))
k = k.reshape((b,paddle.shape(ref)[1],h,-1)).transpose((0,2,1,3))
v = v.reshape((b,paddle.shape(ref)[1],h,-1)).transpose((0,2,1,3))
mult = paddle.matmul(q, k, transpose_y=True) * self.scale

@ -72,7 +72,7 @@ class STANet(nn.Layer):
f1, f2 = self.attend(f1, f2)
y = paddle.abs(f1- f2)
y = F.interpolate(y, size=t1.shape[2:], mode='bilinear', align_corners=True)
y = F.interpolate(y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)
pred = self.conv_out(y)
return pred,
@ -166,9 +166,9 @@ class Decoder(nn.Layer, KaimingInitMixin):
f3 = self.dr3(feats[2])
f4 = self.dr4(feats[3])
f2 = F.interpolate(f2, size=f1.shape[2:], mode='bilinear', align_corners=True)
f3 = F.interpolate(f3, size=f1.shape[2:], mode='bilinear', align_corners=True)
f4 = F.interpolate(f4, size=f1.shape[2:], mode='bilinear', align_corners=True)
f2 = F.interpolate(f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
f3 = F.interpolate(f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
f4 = F.interpolate(f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
x = paddle.concat([f1, f2, f3, f4], axis=1)
y = self.conv_out(x)
@ -195,7 +195,7 @@ class BAM(nn.Layer):
x = x.flatten(-2)
x_rs = self.pool(x)
b, c, h, w = x_rs.shape
b, c, h, w = paddle.shape(x_rs)
query = self.conv_q(x_rs).reshape((b,-1,h*w)).transpose((0,2,1))
key = self.conv_k(x_rs).reshape((b,-1,h*w))
energy = paddle.bmm(query, key)
@ -236,7 +236,7 @@ class PAMBlock(nn.Layer):
value = self.conv_v(x_rs)
# Split the whole image into subregions.
b, c, h, w = x_rs.shape
b, c, h, w = paddle.shape(x_rs)
query = self._split_subregions(query)
key = self._split_subregions(key)
value = self._split_subregions(value)
@ -257,7 +257,7 @@ class PAMBlock(nn.Layer):
return out
def _split_subregions(self, x):
b, c, h, w = x.shape
b, c, h, w = paddle.shape(x)
assert h % self.scale == 0 and w % self.scale == 0
x = x.reshape((b, c, self.scale, h//self.scale, self.scale, w//self.scale))
x = x.transpose((0,2,4,1,3,5)).reshape((b*self.scale*self.scale, c, -1))

Loading…
Cancel
Save