Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

coco_segmentation_dataset_test.py 1.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
  1. import unittest
  2. import hydra
  3. import pkg_resources
  4. import yaml
  5. from torch.utils.data import DataLoader
  6. from super_gradients.training.dataloaders.dataloaders import coco_segmentation_train, coco_segmentation_val
  7. from super_gradients.training.datasets.segmentation_datasets.coco_segmentation import CoCoSegmentationDataSet
  8. class CocoSegmentationDatasetTest(unittest.TestCase):
  9. def setUp(self) -> None:
  10. default_config_path = pkg_resources.resource_filename("super_gradients.recipes", "dataset_params/coco_segmentation_dataset_params.yaml")
  11. with open(default_config_path, "r") as file:
  12. self.recipe = yaml.safe_load(file)
  13. self.recipe = hydra.utils.instantiate(self.recipe)
  14. def dataloader_tester(self, dl: DataLoader):
  15. self.assertTrue(isinstance(dl, DataLoader))
  16. self.assertTrue(isinstance(dl.dataset, CoCoSegmentationDataSet))
  17. it = iter(dl)
  18. for _ in range(10):
  19. next(it)
  20. def test_train_dataset_creation(self):
  21. train_dataset = CoCoSegmentationDataSet(**self.recipe["train_dataset_params"])
  22. for i in range(10):
  23. image, mask = train_dataset[i]
  24. def test_val_dataset_creation(self):
  25. val_dataset = CoCoSegmentationDataSet(**self.recipe["val_dataset_params"])
  26. for i in range(10):
  27. image, mask = val_dataset[i]
  28. def test_coco_seg_train_dataloader(self):
  29. dl_train = coco_segmentation_train()
  30. self.dataloader_tester(dl_train)
  31. def test_coco_seg_val_dataloader(self):
  32. dl_val = coco_segmentation_val()
  33. self.dataloader_tester(dl_val)
  34. if __name__ == "__main__":
  35. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...