|
|
@@ -214,7 +214,7 @@ class v8DetectionLoss:
         targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
         targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

         # Pboxes
         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
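Why `0.0` instead of `0`: `self.preprocess` zero-pads each image's targets to a common box count, so `gt_bboxes.sum(2, keepdim=True)` is positive exactly for real boxes and zero for padding rows, and comparing against the float literal `0.0` keeps the comparison scalar the same dtype as the tensor. Note that PyTorch's in-place `gt_()` writes 0/1 back in the tensor's own dtype rather than producing a `bool` tensor. A minimal, self-contained sketch of this padding-mask trick, with made-up shapes rather than the real Ultralytics tensors:

```python
import torch

# Hypothetical padded targets: 2 images, up to 3 boxes each, xyxy coords.
gt_bboxes = torch.zeros(2, 3, 4)
gt_bboxes[0, 0] = torch.tensor([10.0, 20.0, 30.0, 40.0])  # one real box; the rest is zero padding

# Real boxes have a positive coordinate sum; padded rows sum to exactly 0.
# In-place gt_() keeps the float dtype, so mask_gt holds 0.0/1.0, not bools.
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)  # (b, max_boxes, 1)
print(mask_gt.squeeze(-1))  # tensor([[1., 0., 0.], [0., 0., 0.]])
```

The same one-token change repeats unchanged in the segmentation, pose, and OBB hunks below.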
|
|
@@ -280,7 +280,7 @@ class v8SegmentationLoss(v8DetectionLoss):
             targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
             targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
             gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
         except RuntimeError as e:
             raise TypeError(
                 "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
|
|
@@ -467,7 +467,7 @@ class v8PoseLoss(v8DetectionLoss):
         targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
         targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

         # Pboxes
         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
|
|
@@ -652,7 +652,7 @@ class v8OBBLoss(v8DetectionLoss):
             targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
             targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
             gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
-            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
         except RuntimeError as e:
             raise TypeError(
                 "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
|
|
|