Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#609 Ci fix

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bugfix/infra-000_ci
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
  1. from super_gradients.training.utils import HpmStruct
  2. from copy import deepcopy
  3. DEFAULT_TRAINING_PARAMS = {
  4. "lr_warmup_epochs": 0,
  5. "lr_cooldown_epochs": 0,
  6. "warmup_initial_lr": None,
  7. "cosine_final_lr_ratio": 0.01,
  8. "optimizer": "SGD",
  9. "optimizer_params": {},
  10. "criterion_params": {},
  11. "ema": False,
  12. "batch_accumulate": 1, # number of batches to accumulate before every backward pass
  13. "ema_params": {},
  14. "zero_weight_decay_on_bias_and_bn": False,
  15. "load_opt_params": True,
  16. "run_validation_freq": 1,
  17. "save_model": True,
  18. "metric_to_watch": "Accuracy",
  19. "launch_tensorboard": False,
  20. "tb_files_user_prompt": False, # Asks User for Tensorboard Deletion Prompt
  21. "silent_mode": False, # Silents the Print outs
  22. "mixed_precision": False,
  23. "tensorboard_port": None,
  24. "save_ckpt_epoch_list": [], # indices where the ckpt will save automatically
  25. "average_best_models": True,
  26. "dataset_statistics": False, # add a dataset statistical analysis and sample images to tensorboard
  27. "save_tensorboard_to_s3": False,
  28. "lr_schedule_function": None,
  29. "train_metrics_list": [],
  30. "valid_metrics_list": [],
  31. "greater_metric_to_watch_is_better": True,
  32. "precise_bn": False,
  33. "precise_bn_batch_size": None,
  34. "seed": 42,
  35. "lr_mode": None,
  36. "phase_callbacks": None,
  37. "log_installed_packages": True,
  38. "sg_logger": "base_sg_logger",
  39. "sg_logger_params": {
  40. "tb_files_user_prompt": False, # Asks User for Tensorboard Deletion Prompt
  41. "project_name": "",
  42. "launch_tensorboard": False,
  43. "tensorboard_port": None,
  44. "save_checkpoints_remote": False, # upload checkpoint files to s3
  45. "save_tensorboard_remote": False, # upload tensorboard files to s3
  46. "save_logs_remote": False,
  47. }, # upload log files to s3
  48. "warmup_mode": "linear_step",
  49. "step_lr_update_freq": None,
  50. "lr_updates": [],
  51. "clip_grad_norm": None,
  52. "pre_prediction_callback": None,
  53. "ckpt_best_name": "ckpt_best.pth",
  54. "enable_qat": False,
  55. "qat_params": {
  56. "start_epoch": 0,
  57. "quant_modules_calib_method": "percentile",
  58. "per_channel_quant_modules": False,
  59. "calibrate": True,
  60. "calibrated_model_path": None,
  61. "calib_data_loader": None,
  62. "num_calib_batches": 2,
  63. "percentile": 99.99,
  64. },
  65. "resume": False,
  66. "resume_path": None,
  67. "ckpt_name": "ckpt_latest.pth",
  68. "resume_strict_load": False,
  69. "sync_bn": False,
  70. }
  71. DEFAULT_OPTIMIZER_PARAMS_SGD = {"weight_decay": 1e-4, "momentum": 0.9}
  72. DEFAULT_OPTIMIZER_PARAMS_ADAM = {"weight_decay": 1e-4}
  73. DEFAULT_OPTIMIZER_PARAMS_RMSPROP = {"weight_decay": 1e-4, "momentum": 0.9}
  74. DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF = {"weight_decay": 1e-4, "momentum": 0.9}
  75. TRAINING_PARAM_SCHEMA = {
  76. "type": "object",
  77. "properties": {
  78. "max_epochs": {"type": "number", "minimum": 1},
  79. # FIXME: CHECK THE IMPORTANCE OF THE COMMENTED SCHEMA- AS IT CAUSES HYDRA USE TO CRASH
  80. # "lr_updates": {"type": "array", "minItems": 1},
  81. "lr_decay_factor": {"type": "number", "minimum": 0, "maximum": 1},
  82. "lr_warmup_epochs": {"type": "number", "minimum": 0, "maximum": 10},
  83. "initial_lr": {"type": "number", "exclusiveMinimum": 0, "maximum": 10},
  84. },
  85. "if": {"properties": {"lr_mode": {"const": "step"}}},
  86. "then": {"required": ["lr_updates", "lr_decay_factor"]},
  87. "required": ["max_epochs", "lr_mode", "initial_lr", "loss"],
  88. }
  89. class TrainingParams(HpmStruct):
  90. def __init__(self, **entries):
  91. # WE initialize by the default training params, overridden by the provided params
  92. default_training_params = deepcopy(DEFAULT_TRAINING_PARAMS)
  93. super().__init__(**default_training_params)
  94. self.set_schema(TRAINING_PARAM_SCHEMA)
  95. if len(entries) > 0:
  96. self.override(**entries)
  97. def override(self, **entries):
  98. super().override(**entries)
  99. self.validate()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file