verify_distributed_sampler_wrapper.py

import collections
import sys
from itertools import chain

import torch
from torch.utils.data import Dataset, DataLoader

from super_gradients import setup_device
from super_gradients.training.datasets.samplers.distributed_sampler_wrapper import DistributedSamplerWrapper


class DummyDataset(Dataset):
    def __init__(self, length=42):
        super().__init__()
        self.length = length

    def __getitem__(self, index):
        return -index

    def __len__(self):
        return self.length


class RepeatSampler(torch.utils.data.Sampler):
    def __init__(self, data_source, repeat_times):
        self.data_source = data_source
        self.repeat_times = repeat_times
        self.num_samples = repeat_times * len(data_source)

    def __iter__(self):
        indices = list(range(len(self.data_source)))
        return iter(indices * self.repeat_times)

    def __len__(self):
        return self.num_samples


def aggregate_epoch(data_loader):
    results = list()
    for batch in data_loader:
        for element in batch:
            results.append(element.item())
    return results


def compare_counts(x, y):
    return collections.Counter(x) == collections.Counter(y)


if __name__ == "__main__":
    n_gpus = 2
    sampler_n_repeats = 3
    bs = 4
    data_size = 10 * n_gpus * bs

    setup_device(
        device="cuda",
        multi_gpu="DDP",
        num_gpus=n_gpus,
    )

    dataset = DummyDataset(length=data_size)
    sampler = RepeatSampler(dataset, repeat_times=sampler_n_repeats)
    dataloader = DataLoader(dataset, batch_size=bs, sampler=sampler)

    whole_epoch_data = list(chain.from_iterable([[-i] * sampler_n_repeats for i in range(data_size)]))

    # Test a *non-distributed* sampler *in DDP mode*.
    # THIS IS A BAD EXAMPLE BECAUSE YOU EXPECT A DISTRIBUTED SAMPLER TO BE USED IN DDP MODE.
    # The expected `len(dataloader)`, when implemented correctly, should ALSO be divided by `n_gpus`.
    if len(dataloader) != (data_size * sampler_n_repeats) / bs:
        print(f"Wrong DataLoader length. Expected: {((data_size * sampler_n_repeats) / bs)=}, got {len(dataloader)}")
        torch.distributed.destroy_process_group()
        sys.exit(1)

    epoch_data_per_rank = aggregate_epoch(dataloader)
    if not compare_counts(epoch_data_per_rank, whole_epoch_data):  # NOTE THAT EACH GPU SEES ALL DATA -- NOT WHAT WE WANT!
        torch.distributed.destroy_process_group()
        sys.exit(1)

    dist_sampler = DistributedSamplerWrapper(sampler)
    dataloader = DataLoader(dataset, batch_size=bs, sampler=dist_sampler)

    if len(dataloader) != (data_size * sampler_n_repeats) / (bs * n_gpus):
        print(f"Wrong DataLoader length. Expected: {((data_size * sampler_n_repeats) / (bs * n_gpus))=}, got {len(dataloader)}")
        torch.distributed.destroy_process_group()
        sys.exit(1)

    # We have the dataset split across `n_gpus` processes. Let's aggregate and make sure we get the same results.
    per_rank_aggregated = torch.tensor(aggregate_epoch(dataloader)).cuda()
    all_gathered_placeholder = torch.zeros(len(per_rank_aggregated) * n_gpus, dtype=torch.int64).cuda()
    torch.distributed.all_gather_into_tensor(all_gathered_placeholder, per_rank_aggregated)

    if not compare_counts(all_gathered_placeholder.cpu().tolist(), whole_epoch_data):
        torch.distributed.destroy_process_group()
        sys.exit(1)

    torch.distributed.destroy_process_group()
    sys.exit(0)
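
For context, here is a minimal sketch of the general technique the script verifies: a wrapper takes the index sequence produced by an arbitrary base sampler and hands each DDP rank a disjoint 1/num_replicas slice of it, so the union over all ranks still covers the full repeated epoch. This is an illustration under assumptions, not the actual super_gradients DistributedSamplerWrapper; the class name, constructor arguments, and the strided sharding scheme are hypothetical.

# Hypothetical minimal sketch, NOT the super_gradients implementation.
from torch.utils.data import Sampler


class NaiveDistributedSamplerWrapper(Sampler):  # illustrative name
    def __init__(self, base_sampler, num_replicas, rank):
        self.base_sampler = base_sampler    # e.g. the RepeatSampler defined above
        self.num_replicas = num_replicas    # world size (n_gpus)
        self.rank = rank                    # index of the local process

    def __iter__(self):
        indices = list(self.base_sampler)                     # full repeated index sequence
        return iter(indices[self.rank::self.num_replicas])    # rank-strided, disjoint shard

    def __len__(self):
        return len(self.base_sampler) // self.num_replicas

Under this sketch, wrapping the RepeatSampler from the script with num_replicas=2 would give each rank (data_size * sampler_n_repeats) / n_gpus = 120 indices, i.e. 30 batches of 4, which is the length the script asserts for the wrapped dataloader, and gathering both shards recovers whole_epoch_data up to ordering, which is what the compare_counts check accepts.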