Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#367 fix: Request correct hydra-core

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:hotfix/ALG-000_hydra-req
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
  1. import unittest
  2. import pkg_resources
  3. import yaml
  4. from torch.utils.data import DataLoader
  5. from super_gradients.training.dataloaders.dataloaders import cityscapes_train, cityscapes_val, \
  6. cityscapes_stdc_seg50_train, cityscapes_stdc_seg50_val, cityscapes_stdc_seg75_val, cityscapes_ddrnet_train, \
  7. cityscapes_regseg48_val, cityscapes_regseg48_train, cityscapes_ddrnet_val, cityscapes_stdc_seg75_train
  8. from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset
  9. class CityscapesDatasetTest(unittest.TestCase):
  10. def setUp(self) -> None:
  11. default_config_path = pkg_resources.resource_filename("super_gradients.recipes",
  12. "dataset_params/cityscapes_dataset_params.yaml")
  13. with open(default_config_path, 'r') as file:
  14. self.recipe = yaml.safe_load(file)
  15. def dataloader_tester(self, dl: DataLoader):
  16. self.assertTrue(isinstance(dl, DataLoader))
  17. self.assertTrue(isinstance(dl.dataset, CityscapesDataset))
  18. it = iter(dl)
  19. for _ in range(10):
  20. next(it)
  21. def test_train_dataset_creation(self):
  22. train_dataset = CityscapesDataset(**self.recipe['train_dataset_params'])
  23. for i in range(10):
  24. image, mask = train_dataset[i]
  25. def test_val_dataset_creation(self):
  26. val_dataset = CityscapesDataset(**self.recipe['val_dataset_params'])
  27. for i in range(10):
  28. image, mask = val_dataset[i]
  29. def test_cityscapes_train_dataloader(self):
  30. dl_train = cityscapes_train()
  31. self.dataloader_tester(dl_train)
  32. def test_cityscapes_val_dataloader(self):
  33. dl_val = cityscapes_val()
  34. self.dataloader_tester(dl_val)
  35. def test_cityscapes_stdc_seg50_train_dataloader(self):
  36. dl_train = cityscapes_stdc_seg50_train()
  37. self.dataloader_tester(dl_train)
  38. def test_cityscapes_stdc_seg50_val_dataloader(self):
  39. dl_val = cityscapes_stdc_seg50_val()
  40. self.dataloader_tester(dl_val)
  41. def test_cityscapes_stdc_seg75_train_dataloader(self):
  42. dl_train = cityscapes_stdc_seg75_train()
  43. self.dataloader_tester(dl_train)
  44. def test_cityscapes_stdc_seg75_val_dataloader(self):
  45. dl_val = cityscapes_stdc_seg75_val()
  46. self.dataloader_tester(dl_val)
  47. def test_cityscapes_regseg48_train_dataloader(self):
  48. dl_train = cityscapes_regseg48_train()
  49. self.dataloader_tester(dl_train)
  50. def test_cityscapes_regseg48_val_dataloader(self):
  51. dl_val = cityscapes_regseg48_val()
  52. self.dataloader_tester(dl_val)
  53. def test_cityscapes_ddrnet_train_dataloader(self):
  54. dl_train = cityscapes_ddrnet_train()
  55. self.dataloader_tester(dl_train)
  56. def test_cityscapes_ddrnet_val_dataloader(self):
  57. dl_val = cityscapes_ddrnet_val()
  58. self.dataloader_tester(dl_val)
  59. if __name__ == '__main__':
  60. unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file