|
|
|
@ -54,16 +54,15 @@ class WorldTrainerFromScratch(WorldTrainer): |
|
|
|
|
batch (int, optional): Size of batches, this is for `rect`. Defaults to None. |
|
|
|
|
""" |
|
|
|
|
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) |
|
|
|
|
if mode == "train": |
|
|
|
|
dataset = [ |
|
|
|
|
build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True) |
|
|
|
|
if isinstance(im_path, str) |
|
|
|
|
else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs) |
|
|
|
|
for im_path in img_path |
|
|
|
|
] |
|
|
|
|
return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0] |
|
|
|
|
else: |
|
|
|
|
if mode != "train": |
|
|
|
|
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs) |
|
|
|
|
dataset = [ |
|
|
|
|
build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True) |
|
|
|
|
if isinstance(im_path, str) |
|
|
|
|
else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs) |
|
|
|
|
for im_path in img_path |
|
|
|
|
] |
|
|
|
|
return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0] |
|
|
|
|
|
|
|
|
|
def get_dataset(self): |
|
|
|
|
""" |
|
|
|
@ -71,7 +70,7 @@ class WorldTrainerFromScratch(WorldTrainer): |
|
|
|
|
|
|
|
|
|
Returns None if data format is not recognized. |
|
|
|
|
""" |
|
|
|
|
final_data = dict() |
|
|
|
|
final_data = {} |
|
|
|
|
data_yaml = self.args.data |
|
|
|
|
assert data_yaml.get("train", False) # object365.yaml |
|
|
|
|
assert data_yaml.get("val", False) # lvis.yaml |
|
|
|
@ -88,7 +87,7 @@ class WorldTrainerFromScratch(WorldTrainer): |
|
|
|
|
grounding_data = data_yaml[s].get("grounding_data") |
|
|
|
|
if grounding_data is None: |
|
|
|
|
continue |
|
|
|
|
grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data |
|
|
|
|
grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data] |
|
|
|
|
for g in grounding_data: |
|
|
|
|
assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}" |
|
|
|
|
final_data[s] += grounding_data |
|
|
|
|