|
|
|
@ -68,6 +68,35 @@ class DWConvTranspose2d(nn.ConvTranspose2d): |
|
|
|
|
super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvTranspose(nn.Module): |
|
|
|
|
# Convolution transpose 2d layer |
|
|
|
|
default_act = nn.SiLU() # default activation |
|
|
|
|
|
|
|
|
|
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True): |
|
|
|
|
super().__init__() |
|
|
|
|
self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn) |
|
|
|
|
self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity() |
|
|
|
|
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
return self.act(self.bn(self.conv_transpose(x))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DFL(nn.Module): |
|
|
|
|
# DFL module |
|
|
|
|
def __init__(self, c1=16): |
|
|
|
|
super().__init__() |
|
|
|
|
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False) |
|
|
|
|
x = torch.arange(c1, dtype=torch.float) |
|
|
|
|
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1)) |
|
|
|
|
self.c1 = c1 |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
b, c, a = x.shape # batch, channels, anchors |
|
|
|
|
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a) |
|
|
|
|
# return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerLayer(nn.Module): |
|
|
|
|
# Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance) |
|
|
|
|
def __init__(self, c, num_heads): |
|
|
|
@ -106,11 +135,11 @@ class TransformerBlock(nn.Module): |
|
|
|
|
|
|
|
|
|
class Bottleneck(nn.Module): |
|
|
|
|
# Standard bottleneck |
|
|
|
|
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion |
|
|
|
|
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand |
|
|
|
|
super().__init__() |
|
|
|
|
c_ = int(c2 * e) # hidden channels |
|
|
|
|
self.cv1 = Conv(c1, c_, 1, 1) |
|
|
|
|
self.cv2 = Conv(c_, c2, 3, 1, g=g) |
|
|
|
|
self.cv1 = Conv(c1, c_, k[0], 1) |
|
|
|
|
self.cv2 = Conv(c_, c2, k[1], 1, g=g) |
|
|
|
|
self.add = shortcut and c1 == c2 |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
@ -136,20 +165,6 @@ class BottleneckCSP(nn.Module): |
|
|
|
|
return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1)))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrossConv(nn.Module): |
|
|
|
|
# Cross Convolution Downsample |
|
|
|
|
def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): |
|
|
|
|
# ch_in, ch_out, kernel, stride, groups, expansion, shortcut |
|
|
|
|
super().__init__() |
|
|
|
|
c_ = int(c2 * e) # hidden channels |
|
|
|
|
self.cv1 = Conv(c1, c_, (1, k), (1, s)) |
|
|
|
|
self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) |
|
|
|
|
self.add = shortcut and c1 == c2 |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class C3(nn.Module): |
|
|
|
|
# CSP Bottleneck with 3 convolutions |
|
|
|
|
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion |
|
|
|
@ -164,12 +179,90 @@ class C3(nn.Module): |
|
|
|
|
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class C2(nn.Module): |
|
|
|
|
# CSP Bottleneck with 2 convolutions |
|
|
|
|
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion |
|
|
|
|
super().__init__() |
|
|
|
|
self.c = int(c2 * e) # hidden channels |
|
|
|
|
self.cv1 = Conv(c1, 2 * self.c, 1, 1) |
|
|
|
|
self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2) |
|
|
|
|
# self.attention = ChannelAttention(2 * self.c) # or SpatialAttention() |
|
|
|
|
self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))) |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
a, b = self.cv1(x).split((self.c, self.c), 1) |
|
|
|
|
return self.cv2(torch.cat((self.m(a), b), 1)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class C2f(nn.Module): |
|
|
|
|
# CSP Bottleneck with 2 convolutions |
|
|
|
|
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion |
|
|
|
|
super().__init__() |
|
|
|
|
self.c = int(c2 * e) # hidden channels |
|
|
|
|
self.cv1 = Conv(c1, 2 * self.c, 1, 1) |
|
|
|
|
self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2) |
|
|
|
|
self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
y = list(self.cv1(x).split((self.c, self.c), 1)) |
|
|
|
|
y.extend(m(y[-1]) for m in self.m) |
|
|
|
|
return self.cv2(torch.cat(y, 1)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelAttention(nn.Module): |
|
|
|
|
# Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet |
|
|
|
|
def __init__(self, channels: int) -> None: |
|
|
|
|
super().__init__() |
|
|
|
|
self.pool = nn.AdaptiveAvgPool2d(1) |
|
|
|
|
self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True) |
|
|
|
|
self.act = nn.Sigmoid() |
|
|
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
|
return x * self.act(self.fc(self.pool(x))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpatialAttention(nn.Module): |
|
|
|
|
# Spatial-attention module |
|
|
|
|
def __init__(self, kernel_size=7): |
|
|
|
|
super().__init__() |
|
|
|
|
assert kernel_size in (3, 7), 'kernel size must be 3 or 7' |
|
|
|
|
padding = 3 if kernel_size == 7 else 1 |
|
|
|
|
self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) |
|
|
|
|
self.act = nn.Sigmoid() |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CBAM(nn.Module): |
|
|
|
|
# CSP Bottleneck with 3 convolutions |
|
|
|
|
def __init__(self, c1, ratio=16, kernel_size=7): # ch_in, ch_out, number, shortcut, groups, expansion |
|
|
|
|
super().__init__() |
|
|
|
|
self.channel_attention = ChannelAttention(c1) |
|
|
|
|
self.spatial_attention = SpatialAttention(kernel_size) |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
return self.spatial_attention(self.channel_attention(x)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class C1(nn.Module): |
|
|
|
|
# CSP Bottleneck with 3 convolutions |
|
|
|
|
def __init__(self, c1, c2, n=1): # ch_in, ch_out, number, shortcut, groups, expansion |
|
|
|
|
super().__init__() |
|
|
|
|
self.cv1 = Conv(c1, c2, 1, 1) |
|
|
|
|
self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n))) |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
y = self.cv1(x) |
|
|
|
|
return self.m(y) + y |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class C3x(C3): |
|
|
|
|
# C3 module with cross-convolutions |
|
|
|
|
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): |
|
|
|
|
super().__init__(c1, c2, n, shortcut, g, e) |
|
|
|
|
c_ = int(c2 * e) |
|
|
|
|
self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n))) |
|
|
|
|
self.c_ = int(c2 * e) |
|
|
|
|
self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class C3TR(C3): |
|
|
|
@ -180,14 +273,6 @@ class C3TR(C3): |
|
|
|
|
self.m = TransformerBlock(c_, c_, 4, n) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class C3SPP(C3): |
|
|
|
|
# C3 module with SPP() |
|
|
|
|
def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5): |
|
|
|
|
super().__init__(c1, c2, n, shortcut, g, e) |
|
|
|
|
c_ = int(c2 * e) |
|
|
|
|
self.m = SPP(c_, c_, k) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class C3Ghost(C3): |
|
|
|
|
# C3 module with GhostBottleneck() |
|
|
|
|
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): |
|
|
|
@ -271,34 +356,6 @@ class GhostBottleneck(nn.Module): |
|
|
|
|
return self.conv(x) + self.shortcut(x) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Contract(nn.Module): |
|
|
|
|
# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40) |
|
|
|
|
def __init__(self, gain=2): |
|
|
|
|
super().__init__() |
|
|
|
|
self.gain = gain |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain' |
|
|
|
|
s = self.gain |
|
|
|
|
x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2) |
|
|
|
|
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40) |
|
|
|
|
return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Expand(nn.Module): |
|
|
|
|
# Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160) |
|
|
|
|
def __init__(self, gain=2): |
|
|
|
|
super().__init__() |
|
|
|
|
self.gain = gain |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain' |
|
|
|
|
s = self.gain |
|
|
|
|
x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80) |
|
|
|
|
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2) |
|
|
|
|
return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Concat(nn.Module): |
|
|
|
|
# Concatenate a list of tensors along dimension |
|
|
|
|
def __init__(self, dimension=1): |
|
|
|
|