Add pose model support

triplemu/pose-infer
triple-Mu 2 years ago
parent d9687b12c1
commit ccb5833703
  1. 15
      config.py
  2. 112
      infer-pose-without-torch.py
  3. 107
      infer-pose.py
  4. 22
      models/torch_utils.py
  5. 23
      models/utils.py

@ -36,5 +36,20 @@ MASK_COLORS = np.array([(255, 56, 56), (255, 157, 151), (255, 112, 31),
(255, 149, 200), (255, 55, 199)],
dtype=np.float32) / 255.
# Per-keypoint colours (17 entries): the first 5 keypoints share one colour,
# the next 6 a second colour and the last 6 a third. Every entry is its own
# list object, exactly as in a flat literal.
KPS_COLORS = (
    [[0, 255, 0] for _ in range(5)]
    + [[255, 128, 0] for _ in range(6)]
    + [[51, 153, 255] for _ in range(6)]
)

# Pairs of keypoint indices that are joined by a limb (19 entries).
SKELETON = [
    list(pair) for pair in (
        (16, 14), (14, 12), (17, 15), (15, 13), (12, 13),
        (6, 12), (7, 13), (6, 7), (6, 8), (7, 9),
        (8, 10), (9, 11), (2, 3), (1, 2), (1, 3),
        (2, 4), (3, 5), (4, 6), (5, 7),
    )
]

# One colour per SKELETON entry (19 entries), grouped in runs of 4, 3, 5, 7.
LIMB_COLORS = (
    [[51, 153, 255] for _ in range(4)]
    + [[255, 51, 255] for _ in range(3)]
    + [[255, 128, 0] for _ in range(5)]
    + [[0, 255, 0] for _ in range(7)]
)

# alpha for segment masks
ALPHA = 0.5

@ -0,0 +1,112 @@
import argparse
from pathlib import Path
import cv2
import numpy as np
from config import COLORS, SKELETON, KPS_COLORS, LIMB_COLORS
from models.utils import blob, letterbox, path_to_list, pose_postprocess
def main(args: argparse.Namespace) -> None:
    """Run TensorRT pose inference on images and draw the results.

    For every image the detections' boxes, scores, keypoints and limbs are
    drawn on a copy of the original image, which is then either displayed
    (``--show``) or written to ``args.out_dir`` under the input file name.

    Args:
        args: Parsed command-line options. Reads ``method``, ``engine``,
            ``imgs``, ``out_dir``, ``show``, ``conf_thres`` and
            ``iou_thres``.

    Raises:
        NotImplementedError: If ``args.method`` is neither ``'cudart'`` nor
            ``'pycuda'``.
    """
    # Import lazily so only the selected CUDA binding has to be installed.
    if args.method == 'cudart':
        from models.cudart_api import TRTEngine
    elif args.method == 'pycuda':
        from models.pycuda_api import TRTEngine
    else:
        raise NotImplementedError(
            f"Unsupported method {args.method!r}; use 'cudart' or 'pycuda'")

    Engine = TRTEngine(args.engine)
    H, W = Engine.inp_info[0].shape[-2:]

    images = path_to_list(args.imgs)
    save_path = Path(args.out_dir)

    if not args.show:
        # exist_ok makes a separate existence check unnecessary.
        save_path.mkdir(parents=True, exist_ok=True)

    kpt_thres = 0.5  # minimum keypoint score for a point/limb to be drawn

    for image in images:
        save_image = save_path / image.name
        bgr = cv2.imread(str(image))
        if bgr is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f'Skipping unreadable image: {image}')
            continue
        draw = bgr.copy()
        bgr, ratio, dwdh = letterbox(bgr, (W, H))
        dw, dh = int(dwdh[0]), int(dwdh[1])
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        tensor = blob(rgb, return_seg=False)
        # dwdh * 2 repeats the padding pair so it can be subtracted from
        # 4-value boxes (presumably dwdh is a 2-tuple - confirm in letterbox).
        dwdh = np.array(dwdh * 2, dtype=np.float32)
        tensor = np.ascontiguousarray(tensor)
        # inference
        data = Engine(tensor)

        bboxes, scores, kpts = pose_postprocess(data, args.conf_thres,
                                                args.iou_thres)
        # Undo the letterbox padding and scaling on the boxes.
        bboxes -= dwdh
        bboxes /= ratio

        def to_image(x, y):
            # Map a letterboxed point back to original-image pixel coords.
            return (round(float(x - dw) / ratio),
                    round(float(y - dh) / ratio))

        for (bbox, score, kpt) in zip(bboxes, scores, kpts):
            bbox = bbox.round().astype(np.int32).tolist()
            color = COLORS['person']
            cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
            cv2.putText(draw,
                        f'person:{score:.3f}', (bbox[0], bbox[1] - 2),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.75, [225, 255, 255],
                        thickness=2)
            num_kpts = min(len(kpt), len(KPS_COLORS))
            # Keypoint i and limb i are drawn in the same iteration so the
            # overlap order of circles and lines stays as before.
            for i, (xi, yi) in enumerate(SKELETON):
                if i < num_kpts:
                    px, py, ps = kpt[i]
                    if ps > kpt_thres:
                        cv2.circle(draw, to_image(px, py), 5, KPS_COLORS[i],
                                   -1)
                # SKELETON indices are 1-based, hence the "- 1".
                x1, y1, s1 = kpt[xi - 1]
                x2, y2, s2 = kpt[yi - 1]
                if s1 > kpt_thres and s2 > kpt_thres:
                    cv2.line(draw, to_image(x1, y1), to_image(x2, y2),
                             LIMB_COLORS[i], 2)
        if args.show:
            cv2.imshow('result', draw)
            cv2.waitKey(0)
        else:
            cv2.imwrite(str(save_image), draw)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the torch-free pose inference script.

    Returns:
        Namespace with ``engine``, ``imgs``, ``show``, ``out_dir``,
        ``conf_thres``, ``iou_thres`` and ``method``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--engine', type=str, help='Engine file')
    parser.add_argument('--imgs', type=str, help='Images file')
    parser.add_argument('--show',
                        action='store_true',
                        help='Show the detection results')
    parser.add_argument('--out-dir',
                        type=str,
                        default='./output',
                        help='Path to output file')
    parser.add_argument('--conf-thres',
                        type=float,
                        default=0.25,
                        help='Confidence threshold')
    # The help text used to be a copy of the --conf-thres one.
    parser.add_argument('--iou-thres',
                        type=float,
                        default=0.65,
                        help='IoU threshold')
    parser.add_argument('--method',
                        type=str,
                        default='cudart',
                        help='CUDA binding to use: cudart or pycuda')
    args = parser.parse_args()
    return args
