add yolov9-c

pull/8571/head
Laughing-q 10 months ago
parent 2945cfc6ef
commit f222d549c5
  1. 36
      ultralytics/cfg/models/v9/yolov9-c.yaml
  2. 6
      ultralytics/nn/modules/__init__.py
  3. 94
      ultralytics/nn/modules/block.py
  4. 6
      ultralytics/nn/tasks.py

@ -0,0 +1,36 @@
# YOLOv9
# parameters
nc: 80 # number of classes
# gelan backbone
backbone:
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]] # 2
- [-1, 1, ADown, [256]] # 3-P3/8
- [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]] # 4
- [-1, 1, ADown, [512]] # 5-P4/16
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 6
- [-1, 1, ADown, [512]] # 7-P5/32
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 8
- [-1, 1, SPPELAN, [512, 256]] # 9
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 12
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]] # 15 (P3/8-small)
- [-1, 1, ADown, [256]]
- [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 18 (P4/16-medium)
- [-1, 1, ADown, [512]]
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 21 (P5/32-large)
- [[15, 18, 21], 1, Detect, [nc]] # DDetect(P3, P4, P5)

@ -40,6 +40,9 @@ from .block import (
ResNetLayer,
ContrastiveHead,
BNContrastiveHead,
RepNCSPELAN4,
ADown,
SPPELAN,
)
from .conv import (
CBAM,
@ -123,4 +126,7 @@ __all__ = (
"ImagePoolingAttn",
"ContrastiveHead",
"BNContrastiveHead",
"RepNCSPELAN4",
"ADown",
"SPPELAN",
)

@ -31,6 +31,9 @@ __all__ = (
"Proto",
"RepC3",
"ResNetLayer",
"RepNCSPELAN4",
"ADown",
"SPPELAN",
)
@ -548,3 +551,94 @@ class BNContrastiveHead(nn.Module):
w = F.normalize(w, dim=-1, p=2)
x = torch.einsum("bchw,bkc->bkhw", x, w)
return x * self.logit_scale.exp() + self.bias
class RepBottleneck(nn.Module):
"""Rep bottleneck."""
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = RepConv(c1, c_, k[0], 1)
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
self.add = shortcut and c1 == c2
def forward(self, x):
"""Forward pass through RepBottleneck layer."""
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
class RepCSP(nn.Module):
"""Rep CSP Bottleneck with 3 convolutions."""
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
def forward(self, x):
"""Forward pass through RepCSP layer."""
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
class RepNCSPELAN4(nn.Module):
"""CSP-ELAN."""
def __init__(self, c1, c2, c3, c4, c5=1): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
self.c = c3 // 2
self.cv1 = Conv(c1, c3, 1, 1)
self.cv2 = nn.Sequential(RepCSP(c3 // 2, c4, c5), Conv(c4, c4, 3, 1))
self.cv3 = nn.Sequential(RepCSP(c4, c4, c5), Conv(c4, c4, 3, 1))
self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
def forward(self, x):
"""Forward pass through RepNCSPELAN4 layer."""
y = list(self.cv1(x).chunk(2, 1))
y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
return self.cv4(torch.cat(y, 1))
def forward_split(self, x):
"""Forward pass using split() instead of chunk()."""
y = list(self.cv1(x).split((self.c, self.c), 1))
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
return self.cv4(torch.cat(y, 1))
class ADown(nn.Module):
"""ADown."""
def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
super().__init__()
self.c = c2 // 2
self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)
self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)
def forward(self, x):
x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
x1, x2 = x.chunk(2, 1)
x1 = self.cv1(x1)
x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
x2 = self.cv2(x2)
return torch.cat((x1, x2), 1)
class SPPELAN(nn.Module):
"""SPP-ELAN."""
def __init__(self, c1, c2, c3): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
self.c = c3
self.cv1 = Conv(c1, c3, 1, 1)
self.cv2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=5 // 2)
self.cv3 = nn.MaxPool2d(kernel_size=5, stride=1, padding=5 // 2)
self.cv4 = nn.MaxPool2d(kernel_size=5, stride=1, padding=5 // 2)
self.cv5 = Conv(4 * c3, c2, 1, 1)
def forward(self, x):
y = [self.cv1(x)]
y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
return self.cv5(torch.cat(y, 1))

@ -43,6 +43,9 @@ from ultralytics.nn.modules import (
RTDETRDecoder,
Segment,
WorldDetect,
RepNCSPELAN4,
ADown,
SPPELAN,
)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
@ -850,6 +853,9 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
C1,
C2,
C2f,
RepNCSPELAN4,
ADown,
SPPELAN,
C2fAttn,
C3,
C3TR,

Loading…
Cancel
Save