|
@@ -1,9 +1,11 @@
|
|
# Ultralytics ๐ AGPL-3.0 License - https://ultralytics.com/license
|
|
# Ultralytics ๐ AGPL-3.0 License - https://ultralytics.com/license
|
|
|
|
|
|
|
|
+from pathlib import Path
|
|
|
|
+
|
|
from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset
|
|
from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset
|
|
from ultralytics.data.utils import check_det_dataset
|
|
from ultralytics.data.utils import check_det_dataset
|
|
from ultralytics.models.yolo.world import WorldTrainer
|
|
from ultralytics.models.yolo.world import WorldTrainer
|
|
-from ultralytics.utils import DEFAULT_CFG, LOGGER
|
|
|
|
|
|
+from ultralytics.utils import DATASETS_DIR, DEFAULT_CFG, LOGGER
|
|
from ultralytics.utils.torch_utils import de_parallel
|
|
from ultralytics.utils.torch_utils import de_parallel
|
|
|
|
|
|
|
|
|
|
@@ -136,7 +138,7 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|
if d.get("minival") is None: # for lvis dataset
|
|
if d.get("minival") is None: # for lvis dataset
|
|
continue
|
|
continue
|
|
d["minival"] = str(d["path"] / d["minival"])
|
|
d["minival"] = str(d["path"] / d["minival"])
|
|
- for s in ["train", "val"]:
|
|
|
|
|
|
+ for s in {"train", "val"}:
|
|
final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
|
|
final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
|
|
# save grounding data if there's one
|
|
# save grounding data if there's one
|
|
grounding_data = data_yaml[s].get("grounding_data")
|
|
grounding_data = data_yaml[s].get("grounding_data")
|
|
@@ -145,6 +147,10 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|
grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
|
|
grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
|
|
for g in grounding_data:
|
|
for g in grounding_data:
|
|
assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
|
|
assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
|
|
|
|
+ for k in {"img_path", "json_file"}:
|
|
|
|
+ path = Path(g[k])
|
|
|
|
+ if not path.exists() and not path.is_absolute():
|
|
|
|
+ g[k] = str((DATASETS_DIR / g[k]).resolve()) # path relative to DATASETS_DIR
|
|
final_data[s] += grounding_data
|
|
final_data[s] += grounding_data
|
|
# assign the first val dataset as currently only one validation set is supported
|
|
# assign the first val dataset as currently only one validation set is supported
|
|
data["val"] = data["val"][0]
|
|
data["val"] = data["val"][0]
|