@@ -383,44 +383,44 @@ class TinyViTBlock(nn.Module):
         """Applies attention-based transformation or padding to input 'x' before passing it through a local
         convolution.
         """
-        H, W = self.input_resolution
-        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+        h, w = self.input_resolution
+        b, l, c = x.shape
+        assert l == h * w, "input feature has wrong size"
         res_x = x
-        if H == self.window_size and W == self.window_size:
+        if h == self.window_size and w == self.window_size:
             x = self.attn(x)
         else:
-            x = x.view(B, H, W, C)
-            pad_b = (self.window_size - H % self.window_size) % self.window_size
-            pad_r = (self.window_size - W % self.window_size) % self.window_size
+            x = x.view(b, h, w, c)
+            pad_b = (self.window_size - h % self.window_size) % self.window_size
+            pad_r = (self.window_size - w % self.window_size) % self.window_size
             padding = pad_b > 0 or pad_r > 0

             if padding:
                 x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

-            pH, pW = H + pad_b, W + pad_r
+            pH, pW = h + pad_b, w + pad_r
             nH = pH // self.window_size
             nW = pW // self.window_size
             # Window partition
             x = (
-                x.view(B, nH, self.window_size, nW, self.window_size, C)
+                x.view(b, nH, self.window_size, nW, self.window_size, c)
                 .transpose(2, 3)
-                .reshape(B * nH * nW, self.window_size * self.window_size, C)
+                .reshape(b * nH * nW, self.window_size * self.window_size, c)
             )
             x = self.attn(x)
             # Window reverse
-            x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
+            x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c)

             if padding:
-                x = x[:, :H, :W].contiguous()
+                x = x[:, :h, :w].contiguous()

-            x = x.view(B, L, C)
+            x = x.view(b, l, c)

         x = res_x + self.drop_path(x)

-        x = x.transpose(1, 2).reshape(B, C, H, W)
+        x = x.transpose(1, 2).reshape(b, c, h, w)
         x = self.local_conv(x)
-        x = x.view(B, C, L).transpose(1, 2)
+        x = x.view(b, c, l).transpose(1, 2)

         return x + self.drop_path(self.mlp(x))
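
Note (not part of the patch): the window partition and window reverse in the hunk above are exact inverses on the padded feature map, which is why the block can simply slice the padding back off afterwards. A minimal sketch of that round trip, with illustrative shapes and assuming PyTorch is installed:

```python
import torch

b, c, window_size = 2, 8, 7
ph, pw = 14, 21                                 # padded height/width, multiples of window_size
nh, nw = ph // window_size, pw // window_size   # number of windows per axis

x = torch.randn(b, ph, pw, c)

# Window partition: (b, ph, pw, c) -> (b * nh * nw, window_size * window_size, c)
windows = (
    x.view(b, nh, window_size, nw, window_size, c)
    .transpose(2, 3)
    .reshape(b * nh * nw, window_size * window_size, c)
)

# Window reverse: back to (b, ph, pw, c)
restored = (
    windows.view(b, nh, nw, window_size, window_size, c)
    .transpose(2, 3)
    .reshape(b, ph, pw, c)
)

assert torch.equal(x, restored)  # the round trip is lossless
```

`reshape` rather than `view` is needed after the `transpose`, because the transposed tensor is no longer contiguous.
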
@@ -565,10 +565,10 @@ class TinyViT(nn.Module):
         img_size=224,
         in_chans=3,
         num_classes=1000,
-        embed_dims=[96, 192, 384, 768],
-        depths=[2, 2, 6, 2],
-        num_heads=[3, 6, 12, 24],
-        window_sizes=[7, 7, 14, 7],
+        embed_dims=(96, 192, 384, 768),
+        depths=(2, 2, 6, 2),
+        num_heads=(3, 6, 12, 24),
+        window_sizes=(7, 7, 14, 7),
         mlp_ratio=4.0,
         drop_rate=0.0,
         drop_path_rate=0.1,
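
Note (not part of the patch): switching the constructor defaults from lists to tuples sidesteps Python's shared-mutable-default pitfall, since a default value is evaluated once and reused by every call. A minimal sketch with hypothetical function names:

```python
def bad(depths=[2, 2, 6, 2]):     # default list is built once and shared
    depths.append(0)              # in-place change leaks into later calls
    return depths

def good(depths=(2, 2, 6, 2)):    # tuple default cannot be mutated
    return list(depths)           # copy if a mutable sequence is needed

print(bad())   # [2, 2, 6, 2, 0]
print(bad())   # [2, 2, 6, 2, 0, 0]
print(good())  # [2, 2, 6, 2] every time
```
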
@@ -732,8 +732,8 @@ class TinyViT(nn.Module):
         for i in range(start_i, len(self.layers)):
             layer = self.layers[i]
             x = layer(x)
-        B, _, C = x.shape
-        x = x.view(B, 64, 64, C)
+        batch, _, channel = x.shape
+        x = x.view(batch, 64, 64, channel)
         x = x.permute(0, 3, 1, 2)
         return self.neck(x)
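
Note (not part of the patch): the renamed `batch`/`channel` variables above feed a hard-coded 64 x 64 reshape, i.e. the final token sequence is assumed to hold 64 * 64 = 4096 entries before being turned back into a channels-first feature map for the neck. A minimal shape sketch with illustrative sizes, assuming PyTorch:

```python
import torch

batch, channel = 2, 320                    # illustrative sizes
x = torch.randn(batch, 64 * 64, channel)   # flattened tokens: (batch, L, channel)

x = x.view(batch, 64, 64, channel)         # restore the 64 x 64 spatial grid
x = x.permute(0, 3, 1, 2)                  # channels-first layout for the conv neck

assert x.shape == (batch, channel, 64, 64)
```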