|
|
|
@ -26,6 +26,7 @@ class DetectionValidator(BaseValidator): |
|
|
|
|
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) |
|
|
|
|
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95 |
|
|
|
|
self.niou = self.iouv.numel() |
|
|
|
|
self.lb = [] # for autolabelling |
|
|
|
|
|
|
|
|
|
def preprocess(self, batch): |
|
|
|
|
"""Preprocesses batch of images for YOLO training.""" |
|
|
|
@ -34,8 +35,12 @@ class DetectionValidator(BaseValidator): |
|
|
|
|
for k in ['batch_idx', 'cls', 'bboxes']: |
|
|
|
|
batch[k] = batch[k].to(self.device) |
|
|
|
|
|
|
|
|
|
if self.args.save_hybrid: |
|
|
|
|
height, width = batch['img'].shape[2:] |
|
|
|
|
nb = len(batch['img']) |
|
|
|
|
self.lb = [torch.cat([batch['cls'], batch['bboxes']], dim=-1)[batch['batch_idx'] == i] |
|
|
|
|
bboxes = batch['bboxes'] * torch.tensor((width, height, width, height), device=self.device) |
|
|
|
|
self.lb = [ |
|
|
|
|
torch.cat([batch['cls'][batch['batch_idx'] == i], bboxes[batch['batch_idx'] == i]], dim=-1) |
|
|
|
|
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling |
|
|
|
|
|
|
|
|
|
return batch |
|
|
|
|