diff --git a/models/__init__.py b/models/__init__.py
index 0acfe73..b22509a 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -7,4 +7,5 @@ warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
 warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
 warnings.filterwarnings(action='ignore', category=UserWarning)
 warnings.filterwarnings(action='ignore', category=FutureWarning)
+warnings.filterwarnings(action='ignore', category=DeprecationWarning)
 __all__ = ['EngineBuilder', 'TRTModule', 'TRTProfilerV0', 'TRTProfilerV1']
diff --git a/models/common.py b/models/common.py
index 1d563f8..3141900 100644
--- a/models/common.py
+++ b/models/common.py
@@ -104,7 +104,7 @@ class PostDetect(nn.Module):
 
     def forward(self, x):
         shape = x[0].shape
-        b, res = shape[0], []
+        b, res, b_reg_num = shape[0], [], self.reg_max * 4
         for i in range(self.nl):
             res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
         if self.dynamic or self.shape != shape:
@@ -113,15 +113,14 @@ class PostDetect(nn.Module):
             self.shape = shape
         x = [i.view(b, self.no, -1) for i in res]
         y = torch.cat(x, 2)
-        box, cls = y[:, :self.reg_max * 4, ...], y[:, self.reg_max * 4:,
-                                                   ...].sigmoid()
-        box = box.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2).contiguous()
-        box = box.softmax(-1) @ torch.arange(self.reg_max).to(box)
-        box0, box1 = -box[:, :2, ...], box[:, 2:, ...]
-        box = self.anchors.repeat(b, 2, 1) + torch.cat([box0, box1], 1)
-        box = box * self.strides
-
-        return TRT_NMS.apply(box.transpose(1, 2), cls.transpose(1, 2),
+        boxes, scores = y[:, :b_reg_num, ...], y[:, b_reg_num:, ...].sigmoid()
+        boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2)
+        boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes)
+        boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...]
+        boxes = self.anchors.repeat(b, 2, 1) + torch.cat([boxes0, boxes1], 1)
+        boxes = boxes * self.strides
+
+        return TRT_NMS.apply(boxes.transpose(1, 2), scores.transpose(1, 2),
                              self.iou_thres, self.conf_thres, self.topk)
 
 
@@ -139,30 +138,29 @@ class PostSeg(nn.Module):
         mc = torch.cat(
             [self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)],
             2)  # mask coefficients
-        box, score, cls = self.forward_det(x)
-        out = torch.cat([box, score, cls, mc.transpose(1, 2)], 2)
+        boxes, scores, labels = self.forward_det(x)
+        out = torch.cat([boxes, scores, labels.float(), mc.transpose(1, 2)], 2)
         return out, p.flatten(2)
 
     def forward_det(self, x):
         shape = x[0].shape
-        b, res = shape[0], []
+        b, res, b_reg_num = shape[0], [], self.reg_max * 4
         for i in range(self.nl):
             res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
         if self.dynamic or self.shape != shape:
-            self.anchors, self.strides = (x.transpose(
-                0, 1) for x in make_anchors(x, self.stride, 0.5))
+            self.anchors, self.strides = \
+                (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape
         x = [i.view(b, self.no, -1) for i in res]
         y = torch.cat(x, 2)
-        box, cls = y[:, :self.reg_max * 4, ...], y[:, self.reg_max * 4:,
-                                                   ...].sigmoid()
-        box = box.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2).contiguous()
-        box = box.softmax(-1) @ torch.arange(self.reg_max).to(box)
-        box0, box1 = -box[:, :2, ...], box[:, 2:, ...]
-        box = self.anchors.repeat(b, 2, 1) + torch.cat([box0, box1], 1)
-        box = box * self.strides
-        score, cls = cls.transpose(1, 2).max(dim=-1, keepdim=True)
-        return box.transpose(1, 2), score, cls
+        boxes, scores = y[:, :b_reg_num, ...], y[:, b_reg_num:, ...].sigmoid()
+        boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2)
+        boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes)
+        boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...]
+        boxes = self.anchors.repeat(b, 2, 1) + torch.cat([boxes0, boxes1], 1)
+        boxes = boxes * self.strides
+        scores, labels = scores.transpose(1, 2).max(dim=-1, keepdim=True)
+        return boxes.transpose(1, 2), scores, labels
 
 
 def optim(module: nn.Module):
diff --git a/models/engine.py b/models/engine.py
index a558b0f..5fcfbe2 100644
--- a/models/engine.py
+++ b/models/engine.py
@@ -222,18 +222,19 @@ class TRTModule(torch.nn.Module):
             model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
 
         context = model.create_execution_context()
+        num_bindings = model.num_bindings
+        names = [model.get_binding_name(i) for i in range(num_bindings)]
 
-        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
-        self.num_bindings = model.num_bindings
-        self.bindings: List[int] = [0] * self.num_bindings
+        self.bindings: List[int] = [0] * num_bindings
         num_inputs, num_outputs = 0, 0
 
-        for i in range(model.num_bindings):
+        for i in range(num_bindings):
             if model.binding_is_input(i):
                 num_inputs += 1
             else:
                 num_outputs += 1
 
+        self.num_bindings = num_bindings
         self.num_inputs = num_inputs
         self.num_outputs = num_outputs
         self.model = model
@@ -243,7 +244,7 @@ class TRTModule(torch.nn.Module):
         self.idx = list(range(self.num_outputs))
 
     def __init_bindings(self) -> None:
-        dynamic = False
+        idynamic = odynamic = False
         Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape'))
         inp_info = []
         out_info = []
@@ -252,23 +253,26 @@ class TRTModule(torch.nn.Module):
             dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
             shape = tuple(self.model.get_binding_shape(i))
             if -1 in shape:
-                dynamic = True
+                idynamic |= True
             inp_info.append(Tensor(name, dtype, shape))
         for i, name in enumerate(self.output_names):
             i += self.num_inputs
             assert self.model.get_binding_name(i) == name
             dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
             shape = tuple(self.model.get_binding_shape(i))
+            if -1 in shape:
+                odynamic |= True
             out_info.append(Tensor(name, dtype, shape))
 
-        if not dynamic:
+        if not odynamic:
             self.output_tensor = [
                 torch.empty(info.shape, dtype=info.dtype, device=self.device)
                 for info in out_info
             ]
-        self.is_dynamic = dynamic
+        self.idynamic = idynamic
+        self.odynamic = odynamic
         self.inp_info = inp_info
-        self.out_infp = out_info
+        self.out_info = out_info
 
     def set_profiler(self, profiler: Optional[trt.IProfiler]):
         self.context.profiler = profiler \
@@ -288,7 +292,7 @@ class TRTModule(torch.nn.Module):
 
         for i in range(self.num_inputs):
             self.bindings[i] = contiguous_inputs[i].data_ptr()
-            if self.is_dynamic:
+            if self.idynamic:
                 self.context.set_binding_shape(
                     i, tuple(contiguous_inputs[i].shape))
 