Merge pull request #25 from triple-Mu/dev

Rename variable name
pull/28/head
tripleMu 2 years ago committed by GitHub
commit 0770f5fa60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      models/__init__.py
  2. 46
      models/common.py
  3. 24
      models/engine.py

@ -7,4 +7,5 @@ warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning) warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
warnings.filterwarnings(action='ignore', category=UserWarning) warnings.filterwarnings(action='ignore', category=UserWarning)
warnings.filterwarnings(action='ignore', category=FutureWarning) warnings.filterwarnings(action='ignore', category=FutureWarning)
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
__all__ = ['EngineBuilder', 'TRTModule', 'TRTProfilerV0', 'TRTProfilerV1'] __all__ = ['EngineBuilder', 'TRTModule', 'TRTProfilerV0', 'TRTProfilerV1']

@ -104,7 +104,7 @@ class PostDetect(nn.Module):
def forward(self, x): def forward(self, x):
shape = x[0].shape shape = x[0].shape
b, res = shape[0], [] b, res, b_reg_num = shape[0], [], self.reg_max * 4
for i in range(self.nl): for i in range(self.nl):
res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)) res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
if self.dynamic or self.shape != shape: if self.dynamic or self.shape != shape:
@ -113,15 +113,14 @@ class PostDetect(nn.Module):
self.shape = shape self.shape = shape
x = [i.view(b, self.no, -1) for i in res] x = [i.view(b, self.no, -1) for i in res]
y = torch.cat(x, 2) y = torch.cat(x, 2)
box, cls = y[:, :self.reg_max * 4, ...], y[:, self.reg_max * 4:, boxes, scores = y[:, :b_reg_num, ...], y[:, b_reg_num:, ...].sigmoid()
...].sigmoid() boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2)
box = box.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2).contiguous() boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes)
box = box.softmax(-1) @ torch.arange(self.reg_max).to(box) boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...]
box0, box1 = -box[:, :2, ...], box[:, 2:, ...] boxes = self.anchors.repeat(b, 2, 1) + torch.cat([boxes0, boxes1], 1)
box = self.anchors.repeat(b, 2, 1) + torch.cat([box0, box1], 1) boxes = boxes * self.strides
box = box * self.strides
return TRT_NMS.apply(boxes.transpose(1, 2), scores.transpose(1, 2),
return TRT_NMS.apply(box.transpose(1, 2), cls.transpose(1, 2),
self.iou_thres, self.conf_thres, self.topk) self.iou_thres, self.conf_thres, self.topk)
@ -139,30 +138,29 @@ class PostSeg(nn.Module):
mc = torch.cat( mc = torch.cat(
[self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], [self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)],
2) # mask coefficients 2) # mask coefficients
box, score, cls = self.forward_det(x) boxes, scores, labels = self.forward_det(x)
out = torch.cat([box, score, cls, mc.transpose(1, 2)], 2) out = torch.cat([boxes, scores, labels.float(), mc.transpose(1, 2)], 2)
return out, p.flatten(2) return out, p.flatten(2)
def forward_det(self, x): def forward_det(self, x):
shape = x[0].shape shape = x[0].shape
b, res = shape[0], [] b, res, b_reg_num = shape[0], [], self.reg_max * 4
for i in range(self.nl): for i in range(self.nl):
res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)) res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
if self.dynamic or self.shape != shape: if self.dynamic or self.shape != shape:
self.anchors, self.strides = (x.transpose( self.anchors, self.strides = \
0, 1) for x in make_anchors(x, self.stride, 0.5)) (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape self.shape = shape
x = [i.view(b, self.no, -1) for i in res] x = [i.view(b, self.no, -1) for i in res]
y = torch.cat(x, 2) y = torch.cat(x, 2)
box, cls = y[:, :self.reg_max * 4, ...], y[:, self.reg_max * 4:, boxes, scores = y[:, :, ...], y[:, b_reg_num:, ...].sigmoid()
...].sigmoid() boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2)
box = box.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2).contiguous() boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes)
box = box.softmax(-1) @ torch.arange(self.reg_max).to(box) boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...]
box0, box1 = -box[:, :2, ...], box[:, 2:, ...] boxes = self.anchors.repeat(b, 2, 1) + torch.cat([boxes0, boxes1], 1)
box = self.anchors.repeat(b, 2, 1) + torch.cat([box0, box1], 1) boxes = boxes * self.strides
box = box * self.strides scores, labels = scores.transpose(1, 2).max(dim=-1, keepdim=True)
score, cls = cls.transpose(1, 2).max(dim=-1, keepdim=True) return boxes.transpose(1, 2), scores, labels
return box.transpose(1, 2), score, cls
def optim(module: nn.Module): def optim(module: nn.Module):

