|
|
@ -29,7 +29,7 @@ from tqdm import tqdm |
|
|
|
from ultralytics.cfg import get_cfg |
|
|
|
from ultralytics.cfg import get_cfg |
|
|
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset |
|
|
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset |
|
|
|
from ultralytics.nn.autobackend import AutoBackend |
|
|
|
from ultralytics.nn.autobackend import AutoBackend |
|
|
|
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis |
|
|
|
from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis |
|
|
|
from ultralytics.utils.checks import check_imgsz |
|
|
|
from ultralytics.utils.checks import check_imgsz |
|
|
|
from ultralytics.utils.files import increment_path |
|
|
|
from ultralytics.utils.files import increment_path |
|
|
|
from ultralytics.utils.ops import Profile |
|
|
|
from ultralytics.utils.ops import Profile |
|
|
@ -43,9 +43,9 @@ class BaseValidator: |
|
|
|
A base class for creating validators. |
|
|
|
A base class for creating validators. |
|
|
|
|
|
|
|
|
|
|
|
Attributes: |
|
|
|
Attributes: |
|
|
|
|
|
|
|
args (SimpleNamespace): Configuration for the validator. |
|
|
|
dataloader (DataLoader): Dataloader to use for validation. |
|
|
|
dataloader (DataLoader): Dataloader to use for validation. |
|
|
|
pbar (tqdm): Progress bar to update during validation. |
|
|
|
pbar (tqdm): Progress bar to update during validation. |
|
|
|
args (SimpleNamespace): Configuration for the validator. |
|
|
|
|
|
|
|
model (nn.Module): Model to validate. |
|
|
|
model (nn.Module): Model to validate. |
|
|
|
data (dict): Data dictionary. |
|
|
|
data (dict): Data dictionary. |
|
|
|
device (torch.device): Device to use for validation. |
|
|
|
device (torch.device): Device to use for validation. |
|
|
@ -76,9 +76,9 @@ class BaseValidator: |
|
|
|
args (SimpleNamespace): Configuration for the validator. |
|
|
|
args (SimpleNamespace): Configuration for the validator. |
|
|
|
_callbacks (dict): Dictionary to store various callback functions. |
|
|
|
_callbacks (dict): Dictionary to store various callback functions. |
|
|
|
""" |
|
|
|
""" |
|
|
|
|
|
|
|
self.args = get_cfg(overrides=args) |
|
|
|
self.dataloader = dataloader |
|
|
|
self.dataloader = dataloader |
|
|
|
self.pbar = pbar |
|
|
|
self.pbar = pbar |
|
|
|
self.args = args or get_cfg(DEFAULT_CFG) |
|
|
|
|
|
|
|
self.model = None |
|
|
|
self.model = None |
|
|
|
self.data = None |
|
|
|
self.data = None |
|
|
|
self.device = None |
|
|
|
self.device = None |
|
|
@ -126,8 +126,7 @@ class BaseValidator: |
|
|
|
else: |
|
|
|
else: |
|
|
|
callbacks.add_integration_callbacks(self) |
|
|
|
callbacks.add_integration_callbacks(self) |
|
|
|
self.run_callbacks('on_val_start') |
|
|
|
self.run_callbacks('on_val_start') |
|
|
|
assert model is not None, 'Either trainer or model is needed for validation' |
|
|
|
model = AutoBackend(model or self.args.model, |
|
|
|
model = AutoBackend(model, |
|
|
|
|
|
|
|
device=select_device(self.args.device, self.args.batch), |
|
|
|
device=select_device(self.args.device, self.args.batch), |
|
|
|
dnn=self.args.dnn, |
|
|
|
dnn=self.args.dnn, |
|
|
|
data=self.args.data, |
|
|
|
data=self.args.data, |
|
|
|