Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#381 Feature/sg 000 connect to lab

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/sg-000_connect_to_lab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
  1. import unittest
  2. import torchvision.transforms as transforms
  3. from super_gradients.training.datasets.auto_augment import RandAugment
  4. from super_gradients.training.datasets.datasets_utils import get_color_augmentation
  5. import numpy as np
  6. from PIL import Image
  7. class TestAutoAugment(unittest.TestCase):
  8. def setUp(self):
  9. self.dataset_params = {"batch_size": 1,
  10. "color_jitter": 0.1,
  11. 'rand_augment_config_string': "m9-mstd0.5"}
  12. def test_autoaugment_call(self):
  13. """
  14. tests a simple call to auto augment and other augmentations and verifies image size
  15. """
  16. image_size = 224
  17. color_augmentation = get_color_augmentation("m9-mstd0.5", color_jitter=None, crop_size=image_size)
  18. self.assertTrue(isinstance(color_augmentation, RandAugment))
  19. img = Image.fromarray(np.ones((image_size, image_size, 3)).astype('uint8'))
  20. augmented_image = color_augmentation(img)
  21. self.assertTrue(augmented_image.size == (image_size, image_size))
  22. color_augmentation = get_color_augmentation(None, color_jitter=(0.7, 0.7, 0.7), crop_size=image_size)
  23. self.assertTrue(isinstance(color_augmentation, transforms.ColorJitter))
  24. img = Image.fromarray(np.random.randn(image_size, image_size, 3).astype('uint8'))
  25. augmented_image = color_augmentation(img)
  26. self.assertTrue(augmented_image.size == (image_size, image_size))
  27. if __name__ == '__main__':
  28. unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file