Refactor code for detection and segment

pull/28/head
triple-Mu 2 years ago
parent 0770f5fa60
commit 13ed0033db
  1. 10
      .gitignore
  2. 14
      README.md
  3. 3
      build.py
  4. 40
      config.py
  5. 35
      docs/Segment.md
  6. 0
      export-det.py
  7. 0
      export-seg.py
  8. 91
      infer-det-without-torch.py
  9. 82
      infer-det.py
  10. 254
      infer-no-torch.py
  11. 100
      infer-seg-without-torch.py
  12. 101
      infer-seg.py
  13. 267
      infer.py
  14. 3
      models/engine.py
  15. 54
      models/torch_utils.py
  16. 125
      models/utils.py
  17. 29
      profile.py

10
.gitignore vendored

@ -1,3 +1,13 @@
# Personal
.idea
*.engine
*.pt
*.pth
*.onnx
*.jpg
*.png
*.bmp
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

@ -54,7 +54,7 @@ You can export your onnx model by `ultralytics` API
and add postprocess into model at the same time.
``` shell
python3 export.py \
python3 export-det.py \
--weights yolov8s.pt \
--iou-thres 0.65 \
--conf-thres 0.25 \
@ -148,12 +148,12 @@ Please see more information in [`API-Build.md`](docs/API-Build.md)
## 1. Infer with python script
You can infer images with the engine by [`infer.py`](infer.py) .
You can infer images with the engine by [`infer-det.py`](infer-det.py) .
Usage:
``` shell
python3 infer.py \
python3 infer-det.py \
--engine yolov8s.engine \
--imgs data \
--show \
@ -215,13 +215,13 @@ If you want to profile the TensorRT engine:
Usage:
``` shell
python3 infer.py --engine yolov8s.engine --profile
python3 profile.py --engine yolov8s.engine --device cuda:0
```
# Refuse To Use PyTorch for model inference !!!
# Refuse To Use PyTorch for Model Inference !!!
If you need to break away from pytorch and use tensorrt inference,
you can get more information in [`infer-no-torch.py`](infer-no-torch.py),
you can get more information in [`infer-det-without-torch.py`](infer-det-without-torch.py),
the usage is the same as the pytorch version, but its performance is much worse.
You can use `cuda-python` or `pycuda` for inference.
@ -236,7 +236,7 @@ pip install pycuda
Usage:
``` shell
python3 infer-no-torch.py \
python3 infer-det-without-torch.py \
--engine yolov8s.engine \
--imgs data \
--show \

@ -1,10 +1,7 @@
import argparse
import os
from models import EngineBuilder
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
def parse_args():
parser = argparse.ArgumentParser()

@ -0,0 +1,40 @@
import random
import numpy as np
random.seed(0)
# detection model classes
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush')
# colors for per classes
COLORS = {
cls: [random.randint(0, 255) for _ in range(3)]
for i, cls in enumerate(CLASSES)
}
# colors for segment masks
MASK_COLORS = np.array([(255, 56, 56), (255, 157, 151), (255, 112, 31),
(255, 178, 29), (207, 210, 49), (72, 249, 10),
(146, 204, 23), (61, 219, 134), (26, 147, 52),
(0, 212, 187), (44, 153, 168), (0, 194, 255),
(52, 69, 147), (100, 115, 255), (0, 24, 236),
(132, 56, 255), (82, 0, 133), (203, 56, 255),
(255, 149, 200), (255, 55, 199)],
dtype=np.float32) / 255.
# alpha for segment masks
ALPHA = 0.5

@ -10,7 +10,7 @@ The yolov8-seg model conversion route is :
You can export your onnx model by `ultralytics` API and the onnx is also modify by this repo.
``` shell
python3 export_seg.py \
python3 export-seg.py \
--weights yolov8s-seg.pt \
--opset 11 \
--sim \
@ -68,18 +68,17 @@ Usage:
## Infer with python script
You can infer images with the engine by [`infer.py`](../infer.py) .
You can infer images with the engine by [`infer-seg.py`](../infer-seg.py) .
Usage:
``` shell
python3 infer.py \
python3 infer-seg.py \
--engine yolov8s-seg.engine \
--imgs data \
--show \
--out-dir outputs \
--device cuda:0 \
--seg
--device cuda:0
```
#### Description of all arguments
@ -90,7 +89,6 @@ python3 infer.py \
- `--out-dir` : Where to save detection results images. It will not work when use `--show` flag.
- `--device` : The CUDA deivce you use.
- `--profile` : Profile the TensorRT engine.
- `--seg` : Infer with seg model.
## Infer with C++
@ -207,3 +205,28 @@ Usage:
# infer video
./yolov8-seg yolov8s-seg.engine data/test.mp4 # the video path
```
# Refuse To Use PyTorch for segment Model Inference !!!
It is the same as detection model.
you can get more information in [`infer-seg-without-torch.py`](../infer-seg-without-torch.py),
Usage:
``` shell
python3 infer-seg-without-torch.py \
--engine yolov8s-seg.engine \
--imgs data \
--show \
--out-dir outputs \
--method cudart
```
#### Description of all arguments
- `--engine` : The Engine you export.
- `--imgs` : The images path you want to detect.
- `--show` : Whether to show detection results.
- `--out-dir` : Where to save detection results images. It will not work when use `--show` flag.
- `--method` : Choose `cudart` or `pycuda`, default is `cudart`.
- `--profile` : Profile the TensorRT engine.

@ -0,0 +1,91 @@
import argparse
from pathlib import Path
import cv2
import numpy as np
from config import CLASSES, COLORS
from models.utils import blob, det_postprocess, letterbox, path_to_list
def main(args: argparse.Namespace) -> None:
if args.method == 'cudart':
from models.cudart_api import TRTEngine
elif args.method == 'pycuda':
from models.pycuda_api import TRTEngine
else:
raise NotImplementedError
Engine = TRTEngine(args.engine)
H, W = Engine.inp_info[0].shape[-2:]
images = path_to_list(args.imgs)
save_path = Path(args.out_dir)
if not args.show and not save_path.exists():
save_path.mkdir(parents=True, exist_ok=True)
for image in images:
save_image = save_path / image.name
bgr = cv2.imread(str(image))
draw = bgr.copy()
bgr, ratio, dwdh = letterbox(bgr, (W, H))
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
tensor = blob(rgb, return_seg=False)
dwdh = np.array(dwdh * 2, dtype=np.float32)
tensor = np.ascontiguousarray(tensor)
# inference
data = Engine(tensor)
bboxes, scores, labels = det_postprocess(data)
bboxes -= dwdh
bboxes /= ratio
for (bbox, score, label) in zip(bboxes, scores, labels):
bbox = bbox.round().astype(np.int32).tolist()
cls_id = int(label)
cls = CLASSES[cls_id]
color = COLORS[cls]
cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
cv2.putText(draw,
f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.75, [225, 255, 255],
thickness=2)
if args.show:
cv2.imshow('result', draw)
cv2.waitKey(0)
else:
cv2.imwrite(str(save_image), draw)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
parser.add_argument('--imgs', type=str, help='Images file')
parser.add_argument('--show',
action='store_true',
help='Show the detection results')
parser.add_argument('--out-dir',
type=str,
default='./output',
help='Path to output file')
parser.add_argument('--conf-thres',
type=float,
default=0.25,
help='Confidence threshold')
parser.add_argument('--iou-thres',
type=float,
default=0.65,
help='Confidence threshold')
parser.add_argument('--method',
type=str,
default='cudart',
help='CUDART pipeline')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
main(args)

@ -0,0 +1,82 @@
from models import TRTModule # isort:skip
import argparse
from pathlib import Path
import cv2
import torch
from config import CLASSES, COLORS
from models.torch_utils import det_postprocess
from models.utils import blob, letterbox, path_to_list
def main(args: argparse.Namespace) -> None:
device = torch.device(args.device)
Engine = TRTModule(args.engine, device)
H, W = Engine.inp_info[0].shape[-2:]
# set desired output names order
Engine.set_desired(['num_dets', 'bboxes', 'scores', 'labels'])
images = path_to_list(args.imgs)
save_path = Path(args.out_dir)
if not args.show and not save_path.exists():
save_path.mkdir(parents=True, exist_ok=True)
for image in images:
save_image = save_path / image.name
bgr = cv2.imread(str(image))
draw = bgr.copy()
bgr, ratio, dwdh = letterbox(bgr, (W, H))
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
tensor = blob(rgb, return_seg=False)
dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device)
tensor = torch.asarray(tensor, device=device)
# inference
data = Engine(tensor)
bboxes, scores, labels = det_postprocess(data)
bboxes -= dwdh
bboxes /= ratio
for (bbox, score, label) in zip(bboxes, scores, labels):
bbox = bbox.round().int().tolist()
cls_id = int(label)
cls = CLASSES[cls_id]
color = COLORS[cls]
cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
cv2.putText(draw,
f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.75, [225, 255, 255],
thickness=2)
if args.show:
cv2.imshow('result', draw)
cv2.waitKey(0)
else:
cv2.imwrite(str(save_image), draw)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
parser.add_argument('--imgs', type=str, help='Images file')
parser.add_argument('--show',
action='store_true',
help='Show the detection results')
parser.add_argument('--out-dir',
type=str,
default='./output',
help='Path to output file')
parser.add_argument('--device',
type=str,
default='cuda:0',
help='TensorRT infer device')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
main(args)

@ -1,254 +0,0 @@
import argparse
import os
import random
from pathlib import Path
from typing import List, Tuple, Union
import cv2
import numpy as np
from numpy import ndarray
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
random.seed(0)
SUFFIXS = ('.bmp', '.dng', '.jpeg', '.jpg', '.mpo', '.png', '.tif', '.tiff',
'.webp', '.pfm')
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush')
COLORS = {
cls: [random.randint(0, 255) for _ in range(3)]
for i, cls in enumerate(CLASSES)
}
# the same as yolov8
MASK_COLORS = np.array([(255, 56, 56), (255, 157, 151), (255, 112, 31),
(255, 178, 29), (207, 210, 49), (72, 249, 10),
(146, 204, 23), (61, 219, 134), (26, 147, 52),
(0, 212, 187), (44, 153, 168), (0, 194, 255),
(52, 69, 147), (100, 115, 255), (0, 24, 236),
(132, 56, 255), (82, 0, 133), (203, 56, 255),
(255, 149, 200), (255, 55, 199)],
dtype=np.float32) / 255.
ALPHA = 0.5
def letterbox(
im: ndarray,
new_shape: Union[Tuple, List] = (640, 640),
color: Union[Tuple, List] = (114, 114, 114)
) -> Tuple[ndarray, float, Tuple[float, float]]:
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[
1] # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(im,
top,
bottom,
left,
right,
cv2.BORDER_CONSTANT,
value=color) # add border
return im, r, (dw, dh)
def blob(im: ndarray) -> Tuple[ndarray, ndarray]:
seg = im.astype(np.float32) / 255
im = im.transpose([2, 0, 1])
im = im[np.newaxis, ...]
im = np.ascontiguousarray(im).astype(np.float32) / 255
return im, seg
def main(args):
if args.method == 'cudart':
from models.cudart_api import TRTEngine
elif args.method == 'pycuda':
from models.pycuda_api import TRTEngine
else:
raise NotImplementedError
Engine = TRTEngine(args.engine)
H, W = Engine.inp_info[0].shape[-2:]
images_path = Path(args.imgs)
assert images_path.exists()
save_path = Path(args.out_dir)
if images_path.is_dir():
images = [
i.absolute() for i in images_path.iterdir() if i.suffix in SUFFIXS
]
else:
assert images_path.suffix in SUFFIXS
images = [images_path.absolute()]
if not args.show and not save_path.exists():
save_path.mkdir(parents=True, exist_ok=True)
for image in images:
save_image = save_path / image.name
bgr = cv2.imread(str(image))
draw = bgr.copy()
bgr, ratio, dwdh = letterbox(bgr, (W, H))
dw, dh = int(dwdh[0]), int(dwdh[1])
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
tensor, seg_img = blob(rgb)
dwdh = np.array(dwdh * 2, dtype=np.float32)
tensor = np.ascontiguousarray(tensor)
data = Engine(tensor)
if args.seg:
seg_img = seg_img[dh:H - dh, dw:W - dw, [2, 1, 0]]
bboxes, scores, labels, masks = seg_postprocess(
data, bgr.shape[:2], args.conf_thres, args.iou_thres)
mask, mask_color = [m[:, dh:H - dh, dw:W - dw, :] for m in masks]
inv_alph_masks = (1 - mask * 0.5).cumprod(0)
mcs = (mask_color * inv_alph_masks).sum(0) * 2
seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255
draw = cv2.resize(seg_img.astype(np.uint8), draw.shape[:2][::-1])
else:
bboxes, scores, labels = det_postprocess(data)
bboxes -= dwdh
bboxes /= ratio
for (bbox, score, label) in zip(bboxes, scores, labels):
bbox = bbox.round().astype(np.int32).tolist()
cls_id = int(label)
cls = CLASSES[cls_id]
color = COLORS[cls]
cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
cv2.putText(draw,
f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.75, [225, 255, 255],
thickness=2)
if args.show:
cv2.imshow('result', draw)
cv2.waitKey(0)
else:
cv2.imwrite(str(save_image), draw)
def crop_mask(masks: ndarray, bboxes: ndarray) -> ndarray:
n, h, w = masks.shape
x1, y1, x2, y2 = np.split(bboxes[:, :, None], [1, 2, 3],
1) # x1 shape(1,1,n)
r = np.arange(w, dtype=x1.dtype)[None, None, :] # rows shape(1,w,1)
c = np.arange(h, dtype=x1.dtype)[None, :, None] # cols shape(h,1,1)
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
def seg_postprocess(
data: Tuple[ndarray],
shape: Union[Tuple, List],
conf_thres: float = 0.25,
iou_thres: float = 0.65) -> Tuple[ndarray, ndarray, ndarray, List]:
assert len(data) == 2
h, w = shape[0] // 4, shape[1] // 4 # 4x downsampling
outputs, proto = (i[0] for i in data)
bboxes, scores, labels, maskconf = np.split(outputs, [4, 5, 6], 1)
scores, labels = scores.squeeze(), labels.squeeze()
select = scores > conf_thres
bboxes, scores, labels, maskconf = bboxes[select], scores[select], labels[
select], maskconf[select]
cvbboxes = np.concatenate([bboxes[:, :2], bboxes[:, 2:] - bboxes[:, :2]],
1)
labels = labels.astype(np.int32)
v0, v1 = map(int, (cv2.__version__).split('.')[:2])
assert v0 == 4, 'OpenCV version is wrong'
if v1 > 6:
idx = cv2.dnn.NMSBoxesBatched(cvbboxes, scores, labels, conf_thres,
iou_thres)
else:
idx = cv2.dnn.NMSBoxes(cvbboxes, scores, conf_thres, iou_thres)
bboxes, scores, labels, maskconf = bboxes[idx], scores[idx], labels[
idx], maskconf[idx]
masks = (maskconf @ proto).reshape(-1, h, w)
masks = crop_mask(masks, bboxes / 4.)
masks = cv2.resize(masks.transpose([1, 2, 0]),
shape,
interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)
masks = np.ascontiguousarray((masks > 0.5)[..., None])
cidx = labels % len(MASK_COLORS)
mask_color = MASK_COLORS[cidx].reshape(-1, 1, 1, 3) * ALPHA
out = [masks, masks @ mask_color]
return bboxes, scores, labels, out
def det_postprocess(data: Tuple[ndarray, ndarray, ndarray]):
assert len(data) == 4
num_dets, bboxes, scores, labels = (i[0] for i in data)
nums = num_dets.item()
bboxes = bboxes[:nums]
scores = scores[:nums]
labels = labels[:nums]
return bboxes, scores, labels
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
parser.add_argument('--imgs', type=str, help='Images file')
parser.add_argument('--show',
action='store_true',
help='Show the detection results')
parser.add_argument('--seg', action='store_true', help='Seg inference')
parser.add_argument('--out-dir',
type=str,
default='./output',
help='Path to output file')
parser.add_argument('--conf-thres',
type=float,
default=0.25,
help='Confidence threshold')
parser.add_argument('--iou-thres',
type=float,
default=0.65,
help='Confidence threshold')
parser.add_argument('--method',
type=str,
default='cudart',
help='CUDART pipeline')
parser.add_argument('--profile',
action='store_true',
help='Profile TensorRT engine')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
main(args)

@ -0,0 +1,100 @@
import argparse
from pathlib import Path
import cv2
import numpy as np
from config import CLASSES, COLORS
from models.utils import blob, letterbox, path_to_list, seg_postprocess
def main(args: argparse.Namespace) -> None:
if args.method == 'cudart':
from models.cudart_api import TRTEngine
elif args.method == 'pycuda':
from models.pycuda_api import TRTEngine
else:
raise NotImplementedError
Engine = TRTEngine(args.engine)
H, W = Engine.inp_info[0].shape[-2:]
images = path_to_list(args.imgs)
save_path = Path(args.out_dir)
if not args.show and not save_path.exists():
save_path.mkdir(parents=True, exist_ok=True)
for image in images:
save_image = save_path / image.name
bgr = cv2.imread(str(image))
draw = bgr.copy()
bgr, ratio, dwdh = letterbox(bgr, (W, H))
dw, dh = int(dwdh[0]), int(dwdh[1])
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
tensor, seg_img = blob(rgb, return_seg=True)
dwdh = np.array(dwdh * 2, dtype=np.float32)
tensor = np.ascontiguousarray(tensor)
# inference
data = Engine(tensor)
seg_img = seg_img[dh:H - dh, dw:W - dw, [2, 1, 0]]
bboxes, scores, labels, masks = seg_postprocess(
data, bgr.shape[:2], args.conf_thres, args.iou_thres)
mask, mask_color = [m[:, dh:H - dh, dw:W - dw, :] for m in masks]
inv_alph_masks = (1 - mask * 0.5).cumprod(0)
mcs = (mask_color * inv_alph_masks).sum(0) * 2
seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255
draw = cv2.resize(seg_img.astype(np.uint8), draw.shape[:2][::-1])
bboxes -= dwdh
bboxes /= ratio
for (bbox, score, label) in zip(bboxes, scores, labels):
bbox = bbox.round().astype(np.int32).tolist()
cls_id = int(label)
cls = CLASSES[cls_id]
color = COLORS[cls]
cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
cv2.putText(draw,
f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.75, [225, 255, 255],
thickness=2)
if args.show:
cv2.imshow('result', draw)
cv2.waitKey(0)
else:
cv2.imwrite(str(save_image), draw)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
parser.add_argument('--imgs', type=str, help='Images file')
parser.add_argument('--show',
action='store_true',
help='Show the detection results')
parser.add_argument('--out-dir',
type=str,
default='./output',
help='Path to output file')
parser.add_argument('--conf-thres',
type=float,
default=0.25,
help='Confidence threshold')
parser.add_argument('--iou-thres',
type=float,
default=0.65,
help='Confidence threshold')
parser.add_argument('--method',
type=str,
default='cudart',
help='CUDART pipeline')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
main(args)

@ -0,0 +1,101 @@
from models import TRTModule # isort:skip
import argparse
from pathlib import Path
import cv2
import numpy as np
import torch
from config import CLASSES, COLORS
from models.torch_utils import seg_postprocess
from models.utils import blob, letterbox, path_to_list
def main(args: argparse.Namespace) -> None:
device = torch.device(args.device)
Engine = TRTModule(args.engine, device)
H, W = Engine.inp_info[0].shape[-2:]
# set desired output names order
Engine.set_desired(['outputs', 'proto'])
images = path_to_list(args.imgs)
save_path = Path(args.out_dir)
if not args.show and not save_path.exists():
save_path.mkdir(parents=True, exist_ok=True)
for image in images:
save_image = save_path / image.name
bgr = cv2.imread(str(image))
draw = bgr.copy()
bgr, ratio, dwdh = letterbox(bgr, (W, H))
dw, dh = int(dwdh[0]), int(dwdh[1])
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
tensor, seg_img = blob(rgb, return_seg=True)
dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device)
tensor = torch.asarray(tensor, device=device)
# inference
data = Engine(tensor)
seg_img = torch.asarray(seg_img[dh:H - dh, dw:W - dw, [2, 1, 0]],
device=device)
bboxes, scores, labels, masks = seg_postprocess(
data, bgr.shape[:2], args.conf_thres, args.iou_thres)
mask, mask_color = [m[:, dh:H - dh, dw:W - dw, :] for m in masks]
inv_alph_masks = (1 - mask * 0.5).cumprod(0)
mcs = (mask_color * inv_alph_masks).sum(0) * 2
seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255
draw = cv2.resize(seg_img.cpu().numpy().astype(np.uint8),
draw.shape[:2][::-1])
bboxes -= dwdh
bboxes /= ratio
for (bbox, score, label) in zip(bboxes, scores, labels):
bbox = bbox.round().int().tolist()
cls_id = int(label)
cls = CLASSES[cls_id]
color = COLORS[cls]
cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
cv2.putText(draw,
f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.75, [225, 255, 255],
thickness=2)
if args.show:
cv2.imshow('result', draw)
cv2.waitKey(0)
else:
cv2.imwrite(str(save_image), draw)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
parser.add_argument('--imgs', type=str, help='Images file')
parser.add_argument('--show',
action='store_true',
help='Show the detection results')
parser.add_argument('--out-dir',
type=str,
default='./output',
help='Path to output file')
parser.add_argument('--conf-thres',
type=float,
default=0.25,
help='Confidence threshold')
parser.add_argument('--iou-thres',
type=float,
default=0.65,
help='Confidence threshold')
parser.add_argument('--device',
type=str,
default='cuda:0',
help='TensorRT infer device')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
main(args)

@ -1,267 +0,0 @@
from models import TRTModule, TRTProfilerV0 # isort:skip
import argparse
import os
import random
from pathlib import Path
from typing import List, Tuple, Union
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from numpy import ndarray
from torch import Tensor
from torchvision.ops import batched_nms
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
random.seed(0)
SUFFIXS = ('.bmp', '.dng', '.jpeg', '.jpg', '.mpo', '.png', '.tif', '.tiff',
'.webp', '.pfm')
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush')
COLORS = {
cls: [random.randint(0, 255) for _ in range(3)]
for i, cls in enumerate(CLASSES)
}
# the same as yolov8
MASK_COLORS = np.array([(255, 56, 56), (255, 157, 151), (255, 112, 31),
(255, 178, 29), (207, 210, 49), (72, 249, 10),
(146, 204, 23), (61, 219, 134), (26, 147, 52),
(0, 212, 187), (44, 153, 168), (0, 194, 255),
(52, 69, 147), (100, 115, 255), (0, 24, 236),
(132, 56, 255), (82, 0, 133), (203, 56, 255),
(255, 149, 200), (255, 55, 199)],
dtype=np.float32) / 255.
ALPHA = 0.5
def letterbox(
im: ndarray,
new_shape: Union[Tuple, List] = (640, 640),
color: Union[Tuple, List] = (114, 114, 114)
) -> Tuple[ndarray, float, Tuple[float, float]]:
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[
1] # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(im,
top,
bottom,
left,
right,
cv2.BORDER_CONSTANT,
value=color) # add border
return im, r, (dw, dh)
def blob(im: ndarray) -> Tuple[ndarray, ndarray]:
seg = im.astype(np.float32) / 255
im = im.transpose([2, 0, 1])
im = im[np.newaxis, ...]
im = np.ascontiguousarray(im).astype(np.float32) / 255
return im, seg
def main(args):
device = torch.device(args.device)
Engine = TRTModule(args.engine, device)
H, W = Engine.inp_info[0].shape[-2:]
# set desired output names order
if args.seg:
Engine.set_desired(['outputs', 'proto'])
else:
Engine.set_desired(['num_dets', 'bboxes', 'scores', 'labels'])
images_path = Path(args.imgs)
assert images_path.exists()
save_path = Path(args.out_dir)
if images_path.is_dir():
images = [
i.absolute() for i in images_path.iterdir() if i.suffix in SUFFIXS
]
else:
assert images_path.suffix in SUFFIXS
images = [images_path.absolute()]
if not args.show and not save_path.exists():
save_path.mkdir(parents=True, exist_ok=True)
for image in images:
save_image = save_path / image.name
bgr = cv2.imread(str(image))
draw = bgr.copy()
bgr, ratio, dwdh = letterbox(bgr, (W, H))
dw, dh = int(dwdh[0]), int(dwdh[1])
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
tensor, seg_img = blob(rgb)
dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device)
tensor = torch.asarray(tensor, device=device)
data = Engine(tensor)
if args.seg:
seg_img = torch.asarray(seg_img[dh:H - dh, dw:W - dw, [2, 1, 0]],
device=device)
bboxes, scores, labels, masks = seg_postprocess(
data, bgr.shape[:2], args.conf_thres, args.iou_thres)
mask, mask_color = [m[:, dh:H - dh, dw:W - dw, :] for m in masks]
inv_alph_masks = (1 - mask * 0.5).cumprod(0)
mcs = (mask_color * inv_alph_masks).sum(0) * 2
seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255
draw = cv2.resize(seg_img.cpu().numpy().astype(np.uint8),
draw.shape[:2][::-1])
else:
bboxes, scores, labels = det_postprocess(data)
bboxes -= dwdh
bboxes /= ratio
for (bbox, score, label) in zip(bboxes, scores, labels):
bbox = bbox.round().int().tolist()
cls_id = int(label)
cls = CLASSES[cls_id]
color = COLORS[cls]
cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
cv2.putText(draw,
f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.75, [225, 255, 255],
thickness=2)
if args.show:
cv2.imshow('result', draw)
cv2.waitKey(0)
else:
cv2.imwrite(str(save_image), draw)
def crop_mask(masks: Tensor, bboxes: Tensor) -> Tensor:
n, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(bboxes[:, :, None], 4, 1) # x1 shape(1,1,n)
r = torch.arange(w, device=masks.device,
dtype=x1.dtype)[None, None, :] # rows shape(1,w,1)
c = torch.arange(h, device=masks.device,
dtype=x1.dtype)[None, :, None] # cols shape(h,1,1)
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
def seg_postprocess(
data: Tuple[Tensor],
shape: Union[Tuple, List],
conf_thres: float = 0.25,
iou_thres: float = 0.65) -> Tuple[Tensor, Tensor, Tensor, List]:
assert len(data) == 2
h, w = shape[0] // 4, shape[1] // 4 # 4x downsampling
outputs, proto = (i[0] for i in data)
bboxes, scores, labels, maskconf = outputs.split([4, 1, 1, 32], 1)
scores, labels = scores.squeeze(), labels.squeeze()
select = scores > conf_thres
bboxes, scores, labels, maskconf = bboxes[select], scores[select], labels[
select], maskconf[select]
idx = batched_nms(bboxes, scores, labels, iou_thres)
bboxes, scores, labels, maskconf = bboxes[idx], scores[idx], labels[
idx].int(), maskconf[idx]
masks = (maskconf @ proto).view(-1, h, w)
masks = crop_mask(masks, bboxes / 4.)
masks = F.interpolate(masks[None],
shape,
mode='bilinear',
align_corners=False)[0]
masks = masks.gt_(0.5)[..., None]
cidx = (labels % len(MASK_COLORS)).cpu().numpy()
mask_color = torch.tensor(MASK_COLORS[cidx].reshape(-1, 1, 1,
3)).to(bboxes) * ALPHA
out = [masks, masks @ mask_color]
return bboxes, scores, labels, out
def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]):
assert len(data) == 4
num_dets, bboxes, scores, labels = (i[0] for i in data)
nums = num_dets.item()
bboxes = bboxes[:nums]
scores = scores[:nums]
labels = labels[:nums]
return bboxes, scores, labels
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
parser.add_argument('--imgs', type=str, help='Images file')
parser.add_argument('--show',
action='store_true',
help='Show the detection results')
parser.add_argument('--seg', action='store_true', help='Seg inference')
parser.add_argument('--out-dir',
type=str,
default='./output',
help='Path to output file')
parser.add_argument('--conf-thres',
type=float,
default=0.25,
help='Confidence threshold')
parser.add_argument('--iou-thres',
type=float,
default=0.65,
help='Confidence threshold')
parser.add_argument('--device',
type=str,
default='cuda:0',
help='TensorRT infer device')
parser.add_argument('--profile',
action='store_true',
help='Profile TensorRT engine')
args = parser.parse_args()
return args
def profile(args):
device = torch.device(args.device)
Engine = TRTModule(args.engine, device)
profiler = TRTProfilerV0()
Engine.set_profiler(profiler)
random_input = torch.randn(Engine.inp_info[0].shape, device=device)
_ = Engine(random_input)
if __name__ == '__main__':
args = parse_args()
if args.profile:
profile(args)
else:
main(args)

@ -1,3 +1,4 @@
import os
import pickle
from collections import defaultdict, namedtuple
from pathlib import Path
@ -7,6 +8,8 @@ import onnx
import tensorrt as trt
import torch
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
class EngineBuilder:
seg = False

@ -0,0 +1,54 @@
from typing import List, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import batched_nms
def seg_postprocess(
data: Tuple[Tensor],
shape: Union[Tuple, List],
conf_thres: float = 0.25,
iou_thres: float = 0.65) \
-> Tuple[Tensor, Tensor, Tensor, Tensor]:
assert len(data) == 2
h, w = shape[0] // 4, shape[1] // 4 # 4x downsampling
outputs, proto = (i[0] for i in data)
bboxes, scores, labels, maskconf = outputs.split([4, 1, 1, 32], 1)
scores, labels = scores.squeeze(), labels.squeeze()
idx = scores > conf_thres
bboxes, scores, labels, maskconf = \
bboxes[idx], scores[idx], labels[idx], maskconf[idx]
idx = batched_nms(bboxes, scores, labels, iou_thres)
bboxes, scores, labels, maskconf = \
bboxes[idx], scores[idx], labels[idx].int(), maskconf[idx]
masks = (maskconf @ proto).view(-1, h, w)
masks = crop_mask(masks, bboxes / 4.)
masks = F.interpolate(masks[None],
shape,
mode='bilinear',
align_corners=False)[0]
masks = masks.gt_(0.5)[..., None]
return bboxes, scores, labels, masks
def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]):
assert len(data) == 4
num_dets, bboxes, scores, labels = (i[0] for i in data)
nums = num_dets.item()
bboxes = bboxes[:nums]
scores = scores[:nums]
labels = labels[:nums]
return bboxes, scores, labels
def crop_mask(masks: Tensor, bboxes: Tensor) -> Tensor:
n, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(bboxes[:, :, None], 4, 1) # x1 shape(1,1,n)
r = torch.arange(w, device=masks.device,
dtype=x1.dtype)[None, None, :] # rows shape(1,w,1)
c = torch.arange(h, device=masks.device,
dtype=x1.dtype)[None, :, None] # cols shape(h,1,1)
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))

@ -0,0 +1,125 @@
from pathlib import Path
from typing import List, Tuple, Union
import cv2
import numpy as np
from numpy import ndarray
# image suffixs
SUFFIXS = ('.bmp', '.dng', '.jpeg', '.jpg', '.mpo', '.png', '.tif', '.tiff',
'.webp', '.pfm')
def letterbox(im: ndarray,
new_shape: Union[Tuple, List] = (640, 640),
color: Union[Tuple, List] = (114, 114, 114)) \
-> Tuple[ndarray, float, Tuple[float, float]]:
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[
1] # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(im,
top,
bottom,
left,
right,
cv2.BORDER_CONSTANT,
value=color) # add border
return im, r, (dw, dh)
def blob(im: ndarray, return_seg: bool = False) -> Union[ndarray, Tuple]:
if return_seg:
seg = im.astype(np.float32) / 255
im = im.transpose([2, 0, 1])
im = im[np.newaxis, ...]
im = np.ascontiguousarray(im).astype(np.float32) / 255
if return_seg:
return im, seg
else:
return im
def path_to_list(images_path: Union[str, Path]) -> List:
if isinstance(images_path, str):
images_path = Path(images_path)
assert images_path.exists()
if images_path.is_dir():
images = [
i.absolute() for i in images_path.iterdir() if i.suffix in SUFFIXS
]
else:
assert images_path.suffix in SUFFIXS
images = [images_path.absolute()]
return images
def crop_mask(masks: ndarray, bboxes: ndarray) -> ndarray:
n, h, w = masks.shape
x1, y1, x2, y2 = np.split(bboxes[:, :, None], [1, 2, 3],
1) # x1 shape(1,1,n)
r = np.arange(w, dtype=x1.dtype)[None, None, :] # rows shape(1,w,1)
c = np.arange(h, dtype=x1.dtype)[None, :, None] # cols shape(h,1,1)
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
def det_postprocess(data: Tuple[ndarray, ndarray, ndarray, ndarray]):
assert len(data) == 4
num_dets, bboxes, scores, labels = (i[0] for i in data)
nums = num_dets.item()
bboxes = bboxes[:nums]
scores = scores[:nums]
labels = labels[:nums]
return bboxes, scores, labels
def seg_postprocess(
data: Tuple[ndarray],
shape: Union[Tuple, List],
conf_thres: float = 0.25,
iou_thres: float = 0.65) \
-> Tuple[ndarray, ndarray, ndarray, ndarray]:
assert len(data) == 2
h, w = shape[0] // 4, shape[1] // 4 # 4x downsampling
outputs, proto = (i[0] for i in data)
bboxes, scores, labels, maskconf = np.split(outputs, [4, 5, 6], 1)
scores, labels = scores.squeeze(), labels.squeeze()
idx = scores > conf_thres
bboxes, scores, labels, maskconf = \
bboxes[idx], scores[idx], labels[idx], maskconf[idx]
cvbboxes = np.concatenate([bboxes[:, :2], bboxes[:, 2:] - bboxes[:, :2]],
1)
labels = labels.astype(np.int32)
v0, v1 = map(int, (cv2.__version__).split('.')[:2])
assert v0 == 4, 'OpenCV version is wrong'
if v1 > 6:
idx = cv2.dnn.NMSBoxesBatched(cvbboxes, scores, labels, conf_thres,
iou_thres)
else:
idx = cv2.dnn.NMSBoxes(cvbboxes, scores, conf_thres, iou_thres)
bboxes, scores, labels, maskconf = \
bboxes[idx], scores[idx], labels[idx], maskconf[idx]
masks = (maskconf @ proto).reshape(-1, h, w)
masks = crop_mask(masks, bboxes / 4.)
masks = cv2.resize(masks.transpose([1, 2, 0]),
shape,
interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)
masks = np.ascontiguousarray((masks > 0.5)[..., None])
return bboxes, scores, labels, masks

@ -0,0 +1,29 @@
from models import TRTModule, TRTProfilerV0 # isort:skip
import argparse
import torch
def profile(args):
device = torch.device(args.device)
Engine = TRTModule(args.engine, device)
profiler = TRTProfilerV0()
Engine.set_profiler(profiler)
random_input = torch.randn(Engine.inp_info[0].shape, device=device)
_ = Engine(random_input)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
parser.add_argument('--device',
type=str,
default='cuda:0',
help='TensorRT infer device')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
profile(args)
Loading…
Cancel
Save