|
|
@ -246,9 +246,21 @@ class Pose(Detect): |
|
|
|
def kpts_decode(self, bs, kpts):
    """Decode raw keypoint predictions into image-space coordinates.

    Args:
        bs (int): Batch size.
        kpts (torch.Tensor): Raw keypoint tensor; viewed as
            (bs, *self.kpt_shape, num_anchors) below, so it must be
            reshapeable to that layout — assumed (bs, nk, num_anchors); TODO confirm.

    Returns:
        torch.Tensor: Decoded keypoints of shape (bs, self.nk, num_anchors)
        (export path; see NOTE below about the non-export path).
    """
    ndim = self.kpt_shape[1]
    if self.export:
        if self.format in {
            "tflite",
            "edgetpu",
        }:  # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
            # Precompute normalization factor to increase numerical stability:
            # fold the per-anchor stride and grid size into one factor so the
            # multiply below stays in a well-conditioned range on TFLite.
            y = kpts.view(bs, *self.kpt_shape, -1)
            grid_h, grid_w = self.shape[2], self.shape[3]
            grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
        else:
            # NCNN fix: plain decode, scaling by strides directly.
            y = kpts.view(bs, *self.kpt_shape, -1)
            a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
        # Third keypoint channel (when present) is a visibility/confidence
        # logit — squash it with sigmoid before concatenating.
        if ndim == 3:
            a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
        return a.view(bs, self.nk, -1)
    # NOTE(review): the non-export (training/eval) decode path lies outside
    # this diff hunk; the full method presumably continues past the visible
    # span — confirm against the complete file before relying on a None
    # return here.
|
|
|