|
|
|
@ -4,6 +4,7 @@ from typing import List, Tuple, Union |
|
|
|
|
import cv2 |
|
|
|
|
import numpy as np |
|
|
|
|
from numpy import ndarray |
|
|
|
|
from torchvision.ops import nms |
|
|
|
|
|
|
|
|
|
# image suffixs |
|
|
|
|
SUFFIXS = ('.bmp', '.dng', '.jpeg', '.jpg', '.mpo', '.png', '.tif', '.tiff', |
|
|
|
@ -87,11 +88,18 @@ def crop_mask(masks: ndarray, bboxes: ndarray) -> ndarray: |
|
|
|
|
|
|
|
|
|
def det_postprocess(data: Tuple[ndarray, ndarray, ndarray, ndarray]): |
|
|
|
|
assert len(data) == 4 |
|
|
|
|
iou_thres: float = 0.65 |
|
|
|
|
num_dets, bboxes, scores, labels = (i[0] for i in data) |
|
|
|
|
# check score negative |
|
|
|
|
scores[scores < 0] = 1 + scores[scores < 0] |
|
|
|
|
nums = num_dets.item() |
|
|
|
|
if nums == 0: |
|
|
|
|
return np.empty((0, 4), dtype=np.float32), np.empty( |
|
|
|
|
(0, ), dtype=np.float32), np.empty((0, ), dtype=np.int32) |
|
|
|
|
# add nms |
|
|
|
|
idx = nms(bboxes, scores, iou_thres) |
|
|
|
|
bboxes, scores, labels = bboxes[idx], scores[idx], labels[idx] |
|
|
|
|
|
|
|
|
|
bboxes = bboxes[:nums] |
|
|
|
|
scores = scores[:nums] |
|
|
|
|
labels = labels[:nums] |
|
|
|
|