1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
- import unittest
- import pkg_resources
- import yaml
- from torch.utils.data import DataLoader
- from super_gradients.training.dataloaders.dataloaders import (
- cityscapes_train,
- cityscapes_val,
- cityscapes_stdc_seg50_train,
- cityscapes_stdc_seg50_val,
- cityscapes_stdc_seg75_val,
- cityscapes_ddrnet_train,
- cityscapes_regseg48_val,
- cityscapes_regseg48_train,
- cityscapes_ddrnet_val,
- cityscapes_stdc_seg75_train,
- )
- from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset
- class CityscapesDatasetTest(unittest.TestCase):
- def setUp(self) -> None:
- default_config_path = pkg_resources.resource_filename("super_gradients.recipes", "dataset_params/cityscapes_dataset_params.yaml")
- with open(default_config_path, "r") as file:
- self.recipe = yaml.safe_load(file)
- def dataloader_tester(self, dl: DataLoader):
- self.assertTrue(isinstance(dl, DataLoader))
- self.assertTrue(isinstance(dl.dataset, CityscapesDataset))
- it = iter(dl)
- for _ in range(10):
- next(it)
- def test_train_dataset_creation(self):
- train_dataset = CityscapesDataset(**self.recipe["train_dataset_params"])
- for i in range(10):
- image, mask = train_dataset[i]
- def test_val_dataset_creation(self):
- val_dataset = CityscapesDataset(**self.recipe["val_dataset_params"])
- for i in range(10):
- image, mask = val_dataset[i]
- def test_cityscapes_train_dataloader(self):
- dl_train = cityscapes_train()
- self.dataloader_tester(dl_train)
- def test_cityscapes_val_dataloader(self):
- dl_val = cityscapes_val()
- self.dataloader_tester(dl_val)
- def test_cityscapes_stdc_seg50_train_dataloader(self):
- dl_train = cityscapes_stdc_seg50_train()
- self.dataloader_tester(dl_train)
- def test_cityscapes_stdc_seg50_val_dataloader(self):
- dl_val = cityscapes_stdc_seg50_val()
- self.dataloader_tester(dl_val)
- def test_cityscapes_stdc_seg75_train_dataloader(self):
- dl_train = cityscapes_stdc_seg75_train()
- self.dataloader_tester(dl_train)
- def test_cityscapes_stdc_seg75_val_dataloader(self):
- dl_val = cityscapes_stdc_seg75_val()
- self.dataloader_tester(dl_val)
- def test_cityscapes_regseg48_train_dataloader(self):
- dl_train = cityscapes_regseg48_train()
- self.dataloader_tester(dl_train)
- def test_cityscapes_regseg48_val_dataloader(self):
- dl_val = cityscapes_regseg48_val()
- self.dataloader_tester(dl_val)
- def test_cityscapes_ddrnet_train_dataloader(self):
- dl_train = cityscapes_ddrnet_train()
- self.dataloader_tester(dl_train)
- def test_cityscapes_ddrnet_val_dataloader(self):
- dl_val = cityscapes_ddrnet_val()
- self.dataloader_tester(dl_val)
- if __name__ == "__main__":
- unittest.main()
|