|
|
|
from typing import Tuple
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch import Graph, Tensor, Value
|
|
|
|
|
|
|
|
|
|
|
|
def make_anchors(feats: Tensor,
|
|
|
|
strides: Tensor,
|
|
|
|
grid_cell_offset: float = 0.5) -> Tuple[Tensor, Tensor]:
|
|
|
|
anchor_points, stride_tensor = [], []
|
|
|
|
assert feats is not None
|
|
|
|
dtype, device = feats[0].dtype, feats[0].device
|
|
|
|
for i, stride in enumerate(strides):
|
|
|
|
_, _, h, w = feats[i].shape
|
|
|
|
sx = torch.arange(end=w, device=device,
|
|
|
|
dtype=dtype) + grid_cell_offset # shift x
|
|
|
|
sy = torch.arange(end=h, device=device,
|
|
|
|
dtype=dtype) + grid_cell_offset # shift y
|
|
|
|
sy, sx = torch.meshgrid(sy, sx)
|
|
|
|
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
|
|
|
|
stride_tensor.append(
|
|
|
|
torch.full((h * w, 1), stride, dtype=dtype, device=device))
|
|
|
|
return torch.cat(anchor_points), torch.cat(stride_tensor)
|
|
|
|
|
|
|
|
|
|
|
|
class TRT_NMS(torch.autograd.Function):
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def forward(
|
|
|
|
ctx: Graph,
|
|
|
|
boxes: Tensor,
|
|
|
|
scores: Tensor,
|
|
|
|
iou_threshold: float = 0.65,
|
|
|
|
score_threshold: float = 0.25,
|
|
|
|
max_output_boxes: int = 100,
|
|
|
|
background_class: int = -1,
|
|
|
|
box_coding: int = 0,
|
|
|
|
plugin_version: str = '1',
|
|
|
|
score_activation: int = 0
|
|
|
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
|
|
|
batch_size, num_boxes, num_classes = scores.shape
|
|
|
|
num_dets = torch.randint(0,
|
|
|
|
max_output_boxes, (batch_size, 1),
|
|
|
|
dtype=torch.int32)
|
|
|
|
boxes = torch.randn(batch_size, max_output_boxes, 4)
|
|
|
|
scores = torch.randn(batch_size, max_output_boxes)
|
|
|
|
labels = torch.randint(0,
|
|
|
|
num_classes, (batch_size, max_output_boxes),
|
|
|
|
dtype=torch.int32)
|
|
|
|
|
|
|
|
return num_dets, boxes, scores, labels
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def symbolic(
|
|
|
|
g,
|
|
|
|
boxes: Value,
|
|
|
|
scores: Value,
|
|
|
|
iou_threshold: float = 0.45,
|
|
|
|
score_threshold: float = 0.25,
|
|
|
|
max_output_boxes: int = 100,
|
|
|
|
background_class: int = -1,
|
|
|
|
box_coding: int = 0,
|
|
|
|
score_activation: int = 0,
|
|
|
|
plugin_version: str = '1') -> Tuple[Value, Value, Value, Value]:
|
|
|
|
out = g.op('TRT::EfficientNMS_TRT',
|
|
|
|
boxes,
|
|
|
|
scores,
|
|
|
|
iou_threshold_f=iou_threshold,
|
|
|
|
score_threshold_f=score_threshold,
|
|
|
|
max_output_boxes_i=max_output_boxes,
|
|
|
|
background_class_i=background_class,
|
|
|
|
box_coding_i=box_coding,
|
|
|
|
plugin_version_s=plugin_version,
|
|
|
|
score_activation_i=score_activation,
|
|
|
|
outputs=4)
|
|
|
|
nums_dets, boxes, scores, classes = out
|
|
|
|
return nums_dets, boxes, scores, classes
|
|
|
|
|
|
|
|
|
|
|
|
class C2f(nn.Module):
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
x = self.cv1(x)
|
|
|
|
x = [x, x[:, self.c:, ...]]
|
|
|
|
x.extend(m(x[-1]) for m in self.m)
|
|
|
|
x.pop(1)
|
|
|
|
return self.cv2(torch.cat(x, 1))
|
|
|
|
|
|
|
|
|
|
|
|
class PostDetect(nn.Module):
|
|
|
|
export = True
|
|
|
|
shape = None
|
|
|
|
dynamic = False
|
|
|
|
iou_thres = 0.65
|
|
|
|
conf_thres = 0.25
|
|
|
|
topk = 100
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
shape = x[0].shape
|
|
|
|
b, res, b_reg_num = shape[0], [], self.reg_max * 4
|
|
|
|
for i in range(self.nl):
|
|
|
|
res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
|
|
|
|
if self.dynamic or self.shape != shape:
|
|
|
|
self.anchors, self.strides = (x.transpose(
|
|
|
|
0, 1) for x in make_anchors(x, self.stride, 0.5))
|
|
|
|
self.shape = shape
|
|
|
|
x = [i.view(b, self.no, -1) for i in res]
|
|
|
|
y = torch.cat(x, 2)
|
|
|
|
boxes, scores = y[:, :b_reg_num, ...], y[:, b_reg_num:, ...].sigmoid()
|
|
|
|
boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2)
|
|
|
|
boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes)
|
|
|
|
boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...]
|
|
|
|
boxes = self.anchors.repeat(b, 2, 1) + torch.cat([boxes0, boxes1], 1)
|
|
|
|
boxes = boxes * self.strides
|
|
|
|
|
|
|
|
return TRT_NMS.apply(boxes.transpose(1, 2), scores.transpose(1, 2),
|
|
|
|
self.iou_thres, self.conf_thres, self.topk)
|
|
|
|
|
|
|
|
|
|
|
|
class PostSeg(nn.Module):
|
|
|
|
export = True
|
|
|
|
shape = None
|
|
|
|
dynamic = False
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
p = self.proto(x[0]) # mask protos
|
|
|
|
bs = p.shape[0] # batch size
|
|
|
|
mc = torch.cat(
|
|
|
|
[self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)],
|
|
|
|
2) # mask coefficients
|
|
|
|
boxes, scores, labels = self.forward_det(x)
|
|
|
|
out = torch.cat([boxes, scores, labels.float(), mc.transpose(1, 2)], 2)
|
|
|
|
return out, p.flatten(2)
|
|
|
|
|
|
|
|
def forward_det(self, x):
|
|
|
|
shape = x[0].shape
|
|
|
|
b, res, b_reg_num = shape[0], [], self.reg_max * 4
|
|
|
|
for i in range(self.nl):
|
|
|
|
res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
|
|
|
|
if self.dynamic or self.shape != shape:
|
|
|
|
self.anchors, self.strides = \
|
|
|
|
(x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
|
|
|
self.shape = shape
|
|
|
|
x = [i.view(b, self.no, -1) for i in res]
|
|
|
|
y = torch.cat(x, 2)
|
|
|
|
boxes, scores = y[:, :b_reg_num, ...], y[:, b_reg_num:, ...].sigmoid()
|
|
|
|
boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2)
|
|
|
|
boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes)
|
|
|
|
boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...]
|
|
|
|
boxes = self.anchors.repeat(b, 2, 1) + torch.cat([boxes0, boxes1], 1)
|
|
|
|
boxes = boxes * self.strides
|
|
|
|
scores, labels = scores.transpose(1, 2).max(dim=-1, keepdim=True)
|
|
|
|
return boxes.transpose(1, 2), scores, labels
|
|
|
|
|
|
|
|
|
|
|
|
def optim(module: nn.Module):
|
|
|
|
s = str(type(module))[6:-2].split('.')[-1]
|
|
|
|
if s == 'Detect':
|
|
|
|
setattr(module, '__class__', PostDetect)
|
|
|
|
elif s == 'Segment':
|
|
|
|
setattr(module, '__class__', PostSeg)
|
|
|
|
elif s == 'C2f':
|
|
|
|
setattr(module, '__class__', C2f)
|