Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test.py 1.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import unittest
from collections import Counter, defaultdict

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData

from sampler import PKSampler
  8. class Tester(unittest.TestCase):
  9. def test_pksampler(self):
  10. p, k = 16, 4
  11. # Ensure sampler does not allow p to be greater than num_classes
  12. dataset = FakeData(size=100, num_classes=10, image_size=(3, 1, 1))
  13. targets = [target.item() for _, target in dataset]
  14. self.assertRaises(AssertionError, PKSampler, targets, p, k)
  15. # Ensure p, k constraints on batch
  16. trans = transforms.Compose(
  17. [
  18. transforms.PILToTensor(),
  19. transforms.ConvertImageDtype(torch.float),
  20. ]
  21. )
  22. dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1), transform=trans)
  23. targets = [target.item() for _, target in dataset]
  24. sampler = PKSampler(targets, p, k)
  25. loader = DataLoader(dataset, batch_size=p * k, sampler=sampler)
  26. for _, labels in loader:
  27. bins = defaultdict(int)
  28. for label in labels.tolist():
  29. bins[label] += 1
  30. # Ensure that each batch has samples from exactly p classes
  31. self.assertEqual(len(bins), p)
  32. # Ensure that there are k samples from each class
  33. for b in bins:
  34. self.assertEqual(bins[b], k)
  35. if __name__ == "__main__":
  36. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...