Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_datasets_samplers.py 3.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
  1. import pytest
  2. import torch
  3. from common_utils import assert_equal, get_list_of_videos
  4. from torchvision import io
  5. from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler
  6. from torchvision.datasets.video_utils import VideoClips
  @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
  class TestDatasetsSamplers:
      """Tests for RandomClipSampler, UniformClipSampler and DistributedSampler.

      Every test builds ``VideoClips(paths, 5, 5)``. The expected values below
      assume this yields 5 consecutive global clip indices per 25-frame video
      and 2 per 10-frame video (presumably 5-frame clips with a stride of 5).
      """

      @staticmethod
      def _clips_per_video(clip_indices, clips_per_video=5):
          """Map flat clip indices to video ids and count clips per video.

          Returns the ``(unique video ids, counts)`` pair from ``torch.unique``.
          Assumes each video owns ``clips_per_video`` consecutive indices.
          """
          owning_video = torch.div(clip_indices, clips_per_video, rounding_mode="floor")
          return torch.unique(owning_video, return_counts=True)

      def test_random_clip_sampler(self, tmpdir):
          """Each of three equal-length videos contributes exactly 3 random clips."""
          paths = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
          clips = VideoClips(paths, 5, 5)
          sampler = RandomClipSampler(clips, 3)
          assert len(sampler) == 3 * 3

          sampled = torch.tensor(list(sampler))
          video_ids, per_video = self._clips_per_video(sampled)
          assert_equal(video_ids, torch.tensor([0, 1, 2]))
          assert_equal(per_video, torch.tensor([3, 3, 3]))

      def test_random_clip_sampler_unequal(self, tmpdir):
          """A short video contributes all of its clips; the others contribute 3 each."""
          paths = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
          clips = VideoClips(paths, 5, 5)
          sampler = RandomClipSampler(clips, 3)
          # The 10-frame video has only 2 clips, so the total is 2 + 3 + 3.
          assert len(sampler) == 2 + 3 + 3

          sampled = list(sampler)
          # Both clips of the first (short) video must be drawn; strip them
          # out so the remaining clips can be checked like the equal-size case.
          for short_video_clip in (0, 1):
              assert short_video_clip in sampled
              sampled.remove(short_video_clip)

          # Re-base the remaining indices so the second video starts at clip 0.
          rebased = torch.tensor(sampled) - 2
          video_ids, per_video = self._clips_per_video(rebased)
          assert_equal(video_ids, torch.tensor([0, 1]))
          assert_equal(per_video, torch.tensor([3, 3]))

      def test_uniform_clip_sampler(self, tmpdir):
          """Three evenly spaced clips are taken from each equal-length video."""
          paths = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
          clips = VideoClips(paths, 5, 5)
          sampler = UniformClipSampler(clips, 3)
          assert len(sampler) == 3 * 3

          sampled = torch.tensor(list(sampler))
          video_ids, per_video = self._clips_per_video(sampled)
          assert_equal(video_ids, torch.tensor([0, 1, 2]))
          assert_equal(per_video, torch.tensor([3, 3, 3]))
          # Positions 0, 2, 4 within each block of 5 clips per video.
          assert_equal(sampled, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))

      def test_uniform_clip_sampler_insufficient_clips(self, tmpdir):
          """A video with too few clips still yields 3 indices (with repeats)."""
          paths = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
          clips = VideoClips(paths, 5, 5)
          sampler = UniformClipSampler(clips, 3)
          assert len(sampler) == 3 * 3

          sampled = torch.tensor(list(sampler))
          # Clip 0 is repeated for the 2-clip video.
          assert_equal(sampled, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))

      def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir):
          """DistributedSampler splits UniformClipSampler output across 2 replicas."""
          paths = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
          clips = VideoClips(paths, 5, 5)
          clip_sampler = UniformClipSampler(clips, 3)

          # Expected per-rank output; rank 1 wraps around to pad its share to 6.
          expected_by_rank = {
              0: [0, 2, 4, 10, 12, 14],
              1: [5, 7, 9, 0, 2, 4],
          }
          for rank, expected in expected_by_rank.items():
              shard = DistributedSampler(
                  clip_sampler,
                  num_replicas=2,
                  rank=rank,
                  group_size=3,
              )
              drawn = torch.tensor(list(shard))
              assert len(shard) == 6
              assert_equal(drawn, torch.tensor(expected))
  if __name__ == "__main__":
      # pytest.main returns the session exit code (an int / ExitCode). Raising
      # SystemExit with it makes `python <this file>` exit non-zero when tests
      # fail, instead of always reporting success.
      raise SystemExit(pytest.main([__file__]))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...