remove args

pull/6/head
triple-Mu 2 years ago
parent 7dd6f3684d
commit be3592b2e6
  1. 31
      models/common.py

@ -40,17 +40,16 @@ class TRT_NMS(torch.autograd.Function):
score_activation: int = 0
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
batch_size, num_boxes, num_classes = scores.shape
num_det = torch.randint(0,
max_output_boxes, (batch_size, 1),
dtype=torch.int32)
det_boxes = torch.randn(batch_size, max_output_boxes, 4)
det_scores = torch.randn(batch_size, max_output_boxes)
det_classes = torch.randint(0,
num_classes,
(batch_size, max_output_boxes),
dtype=torch.int32)
return num_det, det_boxes, det_scores, det_classes
num_dets = torch.randint(0,
max_output_boxes, (batch_size, 1),
dtype=torch.int32)
boxes = torch.randn(batch_size, max_output_boxes, 4)
scores = torch.randn(batch_size, max_output_boxes)
labels = torch.randint(0,
num_classes, (batch_size, max_output_boxes),
dtype=torch.int32)
return num_dets, boxes, scores, labels
@staticmethod
def symbolic(
@ -81,13 +80,7 @@ class TRT_NMS(torch.autograd.Function):
class C2f(nn.Module):
def __init__(self,
c1,
c2,
n=1,
shortcut=False,
g=1,
e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, x):
@ -107,7 +100,7 @@ class PostDetect(nn.Module):
conf_thres = 0.25
topk = 100
def __init__(self, nc=80, ch=()):
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, x):

Loading…
Cancel
Save