Revert "x.shape->paddle.shape(x)"

This reverts commit d4232c036d47b02e4e9f474def69245462d76ca7.
Bobholamovic 2 years ago
parent a2e745006e
commit 8176ea16aa
Changed files (8):

  1. paddlers/custom_models/cd/bit.py (12 changed lines)
  2. paddlers/custom_models/cd/fc_ef.py (12 changed lines)
  3. paddlers/custom_models/cd/fc_siam_conc.py (16 changed lines)
  4. paddlers/custom_models/cd/fc_siam_diff.py (16 changed lines)
  5. paddlers/custom_models/cd/stanet.py (10 changed lines)
  6. paddlers/custom_models/cls/condensenet_v2.py (14 changed lines)
  7. paddlers/datasets/coco.py (2 changed lines)
  8. paddlers/datasets/voc.py (2 changed lines)
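For context, the two shape APIs behave differently in Paddle: in dynamic-graph mode Tensor.shape is a plain Python list of ints, whereas paddle.shape(x) returns an int32 tensor whose values are only known when the graph runs, which mainly matters for dynamic shapes in exported static graphs. A minimal sketch of the difference (the tensor below is purely illustrative):

import paddle

x = paddle.rand([2, 3, 8, 8])

# Tensor.shape is an ordinary Python list of ints in dynamic-graph mode.
b, c = x.shape[:2]         # b == 2, c == 3 (Python ints)

# paddle.shape(x) is itself a tensor holding the shape values;
# indexing it yields tensors rather than Python ints.
shape_t = paddle.shape(x)  # int32 tensor with values [2, 3, 8, 8]
b_t = shape_t[0]           # a tensor, not a Python int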

paddlers/custom_models/cd/bit.py

@@ -138,12 +138,12 @@ class BIT(nn.Layer):
             Conv3x3(EBD_DIM, num_classes))

     def _get_semantic_tokens(self, x):
-        b, c = paddle.shape(x)[:2]
+        b, c = x.shape[:2]
         att_map = self.conv_att(x)
         att_map = att_map.reshape(
-            (b, self.token_len, 1, calc_product(*paddle.shape(att_map)[2:])))
+            (b, self.token_len, 1, calc_product(*att_map.shape[2:])))
         att_map = F.softmax(att_map, axis=-1)
-        x = x.reshape((b, 1, c, paddle.shape(att_map)[-1]))
+        x = x.reshape((b, 1, c, att_map.shape[-1]))
         tokens = (x * att_map).sum(-1)
         return tokens
@@ -164,7 +164,7 @@ class BIT(nn.Layer):
         return x

     def decode(self, x, m):
-        b, c, h, w = paddle.shape(x)
+        b, c, h, w = x.shape
         x = x.transpose((0, 2, 3, 1)).flatten(1, 2)
         x = self.decoder(x, m)
         x = x.transpose((0, 2, 1)).reshape((b, c, h, w))
@@ -276,7 +276,7 @@ class CrossAttention(nn.Layer):
             nn.Linear(inner_dim, dim), nn.Dropout(dropout_rate))

     def forward(self, x, ref):
-        b, n = paddle.shape(x)[:2]
+        b, n = x.shape[:2]
         h = self.n_heads

         q = self.fc_q(x)
@@ -284,7 +284,7 @@ class CrossAttention(nn.Layer):
         v = self.fc_v(ref)

         q = q.reshape((b, n, h, self.head_dim)).transpose((0, 2, 1, 3))
-        rn = paddle.shape(ref)[1]
+        rn = ref.shape[1]
         k = k.reshape((b, rn, h, self.head_dim)).transpose((0, 2, 1, 3))
         v = v.reshape((b, rn, h, self.head_dim)).transpose((0, 2, 1, 3))

paddlers/custom_models/cd/fc_ef.py

@@ -131,8 +131,7 @@ class FCEarlyFusion(nn.Layer):
         # Stage 4d
         x4d = self.upconv4(x4p)
-        pad4 = (0, paddle.shape(x43)[3] - paddle.shape(x4d)[3], 0,
-                paddle.shape(x43)[2] - paddle.shape(x4d)[2])
+        pad4 = (0, x43.shape[3] - x4d.shape[3], 0, x43.shape[2] - x4d.shape[2])
         x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), x43], 1)
         x43d = self.do43d(self.conv43d(x4d))
         x42d = self.do42d(self.conv42d(x43d))
@@ -140,8 +139,7 @@ class FCEarlyFusion(nn.Layer):
         # Stage 3d
         x3d = self.upconv3(x41d)
-        pad3 = (0, paddle.shape(x33)[3] - paddle.shape(x3d)[3], 0,
-                paddle.shape(x33)[2] - paddle.shape(x3d)[2])
+        pad3 = (0, x33.shape[3] - x3d.shape[3], 0, x33.shape[2] - x3d.shape[2])
         x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), x33], 1)
         x33d = self.do33d(self.conv33d(x3d))
         x32d = self.do32d(self.conv32d(x33d))
@@ -149,16 +147,14 @@ class FCEarlyFusion(nn.Layer):
         # Stage 2d
         x2d = self.upconv2(x31d)
-        pad2 = (0, paddle.shape(x22)[3] - paddle.shape(x2d)[3], 0,
-                paddle.shape(x22)[2] - paddle.shape(x2d)[2])
+        pad2 = (0, x22.shape[3] - x2d.shape[3], 0, x22.shape[2] - x2d.shape[2])
         x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), x22], 1)
         x22d = self.do22d(self.conv22d(x2d))
         x21d = self.do21d(self.conv21d(x22d))
         # Stage 1d
         x1d = self.upconv1(x21d)
-        pad1 = (0, paddle.shape(x12)[3] - paddle.shape(x1d)[3], 0,
-                paddle.shape(x12)[2] - paddle.shape(x1d)[2])
+        pad1 = (0, x12.shape[3] - x1d.shape[3], 0, x12.shape[2] - x1d.shape[2])
         x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), x12], 1)
         x12d = self.do12d(self.conv12d(x1d))
         x11d = self.conv11d(x12d)
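The pad tuples in these decoder stages re-align the upsampled feature map with its skip connection when the encoder produced odd spatial sizes. Below is a minimal standalone sketch of the same pattern, assuming Paddle's F.pad reads a 4-element pad on NCHW input as (left, right, top, bottom); the tensor sizes are illustrative:

import paddle
import paddle.nn.functional as F

skip = paddle.rand([1, 16, 15, 15])  # encoder feature with odd spatial size
up = paddle.rand([1, 16, 14, 14])    # upsampled decoder feature, one pixel short

# Pad the decoder feature on the right/bottom so it matches the skip tensor,
# then fuse the two along the channel axis.
pad = (0, skip.shape[3] - up.shape[3], 0, skip.shape[2] - up.shape[2])
up = F.pad(up, pad=pad, mode='replicate')
fused = paddle.concat([up, skip], 1)  # shape [1, 32, 15, 15]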

paddlers/custom_models/cd/fc_siam_conc.py

@@ -154,8 +154,8 @@ class FCSiamConc(nn.Layer):
         # Decode
         # Stage 4d
         x4d = self.upconv4(x4p)
-        pad4 = (0, paddle.shape(x43_1)[3] - paddle.shape(x4d)[3], 0,
-                paddle.shape(x43_1)[2] - paddle.shape(x4d)[2])
+        pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0,
+                x43_1.shape[2] - x4d.shape[2])
         x4d = paddle.concat(
             [F.pad(x4d, pad=pad4, mode='replicate'), x43_1, x43_2], 1)
         x43d = self.do43d(self.conv43d(x4d))
@@ -164,8 +164,8 @@ class FCSiamConc(nn.Layer):
         # Stage 3d
         x3d = self.upconv3(x41d)
-        pad3 = (0, paddle.shape(x33_1)[3] - paddle.shape(x3d)[3], 0,
-                paddle.shape(x33_1)[2] - paddle.shape(x3d)[2])
+        pad3 = (0, x33_1.shape[3] - x3d.shape[3], 0,
+                x33_1.shape[2] - x3d.shape[2])
         x3d = paddle.concat(
             [F.pad(x3d, pad=pad3, mode='replicate'), x33_1, x33_2], 1)
         x33d = self.do33d(self.conv33d(x3d))
@@ -174,8 +174,8 @@ class FCSiamConc(nn.Layer):
         # Stage 2d
         x2d = self.upconv2(x31d)
-        pad2 = (0, paddle.shape(x22_1)[3] - paddle.shape(x2d)[3], 0,
-                paddle.shape(x22_1)[2] - paddle.shape(x2d)[2])
+        pad2 = (0, x22_1.shape[3] - x2d.shape[3], 0,
+                x22_1.shape[2] - x2d.shape[2])
         x2d = paddle.concat(
             [F.pad(x2d, pad=pad2, mode='replicate'), x22_1, x22_2], 1)
         x22d = self.do22d(self.conv22d(x2d))
@@ -183,8 +183,8 @@ class FCSiamConc(nn.Layer):
         # Stage 1d
         x1d = self.upconv1(x21d)
-        pad1 = (0, paddle.shape(x12_1)[3] - paddle.shape(x1d)[3], 0,
-                paddle.shape(x12_1)[2] - paddle.shape(x1d)[2])
+        pad1 = (0, x12_1.shape[3] - x1d.shape[3], 0,
+                x12_1.shape[2] - x1d.shape[2])
         x1d = paddle.concat(
             [F.pad(x1d, pad=pad1, mode='replicate'), x12_1, x12_2], 1)
         x12d = self.do12d(self.conv12d(x1d))

paddlers/custom_models/cd/fc_siam_diff.py

@@ -154,8 +154,8 @@ class FCSiamDiff(nn.Layer):
         # Decode
         # Stage 4d
         x4d = self.upconv4(x4p)
-        pad4 = (0, paddle.shape(x43_1)[3] - paddle.shape(x4d)[3], 0,
-                paddle.shape(x43_1)[2] - paddle.shape(x4d)[2])
+        pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0,
+                x43_1.shape[2] - x4d.shape[2])
         x4d = F.pad(x4d, pad=pad4, mode='replicate')
         x4d = paddle.concat([x4d, paddle.abs(x43_1 - x43_2)], 1)
         x43d = self.do43d(self.conv43d(x4d))
@@ -164,8 +164,8 @@ class FCSiamDiff(nn.Layer):
         # Stage 3d
         x3d = self.upconv3(x41d)
-        pad3 = (0, paddle.shape(x33_1)[3] - paddle.shape(x3d)[3], 0,
-                paddle.shape(x33_1)[2] - paddle.shape(x3d)[2])
+        pad3 = (0, x33_1.shape[3] - x3d.shape[3], 0,
+                x33_1.shape[2] - x3d.shape[2])
         x3d = F.pad(x3d, pad=pad3, mode='replicate')
         x3d = paddle.concat([x3d, paddle.abs(x33_1 - x33_2)], 1)
         x33d = self.do33d(self.conv33d(x3d))
@@ -174,8 +174,8 @@ class FCSiamDiff(nn.Layer):
         # Stage 2d
         x2d = self.upconv2(x31d)
-        pad2 = (0, paddle.shape(x22_1)[3] - paddle.shape(x2d)[3], 0,
-                paddle.shape(x22_1)[2] - paddle.shape(x2d)[2])
+        pad2 = (0, x22_1.shape[3] - x2d.shape[3], 0,
+                x22_1.shape[2] - x2d.shape[2])
         x2d = F.pad(x2d, pad=pad2, mode='replicate')
         x2d = paddle.concat([x2d, paddle.abs(x22_1 - x22_2)], 1)
         x22d = self.do22d(self.conv22d(x2d))
@@ -183,8 +183,8 @@ class FCSiamDiff(nn.Layer):
         # Stage 1d
         x1d = self.upconv1(x21d)
-        pad1 = (0, paddle.shape(x12_1)[3] - paddle.shape(x1d)[3], 0,
-                paddle.shape(x12_1)[2] - paddle.shape(x1d)[2])
+        pad1 = (0, x12_1.shape[3] - x1d.shape[3], 0,
+                x12_1.shape[2] - x1d.shape[2])
         x1d = F.pad(x1d, pad=pad1, mode='replicate')
         x1d = paddle.concat([x1d, paddle.abs(x12_1 - x12_2)], 1)
         x12d = self.do12d(self.conv12d(x1d))

paddlers/custom_models/cd/stanet.py

@@ -215,8 +215,7 @@ class BAM(nn.Layer):
         out = F.interpolate(out, scale_factor=self.ds)
         out = out + x
-        return out.reshape(
-            tuple(paddle.shape(out)[:-1]) + (paddle.shape(out)[-1] // 2, 2))
+        return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))


 class PAMBlock(nn.Layer):
@@ -242,7 +241,7 @@ class PAMBlock(nn.Layer):
         value = self.conv_v(x_rs)

         # Split the whole image into subregions.
-        b, c, h, w = paddle.shape(x_rs)
+        b, c, h, w = x_rs.shape
         query = self._split_subregions(query)
         key = self._split_subregions(key)
@@ -265,7 +264,7 @@ class PAMBlock(nn.Layer):
         return out

     def _split_subregions(self, x):
-        b, c, h, w = paddle.shape(x)
+        b, c, h, w = x.shape
         assert h % self.scale == 0 and w % self.scale == 0
         x = x.reshape(
             (b, c, self.scale, h // self.scale, self.scale, w // self.scale))
@@ -297,8 +296,7 @@ class PAM(nn.Layer):
         out = self.conv_out(paddle.concat(res, axis=1))
-        return out.reshape(
-            tuple(paddle.shape(out)[:-1]) + (paddle.shape(out)[-1] // 2, 2))
+        return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))


 class Attention(nn.Layer):

paddlers/custom_models/cls/condensenet_v2.py

@@ -36,10 +36,10 @@ class SELayer(nn.Layer):
             nn.Sigmoid(), )

     def forward(self, x):
-        b, c, _, _ = paddle.shape(x)
+        b, c, _, _ = x.shape
         y = self.avg_pool(x).reshape((b, c))
         y = self.fc(y).reshape((b, c, 1, 1))
-        return x * paddle.expand(y, shape=paddle.shape(x))
+        return x * paddle.expand(y, shape=x.shape)


 class HS(nn.Layer):
@@ -85,7 +85,7 @@ class Conv(nn.Sequential):
 def ShuffleLayer(x, groups):
-    batchsize, num_channels, height, width = paddle.shape(x)
+    batchsize, num_channels, height, width = x.shape
     channels_per_group = num_channels // groups
     # reshape
     x = x.reshape((batchsize, groups, channels_per_group, height, width))
@@ -97,7 +97,7 @@ def ShuffleLayer(x, groups):
 def ShuffleLayerTrans(x, groups):
-    batchsize, num_channels, height, width = paddle.shape(x)
+    batchsize, num_channels, height, width = x.shape
     channels_per_group = num_channels // groups
     # reshape
     x = x.reshape((batchsize, channels_per_group, groups, height, width))
@@ -188,7 +188,7 @@ class CondenseSFR(nn.Layer):
         x = self.activation(x)
         x = ShuffleLayerTrans(x, self.groups)
         x = self.conv(x)  # SIZE: N, C, H, W
-        N, C, H, W = paddle.shape(x)
+        N, C, H, W = x.shape
         x = x.reshape((N, C, H * W))
         x = x.transpose((0, 2, 1))  # SIZE: N, HW, C
         # x SIZE: N, HW, C; self.index SIZE: C, C; OUTPUT SIZE: N, HW, C
@@ -374,8 +374,8 @@ class CondenseNetV2(nn.Layer):
     def forward(self, x):
         features = self.features(x)
-        shape = paddle.shape(features)
-        out = features.reshape((shape[0], shape[1] * shape[2] * shape[3]))
+        out = features.reshape((features.shape[0], features.shape[1] *
+                                features.shape[2] * features.shape[3]))
         out = self.fc(out)
         out = self.fc_act(out)

paddlers/datasets/coco.py

@@ -336,7 +336,7 @@ class COCODetection(Dataset):
                 max_img_id += 1
                 im_fname = osp.join(image_dir, image)
                 img_data = cv2.imread(im_fname, cv2.IMREAD_UNCHANGED)
-                im_h, im_w, im_c = paddle.shape(img_data)
+                im_h, im_w, im_c = img_data.shape
                 im_info = {
                     'im_id': np.asarray([max_img_id]),

paddlers/datasets/voc.py

@@ -400,7 +400,7 @@ class VOCDetection(Dataset):
                 max_img_id += 1
                 im_fname = osp.join(image_dir, image)
                 img_data = cv2.imread(im_fname, cv2.IMREAD_UNCHANGED)
-                im_h, im_w, im_c = paddle.shape(img_data)
+                im_h, im_w, im_c = img_data.shape
                 im_info = {
                     'im_id': np.asarray([max_img_id]),
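Note that in coco.py and voc.py the reverted line operates on the output of cv2.imread, which is a NumPy array rather than a Paddle tensor, so reading the dimensions from its .shape tuple is the natural choice. A minimal sketch (the file name is hypothetical):

import cv2

# cv2.imread returns a numpy.ndarray, or None if the file cannot be read.
img_data = cv2.imread('example.jpg', cv2.IMREAD_UNCHANGED)  # hypothetical path
if img_data is not None and img_data.ndim == 3:
    im_h, im_w, im_c = img_data.shape  # ndarray.shape is a tuple of ints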
