@@ -129,7 +129,7 @@ class BIT(nn.Layer):
             Conv3x3(EBD_DIM, num_classes))
 
     def _get_semantic_tokens(self, x):
-        b, c = paddle.shape(x)[:2]
+        b, c = x.shape[:2]
         att_map = self.conv_att(x)
         att_map = att_map.reshape((b, self.token_len, 1, -1))
         att_map = F.softmax(att_map, axis=-1)
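The pattern in these hunks is the same: the tensor-valued `paddle.shape(x)` is replaced by the `x.shape` attribute. For context (an illustration of mine, not part of the patch): in dynamic-graph mode `paddle.shape(x)` returns an int32 Tensor, whereas `x.shape` is a plain Python list, so the unpacked `b` and `c` are ordinary ints that can be passed straight to `reshape`. A minimal sketch of the difference:

import paddle

x = paddle.rand([4, 32, 64, 64])

# Tensor-valued shape: an int32 Tensor holding [4, 32, 64, 64].
shape_as_tensor = paddle.shape(x)

# Attribute shape: a plain Python list of ints in dynamic-graph mode.
b, c = x.shape[:2]

print(shape_as_tensor.dtype)  # paddle.int32
print(b, c)                   # 4 32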
@@ -154,7 +154,7 @@ class BIT(nn.Layer):
         return x
 
     def decode(self, x, m):
-        b, c, h, w = paddle.shape(x)
+        b, c, h, w = x.shape
         x = x.transpose((0, 2, 3, 1)).flatten(1, 2)
         x = self.decoder(x, m)
         x = x.transpose((0, 2, 1)).reshape((b, c, h, w))
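As a side note (a sketch of mine, not code from the file): `decode` flattens the feature map into a token sequence before the transformer decoder and restores the spatial layout afterwards. The shape bookkeeping can be checked in isolation by skipping the decoder call:

import paddle

# (B, C, H, W) -> (B, H*W, C) token sequence -> back to (B, C, H, W).
x = paddle.rand([2, 8, 4, 4])
b, c, h, w = x.shape

tokens = x.transpose((0, 2, 3, 1)).flatten(1, 2)            # (B, H*W, C)
restored = tokens.transpose((0, 2, 1)).reshape((b, c, h, w))

print(tokens.shape)                        # [2, 16, 8]
print(bool(paddle.allclose(x, restored)))  # True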
@@ -172,7 +172,7 @@ class BIT(nn.Layer):
         else:
             token1 = self._get_reshaped_tokens(x1)
             token2 = self._get_reshaped_tokens(x2)
 
         # Transformer encoder forward
         token = paddle.concat([token1, token2], axis=1)
         token = self.encode(token)
@@ -265,7 +265,7 @@ class CrossAttention(nn.Layer):
             nn.Linear(inner_dim, dim), nn.Dropout(dropout_rate))
 
     def forward(self, x, ref):
-        b, n = paddle.shape(x)[:2]
+        b, n = x.shape[:2]
         h = self.n_heads
 
         q = self.fc_q(x)
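For orientation only (a hand-rolled sketch with made-up dimensions, not the layer's actual code): the `b`, `n`, and `h = self.n_heads` values read at the top of `forward` are the kind of sizes typically used to split the projected queries into per-head chunks along the channel axis, as in:

import paddle

# Hypothetical sizes for illustration only.
b, n, dim, n_heads = 2, 16, 64, 8
head_dim = dim // n_heads

x = paddle.rand([b, n, dim])
q = paddle.nn.Linear(dim, dim)(x)  # stand-in for self.fc_q(x)

# Split channels into heads: (b, n, dim) -> (b, n_heads, n, head_dim).
q = q.reshape((b, n, n_heads, head_dim)).transpose((0, 2, 1, 3))
print(q.shape)  # [2, 8, 16, 8]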