# Script entry point: parse the CLI options and hand them straight to main().
if __name__ == '__main__':
    main(parse_args())

@ -0,0 +1,107 @@
from models import TRTModule # isort:skip
import argparse
from pathlib import Path
import cv2
import torch
from config import COLORS, SKELETON, KPS_COLORS, LIMB_COLORS
from models.torch_utils import pose_postprocess
from models.utils import blob, letterbox, path_to_list
def main(args: argparse.Namespace) -> None:
    """Run TensorRT pose inference (torch backend) and draw the results.

    For every image the detections' boxes, scores, keypoints and limbs are
    drawn on a copy of the original image, which is then either displayed
    (``--show``) or written to ``args.out_dir`` under the input file name.

    Args:
        args: Parsed command-line options. Reads ``device``, ``engine``,
            ``imgs``, ``out_dir``, ``show``, ``conf_thres`` and
            ``iou_thres``.
    """
    device = torch.device(args.device)
    Engine = TRTModule(args.engine, device)
    H, W = Engine.inp_info[0].shape[-2:]

    images = path_to_list(args.imgs)
    save_path = Path(args.out_dir)

    if not args.show:
        # exist_ok makes a separate existence check unnecessary.
        save_path.mkdir(parents=True, exist_ok=True)

    kpt_thres = 0.5  # minimum keypoint score for a point/limb to be drawn

    for image in images:
        save_image = save_path / image.name
        bgr = cv2.imread(str(image))
        if bgr is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f'Skipping unreadable image: {image}')
            continue
        draw = bgr.copy()
        bgr, ratio, dwdh = letterbox(bgr, (W, H))
        dw, dh = int(dwdh[0]), int(dwdh[1])
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        tensor = blob(rgb, return_seg=False)
        # dwdh * 2 repeats the padding pair so it can be subtracted from
        # 4-value boxes (presumably dwdh is a 2-tuple - confirm in letterbox).
        dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device)
        tensor = torch.asarray(tensor, device=device)
        # inference
        data = Engine(tensor)

        bboxes, scores, kpts = pose_postprocess(data, args.conf_thres,
                                                args.iou_thres)
        # Undo the letterbox padding and scaling on the boxes.
        bboxes -= dwdh
        bboxes /= ratio

        def to_image(x, y):
            # Map a letterboxed point back to original-image pixel coords.
            return (round(float(x - dw) / ratio),
                    round(float(y - dh) / ratio))

        for (bbox, score, kpt) in zip(bboxes, scores, kpts):
            bbox = bbox.round().int().tolist()
            color = COLORS['person']
            cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
            cv2.putText(draw,
                        f'person:{score:.3f}', (bbox[0], bbox[1] - 2),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.75, [225, 255, 255],
                        thickness=2)
            num_kpts = min(len(kpt), len(KPS_COLORS))
            # Keypoint i and limb i are drawn in the same iteration so the
            # overlap order of circles and lines stays as before.
            for i, (xi, yi) in enumerate(SKELETON):
                if i < num_kpts:
                    px, py, ps = kpt[i]
                    if ps > kpt_thres:
                        cv2.circle(draw, to_image(px, py), 5, KPS_COLORS[i],
                                   -1)
                # SKELETON indices are 1-based, hence the "- 1".
                x1, y1, s1 = kpt[xi - 1]
                x2, y2, s2 = kpt[yi - 1]
                if s1 > kpt_thres and s2 > kpt_thres:
                    cv2.line(draw, to_image(x1, y1), to_image(x2, y2),
                             LIMB_COLORS[i], 2)
        if args.show:
            cv2.imshow('result', draw)
            cv2.waitKey(0)
        else:
            cv2.imwrite(str(save_image), draw)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the torch pose inference script.

    Returns:
        Namespace with ``engine``, ``imgs``, ``show``, ``out_dir``,
        ``conf_thres``, ``iou_thres`` and ``device``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--engine', type=str, help='Engine file')
    parser.add_argument('--imgs', type=str, help='Images file')
    parser.add_argument('--show',
                        action='store_true',
                        help='Show the detection results')
    parser.add_argument('--out-dir',
                        type=str,
                        default='./output',
                        help='Path to output file')
    parser.add_argument('--conf-thres',
                        type=float,
                        default=0.25,
                        help='Confidence threshold')
    # The help text used to be a copy of the --conf-thres one.
    parser.add_argument('--iou-thres',
                        type=float,
                        default=0.65,
                        help='IoU threshold')
    parser.add_argument('--device',
                        type=str,
                        default='cuda:0',
                        help='TensorRT infer device')
    args = parser.parse_args()
    return args
