#284 Fix training prints

Merged
GitHub User merged 1 commit into Deci-AI:master from deci-ai:hotfix/SG-000-fix_training_prints
from super_gradients.training.utils import HpmStruct

DEFAULT_TRAINING_PARAMS = {
    "lr_warmup_epochs": 0,
    "lr_cooldown_epochs": 0,
    "warmup_initial_lr": None,
    "cosine_final_lr_ratio": 0.01,
    "optimizer": "SGD",
    "criterion_params": {},
    "ema": False,
    "batch_accumulate": 1,  # number of batches to accumulate before every backward pass
    "ema_params": {},
    "zero_weight_decay_on_bias_and_bn": False,
    "load_opt_params": True,
    "run_validation_freq": 1,
    "save_model": True,
    "metric_to_watch": "Accuracy",
    "launch_tensorboard": False,
    "tb_files_user_prompt": False,  # asks the user for a Tensorboard deletion prompt
    "silent_mode": False,  # silences the print outs
    "mixed_precision": False,
    "tensorboard_port": None,
    "save_ckpt_epoch_list": [],  # indices where the ckpt will save automatically
    "average_best_models": True,
    "dataset_statistics": False,  # add a dataset statistical analysis and sample images to tensorboard
    "save_tensorboard_to_s3": False,
    "lr_schedule_function": None,
    "train_metrics_list": [],
    "valid_metrics_list": [],
    "loss_logging_items_names": ["Loss"],
    "greater_metric_to_watch_is_better": True,
    "precise_bn": False,
    "precise_bn_batch_size": None,
    "seed": 42,
    "lr_mode": None,
    "phase_callbacks": None,
    "log_installed_packages": True,
    "save_full_train_log": False,
    "sg_logger": "base_sg_logger",
    "sg_logger_params": {
        "tb_files_user_prompt": False,  # asks the user for a Tensorboard deletion prompt
        "project_name": "",
        "launch_tensorboard": False,
        "tensorboard_port": None,
        "save_checkpoints_remote": False,  # upload checkpoint files to s3
        "save_tensorboard_remote": False,  # upload tensorboard files to s3
        "save_logs_remote": False,  # upload log files to s3
    },
    "warmup_mode": "linear_step",
    "step_lr_update_freq": None,
    "lr_updates": [],
    "clip_grad_norm": None,
    "pre_prediction_callback": None,
    "ckpt_best_name": "ckpt_best.pth",
    "enable_qat": False,
    "qat_params": {
        "start_epoch": 0,
        "quant_modules_calib_method": "percentile",
        "per_channel_quant_modules": False,
        "calibrate": True,
        "calibrated_model_path": None,
        "calib_data_loader": None,
        "num_calib_batches": 2,
        "percentile": 99.99,
    },
}

DEFAULT_OPTIMIZER_PARAMS_SGD = {"weight_decay": 1e-4, "momentum": 0.9}

DEFAULT_OPTIMIZER_PARAMS_ADAM = {"weight_decay": 1e-4}

DEFAULT_OPTIMIZER_PARAMS_RMSPROP = {"weight_decay": 1e-4, "momentum": 0.9}

DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF = {"weight_decay": 1e-4, "momentum": 0.9}

TRAINING_PARAM_SCHEMA = {
    "type": "object",
    "properties": {
        "max_epochs": {"type": "number", "minimum": 1},
        # FIXME: CHECK THE IMPORTANCE OF THE COMMENTED SCHEMA - AS IT CAUSES HYDRA USE TO CRASH
        # "lr_updates": {"type": "array", "minItems": 1},
        "lr_decay_factor": {"type": "number", "minimum": 0, "maximum": 1},
        "lr_warmup_epochs": {"type": "number", "minimum": 0, "maximum": 10},
        "initial_lr": {"type": "number", "exclusiveMinimum": 0, "maximum": 10},
    },
    "if": {
        "properties": {"lr_mode": {"const": "step"}}
    },
    "then": {
        "required": ["lr_updates", "lr_decay_factor"]
    },
    "required": ["max_epochs", "lr_mode", "initial_lr", "loss"],
}


class TrainingParams(HpmStruct):

    def __init__(self, **entries):
        # We initialize with the default training params, overridden by the provided params
        super().__init__(**DEFAULT_TRAINING_PARAMS)
        self.set_schema(TRAINING_PARAM_SCHEMA)
        if len(entries) > 0:
            self.override(**entries)

    def override(self, **entries):
        super().override(**entries)
        self.validate()
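
As a minimal usage sketch of how these pieces fit together: the snippet below constructs a TrainingParams object with the keys the schema marks as required (max_epochs, lr_mode, initial_lr, loss), plus the step-schedule keys the "if"/"then" clause demands once lr_mode is "step". The concrete values, the loss name, and the super_gradients.training.params import path are illustrative assumptions, not taken from this diff.

from super_gradients.training.params import TrainingParams  # import path assumed for this sketch

# Required by TRAINING_PARAM_SCHEMA: max_epochs, lr_mode, initial_lr, loss.
# Because lr_mode == "step", the schema's "if"/"then" clause additionally
# requires lr_updates and lr_decay_factor.
training_params = TrainingParams(
    max_epochs=100,
    lr_mode="step",
    initial_lr=0.1,
    loss="cross_entropy",      # placeholder loss name, not defined in this file
    lr_updates=[30, 60, 90],   # epochs at which the learning rate is decayed
    lr_decay_factor=0.1,
)

# Later overrides are merged into the struct and re-validated against the schema.
training_params.override(ema=True, silent_mode=True)

Note that override() calls validate() after merging, so every override path, not just construction, is checked against TRAINING_PARAM_SCHEMA.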