From 40f7e0806e2d9108c3a8e6eff489283c5401c7e5 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 21 Dec 2022 04:59:58 +0530 Subject: [PATCH] Add v8 modules (#81) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- ultralytics/yolo/utils/modeling/__init__.py | 13 +- ultralytics/yolo/utils/modeling/modules.py | 167 +++++++++++++------- 2 files changed, 115 insertions(+), 65 deletions(-) diff --git a/ultralytics/yolo/utils/modeling/__init__.py b/ultralytics/yolo/utils/modeling/__init__.py index 7b9600b19..2719650eb 100644 --- a/ultralytics/yolo/utils/modeling/__init__.py +++ b/ultralytics/yolo/utils/modeling/__init__.py @@ -25,11 +25,8 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=True): # Module compatibility updates for m in model.modules(): t = type(m) - if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect): + if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment): m.inplace = inplace # torch 1.7.0 compatibility - if t is Detect and not isinstance(m.anchor_grid, list): - delattr(m, 'anchor_grid') - setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl) elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'): m.recompute_scale_factor = None # torch 1.11.0 compatibility @@ -65,8 +62,8 @@ def parse_model(d, ch): # model_dict, input_channels(3) n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain if m in { - Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, CrossConv, BottleneckCSP, C3, - C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}: + Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP, C3, C3TR, + C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}: c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8) @@ -86,10 +83,6 @@ def parse_model(d, ch): # model_dict, input_channels(3) args[1] = [list(range(args[1] * 2))] * len(f) if m is Segment: args[3] = make_divisible(args[3] * gw, 8) - elif m is Contract: - c2 = ch[f] * args[0] ** 2 - elif m is Expand: - c2 = ch[f] // args[0] ** 2 else: c2 = ch[f] diff --git a/ultralytics/yolo/utils/modeling/modules.py b/ultralytics/yolo/utils/modeling/modules.py index 0a5bc5412..6fe61fa6c 100644 --- a/ultralytics/yolo/utils/modeling/modules.py +++ b/ultralytics/yolo/utils/modeling/modules.py @@ -68,6 +68,35 @@ class DWConvTranspose2d(nn.ConvTranspose2d): super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2)) +class ConvTranspose(nn.Module): + # Convolution transpose 2d layer + default_act = nn.SiLU() # default activation + + def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True): + super().__init__() + self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn) + self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity() + self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() + + def forward(self, x): + return self.act(self.bn(self.conv_transpose(x))) + + +class DFL(nn.Module): + # DFL module + def __init__(self, c1=16): + super().__init__() + self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False) + x = torch.arange(c1, dtype=torch.float) + self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1)) + self.c1 = c1 + + def forward(self, x): + b, c, a = x.shape # batch, channels, anchors + return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a) + # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a) + + class TransformerLayer(nn.Module): # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance) def __init__(self, c, num_heads): @@ -106,11 +135,11 @@ class TransformerBlock(nn.Module): class Bottleneck(nn.Module): # Standard bottleneck - def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion + def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand super().__init__() c_ = int(c2 * e) # hidden channels - self.cv1 = Conv(c1, c_, 1, 1) - self.cv2 = Conv(c_, c2, 3, 1, g=g) + self.cv1 = Conv(c1, c_, k[0], 1) + self.cv2 = Conv(c_, c2, k[1], 1, g=g) self.add = shortcut and c1 == c2 def forward(self, x): @@ -136,20 +165,6 @@ class BottleneckCSP(nn.Module): return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1)))) -class CrossConv(nn.Module): - # Cross Convolution Downsample - def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): - # ch_in, ch_out, kernel, stride, groups, expansion, shortcut - super().__init__() - c_ = int(c2 * e) # hidden channels - self.cv1 = Conv(c1, c_, (1, k), (1, s)) - self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) - self.add = shortcut and c1 == c2 - - def forward(self, x): - return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) - - class C3(nn.Module): # CSP Bottleneck with 3 convolutions def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion @@ -164,12 +179,90 @@ class C3(nn.Module): return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) +class C2(nn.Module): + # CSP Bottleneck with 2 convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + self.c = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2) + # self.attention = ChannelAttention(2 * self.c) # or SpatialAttention() + self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))) + + def forward(self, x): + a, b = self.cv1(x).split((self.c, self.c), 1) + return self.cv2(torch.cat((self.m(a), b), 1)) + + +class C2f(nn.Module): + # CSP Bottleneck with 2 convolutions + def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + self.c = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2) + self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) + + def forward(self, x): + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in self.m) + return self.cv2(torch.cat(y, 1)) + + +class ChannelAttention(nn.Module): + # Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet + def __init__(self, channels: int) -> None: + super().__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True) + self.act = nn.Sigmoid() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.act(self.fc(self.pool(x))) + + +class SpatialAttention(nn.Module): + # Spatial-attention module + def __init__(self, kernel_size=7): + super().__init__() + assert kernel_size in (3, 7), 'kernel size must be 3 or 7' + padding = 3 if kernel_size == 7 else 1 + self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) + self.act = nn.Sigmoid() + + def forward(self, x): + return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1))) + + +class CBAM(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, ratio=16, kernel_size=7): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + self.channel_attention = ChannelAttention(c1) + self.spatial_attention = SpatialAttention(kernel_size) + + def forward(self, x): + return self.spatial_attention(self.channel_attention(x)) + + +class C1(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, c2, n=1): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + self.cv1 = Conv(c1, c2, 1, 1) + self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n))) + + def forward(self, x): + y = self.cv1(x) + return self.m(y) + y + + class C3x(C3): # C3 module with cross-convolutions def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): super().__init__(c1, c2, n, shortcut, g, e) - c_ = int(c2 * e) - self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n))) + self.c_ = int(c2 * e) + self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n))) class C3TR(C3): @@ -180,14 +273,6 @@ class C3TR(C3): self.m = TransformerBlock(c_, c_, 4, n) -class C3SPP(C3): - # C3 module with SPP() - def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5): - super().__init__(c1, c2, n, shortcut, g, e) - c_ = int(c2 * e) - self.m = SPP(c_, c_, k) - - class C3Ghost(C3): # C3 module with GhostBottleneck() def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): @@ -271,34 +356,6 @@ class GhostBottleneck(nn.Module): return self.conv(x) + self.shortcut(x) -class Contract(nn.Module): - # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40) - def __init__(self, gain=2): - super().__init__() - self.gain = gain - - def forward(self, x): - b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain' - s = self.gain - x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2) - x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40) - return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40) - - -class Expand(nn.Module): - # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160) - def __init__(self, gain=2): - super().__init__() - self.gain = gain - - def forward(self, x): - b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain' - s = self.gain - x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80) - x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2) - return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160) - - class Concat(nn.Module): # Concatenate a list of tensors along dimension def __init__(self, dimension=1):