# Script entry point: parse the CLI options and hand them straight to main().
if __name__ == '__main__':
    main(parse_args())

@ -3,7 +3,7 @@ from typing import List, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import batched_nms
from torchvision.ops import batched_nms, nms
def seg_postprocess(
@ -33,6 +33,26 @@ def seg_postprocess(
return bboxes, scores, labels, masks
def pose_postprocess(
        data: Union[Tuple, Tensor],
        conf_thres: float = 0.25,
        iou_thres: float = 0.65) \
        -> Tuple[Tensor, Tensor, Tensor]:
    """Decode raw pose-model output into boxes, scores and keypoints.

    ``data[0]`` is transposed so that each row is one candidate made of
    4 box values (centre x, centre y, width, height), 1 score and 51
    keypoint values, which are returned grouped in triples.

    Args:
        data: Model output, or a 1-tuple wrapping it. Only the first batch
            element (``data[0]``) is processed.
        conf_thres: Candidates whose score is not above this are dropped.
        iou_thres: IoU threshold passed to ``nms``.

    Returns:
        ``(bboxes, scores, kpts)`` where ``bboxes`` is ``(K, 4)`` in
        ``x1, y1, x2, y2`` form, ``scores`` is ``(K,)`` and ``kpts`` is
        ``(K, 17, 3)``. ``K`` is 0 when nothing passes ``conf_thres``.
    """
    if isinstance(data, tuple):
        assert len(data) == 1
        data = data[0]
    outputs = torch.transpose(data[0], 0, 1).contiguous()
    bboxes, scores, kpts = outputs.split([4, 1, 51], 1)
    # Squeeze only the score column; an unqualified squeeze() would also
    # collapse the candidate axis when there is a single candidate.
    scores = scores.squeeze(1)
    keep = scores > conf_thres
    bboxes, scores, kpts = bboxes[keep], scores[keep], kpts[keep]
    # Explicit keypoint count: reshape(0, -1, 3) is ambiguous (and raises)
    # for an empty tensor.
    num_kpts = kpts.shape[1] // 3
    if bboxes.shape[0] == 0:
        # Nothing survived the confidence filter; skip NMS entirely.
        return bboxes, scores, kpts.reshape(0, num_kpts, 3)
    xycenter, wh = bboxes.chunk(2, -1)
    bboxes = torch.cat([xycenter - 0.5 * wh, xycenter + 0.5 * wh], -1)
    idx = nms(bboxes, scores, iou_thres)
    bboxes, scores, kpts = bboxes[idx], scores[idx], kpts[idx]
    return bboxes, scores, kpts.reshape(idx.shape[0], num_kpts, 3)
def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]):
assert len(data) == 4
num_dets, bboxes, scores, labels = (i[0] for i in data)

