Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

cityscapes_dataset_test.py 4.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
  1. import unittest
  2. from typing import Type
  3. import pkg_resources
  4. import yaml
  5. from torch.utils.data import DataLoader, Dataset
  6. from super_gradients.training.dataloaders.dataloaders import (
  7. cityscapes_train,
  8. cityscapes_val,
  9. cityscapes_stdc_seg50_train,
  10. cityscapes_stdc_seg50_val,
  11. cityscapes_stdc_seg75_val,
  12. cityscapes_ddrnet_train,
  13. cityscapes_regseg48_val,
  14. cityscapes_regseg48_train,
  15. cityscapes_ddrnet_val,
  16. cityscapes_stdc_seg75_train,
  17. get,
  18. )
  19. from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset, CityscapesConcatDataset
  20. class CityscapesDatasetTest(unittest.TestCase):
  21. def _cityscapes_dataset_params(self):
  22. default_config_path = pkg_resources.resource_filename("super_gradients.recipes", "dataset_params/cityscapes_dataset_params.yaml")
  23. with open(default_config_path, "r") as file:
  24. dataset_params = yaml.safe_load(file)
  25. return dataset_params
  26. def _cityscapes_al_dataset_params(self):
  27. default_config_path = pkg_resources.resource_filename("super_gradients.recipes", "dataset_params/cityscapes_al_dataset_params.yaml")
  28. with open(default_config_path, "r") as file:
  29. dataset_params = yaml.safe_load(file)
  30. return dataset_params
  31. def dataloader_tester(self, dl: DataLoader, dataset_cls: Type[Dataset] = CityscapesDataset):
  32. self.assertTrue(isinstance(dl, DataLoader))
  33. self.assertTrue(isinstance(dl.dataset, dataset_cls))
  34. it = iter(dl)
  35. for _ in range(10):
  36. next(it)
  37. def test_train_dataset_creation(self):
  38. dataset_params = self._cityscapes_dataset_params()
  39. train_dataset = CityscapesDataset(**dataset_params["train_dataset_params"])
  40. for i in range(10):
  41. image, mask = train_dataset[i]
  42. def test_al_train_dataset_creation(self):
  43. dataset_params = self._cityscapes_al_dataset_params()
  44. train_dataset = CityscapesConcatDataset(**dataset_params["train_dataset_params"])
  45. for i in range(10):
  46. image, mask = train_dataset[i]
  47. def test_val_dataset_creation(self):
  48. dataset_params = self._cityscapes_dataset_params()
  49. val_dataset = CityscapesDataset(**dataset_params["val_dataset_params"])
  50. for i in range(10):
  51. image, mask = val_dataset[i]
  52. def test_cityscapes_train_dataloader(self):
  53. dl_train = cityscapes_train()
  54. self.dataloader_tester(dl_train)
  55. def test_cityscapes_al_train_dataloader(self):
  56. dataset_params = self._cityscapes_al_dataset_params()
  57. # Same dataloader creation as in `train_from_recipe`
  58. dl_train = get(
  59. name=None,
  60. dataset_params=dataset_params["train_dataset_params"],
  61. dataloader_params=dataset_params["train_dataloader_params"],
  62. )
  63. self.dataloader_tester(dl_train, dataset_cls=CityscapesConcatDataset)
  64. def test_cityscapes_val_dataloader(self):
  65. dl_val = cityscapes_val()
  66. self.dataloader_tester(dl_val)
  67. def test_cityscapes_stdc_seg50_train_dataloader(self):
  68. dl_train = cityscapes_stdc_seg50_train()
  69. self.dataloader_tester(dl_train)
  70. def test_cityscapes_stdc_seg50_val_dataloader(self):
  71. dl_val = cityscapes_stdc_seg50_val()
  72. self.dataloader_tester(dl_val)
  73. def test_cityscapes_stdc_seg75_train_dataloader(self):
  74. dl_train = cityscapes_stdc_seg75_train()
  75. self.dataloader_tester(dl_train)
  76. def test_cityscapes_stdc_seg75_val_dataloader(self):
  77. dl_val = cityscapes_stdc_seg75_val()
  78. self.dataloader_tester(dl_val)
  79. def test_cityscapes_regseg48_train_dataloader(self):
  80. dl_train = cityscapes_regseg48_train()
  81. self.dataloader_tester(dl_train)
  82. def test_cityscapes_regseg48_val_dataloader(self):
  83. dl_val = cityscapes_regseg48_val()
  84. self.dataloader_tester(dl_val)
  85. def test_cityscapes_ddrnet_train_dataloader(self):
  86. dl_train = cityscapes_ddrnet_train()
  87. self.dataloader_tester(dl_train)
  88. def test_cityscapes_ddrnet_val_dataloader(self):
  89. dl_val = cityscapes_ddrnet_val()
  90. self.dataloader_tester(dl_val)
  91. if __name__ == "__main__":
  92. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...