diff --git a/paddlers/custom_models/cd/bit.py b/paddlers/custom_models/cd/bit.py index b30d280..0b38fbe 100644 --- a/paddlers/custom_models/cd/bit.py +++ b/paddlers/custom_models/cd/bit.py @@ -138,12 +138,12 @@ class BIT(nn.Layer): Conv3x3(EBD_DIM, num_classes)) def _get_semantic_tokens(self, x): - b, c = paddle.shape(x)[:2] + b, c = x.shape[:2] att_map = self.conv_att(x) att_map = att_map.reshape( - (b, self.token_len, 1, calc_product(*paddle.shape(att_map)[2:]))) + (b, self.token_len, 1, calc_product(*att_map.shape[2:]))) att_map = F.softmax(att_map, axis=-1) - x = x.reshape((b, 1, c, paddle.shape(att_map)[-1])) + x = x.reshape((b, 1, c, att_map.shape[-1])) tokens = (x * att_map).sum(-1) return tokens @@ -164,7 +164,7 @@ class BIT(nn.Layer): return x def decode(self, x, m): - b, c, h, w = paddle.shape(x) + b, c, h, w = x.shape x = x.transpose((0, 2, 3, 1)).flatten(1, 2) x = self.decoder(x, m) x = x.transpose((0, 2, 1)).reshape((b, c, h, w)) @@ -276,7 +276,7 @@ class CrossAttention(nn.Layer): nn.Linear(inner_dim, dim), nn.Dropout(dropout_rate)) def forward(self, x, ref): - b, n = paddle.shape(x)[:2] + b, n = x.shape[:2] h = self.n_heads q = self.fc_q(x) @@ -284,7 +284,7 @@ class CrossAttention(nn.Layer): v = self.fc_v(ref) q = q.reshape((b, n, h, self.head_dim)).transpose((0, 2, 1, 3)) - rn = paddle.shape(ref)[1] + rn = ref.shape[1] k = k.reshape((b, rn, h, self.head_dim)).transpose((0, 2, 1, 3)) v = v.reshape((b, rn, h, self.head_dim)).transpose((0, 2, 1, 3)) diff --git a/paddlers/custom_models/cd/fc_ef.py b/paddlers/custom_models/cd/fc_ef.py index a008688..a831485 100644 --- a/paddlers/custom_models/cd/fc_ef.py +++ b/paddlers/custom_models/cd/fc_ef.py @@ -131,8 +131,7 @@ class FCEarlyFusion(nn.Layer): # Stage 4d x4d = self.upconv4(x4p) - pad4 = (0, paddle.shape(x43)[3] - paddle.shape(x4d)[3], 0, - paddle.shape(x43)[2] - paddle.shape(x4d)[2]) + pad4 = (0, x43.shape[3] - x4d.shape[3], 0, x43.shape[2] - x4d.shape[2]) x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), x43], 1) x43d = self.do43d(self.conv43d(x4d)) x42d = self.do42d(self.conv42d(x43d)) @@ -140,8 +139,7 @@ class FCEarlyFusion(nn.Layer): # Stage 3d x3d = self.upconv3(x41d) - pad3 = (0, paddle.shape(x33)[3] - paddle.shape(x3d)[3], 0, - paddle.shape(x33)[2] - paddle.shape(x3d)[2]) + pad3 = (0, x33.shape[3] - x3d.shape[3], 0, x33.shape[2] - x3d.shape[2]) x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), x33], 1) x33d = self.do33d(self.conv33d(x3d)) x32d = self.do32d(self.conv32d(x33d)) @@ -149,16 +147,14 @@ class FCEarlyFusion(nn.Layer): # Stage 2d x2d = self.upconv2(x31d) - pad2 = (0, paddle.shape(x22)[3] - paddle.shape(x2d)[3], 0, - paddle.shape(x22)[2] - paddle.shape(x2d)[2]) + pad2 = (0, x22.shape[3] - x2d.shape[3], 0, x22.shape[2] - x2d.shape[2]) x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), x22], 1) x22d = self.do22d(self.conv22d(x2d)) x21d = self.do21d(self.conv21d(x22d)) # Stage 1d x1d = self.upconv1(x21d) - pad1 = (0, paddle.shape(x12)[3] - paddle.shape(x1d)[3], 0, - paddle.shape(x12)[2] - paddle.shape(x1d)[2]) + pad1 = (0, x12.shape[3] - x1d.shape[3], 0, x12.shape[2] - x1d.shape[2]) x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), x12], 1) x12d = self.do12d(self.conv12d(x1d)) x11d = self.conv11d(x12d) diff --git a/paddlers/custom_models/cd/fc_siam_conc.py b/paddlers/custom_models/cd/fc_siam_conc.py index af70543..bbe2632 100644 --- a/paddlers/custom_models/cd/fc_siam_conc.py +++ b/paddlers/custom_models/cd/fc_siam_conc.py @@ -154,8 +154,8 @@ class FCSiamConc(nn.Layer): # Decode # Stage 4d x4d = self.upconv4(x4p) - pad4 = (0, paddle.shape(x43_1)[3] - paddle.shape(x4d)[3], 0, - paddle.shape(x43_1)[2] - paddle.shape(x4d)[2]) + pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0, + x43_1.shape[2] - x4d.shape[2]) x4d = paddle.concat( [F.pad(x4d, pad=pad4, mode='replicate'), x43_1, x43_2], 1) x43d = self.do43d(self.conv43d(x4d)) @@ -164,8 +164,8 @@ class FCSiamConc(nn.Layer): # Stage 3d x3d = self.upconv3(x41d) - pad3 = (0, paddle.shape(x33_1)[3] - paddle.shape(x3d)[3], 0, - paddle.shape(x33_1)[2] - paddle.shape(x3d)[2]) + pad3 = (0, x33_1.shape[3] - x3d.shape[3], 0, + x33_1.shape[2] - x3d.shape[2]) x3d = paddle.concat( [F.pad(x3d, pad=pad3, mode='replicate'), x33_1, x33_2], 1) x33d = self.do33d(self.conv33d(x3d)) @@ -174,8 +174,8 @@ class FCSiamConc(nn.Layer): # Stage 2d x2d = self.upconv2(x31d) - pad2 = (0, paddle.shape(x22_1)[3] - paddle.shape(x2d)[3], 0, - paddle.shape(x22_1)[2] - paddle.shape(x2d)[2]) + pad2 = (0, x22_1.shape[3] - x2d.shape[3], 0, + x22_1.shape[2] - x2d.shape[2]) x2d = paddle.concat( [F.pad(x2d, pad=pad2, mode='replicate'), x22_1, x22_2], 1) x22d = self.do22d(self.conv22d(x2d)) @@ -183,8 +183,8 @@ class FCSiamConc(nn.Layer): # Stage 1d x1d = self.upconv1(x21d) - pad1 = (0, paddle.shape(x12_1)[3] - paddle.shape(x1d)[3], 0, - paddle.shape(x12_1)[2] - paddle.shape(x1d)[2]) + pad1 = (0, x12_1.shape[3] - x1d.shape[3], 0, + x12_1.shape[2] - x1d.shape[2]) x1d = paddle.concat( [F.pad(x1d, pad=pad1, mode='replicate'), x12_1, x12_2], 1) x12d = self.do12d(self.conv12d(x1d)) diff --git a/paddlers/custom_models/cd/fc_siam_diff.py b/paddlers/custom_models/cd/fc_siam_diff.py index 9343cfe..b60b5db 100644 --- a/paddlers/custom_models/cd/fc_siam_diff.py +++ b/paddlers/custom_models/cd/fc_siam_diff.py @@ -154,8 +154,8 @@ class FCSiamDiff(nn.Layer): # Decode # Stage 4d x4d = self.upconv4(x4p) - pad4 = (0, paddle.shape(x43_1)[3] - paddle.shape(x4d)[3], 0, - paddle.shape(x43_1)[2] - paddle.shape(x4d)[2]) + pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0, + x43_1.shape[2] - x4d.shape[2]) x4d = F.pad(x4d, pad=pad4, mode='replicate') x4d = paddle.concat([x4d, paddle.abs(x43_1 - x43_2)], 1) x43d = self.do43d(self.conv43d(x4d)) @@ -164,8 +164,8 @@ class FCSiamDiff(nn.Layer): # Stage 3d x3d = self.upconv3(x41d) - pad3 = (0, paddle.shape(x33_1)[3] - paddle.shape(x3d)[3], 0, - paddle.shape(x33_1)[2] - paddle.shape(x3d)[2]) + pad3 = (0, x33_1.shape[3] - x3d.shape[3], 0, + x33_1.shape[2] - x3d.shape[2]) x3d = F.pad(x3d, pad=pad3, mode='replicate') x3d = paddle.concat([x3d, paddle.abs(x33_1 - x33_2)], 1) x33d = self.do33d(self.conv33d(x3d)) @@ -174,8 +174,8 @@ class FCSiamDiff(nn.Layer): # Stage 2d x2d = self.upconv2(x31d) - pad2 = (0, paddle.shape(x22_1)[3] - paddle.shape(x2d)[3], 0, - paddle.shape(x22_1)[2] - paddle.shape(x2d)[2]) + pad2 = (0, x22_1.shape[3] - x2d.shape[3], 0, + x22_1.shape[2] - x2d.shape[2]) x2d = F.pad(x2d, pad=pad2, mode='replicate') x2d = paddle.concat([x2d, paddle.abs(x22_1 - x22_2)], 1) x22d = self.do22d(self.conv22d(x2d)) @@ -183,8 +183,8 @@ class FCSiamDiff(nn.Layer): # Stage 1d x1d = self.upconv1(x21d) - pad1 = (0, paddle.shape(x12_1)[3] - paddle.shape(x1d)[3], 0, - paddle.shape(x12_1)[2] - paddle.shape(x1d)[2]) + pad1 = (0, x12_1.shape[3] - x1d.shape[3], 0, + x12_1.shape[2] - x1d.shape[2]) x1d = F.pad(x1d, pad=pad1, mode='replicate') x1d = paddle.concat([x1d, paddle.abs(x12_1 - x12_2)], 1) x12d = self.do12d(self.conv12d(x1d)) diff --git a/paddlers/custom_models/cd/stanet.py b/paddlers/custom_models/cd/stanet.py index ee98611..b4c1b19 100644 --- a/paddlers/custom_models/cd/stanet.py +++ b/paddlers/custom_models/cd/stanet.py @@ -215,8 +215,7 @@ class BAM(nn.Layer): out = F.interpolate(out, scale_factor=self.ds) out = out + x - return out.reshape( - tuple(paddle.shape(out)[:-1]) + (paddle.shape(out)[-1] // 2, 2)) + return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2)) class PAMBlock(nn.Layer): @@ -242,7 +241,7 @@ class PAMBlock(nn.Layer): value = self.conv_v(x_rs) # Split the whole image into subregions. - b, c, h, w = paddle.shape(x_rs) + b, c, h, w = x_rs.shape query = self._split_subregions(query) key = self._split_subregions(key) @@ -265,7 +264,7 @@ class PAMBlock(nn.Layer): return out def _split_subregions(self, x): - b, c, h, w = paddle.shape(x) + b, c, h, w = x.shape assert h % self.scale == 0 and w % self.scale == 0 x = x.reshape( (b, c, self.scale, h // self.scale, self.scale, w // self.scale)) @@ -297,8 +296,7 @@ class PAM(nn.Layer): out = self.conv_out(paddle.concat(res, axis=1)) - return out.reshape( - tuple(paddle.shape(out)[:-1]) + (paddle.shape(out)[-1] // 2, 2)) + return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2)) class Attention(nn.Layer): diff --git a/paddlers/custom_models/cls/condensenet_v2.py b/paddlers/custom_models/cls/condensenet_v2.py index df70113..2ca1073 100644 --- a/paddlers/custom_models/cls/condensenet_v2.py +++ b/paddlers/custom_models/cls/condensenet_v2.py @@ -36,10 +36,10 @@ class SELayer(nn.Layer): nn.Sigmoid(), ) def forward(self, x): - b, c, _, _ = paddle.shape(x) + b, c, _, _ = x.shape y = self.avg_pool(x).reshape((b, c)) y = self.fc(y).reshape((b, c, 1, 1)) - return x * paddle.expand(y, shape=paddle.shape(x)) + return x * paddle.expand(y, shape=x.shape) class HS(nn.Layer): @@ -85,7 +85,7 @@ class Conv(nn.Sequential): def ShuffleLayer(x, groups): - batchsize, num_channels, height, width = paddle.shape(x) + batchsize, num_channels, height, width = x.shape channels_per_group = num_channels // groups # reshape x = x.reshape((batchsize, groups, channels_per_group, height, width)) @@ -97,7 +97,7 @@ def ShuffleLayer(x, groups): def ShuffleLayerTrans(x, groups): - batchsize, num_channels, height, width = paddle.shape(x) + batchsize, num_channels, height, width = x.shape channels_per_group = num_channels // groups # reshape x = x.reshape((batchsize, channels_per_group, groups, height, width)) @@ -188,7 +188,7 @@ class CondenseSFR(nn.Layer): x = self.activation(x) x = ShuffleLayerTrans(x, self.groups) x = self.conv(x) # SIZE: N, C, H, W - N, C, H, W = paddle.shape(x) + N, C, H, W = x.shape x = x.reshape((N, C, H * W)) x = x.transpose((0, 2, 1)) # SIZE: N, HW, C # x SIZE: N, HW, C; self.index SIZE: C, C; OUTPUT SIZE: N, HW, C @@ -374,8 +374,8 @@ class CondenseNetV2(nn.Layer): def forward(self, x): features = self.features(x) - shape = paddle.shape(features) - out = features.reshape((shape[0], shape[1] * shape[2] * shape[3])) + out = features.reshape((features.shape[0], features.shape[1] * + features.shape[2] * features.shape[3])) out = self.fc(out) out = self.fc_act(out) diff --git a/paddlers/datasets/coco.py b/paddlers/datasets/coco.py index f2f236c..b4fc845 100644 --- a/paddlers/datasets/coco.py +++ b/paddlers/datasets/coco.py @@ -336,7 +336,7 @@ class COCODetection(Dataset): max_img_id += 1 im_fname = osp.join(image_dir, image) img_data = cv2.imread(im_fname, cv2.IMREAD_UNCHANGED) - im_h, im_w, im_c = paddle.shape(img_data) + im_h, im_w, im_c = img_data.shape im_info = { 'im_id': np.asarray([max_img_id]), diff --git a/paddlers/datasets/voc.py b/paddlers/datasets/voc.py index 7bef4e7..1876910 100644 --- a/paddlers/datasets/voc.py +++ b/paddlers/datasets/voc.py @@ -400,7 +400,7 @@ class VOCDetection(Dataset): max_img_id += 1 im_fname = osp.join(image_dir, image) img_data = cv2.imread(im_fname, cv2.IMREAD_UNCHANGED) - im_h, im_w, im_c = paddle.shape(img_data) + im_h, im_w, im_c = img_data.shape im_info = { 'im_id': np.asarray([max_img_id]),