@ -128,3 +128,26 @@ def seg_postprocess(
masks = masks.transpose(2, 0, 1)
masks = np.ascontiguousarray((masks > 0.5)[..., None], dtype=np.float32)
return bboxes, scores, labels, masks
def pose_postprocess(
        data: Union[Tuple, ndarray],
        conf_thres: float = 0.25,
        iou_thres: float = 0.65) \
        -> Tuple[ndarray, ndarray, ndarray]:
    """Decode raw pose-model output into boxes, scores and keypoints.

    ``data[0]`` is transposed so that each row is one candidate made of
    4 box values (centre x, centre y, width, height), 1 score and the
    keypoint values, which are returned grouped in triples.

    Args:
        data: Model output, or a 1-tuple wrapping it. Only the first batch
            element (``data[0]``) is processed.
        conf_thres: Candidates whose score is not above this are dropped;
            also passed to ``cv2.dnn.NMSBoxes`` as its score threshold.
        iou_thres: NMS threshold passed to ``cv2.dnn.NMSBoxes``.

    Returns:
        ``(bboxes, scores, kpts)`` where ``bboxes`` is ``(K, 4)`` in
        ``x1, y1, x2, y2`` form, ``scores`` is ``(K,)`` and ``kpts`` is
        ``(K, num_keypoints, 3)``. ``K`` is 0 when nothing passes
        ``conf_thres``.
    """
    if isinstance(data, tuple):
        assert len(data) == 1
        data = data[0]
    outputs = np.transpose(data[0], (1, 0))
    bboxes, scores, kpts = np.split(outputs, [4, 5], 1)
    # Squeeze only the score column; an unqualified squeeze() would also
    # collapse the candidate axis when there is a single candidate.
    scores = scores.squeeze(1)
    keep = scores > conf_thres
    bboxes, scores, kpts = bboxes[keep], scores[keep], kpts[keep]
    # Explicit keypoint count: reshape(0, -1, 3) is ambiguous (and raises)
    # for an empty array.
    num_kpts = kpts.shape[1] // 3
    if bboxes.shape[0] == 0:
        # Nothing survived the confidence filter; skip NMS entirely.
        return bboxes, scores, kpts.reshape(0, num_kpts, 3)
    xycenter, wh = np.split(bboxes, [2], -1)
    # NMSBoxes takes boxes as (x, y, w, h) with (x, y) the top-left corner.
    cvbboxes = np.concatenate([xycenter - 0.5 * wh, wh], -1)
    idx = cv2.dnn.NMSBoxes(cvbboxes, scores, conf_thres, iou_thres)
    # Depending on the OpenCV version the result may be an empty tuple or
    # an (N, 1) array; normalise to a flat integer index array.
    idx = np.asarray(idx, dtype=np.int64).reshape(-1)
    cvbboxes, scores, kpts = cvbboxes[idx], scores[idx], kpts[idx]
    # Convert (x, y, w, h) to (x1, y1, x2, y2).
    cvbboxes[:, 2:] += cvbboxes[:, :2]
    return cvbboxes, scores, kpts.reshape(idx.shape[0], num_kpts, 3)

Loading…
Cancel
Save