diff --git a/infer-seg.py b/infer-seg.py index 46e3251..de77316 100644 --- a/infer-seg.py +++ b/infer-seg.py @@ -42,6 +42,11 @@ def main(args: argparse.Namespace) -> None: device=device) bboxes, scores, labels, masks = seg_postprocess( data, bgr.shape[:2], args.conf_thres, args.iou_thres) + if bboxes is None: + # if no bounding box or others save original image + if not args.show: + cv2.imwrite(str(save_image), draw) + continue masks = masks[:, dh:H - dh, dw:W - dw, :] indices = (labels % len(MASK_COLORS)).long() mask_colors = torch.asarray(MASK_COLORS, device=device)[indices] diff --git a/models/torch_utils.py b/models/torch_utils.py index 6624ff1..1bf7d39 100644 --- a/models/torch_utils.py +++ b/models/torch_utils.py @@ -18,6 +18,8 @@ def seg_postprocess( bboxes, scores, labels, maskconf = outputs.split([4, 1, 1, 32], 1) scores, labels = scores.squeeze(), labels.squeeze() idx = scores > conf_thres + if idx.sum() == 0: # no bounding boxes or seg were created + return None, None, None, None bboxes, scores, labels, maskconf = \ bboxes[idx], scores[idx], labels[idx], maskconf[idx] idx = batched_nms(bboxes, scores, labels, iou_thres)