Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

verify_min_samples_ddp.py 1.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
  1. import sys
  2. import numpy as np
  3. import torch
  4. from torch.utils.data import TensorDataset
  5. from super_gradients import setup_device
  6. from super_gradients.training import dataloaders
  def get_dataset(dataset_size: int, image_size: int) -> TensorDataset:
      """Build an all-zeros dummy dataset for exercising a DataLoader.

      The content of the samples is irrelevant to the caller; only the number
      of samples matters, so every image and every label is zero.

      Args:
          dataset_size: Number of samples in the dataset.
          image_size: Height and width of each square, 3-channel image.

      Returns:
          A ``TensorDataset`` pairing images of shape
          ``(dataset_size, 3, image_size, image_size)`` (default float dtype)
          with integer (int64) labels of shape ``(dataset_size,)``.
      """
      # Allocate directly in torch: avoids building a float64 NumPy array and
      # converting it through the legacy Tensor/LongTensor constructors.
      images = torch.zeros(dataset_size, 3, image_size, image_size)
      ground_truth = torch.zeros(dataset_size, dtype=torch.long)
      return TensorDataset(images, ground_truth)
  if __name__ == "__main__":
      # Script check: with `min_samples` set, the DataLoader built under DDP is
      # expected to have length min_samples / (world_size * batch_size).
      # Exits with status 0 when the length matches and 1 otherwise.
      num_gpus = 4
      batch_size = 4
      min_samples = 80
      expected_length = min_samples // (num_gpus * batch_size)  # 80 / (4 * 4) = 5

      setup_device(
          device="cuda",
          multi_gpu="DDP",
          num_gpus=num_gpus,
      )

      # The dataset (64 samples) is deliberately smaller than min_samples (80).
      dataset = get_dataset(dataset_size=64, image_size=32)
      dataloader = dataloaders.get(
          dataset=dataset,
          dataloader_params={"batch_size": batch_size, "min_samples": min_samples, "drop_last": True},
      )

      actual_length = len(dataloader)
      length_ok = actual_length == expected_length
      if not length_ok:
          print(
              "wrong DataLoader length, expected min_samples/(world_size*batch_size)="
              f"{min_samples}/({num_gpus}*{batch_size})={expected_length}, got {actual_length}"
          )

      # Tear down the process group on both outcomes before exiting.
      torch.distributed.destroy_process_group()
      sys.exit(0 if length_ok else 1)
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...