imagenet_resnet_example.py

"""
ResNet50 ImageNet classification training.

This example trains with batch_size = 64 per GPU * 4 GPUs = 256 in total.

Training times:
    ResNet18: 36 hours with 4 x NVIDIA RTX A5000.
    ResNet34: 36 hours with 4 x NVIDIA RTX A5000.
    ResNet50: 46 hours with 4 x NVIDIA GeForce RTX 3090 Ti.

Top1 / Top5 accuracy (%):
    ResNet18: Top1: 70.60 Top5: 89.64
    ResNet34: Top1: 74.13 Top5: 91.70
    ResNet50: Top1: 76.30 Top5: 93.03

BE AWARE THAT THIS RECIPE USES DATA_PARALLEL. WHEN USING DDP FOR DISTRIBUTED TRAINING, THIS RECIPE REACHES ONLY 75.4 TOP1
ACCURACY.
"""

import super_gradients
from omegaconf import DictConfig
import hydra
import pkg_resources


@hydra.main(config_path=pkg_resources.resource_filename("conf", ""), config_name="imagenet_resnet50_conf")
def train(cfg: DictConfig) -> None:
    # INSTANTIATE ALL OBJECTS IN CFG
    cfg = hydra.utils.instantiate(cfg)

    # CONNECT THE DATASET INTERFACE WITH DECI MODEL
    cfg.sg_model.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)

    # BUILD NETWORK
    cfg.sg_model.build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)

    # TRAIN
    cfg.sg_model.train(training_params=cfg.training_params)


if __name__ == "__main__":
    super_gradients.init_trainer()
    train()

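The Top1 and Top5 figures in the docstring are top-k accuracies: a sample counts as correct for top-k if its true label is among the k highest-scoring classes. A minimal pure-Python sketch of the metric, assuming one list of class scores per sample; this is not the super_gradients metric implementation, and ties are broken by Python's stable sort order.

```python
from typing import Sequence


def top_k_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices ordered from highest to lowest score, truncated to k.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label in top_k:
            hits += 1
    return hits / len(labels)


# Toy example: 3 samples, 4 classes.
scores = [
    [0.1, 0.7, 0.1, 0.1],   # true label 1 is ranked 1st
    [0.4, 0.3, 0.2, 0.1],   # true label 1 is ranked 2nd
    [0.2, 0.1, 0.6, 0.15],  # true label 3 is ranked 3rd
]
labels = [1, 1, 3]
print(top_k_accuracy(scores, labels, k=1))  # 0.3333333333333333
print(top_k_accuracy(scores, labels, k=2))  # 0.6666666666666666
print(top_k_accuracy(scores, labels, k=3))  # 1.0
```

A reported value such as Top1: 76.30 corresponds to this fraction (0.7630) multiplied by 100 over the validation set.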

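As the script uses it, `hydra.utils.instantiate(cfg)` replaces the config nodes that carry a `_target_` key with the objects those targets construct, which is why `cfg.sg_model` and `cfg.dataset_interface` can be called directly afterwards. The sketch below is a deliberately simplified, stdlib-only imitation of that recursive behaviour on plain dicts, not Hydra's implementation (it ignores `_partial_`, `_convert_`, interpolation and positional arguments); the `instantiate_sketch` name and the example config keys are hypothetical.

```python
import importlib
from typing import Any


def _resolve(path: str) -> Any:
    """Import 'package.module.attr' and return attr, e.g. 'collections.Counter'."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def instantiate_sketch(node: Any) -> Any:
    """Simplified stand-in for recursive `_target_` instantiation (not Hydra's code)."""
    if isinstance(node, dict):
        # Build children first so nested objects can be passed as keyword arguments.
        children = {k: instantiate_sketch(v) for k, v in node.items() if k != "_target_"}
        if "_target_" in node:
            return _resolve(node["_target_"])(**children)
        return children
    if isinstance(node, list):
        return [instantiate_sketch(item) for item in node]
    # Plain values (ints, strings, ...) pass through unchanged.
    return node


# Hypothetical config: plain dicts standing in for the YAML recipe.
cfg = {
    "data_loader_num_workers": 8,
    "counter": {"_target_": "collections.Counter", "a": 2, "b": 1},
    "nested": {
        "_target_": "collections.OrderedDict",
        "inner": {"_target_": "collections.Counter", "x": 3},
    },
}
built = instantiate_sketch(cfg)
print(type(built["counter"]).__name__)  # Counter
print(type(built["nested"]["inner"]).__name__)  # Counter
print(built["data_loader_num_workers"])  # 8
```

Building children before their parent mirrors the recipe's needs: an object such as the model wrapper can receive already-constructed sub-objects as arguments.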