You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

174 lines
6.3 KiB

2 years ago
from typing import Tuple
import torch
import torch.nn as nn
from torch import Graph, Tensor, Value
def make_anchors(feats: Tensor,
                 strides: Tensor,
                 grid_cell_offset: float = 0.5) -> Tuple[Tensor, Tensor]:
    """Build anchor points and per-anchor stride values for feature maps.

    Args:
        feats: Indexable collection of feature maps. ``feats[i]`` must be a
            4-D tensor whose last two dimensions are height and width; the
            dtype and device of ``feats[0]`` are used for every output.
        strides: Iterable yielding one stride value per feature map.
        grid_cell_offset: Value added to every integer cell coordinate.

    Returns:
        ``(anchor_points, stride_tensor)`` where ``anchor_points`` has shape
        ``(sum(h * w), 2)`` holding ``(x, y)`` pairs in row-major cell order
        and ``stride_tensor`` has shape ``(sum(h * w), 1)`` holding the
        stride of the level each point came from.
    """
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    anchor_points, stride_tensor = [], []
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device,
                          dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device,
                          dtype=dtype) + grid_cell_offset  # shift y
        # Pass the indexing mode explicitly: 'ij' is what the implicit
        # default has always produced, and omitting it is deprecated.
        sy, sx = torch.meshgrid(sy, sx, indexing='ij')
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(
            torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)
