Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

detection_subclass_test.py 4.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
  1. import unittest
  2. import super_gradients
  3. import torch
  4. from super_gradients.training.datasets import COCODetectionDataSet
  5. from super_gradients.training.utils.detection_utils import base_detection_collate_fn, \
  6. plot_coco_datasaet_images_with_detections
  7. import os
  8. class DeciDataSetIntegrationTest(unittest.TestCase):
  9. def setUp(self) -> None:
  10. super_gradients.init_trainer()
  11. self.class_to_test = COCODetectionDataSet
  12. @classmethod
  13. def setUpClass(cls) -> None:
  14. cls.coco_dataset_params = {"batch_size": 1,
  15. "test_batch_size": 1,
  16. "dataset_dir": "/data/coco/",
  17. "s3_link": None,
  18. "image_size": 416,
  19. "degrees": 1.98, # image rotation (+/- deg)
  20. "translate": 0.05, # image translation (+/- fraction)
  21. "scale": 0.05, # image scale (+/- gain)
  22. "shear": 0.641,
  23. "hsv_h": 0.0138, # image HSV-Hue augmentation (fraction)
  24. "hsv_s": 0.678, # image HSV-Saturation augmentation (fraction)
  25. "hsv_v": 0.36, # image HSV-Value augmentation (fraction)
  26. }
  27. @classmethod
  28. def tearDownClass(cls) -> None:
  29. pass
  30. def test_coco_dataset_subclass_mosaic_loading_labels_cached(self, class_inclusion_list=['chair', 'dining table']):
  31. """
  32. Plots a single image with single bbox of an object from the sub class list, when in mosaic mode.
  33. @param class_inclusion_list: list(str) list of sub class names (from coco classes).
  34. @return:
  35. """
  36. test_batch_size = 64
  37. # TZAG COCO DATASET LOCATION
  38. coco_dataset = COCODetectionDataSet('/data/coco/', 'val2017.txt', batch_size=test_batch_size, img_size=640,
  39. dataset_hyper_params=self.coco_dataset_params,
  40. augment=True,
  41. cache_labels=True,
  42. cache_images=False,
  43. sample_loading_method='mosaic',
  44. class_inclusion_list=class_inclusion_list)
  45. self.assertTrue(len(coco_dataset) > 0)
  46. # LOAD DATA USING A DATA LOADER
  47. nw = min([os.cpu_count(), test_batch_size if test_batch_size > 1 else 0, 4]) # number of workers
  48. dataloader = torch.utils.data.DataLoader(coco_dataset,
  49. batch_size=test_batch_size,
  50. num_workers=nw,
  51. shuffle=True,
  52. # Shuffle=True unless rectangular training is used
  53. pin_memory=True,
  54. collate_fn=base_detection_collate_fn)
  55. plot_coco_datasaet_images_with_detections(dataloader, num_images_to_plot=1)
  56. def test_coco_dataset_subclass_integration_rectangular_loading_labels_cached(self, class_inclusion_list=['chair', 'dining table']):
  57. """
  58. Plots a single image with single bbox of an object from the sub class list, when in mosaic mode.
  59. @param class_inclusion_list: list(str) list of sub class names (from coco classes).
  60. @return:
  61. """
  62. test_batch_size = 64
  63. # TZAG COCO DATASET LOCATION
  64. coco_dataset = COCODetectionDataSet(
  65. '/data/coco/', 'val2017.txt', img_size=640, batch_size=test_batch_size,
  66. dataset_hyper_params=self.coco_dataset_params,
  67. cache_labels=True,
  68. cache_images=False,
  69. augment=False,
  70. sample_loading_method='rectangular',
  71. class_inclusion_list=class_inclusion_list
  72. )
  73. # LOAD DATA USING A DATA LOADER
  74. nw = min([os.cpu_count(), test_batch_size if test_batch_size > 1 else 0, 4]) # number of workers
  75. dataloader = torch.utils.data.DataLoader(coco_dataset,
  76. batch_size=test_batch_size,
  77. num_workers=nw,
  78. pin_memory=True,
  79. collate_fn=base_detection_collate_fn
  80. )
  81. plot_coco_datasaet_images_with_detections(dataloader, num_images_to_plot=1)
  82. if __name__ == '__main__':
  83. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...