From 5b09fff3cd7621626b8a76a8b0604c5ce6e40e3b Mon Sep 17 00:00:00 2001 From: triple-Mu Date: Tue, 11 Jul 2023 16:21:31 +0800 Subject: [PATCH] Support yolov8 original pose model. --- config.py | 15 +++++ infer-det-without-torch.py | 4 ++ infer-det.py | 4 ++ infer-pose-without-torch.py | 116 ++++++++++++++++++++++++++++++++++++ infer-pose.py | 112 ++++++++++++++++++++++++++++++++++ infer-seg-without-torch.py | 4 ++ infer-seg.py | 7 +-- models/torch_utils.py | 38 ++++++++++-- models/utils.py | 36 +++++++++++ 9 files changed, 327 insertions(+), 9 deletions(-) create mode 100644 infer-pose-without-torch.py create mode 100644 infer-pose.py diff --git a/config.py b/config.py index e5b4e7b..6a0e315 100644 --- a/config.py +++ b/config.py @@ -36,5 +36,20 @@ MASK_COLORS = np.array([(255, 56, 56), (255, 157, 151), (255, 112, 31), (255, 149, 200), (255, 55, 199)], dtype=np.float32) / 255. +KPS_COLORS = [[0, 255, 0], [0, 255, 0], [0, 255, 0], [0, 255, 0], [0, 255, 0], + [255, 128, 0], [255, 128, 0], [255, 128, 0], [255, 128, 0], + [255, 128, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], + [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255]] + +SKELETON = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], + [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], + [2, 4], [3, 5], [4, 6], [5, 7]] + +LIMB_COLORS = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], + [255, 51, 255], [255, 51, 255], [255, 51, 255], [255, 128, 0], + [255, 128, 0], [255, 128, 0], [255, 128, 0], [255, 128, 0], + [0, 255, 0], [0, 255, 0], [0, 255, 0], [0, 255, 0], [0, 255, 0], + [0, 255, 0], [0, 255, 0]] + # alpha for segment masks ALPHA = 0.5 diff --git a/infer-det-without-torch.py b/infer-det-without-torch.py index a38eecd..81055ce 100644 --- a/infer-det-without-torch.py +++ b/infer-det-without-torch.py @@ -38,6 +38,10 @@ def main(args: argparse.Namespace) -> None: data = Engine(tensor) bboxes, scores, labels = det_postprocess(data) + if bboxes.size == 0: + # if no bounding box + print(f'{image}: no object!') + continue bboxes -= dwdh bboxes /= ratio diff --git a/infer-det.py b/infer-det.py index 55c50e9..ab9586d 100644 --- a/infer-det.py +++ b/infer-det.py @@ -37,6 +37,10 @@ def main(args: argparse.Namespace) -> None: data = Engine(tensor) bboxes, scores, labels = det_postprocess(data) + if bboxes.numel() == 0: + # if no bounding box + print(f'{image}: no object!') + continue bboxes -= dwdh bboxes /= ratio diff --git a/infer-pose-without-torch.py b/infer-pose-without-torch.py new file mode 100644 index 0000000..7b89a84 --- /dev/null +++ b/infer-pose-without-torch.py @@ -0,0 +1,116 @@ +import argparse +from pathlib import Path + +import cv2 +import numpy as np + +from config import COLORS, KPS_COLORS, LIMB_COLORS, SKELETON +from models.utils import blob, letterbox, path_to_list, pose_postprocess + + +def main(args: argparse.Namespace) -> None: + if args.method == 'cudart': + from models.cudart_api import TRTEngine + elif args.method == 'pycuda': + from models.pycuda_api import TRTEngine + else: + raise NotImplementedError + + Engine = TRTEngine(args.engine) + H, W = Engine.inp_info[0].shape[-2:] + + images = path_to_list(args.imgs) + save_path = Path(args.out_dir) + + if not args.show and not save_path.exists(): + save_path.mkdir(parents=True, exist_ok=True) + + for image in images: + save_image = save_path / image.name + bgr = cv2.imread(str(image)) + draw = bgr.copy() + bgr, ratio, dwdh = letterbox(bgr, (W, H)) + dw, dh = int(dwdh[0]), int(dwdh[1]) + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + tensor = blob(rgb, return_seg=False) + dwdh = np.array(dwdh * 2, dtype=np.float32) + tensor = np.ascontiguousarray(tensor) + # inference + data = Engine(tensor) + + bboxes, scores, kpts = pose_postprocess(data, args.conf_thres, + args.iou_thres) + if bboxes.size == 0: + # if no bounding box + print(f'{image}: no object!') + continue + bboxes -= dwdh + bboxes /= ratio + + for (bbox, score, kpt) in zip(bboxes, scores, kpts): + bbox = bbox.round().astype(np.int32).tolist() + color = COLORS['person'] + cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2) + cv2.putText(draw, + f'person:{score:.3f}', (bbox[0], bbox[1] - 2), + cv2.FONT_HERSHEY_SIMPLEX, + 0.75, [225, 255, 255], + thickness=2) + for i in range(19): + if i < 17: + px, py, ps = kpt[i] + if ps > 0.5: + kcolor = KPS_COLORS[i] + px = round(float(px - dw) / ratio) + py = round(float(py - dh) / ratio) + cv2.circle(draw, (px, py), 5, kcolor, -1) + xi, yi = SKELETON[i] + pos1_s = kpt[xi - 1][2] + pos2_s = kpt[yi - 1][2] + if pos1_s > 0.5 and pos2_s > 0.5: + limb_color = LIMB_COLORS[i] + pos1_x = round(float(kpt[xi - 1][0] - dw) / ratio) + pos1_y = round(float(kpt[xi - 1][1] - dh) / ratio) + + pos2_x = round(float(kpt[yi - 1][0] - dw) / ratio) + pos2_y = round(float(kpt[yi - 1][1] - dh) / ratio) + + cv2.line(draw, (pos1_x, pos1_y), (pos2_x, pos2_y), + limb_color, 2) + if args.show: + cv2.imshow('result', draw) + cv2.waitKey(0) + else: + cv2.imwrite(str(save_image), draw) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--engine', type=str, help='Engine file') + parser.add_argument('--imgs', type=str, help='Images file') + parser.add_argument('--show', + action='store_true', + help='Show the detection results') + parser.add_argument('--out-dir', + type=str, + default='./output', + help='Path to output file') + parser.add_argument('--conf-thres', + type=float, + default=0.25, + help='Confidence threshold') + parser.add_argument('--iou-thres', + type=float, + default=0.65, + help='Confidence threshold') + parser.add_argument('--method', + type=str, + default='cudart', + help='CUDART pipeline') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/infer-pose.py b/infer-pose.py new file mode 100644 index 0000000..0ba0679 --- /dev/null +++ b/infer-pose.py @@ -0,0 +1,112 @@ +from models import TRTModule # isort:skip +import argparse +from pathlib import Path + +import cv2 +import torch + +from config import COLORS, KPS_COLORS, LIMB_COLORS, SKELETON +from models.torch_utils import pose_postprocess +from models.utils import blob, letterbox, path_to_list + + +def main(args: argparse.Namespace) -> None: + device = torch.device(args.device) + Engine = TRTModule(args.engine, device) + H, W = Engine.inp_info[0].shape[-2:] + + images = path_to_list(args.imgs) + save_path = Path(args.out_dir) + + if not args.show and not save_path.exists(): + save_path.mkdir(parents=True, exist_ok=True) + + for image in images: + save_image = save_path / image.name + bgr = cv2.imread(str(image)) + draw = bgr.copy() + bgr, ratio, dwdh = letterbox(bgr, (W, H)) + dw, dh = int(dwdh[0]), int(dwdh[1]) + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + tensor = blob(rgb, return_seg=False) + dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device) + tensor = torch.asarray(tensor, device=device) + # inference + data = Engine(tensor) + + bboxes, scores, kpts = pose_postprocess(data, args.conf_thres, + args.iou_thres) + if bboxes.numel() == 0: + # if no bounding box + print(f'{image}: no object!') + continue + bboxes -= dwdh + bboxes /= ratio + + for (bbox, score, kpt) in zip(bboxes, scores, kpts): + bbox = bbox.round().int().tolist() + color = COLORS['person'] + cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2) + cv2.putText(draw, + f'person:{score:.3f}', (bbox[0], bbox[1] - 2), + cv2.FONT_HERSHEY_SIMPLEX, + 0.75, [225, 255, 255], + thickness=2) + for i in range(19): + if i < 17: + px, py, ps = kpt[i] + if ps > 0.5: + kcolor = KPS_COLORS[i] + px = round(float(px - dw) / ratio) + py = round(float(py - dh) / ratio) + cv2.circle(draw, (px, py), 5, kcolor, -1) + xi, yi = SKELETON[i] + pos1_s = kpt[xi - 1][2] + pos2_s = kpt[yi - 1][2] + if pos1_s > 0.5 and pos2_s > 0.5: + limb_color = LIMB_COLORS[i] + pos1_x = round(float(kpt[xi - 1][0] - dw) / ratio) + pos1_y = round(float(kpt[xi - 1][1] - dh) / ratio) + + pos2_x = round(float(kpt[yi - 1][0] - dw) / ratio) + pos2_y = round(float(kpt[yi - 1][1] - dh) / ratio) + + cv2.line(draw, (pos1_x, pos1_y), (pos2_x, pos2_y), + limb_color, 2) + if args.show: + cv2.imshow('result', draw) + cv2.waitKey(0) + else: + cv2.imwrite(str(save_image), draw) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument('--engine', type=str, help='Engine file') + parser.add_argument('--imgs', type=str, help='Images file') + parser.add_argument('--show', + action='store_true', + help='Show the detection results') + parser.add_argument('--out-dir', + type=str, + default='./output', + help='Path to output file') + parser.add_argument('--conf-thres', + type=float, + default=0.25, + help='Confidence threshold') + parser.add_argument('--iou-thres', + type=float, + default=0.65, + help='Confidence threshold') + parser.add_argument('--device', + type=str, + default='cuda:0', + help='TensorRT infer device') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/infer-seg-without-torch.py b/infer-seg-without-torch.py index f1923e3..fa08b62 100644 --- a/infer-seg-without-torch.py +++ b/infer-seg-without-torch.py @@ -41,6 +41,10 @@ def main(args: argparse.Namespace) -> None: seg_img = seg_img[dh:H - dh, dw:W - dw, [2, 1, 0]] bboxes, scores, labels, masks = seg_postprocess( data, bgr.shape[:2], args.conf_thres, args.iou_thres) + if bboxes.size == 0: + # if no bounding box + print(f'{image}: no object!') + continue masks = masks[:, dh:H - dh, dw:W - dw, :] mask_colors = MASK_COLORS[labels % len(MASK_COLORS)] mask_colors = mask_colors.reshape(-1, 1, 1, 3) * ALPHA diff --git a/infer-seg.py b/infer-seg.py index de77316..0284244 100644 --- a/infer-seg.py +++ b/infer-seg.py @@ -42,10 +42,9 @@ def main(args: argparse.Namespace) -> None: device=device) bboxes, scores, labels, masks = seg_postprocess( data, bgr.shape[:2], args.conf_thres, args.iou_thres) - if bboxes is None: - # if no bounding box or others save original image - if not args.show: - cv2.imwrite(str(save_image), draw) + if bboxes.numel() == 0: + # if no bounding box + print(f'{image}: no object!') continue masks = masks[:, dh:H - dh, dw:W - dw, :] indices = (labels % len(MASK_COLORS)).long() diff --git a/models/torch_utils.py b/models/torch_utils.py index 1bf7d39..36d0255 100644 --- a/models/torch_utils.py +++ b/models/torch_utils.py @@ -3,7 +3,7 @@ from typing import List, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor -from torchvision.ops import batched_nms +from torchvision.ops import batched_nms, nms def seg_postprocess( @@ -14,12 +14,13 @@ def seg_postprocess( -> Tuple[Tensor, Tensor, Tensor, Tensor]: assert len(data) == 2 h, w = shape[0] // 4, shape[1] // 4 # 4x downsampling - outputs, proto = (i[0] for i in data) + outputs, proto = data[0][0], data[1][0] bboxes, scores, labels, maskconf = outputs.split([4, 1, 1, 32], 1) scores, labels = scores.squeeze(), labels.squeeze() idx = scores > conf_thres - if idx.sum() == 0: # no bounding boxes or seg were created - return None, None, None, None + if not idx.any(): # no bounding boxes or seg were created + return bboxes.new_zeros((0, 4)), scores.new_zeros( + (0, )), labels.new_zeros((0, )), bboxes.new_zeros((0, 0, 0, 0)) bboxes, scores, labels, maskconf = \ bboxes[idx], scores[idx], labels[idx], maskconf[idx] idx = batched_nms(bboxes, scores, labels, iou_thres) @@ -35,10 +36,37 @@ def seg_postprocess( return bboxes, scores, labels, masks +def pose_postprocess( + data: Union[Tuple, Tensor], + conf_thres: float = 0.25, + iou_thres: float = 0.65) \ + -> Tuple[Tensor, Tensor, Tensor]: + if isinstance(data, tuple): + assert len(data) == 1 + data = data[0] + outputs = torch.transpose(data[0], 0, 1).contiguous() + bboxes, scores, kpts = outputs.split([4, 1, 51], 1) + scores, kpts = scores.squeeze(), kpts.squeeze() + idx = scores > conf_thres + if not idx.any(): # no bounding boxes or seg were created + return bboxes.new_zeros((0, 4)), scores.new_zeros( + (0, )), bboxes.new_zeros((0, 0, 0)) + bboxes, scores, kpts = bboxes[idx], scores[idx], kpts[idx] + xycenter, wh = bboxes.chunk(2, -1) + bboxes = torch.cat([xycenter - 0.5 * wh, xycenter + 0.5 * wh], -1) + idx = nms(bboxes, scores, iou_thres) + bboxes, scores, kpts = bboxes[idx], scores[idx], kpts[idx] + return bboxes, scores, kpts.reshape(idx.shape[0], -1, 3) + + def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]): assert len(data) == 4 - num_dets, bboxes, scores, labels = (i[0] for i in data) + num_dets, bboxes, scores, labels = data[0][0], data[1][0], data[2][ + 0], data[3][0] nums = num_dets.item() + if nums == 0: + return bboxes.new_zeros((0, 4)), scores.new_zeros( + (0, )), labels.new_zeros((0, )) bboxes = bboxes[:nums] scores = scores[:nums] labels = labels[:nums] diff --git a/models/utils.py b/models/utils.py index 01c5034..c0d78ae 100644 --- a/models/utils.py +++ b/models/utils.py @@ -45,6 +45,7 @@ def letterbox(im: ndarray, def blob(im: ndarray, return_seg: bool = False) -> Union[ndarray, Tuple]: + seg = None if return_seg: seg = im.astype(np.float32) / 255 im = im.transpose([2, 0, 1]) @@ -88,6 +89,9 @@ def det_postprocess(data: Tuple[ndarray, ndarray, ndarray, ndarray]): assert len(data) == 4 num_dets, bboxes, scores, labels = (i[0] for i in data) nums = num_dets.item() + if nums == 0: + return np.empty((0, 4), dtype=np.float32), np.empty( + (0, ), dtype=np.float32), np.empty((0, ), dtype=np.int32) bboxes = bboxes[:nums] scores = scores[:nums] labels = labels[:nums] @@ -106,6 +110,12 @@ def seg_postprocess( bboxes, scores, labels, maskconf = np.split(outputs, [4, 5, 6], 1) scores, labels = scores.squeeze(), labels.squeeze() idx = scores > conf_thres + if not idx.any(): # no bounding boxes or seg were created + return np.empty((0, 4), dtype=np.float32), \ + np.empty((0,), dtype=np.float32), \ + np.empty((0,), dtype=np.int32), \ + np.empty((0, 0, 0, 0), dtype=np.int32) + bboxes, scores, labels, maskconf = \ bboxes[idx], scores[idx], labels[idx], maskconf[idx] cvbboxes = np.concatenate([bboxes[:, :2], bboxes[:, 2:] - bboxes[:, :2]], @@ -128,3 +138,29 @@ def seg_postprocess( masks = masks.transpose(2, 0, 1) masks = np.ascontiguousarray((masks > 0.5)[..., None], dtype=np.float32) return bboxes, scores, labels, masks + + +def pose_postprocess( + data: Union[Tuple, ndarray], + conf_thres: float = 0.25, + iou_thres: float = 0.65) \ + -> Tuple[ndarray, ndarray, ndarray]: + if isinstance(data, tuple): + assert len(data) == 1 + data = data[0] + outputs = np.transpose(data[0], (1, 0)) + bboxes, scores, kpts = np.split(outputs, [4, 5], 1) + scores, kpts = scores.squeeze(), kpts.squeeze() + idx = scores > conf_thres + if not idx.any(): # no bounding boxes or seg were created + return np.empty((0, 4), dtype=np.float32), np.empty( + (0, ), dtype=np.float32), np.empty((0, 0, 0), dtype=np.float32) + bboxes, scores, kpts = bboxes[idx], scores[idx], kpts[idx] + xycenter, wh = np.split(bboxes, [ + 2, + ], -1) + cvbboxes = np.concatenate([xycenter - 0.5 * wh, wh], -1) + idx = cv2.dnn.NMSBoxes(cvbboxes, scores, conf_thres, iou_thres) + cvbboxes, scores, kpts = cvbboxes[idx], scores[idx], kpts[idx] + cvbboxes[:, 2:] += cvbboxes[:, :2] + return cvbboxes, scores, kpts.reshape(idx.shape[0], -1, 3)