|
|
|
@ -91,13 +91,14 @@ class Detect(nn.Module): |
|
|
|
|
# Inference path |
|
|
|
|
shape = x[0].shape # BCHW |
|
|
|
|
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) |
|
|
|
|
# if self.dynamic or self.shape != shape: |
|
|
|
|
self.anchors, self.strides = ( |
|
|
|
|
x.transpose(0, 1) for x in make_anchors(torch.Tensor([80, 40, 20]).cuda(), self.stride, 0.5) |
|
|
|
|
) |
|
|
|
|
self.shape = shape |
|
|
|
|
if self.export and self.format == "sony": |
|
|
|
|
self.anchors, self.strides = ( |
|
|
|
|
x.transpose(0, 1) for x in make_anchors(torch.Tensor([80, 40, 20]).cuda(), self.stride, 0.5) |
|
|
|
|
) |
|
|
|
|
self.strides /= 640 # NOTE: the relu could be removed in the future. |
|
|
|
|
elif self.dynamic or self.shape != shape: |
|
|
|
|
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) |
|
|
|
|
self.shape = shape |
|
|
|
|
|
|
|
|
|
if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops |
|
|
|
|
box = x_cat[:, : self.reg_max * 4] |
|
|
|
@ -114,7 +115,7 @@ class Detect(nn.Module): |
|
|
|
|
norm = self.strides / (self.stride[0] * grid_size) |
|
|
|
|
dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) |
|
|
|
|
elif self.export and self.format == "sony": |
|
|
|
|
dbox = self.decode_bboxes(self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides) |
|
|
|
|
dbox = self.decode_bboxes(self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False) |
|
|
|
|
# NOTE: the relu could be removed in the future. |
|
|
|
|
y1 = self.relu(dbox[:, 0, :]) |
|
|
|
|
x1 = self.relu(dbox[:, 1, :]) |
|
|
|
@ -140,9 +141,9 @@ class Detect(nn.Module): |
|
|
|
|
a[-1].bias.data[:] = 1.0 # box |
|
|
|
|
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) |
|
|
|
|
|
|
|
|
|
def decode_bboxes(self, bboxes, anchors): |
|
|
|
|
def decode_bboxes(self, bboxes, anchors, xywh=True): |
|
|
|
|
"""Decode bounding boxes.""" |
|
|
|
|
return dist2bbox(bboxes, anchors, xywh=False, dim=1) |
|
|
|
|
return dist2bbox(bboxes, anchors, xywh=xywh and (not self.end2end), dim=1) |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80): |
|
|
|
|