Fix seg inference with torch and numpy

pull/32/head
triple-Mu 2 years ago
parent 133d1625d9
commit 7902783618
  1. 11
      infer-seg-without-torch.py
  2. 12
      infer-seg.py
  3. 2
      models/engine.py
  4. 2
      models/utils.py
  5. 2
      requirements.txt

@ -4,7 +4,7 @@ from pathlib import Path
import cv2 import cv2
import numpy as np import numpy as np
from config import CLASSES, COLORS from config import ALPHA, CLASSES, COLORS, MASK_COLORS
from models.utils import blob, letterbox, path_to_list, seg_postprocess from models.utils import blob, letterbox, path_to_list, seg_postprocess
@ -41,9 +41,12 @@ def main(args: argparse.Namespace) -> None:
seg_img = seg_img[dh:H - dh, dw:W - dw, [2, 1, 0]] seg_img = seg_img[dh:H - dh, dw:W - dw, [2, 1, 0]]
bboxes, scores, labels, masks = seg_postprocess( bboxes, scores, labels, masks = seg_postprocess(
data, bgr.shape[:2], args.conf_thres, args.iou_thres) data, bgr.shape[:2], args.conf_thres, args.iou_thres)
mask, mask_color = [m[:, dh:H - dh, dw:W - dw, :] for m in masks] masks = masks[:, dh:H - dh, dw:W - dw, :]
inv_alph_masks = (1 - mask * 0.5).cumprod(0) mask_colors = MASK_COLORS[labels % len(MASK_COLORS)]
mcs = (mask_color * inv_alph_masks).sum(0) * 2 mask_colors = mask_colors.reshape(-1, 1, 1, 3) * ALPHA
mask_colors = masks @ mask_colors
inv_alph_masks = (1 - masks * 0.5).cumprod(0)
mcs = (mask_colors * inv_alph_masks).sum(0) * 2
seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255 seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255
draw = cv2.resize(seg_img.astype(np.uint8), draw.shape[:2][::-1]) draw = cv2.resize(seg_img.astype(np.uint8), draw.shape[:2][::-1])

@ -6,7 +6,7 @@ import cv2
import numpy as np import numpy as np
import torch import torch
from config import CLASSES, COLORS from config import ALPHA, CLASSES, COLORS, MASK_COLORS
from models.torch_utils import seg_postprocess from models.torch_utils import seg_postprocess
from models.utils import blob, letterbox, path_to_list from models.utils import blob, letterbox, path_to_list
@ -41,9 +41,13 @@ def main(args: argparse.Namespace) -> None:
device=device) device=device)
bboxes, scores, labels, masks = seg_postprocess( bboxes, scores, labels, masks = seg_postprocess(
data, bgr.shape[:2], args.conf_thres, args.iou_thres) data, bgr.shape[:2], args.conf_thres, args.iou_thres)
mask, mask_color = [m[:, dh:H - dh, dw:W - dw, :] for m in masks] masks = masks[:, dh:H - dh, dw:W - dw, :]
inv_alph_masks = (1 - mask * 0.5).cumprod(0) indices = (labels % len(MASK_COLORS)).long()
mcs = (mask_color * inv_alph_masks).sum(0) * 2 mask_colors = torch.asarray(MASK_COLORS, device=device)[indices]
mask_colors = mask_colors.view(-1, 1, 1, 3) * ALPHA
mask_colors = masks @ mask_colors
inv_alph_masks = (1 - masks * 0.5).cumprod(0)
mcs = (mask_colors * inv_alph_masks).sum(0) * 2
seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255 seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255
draw = cv2.resize(seg_img.cpu().numpy().astype(np.uint8), draw = cv2.resize(seg_img.cpu().numpy().astype(np.uint8),
draw.shape[:2][::-1]) draw.shape[:2][::-1])

@ -303,7 +303,7 @@ class TRTModule(torch.nn.Module):
for i in range(self.num_outputs): for i in range(self.num_outputs):
j = i + self.num_inputs j = i + self.num_inputs
if self.is_dynamic: if self.odynamic:
shape = tuple(self.context.get_binding_shape(j)) shape = tuple(self.context.get_binding_shape(j))
output = torch.empty(size=shape, output = torch.empty(size=shape,
dtype=self.out_info[i].dtype, dtype=self.out_info[i].dtype,

@ -121,5 +121,5 @@ def seg_postprocess(
masks = cv2.resize(masks.transpose([1, 2, 0]), masks = cv2.resize(masks.transpose([1, 2, 0]),
shape, shape,
interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1) interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)
masks = np.ascontiguousarray((masks > 0.5)[..., None]) masks = np.ascontiguousarray((masks > 0.5)[..., None], dtype=np.float32)
return bboxes, scores, labels, masks return bboxes, scores, labels, masks

@ -1,4 +1,4 @@
numpy numpy<=1.23.5
onnx onnx
onnxsim onnxsim
opencv-python opencv-python

Loading…
Cancel
Save