Update head.py

mct-2.1.1
Laughing-q 6 months ago
parent b65cdfadba
commit 72184613a2
  1. 17
      ultralytics/nn/modules/head.py

@ -91,13 +91,14 @@ class Detect(nn.Module):
# Inference path
shape = x[0].shape # BCHW
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
# if self.dynamic or self.shape != shape:
self.anchors, self.strides = (
x.transpose(0, 1) for x in make_anchors(torch.Tensor([80, 40, 20]).cuda(), self.stride, 0.5)
)
self.shape = shape
if self.export and self.format == "sony":
self.anchors, self.strides = (
x.transpose(0, 1) for x in make_anchors(torch.Tensor([80, 40, 20]).cuda(), self.stride, 0.5)
)
self.strides /= 640 # NOTE: the relu could be removed in the future.
elif self.dynamic or self.shape != shape:
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
box = x_cat[:, : self.reg_max * 4]
@ -114,7 +115,7 @@ class Detect(nn.Module):
norm = self.strides / (self.stride[0] * grid_size)
dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
elif self.export and self.format == "sony":
dbox = self.decode_bboxes(self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides)
dbox = self.decode_bboxes(self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False)
# NOTE: the relu could be removed in the future.
y1 = self.relu(dbox[:, 0, :])
x1 = self.relu(dbox[:, 1, :])
@ -140,9 +141,9 @@ class Detect(nn.Module):
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
def decode_bboxes(self, bboxes, anchors):
def decode_bboxes(self, bboxes, anchors, xywh=True):
"""Decode bounding boxes."""
return dist2bbox(bboxes, anchors, xywh=False, dim=1)
return dist2bbox(bboxes, anchors, xywh=xywh and (not self.end2end), dim=1)
@staticmethod
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):

Loading…
Cancel
Save