|
|
|
@ -26,6 +26,7 @@ from ultralytics.data.utils import check_cls_dataset, check_det_dataset |
|
|
|
|
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights |
|
|
|
|
from ultralytics.utils import ( |
|
|
|
|
DEFAULT_CFG, |
|
|
|
|
LOCAL_RANK, |
|
|
|
|
LOGGER, |
|
|
|
|
RANK, |
|
|
|
|
TQDM, |
|
|
|
@ -129,7 +130,7 @@ class BaseTrainer: |
|
|
|
|
|
|
|
|
|
# Model and Dataset |
|
|
|
|
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt |
|
|
|
|
with torch_distributed_zero_first(RANK): # avoid auto-downloading dataset multiple times |
|
|
|
|
with torch_distributed_zero_first(LOCAL_RANK): # avoid auto-downloading dataset multiple times |
|
|
|
|
self.trainset, self.testset = self.get_dataset() |
|
|
|
|
self.ema = None |
|
|
|
|
|
|
|
|
@ -285,7 +286,7 @@ class BaseTrainer: |
|
|
|
|
|
|
|
|
|
# Dataloaders |
|
|
|
|
batch_size = self.batch_size // max(world_size, 1) |
|
|
|
|
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train") |
|
|
|
|
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train") |
|
|
|
|
if RANK in {-1, 0}: |
|
|
|
|
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects. |
|
|
|
|
self.test_loader = self.get_dataloader( |
|
|
|
|