|
|
|
@ -49,9 +49,12 @@ class BaseTrainer: |
|
|
|
|
# dirs |
|
|
|
|
project = self.args.project or f"runs/{self.args.task}" |
|
|
|
|
name = self.args.name or f"{self.args.mode}" |
|
|
|
|
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK == -1 else True) |
|
|
|
|
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True) |
|
|
|
|
self.wdir = self.save_dir / 'weights' # weights dir |
|
|
|
|
self.wdir.mkdir(parents=True, exist_ok=True) # make dir |
|
|
|
|
if RANK in {-1, 0}: |
|
|
|
|
self.wdir.mkdir(parents=True, exist_ok=True) # make dir |
|
|
|
|
# Save run settings |
|
|
|
|
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) |
|
|
|
|
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths |
|
|
|
|
|
|
|
|
|
self.batch_size = self.args.batch_size |
|
|
|
@ -60,9 +63,6 @@ class BaseTrainer: |
|
|
|
|
if RANK == -1: |
|
|
|
|
print_args(dict(self.args)) |
|
|
|
|
|
|
|
|
|
# Save run settings |
|
|
|
|
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) |
|
|
|
|
|
|
|
|
|
# device |
|
|
|
|
self.device = utils.torch_utils.select_device(self.args.device, self.batch_size) |
|
|
|
|
self.amp = self.device.type != 'cpu' |
|
|
|
|