
random_erase_test.py

import unittest

import torch

from super_gradients.training.datasets.data_augmentation import RandomErase


class RandomEraseTest(unittest.TestCase):
    def test_random_erase(self):
        # A dummy batch: one 3-channel 32x32 image.
        dummy_input = torch.randn(1, 3, 32, 32)

        # A numeric string is parsed into a float fill value. With
        # probability=0 no region is ever erased, so calling the transform
        # only checks that it runs on a tensor input.
        one_erase = RandomErase(probability=0, value='1.')
        self.assertEqual(one_erase.p, 0)
        self.assertEqual(one_erase.value, 1.)
        one_erase(dummy_input)

        # The special string 'random' is kept unchanged (fill with random values).
        rndm_erase = RandomErase(probability=0, value='random')
        self.assertEqual(rndm_erase.value, 'random')
        rndm_erase(dummy_input)


if __name__ == '__main__':
    unittest.main()
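The assertions above pin down how the `value` argument is interpreted: the string `'1.'` ends up stored as the float `1.0`, while `'random'` is stored unchanged. The sketch below illustrates that parsing rule in isolation. `parse_erase_value` is a hypothetical helper, not part of super_gradients, and rejecting every other string is an assumption that this test does not check.

```python
def parse_erase_value(value: str):
    """Hypothetical helper mirroring the behaviour the test asserts.

    Numeric strings become floats, 'random' is passed through unchanged.
    Rejecting any other string is an assumption, not something the test checks.
    """
    try:
        # '1.' -> 1.0, '0.5' -> 0.5
        return float(value)
    except ValueError:
        pass
    if value != 'random':
        # Assumed behaviour: unknown strings are an error.
        raise ValueError(f"unsupported erase value: {value!r}")
    return value


print(parse_erase_value('1.'))      # 1.0
print(parse_erase_value('random'))  # random
```

Trying the float conversion first means any numeric string, not only `'1.'`, is accepted as a constant fill value. That leaves `'random'` as the single special keyword.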