Fix seg model sigmoid

pull/49/head
triple-Mu 2 years ago
parent a246b5b5b0
commit d59533ba12
  1. 2
      models/torch_utils.py
  2. 20
      models/utils.py

@ -23,7 +23,7 @@ def seg_postprocess(
idx = batched_nms(bboxes, scores, labels, iou_thres)
bboxes, scores, labels, maskconf = \
bboxes[idx], scores[idx], labels[idx].int(), maskconf[idx]
masks = (maskconf @ proto).view(-1, h, w)
masks = (maskconf @ proto).sigmoid().view(-1, h, w)
masks = crop_mask(masks, bboxes / 4.)
masks = F.interpolate(masks[None],
shape,

@ -18,12 +18,14 @@ def letterbox(im: ndarray,
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# new_shape: [width, height]
# Scale ratio (new / old)
r = min(new_shape[0] / shape[1], new_shape[1] / shape[0])
# Compute padding
# Compute padding [width, height]
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[0] - new_unpad[0], new_shape[1] - new_unpad[1] # wh padding
dw, dh = new_shape[0] - new_unpad[0], new_shape[1] - new_unpad[
1] # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
@ -54,6 +56,10 @@ def blob(im: ndarray, return_seg: bool = False) -> Union[ndarray, Tuple]:
return im
def sigmoid(x):
return 1. / (1. + np.exp(-x))
def path_to_list(images_path: Union[str, Path]) -> List:
if isinstance(images_path, str):
images_path = Path(images_path)
@ -114,13 +120,11 @@ def seg_postprocess(
idx = cv2.dnn.NMSBoxes(cvbboxes, scores, conf_thres, iou_thres)
bboxes, scores, labels, maskconf = \
bboxes[idx], scores[idx], labels[idx], maskconf[idx]
if bboxes==[]:return None,None,None,None
masks = (maskconf @ proto).reshape(-1, h, w)
masks = sigmoid(maskconf @ proto).reshape(-1, h, w)
masks = crop_mask(masks, bboxes / 4.)
masks=masks.transpose([1, 2, 0])
masks = cv2.resize(masks,
(shape[1],shape[0]),
masks = masks.transpose([1, 2, 0])
masks = cv2.resize(masks, (shape[1], shape[0]),
interpolation=cv2.INTER_LINEAR)
masks=masks.transpose(2, 0, 1)
masks = masks.transpose(2, 0, 1)
masks = np.ascontiguousarray((masks > 0.5)[..., None], dtype=np.float32)
return bboxes, scores, labels, masks

Loading…
Cancel
Save