class TRT_NMS(torch.autograd.Function):
    """Autograd function that maps to the ``TRT::EfficientNMS_TRT`` op.

    ``forward`` performs no real NMS: it only returns randomly filled
    tensors with the output shapes, presumably so that tracing/export sees
    correctly shaped results. ``symbolic`` emits the actual graph node.

    Both methods share the same parameter order and defaults. The ONNX
    exporter hands the arguments given to ``apply()`` to ``symbolic``
    positionally, so a differing order would silently swap attributes.
    """

    @staticmethod
    def forward(
            ctx,
            boxes: Tensor,
            scores: Tensor,
            iou_threshold: float = 0.65,
            score_threshold: float = 0.25,
            max_output_boxes: int = 100,
            background_class: int = -1,
            box_coding: int = 0,
            plugin_version: str = '1',
            score_activation: int = 0
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return placeholder detections shaped like the op's outputs.

        Args:
            ctx: Autograd context (unused).
            boxes: Box tensor; only its device is read here.
            scores: 3-D tensor ``(batch_size, num_boxes, num_classes)``.
            iou_threshold: Unused here; forwarded to the op by ``symbolic``.
            score_threshold: Unused here; forwarded by ``symbolic``.
            max_output_boxes: Size of the detection dimension of the outputs.
            background_class: Unused here; forwarded by ``symbolic``.
            box_coding: Unused here; forwarded by ``symbolic``.
            plugin_version: Unused here; forwarded by ``symbolic``.
            score_activation: Unused here; forwarded by ``symbolic``.

        Returns:
            ``(num_dets, boxes, scores, labels)`` with shapes
            ``(batch_size, 1)`` int32, ``(batch_size, max_output_boxes, 4)``,
            ``(batch_size, max_output_boxes)`` and
            ``(batch_size, max_output_boxes)`` int32, all filled with random
            values.
        """
        batch_size, _, num_classes = scores.shape
        # Keep the placeholders on the same device as the inputs so the
        # traced outputs do not mix devices.
        device = boxes.device
        num_dets = torch.randint(0,
                                 max_output_boxes, (batch_size, 1),
                                 dtype=torch.int32,
                                 device=device)
        det_boxes = torch.randn(batch_size, max_output_boxes, 4,
                                device=device)
        det_scores = torch.randn(batch_size, max_output_boxes, device=device)
        det_labels = torch.randint(0,
                                   num_classes,
                                   (batch_size, max_output_boxes),
                                   dtype=torch.int32,
                                   device=device)
        return num_dets, det_boxes, det_scores, det_labels

    @staticmethod
    def symbolic(
            g,
            boxes: Value,
            scores: Value,
            iou_threshold: float = 0.65,
            score_threshold: float = 0.25,
            max_output_boxes: int = 100,
            background_class: int = -1,
            box_coding: int = 0,
            plugin_version: str = '1',
            score_activation: int = 0) -> Tuple[Value, Value, Value, Value]:
        """Emit the ``TRT::EfficientNMS_TRT`` node with four outputs.

        The parameter order and defaults mirror ``forward`` exactly; every
        non-tensor argument is attached to the node as an attribute.

        Args:
            g: Graph (context) object providing ``op``.
            boxes: Graph value for the box input.
            scores: Graph value for the score input.
            iou_threshold: Stored as float attribute ``iou_threshold``.
            score_threshold: Stored as float attribute ``score_threshold``.
            max_output_boxes: Stored as int attribute ``max_output_boxes``.
            background_class: Stored as int attribute ``background_class``.
            box_coding: Stored as int attribute ``box_coding``.
            plugin_version: Stored as string attribute ``plugin_version``.
            score_activation: Stored as int attribute ``score_activation``.

        Returns:
            The four node outputs ``(num_dets, boxes, scores, classes)``.
        """
        out = g.op('TRT::EfficientNMS_TRT',
                   boxes,
                   scores,
                   iou_threshold_f=iou_threshold,
                   score_threshold_f=score_threshold,
                   max_output_boxes_i=max_output_boxes,
                   background_class_i=background_class,
                   box_coding_i=box_coding,
                   plugin_version_s=plugin_version,
                   score_activation_i=score_activation,
                   outputs=4)
        nums_dets, boxes, scores, classes = out
        return nums_dets, boxes, scores, classes
class C2f(nn.Module):
    """Replacement ``forward`` for an existing C2f-style module.

    The constructor creates no layers. ``forward`` reads ``cv1``, ``cv2``,
    ``m`` and ``c`` from the instance, so those must already be present
    (``optim`` in this file swaps ``__class__`` on an existing module).
    """

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but ignored; only nn.Module is initialised.
        super().__init__()

    def forward(self, x):
        """Apply ``cv1``, chain the blocks in ``m``, concatenate, apply ``cv2``.

        The first block receives the channels of the ``cv1`` output from
        index ``self.c`` onwards; each later block receives the previous
        block's output. The full ``cv1`` output and every block output are
        concatenated along dim 1 (the channel slice itself is not included).
        """
        stem = self.cv1(x)
        collected = [stem]
        current = stem[:, self.c:, ...]
        for block in self.m:
            current = block(current)
            collected.append(current)
        return self.cv2(torch.cat(collected, 1))
2 years ago
class PostDetect(nn.Module):
    """Detection head ``forward`` that decodes boxes and ends in ``TRT_NMS``.

    The constructor creates no layers. ``forward`` reads ``nl``, ``no``,
    ``reg_max``, ``stride``, ``cv2`` and ``cv3`` from the instance, so they
    must already exist (``optim`` in this file swaps ``__class__`` on an
    existing module).
    """
    export = True  # not read by the code in this class
    shape = None  # shape of x[0] for which anchors/strides were last built
    dynamic = False  # when True, anchors/strides are rebuilt on every call
    iou_thres = 0.65  # passed to TRT_NMS.apply as iou_threshold
    conf_thres = 0.25  # passed to TRT_NMS.apply as score_threshold
    topk = 100  # passed to TRT_NMS.apply as max_output_boxes

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but ignored; only nn.Module is initialised.
        super().__init__()

    def forward(self, x):
        """Decode per-level head outputs and hand them to ``TRT_NMS``.

        Args:
            x: Indexable collection of ``self.nl`` feature maps.

        Returns:
            Whatever ``TRT_NMS.apply`` returns for the decoded boxes and the
            sigmoid class scores (both transposed so anchors are on dim 1).
        """
        in_shape = x[0].shape
        batch = in_shape[0]
        reg_channels = self.reg_max * 4
        # Per level: box-branch and class-branch outputs joined on channels.
        level_outs = [
            torch.cat((self.cv2[lvl](x[lvl]), self.cv3[lvl](x[lvl])), 1)
            for lvl in range(self.nl)
        ]
        if self.dynamic or self.shape != in_shape:
            points, stride_col = make_anchors(x, self.stride, 0.5)
            self.anchors = points.transpose(0, 1)
            self.strides = stride_col.transpose(0, 1)
            self.shape = in_shape
        merged = torch.cat(
            [out.view(batch, self.no, -1) for out in level_outs], 2)
        dist = merged[:, :reg_channels, ...]
        cls_scores = merged[:, reg_channels:, ...].sigmoid()
        # Softmax over the reg_max bins, then dot with arange(reg_max):
        # one scalar per box side and anchor.
        dist = dist.view(batch, 4, self.reg_max, -1).permute(0, 1, 3, 2)
        probs = dist.softmax(-1)
        dist = probs @ torch.arange(self.reg_max).to(dist)
        # First two channels are subtracted from the anchor point, last two
        # added (presumably giving x1, y1, x2, y2 corners - confirm).
        near = -dist[:, :2, ...]
        far = dist[:, 2:, ...]
        decoded = self.anchors.repeat(batch, 2, 1) + torch.cat([near, far], 1)
        decoded = decoded * self.strides
        return TRT_NMS.apply(decoded.transpose(1, 2),
                             cls_scores.transpose(1, 2), self.iou_thres,
                             self.conf_thres, self.topk)
class PostSeg(nn.Module):
    """Segmentation head ``forward`` producing detections plus mask data.

    The constructor creates no layers. The methods read ``nl``, ``no``,
    ``nm``, ``reg_max``, ``stride``, ``proto``, ``cv2``, ``cv3`` and ``cv4``
    from the instance, so they must already exist (``optim`` in this file
    swaps ``__class__`` on an existing module).
    """
    export = True  # not read by the code in this class
    shape = None  # shape of x[0] for which anchors/strides were last built
    dynamic = False  # when True, anchors/strides are rebuilt on every call

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but ignored; only nn.Module is initialised.
        super().__init__()

    def forward(self, x):
        """Return per-anchor detections with mask coefficients, and protos.

        Args:
            x: Indexable collection of ``self.nl`` feature maps.

        Returns:
            ``(out, protos)``. ``out`` concatenates, along the last dim,
            the boxes, best score, best label (as float) and the ``nm``
            mask coefficients of every anchor; ``protos`` is
            ``self.proto(x[0])`` flattened from dim 2 onwards.
        """
        protos = self.proto(x[0])  # mask protos
        batch = protos.shape[0]  # batch size
        per_level = [
            self.cv4[lvl](x[lvl]).view(batch, self.nm, -1)
            for lvl in range(self.nl)
        ]
        coeffs = torch.cat(per_level, 2)  # mask coefficients
        boxes, scores, labels = self.forward_det(x)
        pieces = [boxes, scores, labels.float(), coeffs.transpose(1, 2)]
        out = torch.cat(pieces, 2)
        return out, protos.flatten(2)

    def forward_det(self, x):
        """Decode boxes and pick the best class for every anchor.

        Args:
            x: Indexable collection of ``self.nl`` feature maps.

        Returns:
            ``(boxes, scores, labels)`` with anchors on dim 1: decoded
            boxes with 4 values per anchor, the maximum sigmoid class score
            (kept as a trailing dim of size 1) and its class index.
        """
        in_shape = x[0].shape
        batch = in_shape[0]
        reg_channels = self.reg_max * 4
        # Per level: box-branch and class-branch outputs joined on channels.
        level_outs = [
            torch.cat((self.cv2[lvl](x[lvl]), self.cv3[lvl](x[lvl])), 1)
            for lvl in range(self.nl)
        ]
        if self.dynamic or self.shape != in_shape:
            points, stride_col = make_anchors(x, self.stride, 0.5)
            self.anchors = points.transpose(0, 1)
            self.strides = stride_col.transpose(0, 1)
            self.shape = in_shape
        merged = torch.cat(
            [out.view(batch, self.no, -1) for out in level_outs], 2)
        dist = merged[:, :reg_channels, ...]
        cls_scores = merged[:, reg_channels:, ...].sigmoid()
        # Softmax over the reg_max bins, then dot with arange(reg_max):
        # one scalar per box side and anchor.
        dist = dist.view(batch, 4, self.reg_max, -1).permute(0, 1, 3, 2)
        probs = dist.softmax(-1)
        dist = probs @ torch.arange(self.reg_max).to(dist)
        # First two channels are subtracted from the anchor point, last two
        # added (presumably giving x1, y1, x2, y2 corners - confirm).
        near = -dist[:, :2, ...]
        far = dist[:, 2:, ...]
        decoded = self.anchors.repeat(batch, 2, 1) + torch.cat([near, far], 1)
        decoded = decoded * self.strides
        best_scores, best_labels = cls_scores.transpose(1, 2).max(
            dim=-1, keepdim=True)
        return decoded.transpose(1, 2), best_scores, best_labels
2 years ago
def optim(module: nn.Module):
    """Swap ``module``'s class for the matching replacement, in place.

    The decision is made on the bare class name of ``module``:
    ``'Detect'`` -> ``PostDetect``, ``'Segment'`` -> ``PostSeg`` and
    ``'C2f'`` -> ``C2f`` (the class defined in this file). Modules with any
    other class name are left untouched.

    Args:
        module: Module whose ``__class__`` may be reassigned.

    Returns:
        None. The module is modified in place.
    """
    # Read the class name directly instead of slicing the repr of the type,
    # which depended on the exact "<class 'pkg.Name'>" layout.
    name = type(module).__name__
    if name == 'Detect':
        module.__class__ = PostDetect
    elif name == 'Segment':
        module.__class__ = PostSeg
    elif name == 'C2f':
        module.__class__ = C2f