diff --git a/models/torch_utils.py b/models/torch_utils.py index 8ad2a2e..b25ede9 100644 --- a/models/torch_utils.py +++ b/models/torch_utils.py @@ -64,12 +64,12 @@ def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]): iou_thres: float = 0.65 num_dets, bboxes, scores, labels = data[0][0], data[1][0], data[2][ 0], data[3][0] - # check score negative - scores[scores < 0] = 1 + scores[scores < 0] nums = num_dets.item() if nums == 0: return bboxes.new_zeros((0, 4)), scores.new_zeros( (0, )), labels.new_zeros((0, )) + # check score negative + scores[scores < 0] = 1 + scores[scores < 0] # add nms idx = nms(bboxes, scores, iou_thres) bboxes, scores, labels = bboxes[idx], scores[idx], labels[idx] diff --git a/models/utils.py b/models/utils.py index db2da2d..5f22a3c 100644 --- a/models/utils.py +++ b/models/utils.py @@ -90,12 +90,12 @@ def det_postprocess(data: Tuple[ndarray, ndarray, ndarray, ndarray]): assert len(data) == 4 iou_thres: float = 0.65 num_dets, bboxes, scores, labels = (i[0] for i in data) - # check score negative - scores[scores < 0] = 1 + scores[scores < 0] nums = num_dets.item() if nums == 0: return np.empty((0, 4), dtype=np.float32), np.empty( (0, ), dtype=np.float32), np.empty((0, ), dtype=np.int32) + # check score negative + scores[scores < 0] = 1 + scores[scores < 0] # add nms idx = nms(bboxes, scores, iou_thres) bboxes, scores, labels = bboxes[idx], scores[idx], labels[idx]