Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_auto_augment.py 1.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
  1. import unittest
  2. import torchvision.transforms as transforms
  3. from super_gradients.training.datasets.auto_augment import RandAugment
  4. from super_gradients.training.datasets.datasets_utils import get_color_augmentation
  5. import numpy as np
  6. from PIL import Image
  7. class TestAutoAugment(unittest.TestCase):
  8. def setUp(self):
  9. self.dataset_params = {"batch_size": 1, "color_jitter": 0.1, "rand_augment_config_string": "m9-mstd0.5"}
  10. def test_autoaugment_call(self):
  11. """
  12. tests a simple call to auto augment and other augmentations and verifies image size
  13. """
  14. image_size = 224
  15. color_augmentation = get_color_augmentation("m9-mstd0.5", color_jitter=None, crop_size=image_size)
  16. self.assertTrue(isinstance(color_augmentation, RandAugment))
  17. img = Image.fromarray(np.ones((image_size, image_size, 3)).astype("uint8"))
  18. augmented_image = color_augmentation(img)
  19. self.assertTrue(augmented_image.size == (image_size, image_size))
  20. color_augmentation = get_color_augmentation(None, color_jitter=(0.7, 0.7, 0.7), crop_size=image_size)
  21. self.assertTrue(isinstance(color_augmentation, transforms.ColorJitter))
  22. img = Image.fromarray(np.random.randn(image_size, image_size, 3).astype("uint8"))
  23. augmented_image = color_augmentation(img)
  24. self.assertTrue(augmented_image.size == (image_size, image_size))
  25. if __name__ == "__main__":
  26. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...