|
|
|
@ -384,8 +384,8 @@ class TinyViTBlock(nn.Module): |
|
|
|
|
convolution. |
|
|
|
|
""" |
|
|
|
|
h, w = self.input_resolution |
|
|
|
|
b, l, c = x.shape |
|
|
|
|
assert l == h * w, "input feature has wrong size" |
|
|
|
|
b, hw, c = x.shape # batch, height*width, channels |
|
|
|
|
assert hw == h * w, "input feature has wrong size" |
|
|
|
|
res_x = x |
|
|
|
|
if h == self.window_size and w == self.window_size: |
|
|
|
|
x = self.attn(x) |
|
|
|
@ -394,13 +394,13 @@ class TinyViTBlock(nn.Module): |
|
|
|
|
pad_b = (self.window_size - h % self.window_size) % self.window_size |
|
|
|
|
pad_r = (self.window_size - w % self.window_size) % self.window_size |
|
|
|
|
padding = pad_b > 0 or pad_r > 0 |
|
|
|
|
|
|
|
|
|
if padding: |
|
|
|
|
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) |
|
|
|
|
|
|
|
|
|
pH, pW = h + pad_b, w + pad_r |
|
|
|
|
nH = pH // self.window_size |
|
|
|
|
nW = pW // self.window_size |
|
|
|
|
|
|
|
|
|
# Window partition |
|
|
|
|
x = ( |
|
|
|
|
x.view(b, nH, self.window_size, nW, self.window_size, c) |
|
|
|
@ -408,19 +408,18 @@ class TinyViTBlock(nn.Module): |
|
|
|
|
.reshape(b * nH * nW, self.window_size * self.window_size, c) |
|
|
|
|
) |
|
|
|
|
x = self.attn(x) |
|
|
|
|
|
|
|
|
|
# Window reverse |
|
|
|
|
x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c) |
|
|
|
|
|
|
|
|
|
if padding: |
|
|
|
|
x = x[:, :h, :w].contiguous() |
|
|
|
|
|
|
|
|
|
x = x.view(b, l, c) |
|
|
|
|
x = x.view(b, hw, c) |
|
|
|
|
|
|
|
|
|
x = res_x + self.drop_path(x) |
|
|
|
|
|
|
|
|
|
x = x.transpose(1, 2).reshape(b, c, h, w) |
|
|
|
|
x = self.local_conv(x) |
|
|
|
|
x = x.view(b, c, l).transpose(1, 2) |
|
|
|
|
x = x.view(b, c, hw).transpose(1, 2) |
|
|
|
|
|
|
|
|
|
return x + self.drop_path(self.mlp(x)) |
|
|
|
|
|
|
|
|
|