Update head.py

exp-b
Laughing-q 6 months ago
parent 29785074cc
commit 9e218b86a8
  1. 114
      ultralytics/nn/modules/head.py

@@ -104,6 +104,120 @@ class Detect(nn.Module):
return dist2bbox(bboxes, anchors, xywh=True, dim=1)
class Detect4(nn.Module):
    """YOLOv8-style detection head with an additional per-location confidence branch.

    Every detection layer owns three branches:

    * ``cv2`` produces ``4 * reg_max`` box-distribution channels,
    * ``cv3`` produces ``nc`` class channels,
    * ``cv4`` maps the ``cv2`` output to one sigmoid-activated confidence channel.

    The class output is multiplied by that confidence before it is concatenated with the box output.
    """

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    format = None  # export format read by forward(); never assigned in this class, so default to None
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initializes the detection layer with the specified number of classes and channels.

        Args:
            nc (int): Number of classes.
            ch (tuple): Input channels of each detection layer. ``ch[0]`` sizes the hidden widths, so an
                empty ``ch`` raises ``IndexError``.
        """
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2 = max((16, ch[0] // 4, self.reg_max * 4))  # hidden width of the box branch
        c3 = max(ch[0], min(self.nc, 100))  # hidden width of the class branch

        # Box branch: two 3x3 Conv blocks followed by a 1x1 projection to the DFL channels.
        # NOTE: a depthwise-separable variant of this branch (mirroring cv3 below) existed only as disabled code.
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        # Class branch: two depthwise-separable stages (DWConv 3x3 + Conv 1x1) and a 1x1 projection to nc.
        self.cv3 = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
                nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, self.nc, 1),
            )
            for x in ch
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
        # Confidence branch: consumes the raw cv2 output (4 * reg_max channels) and emits one value in (0, 1).
        # bias_init() addresses the nn.Conv2d as index -2, so keep it directly before the Sigmoid.
        self.cv4 = nn.ModuleList(
            nn.Sequential(
                Conv(4 * self.reg_max, 16, 1),
                nn.Conv2d(16, 1, 1),
                nn.Sigmoid(),
            )
            for _ in ch
        )

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities.

        Args:
            x (list): One feature map per detection layer. The list is overwritten in place with the
                per-layer head outputs.

        Returns:
            In training mode, the list ``x`` whose entries are ``cat(box, cls * conf)`` along dim 1.
            Otherwise ``y`` when ``self.export`` is set, else the tuple ``(y, x)``, where ``y`` is the
            decoded boxes concatenated with the sigmoid of the class channels along dim 1.
        """
        for i in range(self.nl):
            box = self.cv2[i](x[i])
            cls = self.cv3[i](x[i])
            # Confidence is predicted from the raw box-distribution output. Earlier variants (disabled by the
            # author) instead fed cv4 with statistics of the softmax over the reg_max bins of each box side:
            # the top-4 probabilities plus their mean, or the max and/or mean probability.
            conf = self.cv4[i](box)
            # conf lies in (0, 1) and scales the class output before any sigmoid is applied to it.
            x[i] = torch.cat((box, cls * conf), 1)
        if self.training:  # Training path
            return x

        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            # Rebuild the cached anchors/strides only when the input shape changed (or always if dynamic).
            self.anchors, self.strides = (t.transpose(0, 1) for t in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:
            # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides

        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        for box_head, cls_head, conf_head, s in zip(self.cv2, self.cv3, self.cv4, self.stride):
            box_head[-1].bias.data[:] = 1.0  # box
            # cls (.01 objects, 80 classes, 640 img)
            cls_head[-1].bias.data[: self.nc] = math.log(5 / self.nc / (640 / s) ** 2)
            # conf: index -2 is the nn.Conv2d ahead of the Sigmoid, so the initial confidence is sigmoid(2) ~ 0.88
            conf_head[-2].bias.data[:] = 2.0

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        return dist2bbox(bboxes, anchors, xywh=True, dim=1)
class DetectNew(nn.Module):
"""YOLOv8 Detect head for detection models."""

Loading…
Cancel
Save