@ -222,18 +222,19 @@ class TRTModule(torch.nn.Module):
model = runtime.deserialize_cuda_engine(self.weight.read_bytes()) model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
context = model.create_execution_context() context = model.create_execution_context()
num_bindings = model.num_bindings
names = [model.get_binding_name(i) for i in range(num_bindings)]
names = [model.get_binding_name(i) for i in range(model.num_bindings)] self.bindings: List[int] = [0] * num_bindings
self.num_bindings = model.num_bindings
self.bindings: List[int] = [0] * self.num_bindings
num_inputs, num_outputs = 0, 0 num_inputs, num_outputs = 0, 0
for i in range(model.num_bindings): for i in range(num_bindings):
if model.binding_is_input(i): if model.binding_is_input(i):
num_inputs += 1 num_inputs += 1
else: else:
num_outputs += 1 num_outputs += 1
self.num_bindings = num_bindings
self.num_inputs = num_inputs self.num_inputs = num_inputs
self.num_outputs = num_outputs self.num_outputs = num_outputs
self.model = model self.model = model
@ -243,7 +244,7 @@ class TRTModule(torch.nn.Module):
self.idx = list(range(self.num_outputs)) self.idx = list(range(self.num_outputs))
def __init_bindings(self) -> None: def __init_bindings(self) -> None:
dynamic = False idynamic = odynamic = False
Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape')) Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape'))
inp_info = [] inp_info = []
out_info = [] out_info = []
@ -252,23 +253,26 @@ class TRTModule(torch.nn.Module):
dtype = self.dtypeMapping[self.model.get_binding_dtype(i)] dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
shape = tuple(self.model.get_binding_shape(i)) shape = tuple(self.model.get_binding_shape(i))
if -1 in shape: if -1 in shape:
dynamic = True idynamic |= True
inp_info.append(Tensor(name, dtype, shape)) inp_info.append(Tensor(name, dtype, shape))
for i, name in enumerate(self.output_names): for i, name in enumerate(self.output_names):
i += self.num_inputs i += self.num_inputs
assert self.model.get_binding_name(i) == name assert self.model.get_binding_name(i) == name
dtype = self.dtypeMapping[self.model.get_binding_dtype(i)] dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
shape = tuple(self.model.get_binding_shape(i)) shape = tuple(self.model.get_binding_shape(i))
if -1 in shape:
odynamic |= True
out_info.append(Tensor(name, dtype, shape)) out_info.append(Tensor(name, dtype, shape))
if not dynamic: if not odynamic:
self.output_tensor = [ self.output_tensor = [
torch.empty(info.shape, dtype=info.dtype, device=self.device) torch.empty(info.shape, dtype=info.dtype, device=self.device)
for info in out_info for info in out_info
] ]
self.is_dynamic = dynamic self.idynamic = idynamic
self.odynamic = odynamic
self.inp_info = inp_info self.inp_info = inp_info
self.out_infp = out_info self.out_info = out_info
def set_profiler(self, profiler: Optional[trt.IProfiler]): def set_profiler(self, profiler: Optional[trt.IProfiler]):
self.context.profiler = profiler \ self.context.profiler = profiler \
@ -288,7 +292,7 @@ class TRTModule(torch.nn.Module):
for i in range(self.num_inputs): for i in range(self.num_inputs):
self.bindings[i] = contiguous_inputs[i].data_ptr() self.bindings[i] = contiguous_inputs[i].data_ptr()
if self.is_dynamic: if self.idynamic:
self.context.set_binding_shape( self.context.set_binding_shape(
i, tuple(contiguous_inputs[i].shape)) i, tuple(contiguous_inputs[i].shape))

Loading…
Cancel
Save