@@ -129,7 +129,7 @@ class BIT(nn.Layer):
             Conv3x3(EBD_DIM, num_classes))
 
     def _get_semantic_tokens(self, x):
-        b, c = paddle.shape(x)[:2]
+        b, c = x.shape[:2]
         att_map = self.conv_att(x)
         att_map = att_map.reshape((b, self.token_len, 1, -1))
         att_map = F.softmax(att_map, axis=-1)
@@ -154,7 +154,7 @@ class BIT(nn.Layer):
         return x
 
     def decode(self, x, m):
-        b, c, h, w = paddle.shape(x)
+        b, c, h, w = x.shape
         x = x.transpose((0, 2, 3, 1)).flatten(1, 2)
         x = self.decoder(x, m)
         x = x.transpose((0, 2, 1)).reshape((b, c, h, w))
@@ -265,7 +265,7 @@ class CrossAttention(nn.Layer):
             nn.Linear(inner_dim, dim), nn.Dropout(dropout_rate))
 
     def forward(self, x, ref):
-        b, n = paddle.shape(x)[:2]
+        b, n = x.shape[:2]
         h = self.n_heads
 
         q = self.fc_q(x)
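
All three hunks make the same substitution: the paddle.shape() op is replaced by the Tensor's .shape attribute. A minimal dynamic-graph sketch of the difference between the two forms (not part of the patch; the tensor sizes below are made up for illustration):

import paddle

x = paddle.rand([2, 32, 64, 64])  # e.g. a feature map of shape (N, C, H, W)

# Attribute access: in dynamic-graph mode, x.shape is a plain Python list of ints.
b, c = x.shape[:2]           # b == 2, c == 32 (Python ints)

# paddle.shape() op: returns a 1-D integer Tensor holding the runtime shape.
shp = paddle.shape(x)        # Tensor([2, 32, 64, 64])
b_t = shp[0]                 # a Tensor, not a Python int
c_t = shp[1]

# Both forms can be passed to reshape(); the attribute form keeps the sizes as ints.
y = x.reshape((b, c, -1))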