distributed_training_imagenet.py
#!/usr/bin/env python
"""Single node distributed training.

The program will dispatch distributed training on all available GPUs residing in a single node.

Usage:
    python -m torch.distributed.launch --nproc_per_node=n distributed_training_imagenet.py
where n is the number of GPUs required, e.g., n=8.

Important notes:
(1) In distributed training it is customary to specify learning rates and batch sizes per GPU.
    Whatever learning rate and schedule you specify will be applied to each GPU individually.
    Since gradients are summed (all-reduced) across all GPUs, the effective batch size is the
    batch size you specify times the number of GPUs. The literature offers several "best practices"
    for setting learning rates and schedules for large batch sizes, and these should be consulted
    (a minimal illustration is sketched right below this docstring).
(2) The training protocol specified in this file for 8 GPUs is far from optimal; a better protocol
    would use a cosine schedule.

In the example below, ImageNet training with ResNet50 and n=8 computes an epoch in about
5 min 20 sec on 8 V100 GPUs.

Todo:
(1) The code is more or less ready for multiple nodes, but I have not experimented with it at all.
(2) The detection and segmentation code was not modified and is not expected to work properly.
    Specifically, the analogous changes made in sg_classification_model should also be made in
    deci_segmentation_model and deci_detection_model.
"""
import super_gradients
import torch.distributed
from super_gradients.training.sg_model import MultiGPUMode
from super_gradients.training import SgModel
from super_gradients.training.datasets.dataset_interfaces import ImageNetDatasetInterface
from super_gradients.common.aws_connection.aws_secrets_manager_connector import AWSSecretsManagerConnector
from super_gradients.training.metrics.classification_metrics import Accuracy, Top5

# Let cuDNN pick the fastest convolution algorithms for the fixed input size.
torch.backends.cudnn.benchmark = True

# Set up the SuperGradients training environment for this process.
super_gradients.init_trainer()

# TODO - VALIDATE THE HYPER PARAMETERS WITH RAN TO FIX THIS EXAMPLE CODE
train_params = {"max_epochs": 110,
                "lr_updates": [30, 60, 90],
                "lr_decay_factor": 0.1,
                "initial_lr": 0.6,
                "loss": "cross_entropy",
                "lr_mode": "step",
                # "initial_lr": 0.05 * 2,
                "lr_warmup_epochs": 5,
                # "criterion_params": {"smooth_eps": 0.1},
                "mixed_precision": True,
                # "mixed_precision_opt_level": "O3",
                "optimizer_params": {"weight_decay": 0.000, "momentum": 0.9},
                # "optimizer_params": {"weight_decay": 0.0001, "momentum": 0.9},
                "train_metrics_list": [Accuracy(), Top5()],
                "valid_metrics_list": [Accuracy(), Top5()],
                "loss_logging_items_names": ["Loss"],
                "metric_to_watch": "Accuracy",
                "greater_metric_to_watch_is_better": True}

# Per-GPU batch size; the effective batch size is this value times the number of GPUs.
dataset_params = {"batch_size": 128}

# Resolve the S3 bucket used for model checkpoints from AWS Secrets Manager.
model_repo_bucket_name = AWSSecretsManagerConnector.get_secret_value_for_secret_key(
    aws_env='research',
    secret_name='training_secrets',
    secret_key='S3.MODEL_REPOSITORY_BUCKET_NAME')

model = SgModel("test_checkpoints_resnet_8_gpus",
                model_checkpoints_location='s3://' + model_repo_bucket_name,
                multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)

# FOR AWS
dataset = ImageNetDatasetInterface(data_dir="/data/Imagenet", dataset_params=dataset_params)
model.connect_dataset_interface(dataset, data_loader_num_workers=8)
model.build_model("resnet50", load_checkpoint=False)
model.train(training_params=train_params)
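
# --- Illustrative addition (not part of the original script) -----------------
# The docstring notes that the step schedule above is far from optimal and that
# a cosine schedule would be preferable. The commented-out sketch below shows
# how the hyperparameters might change; "cosine" as an lr_mode value and the
# "cosine_final_lr_ratio" key are assumptions about the installed
# super_gradients version's options and should be verified before use.
# cosine_train_params = {**train_params,
#                        "lr_mode": "cosine",
#                        "cosine_final_lr_ratio": 0.01}
# cosine_train_params.pop("lr_updates", None)       # step-decay milestones no longer apply
# cosine_train_params.pop("lr_decay_factor", None)
# model.train(training_params=cosine_train_params)
# -----------------------------------------------------------------------------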