diff --git a/models/common.py b/models/common.py index 6821023..f8780ab 100644 --- a/models/common.py +++ b/models/common.py @@ -40,17 +40,16 @@ class TRT_NMS(torch.autograd.Function): score_activation: int = 0 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: batch_size, num_boxes, num_classes = scores.shape - num_det = torch.randint(0, - max_output_boxes, (batch_size, 1), - dtype=torch.int32) - det_boxes = torch.randn(batch_size, max_output_boxes, 4) - det_scores = torch.randn(batch_size, max_output_boxes) - det_classes = torch.randint(0, - num_classes, - (batch_size, max_output_boxes), - dtype=torch.int32) - - return num_det, det_boxes, det_scores, det_classes + num_dets = torch.randint(0, + max_output_boxes, (batch_size, 1), + dtype=torch.int32) + boxes = torch.randn(batch_size, max_output_boxes, 4) + scores = torch.randn(batch_size, max_output_boxes) + labels = torch.randint(0, + num_classes, (batch_size, max_output_boxes), + dtype=torch.int32) + + return num_dets, boxes, scores, labels @staticmethod def symbolic( @@ -81,13 +80,7 @@ class TRT_NMS(torch.autograd.Function): class C2f(nn.Module): - def __init__(self, - c1, - c2, - n=1, - shortcut=False, - g=1, - e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + def __init__(self, *args, **kwargs): super().__init__() def forward(self, x): @@ -107,7 +100,7 @@ class PostDetect(nn.Module): conf_thres = 0.25 topk = 100 - def __init__(self, nc=80, ch=()): + def __init__(self, *args, **kwargs): super().__init__() def forward(self, x):