pull/25/head
triple-Mu 2 years ago
parent 76de08073d
commit 953eb1971a
  1. 1
      models/__init__.py
  2. 46
      models/common.py
  3. 24
      models/engine.py

@ -7,4 +7,5 @@ warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
warnings.filterwarnings(action='ignore', category=UserWarning)
warnings.filterwarnings(action='ignore', category=FutureWarning)
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
__all__ = ['EngineBuilder', 'TRTModule', 'TRTProfilerV0', 'TRTProfilerV1']

@ -104,7 +104,7 @@ class PostDetect(nn.Module):
def forward(self, x):
shape = x[0].shape
b, res = shape[0], []
b, res, b_reg_num = shape[0], [], self.reg_max * 4
for i in range(self.nl):
res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
if self.dynamic or self.shape != shape:
@ -113,15 +113,14 @@ class PostDetect(nn.Module):
self.shape = shape
x = [i.view(b, self.no, -1) for i in res]
y = torch.cat(x, 2)
box, cls = y[:, :self.reg_max * 4, ...], y[:, self.reg_max * 4:,
...].sigmoid()
box = box.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2).contiguous()
box = box.softmax(-1) @ torch.arange(self.reg_max).to(box)
box0, box1 = -box[:, :2, ...], box[:, 2:, ...]
box = self.anchors.repeat(b, 2, 1) + torch.cat([box0, box1], 1)
box = box * self.strides
return TRT_NMS.apply(box.transpose(1, 2), cls.transpose(1, 2),
boxes, scores = y[:, :b_reg_num, ...], y[:, b_reg_num:, ...].sigmoid()
boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2)
boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes)
boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...]
boxes = self.anchors.repeat(b, 2, 1) + torch.cat([boxes0, boxes1], 1)
boxes = boxes * self.strides
return TRT_NMS.apply(boxes.transpose(1, 2), scores.transpose(1, 2),
self.iou_thres, self.conf_thres, self.topk)
@ -139,30 +138,29 @@ class PostSeg(nn.Module):
mc = torch.cat(
[self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)],
2) # mask coefficients
box, score, cls = self.forward_det(x)
out = torch.cat([box, score, cls, mc.transpose(1, 2)], 2)
boxes, scores, labels = self.forward_det(x)
out = torch.cat([boxes, scores, labels.float(), mc.transpose(1, 2)], 2)
return out, p.flatten(2)
def forward_det(self, x):
shape = x[0].shape
b, res = shape[0], []
b, res, b_reg_num = shape[0], [], self.reg_max * 4
for i in range(self.nl):
res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
if self.dynamic or self.shape != shape:
self.anchors, self.strides = (x.transpose(
0, 1) for x in make_anchors(x, self.stride, 0.5))
self.anchors, self.strides = \
(x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
x = [i.view(b, self.no, -1) for i in res]
y = torch.cat(x, 2)
box, cls = y[:, :self.reg_max * 4, ...], y[:, self.reg_max * 4:,
...].sigmoid()
box = box.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2).contiguous()
box = box.softmax(-1) @ torch.arange(self.reg_max).to(box)
box0, box1 = -box[:, :2, ...], box[:, 2:, ...]
box = self.anchors.repeat(b, 2, 1) + torch.cat([box0, box1], 1)
box = box * self.strides
score, cls = cls.transpose(1, 2).max(dim=-1, keepdim=True)
return box.transpose(1, 2), score, cls
boxes, scores = y[:, :, ...], y[:, b_reg_num:, ...].sigmoid()
boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2)
boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes)
boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...]
boxes = self.anchors.repeat(b, 2, 1) + torch.cat([boxes0, boxes1], 1)
boxes = boxes * self.strides
scores, labels = scores.transpose(1, 2).max(dim=-1, keepdim=True)
return boxes.transpose(1, 2), scores, labels
def optim(module: nn.Module):

@ -222,18 +222,19 @@ class TRTModule(torch.nn.Module):
model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
context = model.create_execution_context()
num_bindings = model.num_bindings
names = [model.get_binding_name(i) for i in range(num_bindings)]
names = [model.get_binding_name(i) for i in range(model.num_bindings)]
self.num_bindings = model.num_bindings
self.bindings: List[int] = [0] * self.num_bindings
self.bindings: List[int] = [0] * num_bindings
num_inputs, num_outputs = 0, 0
for i in range(model.num_bindings):
for i in range(num_bindings):
if model.binding_is_input(i):
num_inputs += 1
else:
num_outputs += 1
self.num_bindings = num_bindings
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.model = model
@ -243,7 +244,7 @@ class TRTModule(torch.nn.Module):
self.idx = list(range(self.num_outputs))
def __init_bindings(self) -> None:
dynamic = False
idynamic = odynamic = False
Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape'))
inp_info = []
out_info = []
@ -252,23 +253,26 @@ class TRTModule(torch.nn.Module):
dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
shape = tuple(self.model.get_binding_shape(i))
if -1 in shape:
dynamic = True
idynamic |= True
inp_info.append(Tensor(name, dtype, shape))
for i, name in enumerate(self.output_names):
i += self.num_inputs
assert self.model.get_binding_name(i) == name
dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
shape = tuple(self.model.get_binding_shape(i))
if -1 in shape:
odynamic |= True
out_info.append(Tensor(name, dtype, shape))
if not dynamic:
if not odynamic:
self.output_tensor = [
torch.empty(info.shape, dtype=info.dtype, device=self.device)
for info in out_info
]
self.is_dynamic = dynamic
self.idynamic = idynamic
self.odynamic = odynamic
self.inp_info = inp_info
self.out_infp = out_info
self.out_info = out_info
def set_profiler(self, profiler: Optional[trt.IProfiler]):
self.context.profiler = profiler \
@ -288,7 +292,7 @@ class TRTModule(torch.nn.Module):
for i in range(self.num_inputs):
self.bindings[i] = contiguous_inputs[i].data_ptr()
if self.is_dynamic:
if self.idynamic:
self.context.set_binding_shape(
i, tuple(contiguous_inputs[i].shape))

Loading…
Cancel
Save