Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#1000 Cityscapes AutoLabelling dataset

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/ALG-1373_cityscapes_auto_label
1 changed files with 35 additions and 8 deletions
  1. 35
    8
      tests/unit_tests/cityscapes_dataset_test.py
@@ -1,8 +1,9 @@
 import unittest
 import unittest
+from typing import Type
 
 
 import pkg_resources
 import pkg_resources
 import yaml
 import yaml
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 
 from super_gradients.training.dataloaders.dataloaders import (
 from super_gradients.training.dataloaders.dataloaders import (
     cityscapes_train,
     cityscapes_train,
@@ -15,30 +16,46 @@ from super_gradients.training.dataloaders.dataloaders import (
     cityscapes_regseg48_train,
     cityscapes_regseg48_train,
     cityscapes_ddrnet_val,
     cityscapes_ddrnet_val,
     cityscapes_stdc_seg75_train,
     cityscapes_stdc_seg75_train,
+    get,
 )
 )
-from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset
+from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset, CityscapesConcatDataset
 
 
 
 
 class CityscapesDatasetTest(unittest.TestCase):
 class CityscapesDatasetTest(unittest.TestCase):
-    def setUp(self) -> None:
+    def _cityscapes_dataset_params(self):
         default_config_path = pkg_resources.resource_filename("super_gradients.recipes", "dataset_params/cityscapes_dataset_params.yaml")
         default_config_path = pkg_resources.resource_filename("super_gradients.recipes", "dataset_params/cityscapes_dataset_params.yaml")
         with open(default_config_path, "r") as file:
         with open(default_config_path, "r") as file:
-            self.recipe = yaml.safe_load(file)
+            dataset_params = yaml.safe_load(file)
+        return dataset_params
 
 
-    def dataloader_tester(self, dl: DataLoader):
+    def _cityscapes_al_dataset_params(self):
+        default_config_path = pkg_resources.resource_filename("super_gradients.recipes", "dataset_params/cityscapes_al_dataset_params.yaml")
+        with open(default_config_path, "r") as file:
+            dataset_params = yaml.safe_load(file)
+        return dataset_params
+
+    def dataloader_tester(self, dl: DataLoader, dataset_cls: Type[Dataset] = CityscapesDataset):
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
-        self.assertTrue(isinstance(dl.dataset, CityscapesDataset))
+        self.assertTrue(isinstance(dl.dataset, dataset_cls))
         it = iter(dl)
         it = iter(dl)
         for _ in range(10):
         for _ in range(10):
             next(it)
             next(it)
 
 
     def test_train_dataset_creation(self):
     def test_train_dataset_creation(self):
-        train_dataset = CityscapesDataset(**self.recipe["train_dataset_params"])
+        dataset_params = self._cityscapes_dataset_params()
+        train_dataset = CityscapesDataset(**dataset_params["train_dataset_params"])
+        for i in range(10):
+            image, mask = train_dataset[i]
+
+    def test_al_train_dataset_creation(self):
+        dataset_params = self._cityscapes_al_dataset_params()
+        train_dataset = CityscapesConcatDataset(**dataset_params["train_dataset_params"])
         for i in range(10):
         for i in range(10):
             image, mask = train_dataset[i]
             image, mask = train_dataset[i]
 
 
     def test_val_dataset_creation(self):
     def test_val_dataset_creation(self):
-        val_dataset = CityscapesDataset(**self.recipe["val_dataset_params"])
+        dataset_params = self._cityscapes_dataset_params()
+        val_dataset = CityscapesDataset(**dataset_params["val_dataset_params"])
         for i in range(10):
         for i in range(10):
             image, mask = val_dataset[i]
             image, mask = val_dataset[i]
 
 
@@ -46,6 +63,16 @@ class CityscapesDatasetTest(unittest.TestCase):
         dl_train = cityscapes_train()
         dl_train = cityscapes_train()
         self.dataloader_tester(dl_train)
         self.dataloader_tester(dl_train)
 
 
+    def test_cityscapes_al_train_dataloader(self):
+        dataset_params = self._cityscapes_al_dataset_params()
+        # Same dataloader creation as in `train_from_recipe`
+        dl_train = get(
+            name=None,
+            dataset_params=dataset_params["train_dataset_params"],
+            dataloader_params=dataset_params["train_dataloader_params"],
+        )
+        self.dataloader_tester(dl_train, dataset_cls=CityscapesConcatDataset)
+
     def test_cityscapes_val_dataloader(self):
     def test_cityscapes_val_dataloader(self):
         dl_val = cityscapes_val()
         dl_val = cityscapes_val()
         self.dataloader_tester(dl_val)
         self.dataloader_tester(dl_val)
Discard