Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

detection_dataset_test.py 4.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import tempfile
import unittest

import super_gradients
from super_gradients.training.datasets import PascalVOCDetectionDataset, COCODetectionDataset
from super_gradients.training.transforms import DetectionMosaic, DetectionPaddedRescale, DetectionTargetsFormatTransform
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
from super_gradients.training.exceptions.dataset_exceptions import EmptyDatasetException
  8. class DatasetIntegrationTest(unittest.TestCase):
  9. def setUp(self) -> None:
  10. super_gradients.init_trainer()
  11. self.batch_size = 64
  12. self.max_samples_per_plot = 16
  13. self.n_plot = 1
  14. transforms = [
  15. DetectionMosaic(input_dim=(640, 640), prob=0.8),
  16. DetectionPaddedRescale(input_dim=(640, 640), max_targets=120),
  17. DetectionTargetsFormatTransform(output_format=DetectionTargetsFormat.XYXY_LABEL),
  18. ]
  19. self.test_dir = tempfile.TemporaryDirectory().name
  20. PascalVOCDetectionDataset.download(self.test_dir)
  21. self.pascal_class_inclusion_lists = [["aeroplane", "bicycle"], ["bird", "boat", "bottle", "bus"], ["pottedplant"], ["person"]]
  22. self.pascal_base_config = dict(data_dir=self.test_dir, images_sub_directory="images/train2012/", input_dim=(640, 640), transforms=transforms)
  23. self.coco_class_inclusion_lists = [["airplane", "bicycle"], ["bird", "boat", "bottle", "bus"], ["potted plant"], ["person"]]
  24. self.dataset_coco_base_config = dict(
  25. data_dir="/data/coco",
  26. subdir="images/val2017",
  27. json_file="instances_val2017.json",
  28. input_dim=(640, 640),
  29. transforms=transforms,
  30. )
  31. def test_multiple_pascal_dataset_subclass_before_transforms(self):
  32. """Run test_pascal_dataset_subclass on multiple inclusion lists"""
  33. for class_inclusion_list in self.pascal_class_inclusion_lists:
  34. dataset = PascalVOCDetectionDataset(
  35. class_inclusion_list=class_inclusion_list, max_num_samples=self.max_samples_per_plot * self.n_plot, **self.pascal_base_config
  36. )
  37. dataset.plot(max_samples_per_plot=self.max_samples_per_plot, n_plots=self.n_plot, plot_transformed_data=False)
  38. def test_multiple_pascal_dataset_subclass_after_transforms(self):
  39. """Run test_pascal_dataset_subclass on multiple inclusion lists"""
  40. for class_inclusion_list in self.pascal_class_inclusion_lists:
  41. dataset = PascalVOCDetectionDataset(
  42. class_inclusion_list=class_inclusion_list, max_num_samples=self.max_samples_per_plot * self.n_plot, **self.pascal_base_config
  43. )
  44. dataset.plot(max_samples_per_plot=self.max_samples_per_plot, n_plots=self.n_plot, plot_transformed_data=True)
  45. def test_multiple_coco_dataset_subclass_before_transforms(self):
  46. """Check subclass on multiple inclusions before transform"""
  47. for class_inclusion_list in self.coco_class_inclusion_lists:
  48. dataset = COCODetectionDataset(
  49. class_inclusion_list=class_inclusion_list, max_num_samples=self.max_samples_per_plot * self.n_plot, **self.dataset_coco_base_config
  50. )
  51. dataset.plot(max_samples_per_plot=self.max_samples_per_plot, n_plots=self.n_plot, plot_transformed_data=False)
  52. def test_multiple_coco_dataset_subclass_after_transforms(self):
  53. """Check subclass on multiple inclusions after transform"""
  54. for class_inclusion_list in self.coco_class_inclusion_lists:
  55. dataset = COCODetectionDataset(
  56. class_inclusion_list=class_inclusion_list, max_num_samples=self.max_samples_per_plot * self.n_plot, **self.dataset_coco_base_config
  57. )
  58. dataset.plot(max_samples_per_plot=self.max_samples_per_plot, n_plots=self.n_plot, plot_transformed_data=True)
  59. def test_subclass_non_existing_class(self):
  60. """Check that EmptyDatasetException is raised when unknown label."""
  61. with self.assertRaises(ValueError):
  62. PascalVOCDetectionDataset(class_inclusion_list=["new_class"], **self.pascal_base_config)
  63. def test_sub_sampling_dataset(self):
  64. """Check that sub sampling works."""
  65. full_dataset = PascalVOCDetectionDataset(**self.pascal_base_config)
  66. with self.assertRaises(EmptyDatasetException):
  67. PascalVOCDetectionDataset(max_num_samples=0, **self.pascal_base_config)
  68. for max_num_samples in [1, 10, 1000, 1_000_000]:
  69. sampled_dataset = PascalVOCDetectionDataset(max_num_samples=max_num_samples, **self.pascal_base_config)
  70. self.assertEqual(len(sampled_dataset), min(max_num_samples, len(full_dataset)))
  71. if __name__ == "__main__":
  72. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...