You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

54 lines
2.0 KiB

from typing import List, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import batched_nms
def seg_postprocess(
        data: Tuple[Tensor, Tensor],
        shape: Union[Tuple, List],
        conf_thres: float = 0.25,
        iou_thres: float = 0.65) \
        -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Turn raw segmentation-model outputs for one image into final detections.

    Only batch element 0 of each tensor in ``data`` is used.

    Args:
        data: ``(outputs, proto)``. ``outputs[0]`` is split along dim 1 into
            4 box values, 1 score, 1 label and 32 mask coefficients per
            candidate. The boxes are handed to ``batched_nms`` and
            ``crop_mask``, which both treat them as (x1, y1, x2, y2).
            ``proto[0]`` is right-multiplied by the (K, 32) coefficients and
            the product is reshaped to (K, h, w), so it is presumably
            (32, h * w) -- TODO confirm against the model exporter.
        shape: (H, W) of the network input. Prototype masks are assumed to
            be 4x downsampled relative to it; final masks are resized back
            to (H, W).
        conf_thres: candidates whose score is <= this value are dropped
            before NMS.
        iou_thres: IoU threshold for class-aware NMS.

    Returns:
        ``(bboxes, scores, labels, masks)`` with shapes (K, 4), (K,), (K,)
        and (K, H, W, 1). ``labels`` is int32; ``masks`` holds 0.0 / 1.0
        values. K may be 0, in which case empty tensors with these layouts
        are returned.
    """
    assert len(data) == 2
    h, w = shape[0] // 4, shape[1] // 4  # 4x downsampling
    outputs, proto = (i[0] for i in data)
    bboxes, scores, labels, maskconf = outputs.split([4, 1, 1, 32], 1)
    # Squeeze only the singleton column. A bare ``squeeze()`` would also
    # collapse the candidate axis when there is exactly one row, so the
    # boolean index below would add a dimension instead of filtering rows.
    scores, labels = scores.squeeze(1), labels.squeeze(1)
    idx = scores > conf_thres
    bboxes, scores, labels, maskconf = \
        bboxes[idx], scores[idx], labels[idx], maskconf[idx]
    idx = batched_nms(bboxes, scores, labels, iou_thres)
    bboxes, scores, labels, maskconf = \
        bboxes[idx], scores[idx], labels[idx].int(), maskconf[idx]
    if bboxes.shape[0] == 0:
        # Nothing survived. F.interpolate rejects a (1, 0, h, w) input, so
        # return an empty mask tensor with the normal (K, H, W, 1) layout.
        masks = proto.new_zeros((0, shape[0], shape[1], 1))
        return bboxes, scores, labels, masks
    # The coefficient x prototype product is a logit; the sigmoid makes the
    # 0.5 threshold below a probability cut-off.
    masks = (maskconf @ proto).sigmoid().view(-1, h, w)
    # Boxes are scaled by the same 4x factor into prototype-mask pixels.
    masks = crop_mask(masks, bboxes / 4.)
    masks = F.interpolate(masks[None],
                          shape,
                          mode='bilinear',
                          align_corners=False)[0]
    masks = masks.gt_(0.5)[..., None]
    return bboxes, scores, labels, masks
def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]):
    """Trim padded detection outputs of one image to the valid entries.

    Only batch element 0 of each tensor in ``data`` is used.

    Args:
        data: ``(num_dets, bboxes, scores, labels)``. ``num_dets[0]`` must
            hold a single value (it is read with ``.item()``); the other
            three are sliced along their first axis to that many entries.

    Returns:
        ``(bboxes, scores, labels)`` restricted to the first ``num_dets``
        rows.
    """
    assert len(data) == 4
    count_t, *per_box = (tensor[0] for tensor in data)
    keep = count_t.item()
    bboxes, scores, labels = (tensor[:keep] for tensor in per_box)
    return bboxes, scores, labels
def crop_mask(masks: Tensor, bboxes: Tensor) -> Tensor:
    """Zero every mask pixel that lies outside that mask's own box.

    Args:
        masks: (n, h, w) tensor, one mask per box.
        bboxes: (n, 4) tensor of (x1, y1, x2, y2) in mask-pixel coordinates.
            A pixel at column ``x`` and row ``y`` is kept when
            ``x1 <= x < x2`` and ``y1 <= y < y2``.

    Returns:
        ``masks`` multiplied by the per-box inside-the-box indicator, with
        shape (n, h, w).
    """
    _, height, width = masks.shape
    # Each coordinate becomes (n, 1, 1) so it broadcasts over (h, w).
    left, top, right, bottom = bboxes[:, :, None].chunk(4, 1)
    # Column (x) indices as (1, 1, w) and row (y) indices as (1, h, 1).
    xs = torch.arange(width, device=masks.device,
                      dtype=left.dtype).view(1, 1, width)
    ys = torch.arange(height, device=masks.device,
                      dtype=left.dtype).view(1, height, 1)
    inside_x = (xs >= left) * (xs < right)
    inside_y = (ys >= top) * (ys < bottom)
    return masks * (inside_x * inside_y)