from typing import List, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor from torchvision.ops import batched_nms def seg_postprocess( data: Tuple[Tensor], shape: Union[Tuple, List], conf_thres: float = 0.25, iou_thres: float = 0.65) \ -> Tuple[Tensor, Tensor, Tensor, Tensor]: assert len(data) == 2 h, w = shape[0] // 4, shape[1] // 4 # 4x downsampling outputs, proto = (i[0] for i in data) bboxes, scores, labels, maskconf = outputs.split([4, 1, 1, 32], 1) scores, labels = scores.squeeze(), labels.squeeze() idx = scores > conf_thres bboxes, scores, labels, maskconf = \ bboxes[idx], scores[idx], labels[idx], maskconf[idx] idx = batched_nms(bboxes, scores, labels, iou_thres) bboxes, scores, labels, maskconf = \ bboxes[idx], scores[idx], labels[idx].int(), maskconf[idx] masks = (maskconf @ proto).view(-1, h, w) masks = crop_mask(masks, bboxes / 4.) masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] masks = masks.gt_(0.5)[..., None] return bboxes, scores, labels, masks def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]): assert len(data) == 4 num_dets, bboxes, scores, labels = (i[0] for i in data) nums = num_dets.item() bboxes = bboxes[:nums] scores = scores[:nums] labels = labels[:nums] return bboxes, scores, labels def crop_mask(masks: Tensor, bboxes: Tensor) -> Tensor: n, h, w = masks.shape x1, y1, x2, y2 = torch.chunk(bboxes[:, :, None], 4, 1) # x1 shape(1,1,n) r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,w,1) c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(h,1,1) return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))