|
|
|
@ -107,7 +107,7 @@ class BaseTrainer: |
|
|
|
|
self.save_dir = get_save_dir(self.args) |
|
|
|
|
self.args.name = self.save_dir.name # update name for loggers |
|
|
|
|
self.wdir = self.save_dir / "weights" # weights dir |
|
|
|
|
if RANK in (-1, 0): |
|
|
|
|
if RANK in {-1, 0}: |
|
|
|
|
self.wdir.mkdir(parents=True, exist_ok=True) # make dir |
|
|
|
|
self.args.save_dir = str(self.save_dir) |
|
|
|
|
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args |
|
|
|
@ -121,7 +121,7 @@ class BaseTrainer: |
|
|
|
|
print_args(vars(self.args)) |
|
|
|
|
|
|
|
|
|
# Device |
|
|
|
|
if self.device.type in ("cpu", "mps"): |
|
|
|
|
if self.device.type in {"cpu", "mps"}: |
|
|
|
|
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading |
|
|
|
|
|
|
|
|
|
# Model and Dataset |
|
|
|
@ -144,7 +144,7 @@ class BaseTrainer: |
|
|
|
|
|
|
|
|
|
# Callbacks |
|
|
|
|
self.callbacks = _callbacks or callbacks.get_default_callbacks() |
|
|
|
|
if RANK in (-1, 0): |
|
|
|
|
if RANK in {-1, 0}: |
|
|
|
|
callbacks.add_integration_callbacks(self) |
|
|
|
|
|
|
|
|
|
def add_callback(self, event: str, callback): |
|
|
|
@ -251,7 +251,7 @@ class BaseTrainer: |
|
|
|
|
|
|
|
|
|
# Check AMP |
|
|
|
|
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False |
|
|
|
|
if self.amp and RANK in (-1, 0): # Single-GPU and DDP |
|
|
|
|
if self.amp and RANK in {-1, 0}: # Single-GPU and DDP |
|
|
|
|
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them |
|
|
|
|
self.amp = torch.tensor(check_amp(self.model), device=self.device) |
|
|
|
|
callbacks.default_callbacks = callbacks_backup # restore callbacks |
|
|
|
@ -274,7 +274,7 @@ class BaseTrainer: |
|
|
|
|
# Dataloaders |
|
|
|
|
batch_size = self.batch_size // max(world_size, 1) |
|
|
|
|
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train") |
|
|
|
|
if RANK in (-1, 0): |
|
|
|
|
if RANK in {-1, 0}: |
|
|
|
|
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects. |
|
|
|
|
self.test_loader = self.get_dataloader( |
|
|
|
|
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val" |
|
|
|
@ -340,7 +340,7 @@ class BaseTrainer: |
|
|
|
|
self._close_dataloader_mosaic() |
|
|
|
|
self.train_loader.reset() |
|
|
|
|
|
|
|
|
|
if RANK in (-1, 0): |
|
|
|
|
if RANK in {-1, 0}: |
|
|
|
|
LOGGER.info(self.progress_string()) |
|
|
|
|
pbar = TQDM(enumerate(self.train_loader), total=nb) |
|
|
|
|
self.tloss = None |
|
|
|
@ -392,7 +392,7 @@ class BaseTrainer: |
|
|
|
|
mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB) |
|
|
|
|
loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1 |
|
|
|
|
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0) |
|
|
|
|
if RANK in (-1, 0): |
|
|
|
|
if RANK in {-1, 0}: |
|
|
|
|
pbar.set_description( |
|
|
|
|
("%11s" * 2 + "%11.4g" * (2 + loss_len)) |
|
|
|
|
% (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]) |
|
|
|
@ -405,7 +405,7 @@ class BaseTrainer: |
|
|
|
|
|
|
|
|
|
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers |
|
|
|
|
self.run_callbacks("on_train_epoch_end") |
|
|
|
|
if RANK in (-1, 0): |
|
|
|
|
if RANK in {-1, 0}: |
|
|
|
|
final_epoch = epoch + 1 >= self.epochs |
|
|
|
|
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"]) |
|
|
|
|
|
|
|
|
@ -447,7 +447,7 @@ class BaseTrainer: |
|
|
|
|
break # must break all DDP ranks |
|
|
|
|
epoch += 1 |
|
|
|
|
|
|
|
|
|
if RANK in (-1, 0): |
|
|
|
|
if RANK in {-1, 0}: |
|
|
|
|
# Do final val with best.pt |
|
|
|
|
LOGGER.info( |
|
|
|
|
f"\n{epoch - self.start_epoch + 1} epochs completed in " |
|
|
|
@ -503,12 +503,12 @@ class BaseTrainer: |
|
|
|
|
try: |
|
|
|
|
if self.args.task == "classify": |
|
|
|
|
data = check_cls_dataset(self.args.data) |
|
|
|
|
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ( |
|
|
|
|
elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in { |
|
|
|
|
"detect", |
|
|
|
|
"segment", |
|
|
|
|
"pose", |
|
|
|
|
"obb", |
|
|
|
|
): |
|
|
|
|
}: |
|
|
|
|
data = check_det_dataset(self.args.data) |
|
|
|
|
if "yaml_file" in data: |
|
|
|
|
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage |
|
|
|
@ -740,7 +740,7 @@ class BaseTrainer: |
|
|
|
|
else: # weight (with decay) |
|
|
|
|
g[0].append(param) |
|
|
|
|
|
|
|
|
|
if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"): |
|
|
|
|
if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}: |
|
|
|
|
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0) |
|
|
|
|
elif name == "RMSProp": |
|
|
|
|
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum) |
|
|
|
|