default_train_params.yaml

resume: False # whether to continue training from a checkpoint with the same experiment name.
resume_path: # Explicit checkpoint path (.pth file) to use to resume training.
ckpt_name: ckpt_latest.pth # The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and resume_path=None

lr_mode: # Learning rate scheduling policy, one of ['step', 'poly', 'cosine', 'function']
lr_schedule_function: # Learning rate scheduling function to be used when `lr_mode` is 'function'.
lr_warmup_epochs: 0 # number of epochs for learning rate warm-up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
lr_cooldown_epochs: 0 # number of epochs to cool down the LR (i.e. from a scheduling standpoint, the last epoch is max_epochs - cooldown)
warmup_initial_lr: # Initial LR for linear_step. When none is given, initial_lr / (warmup_epochs + 1) will be used.
step_lr_update_freq: # (float) update frequency in epoch units for computing lr_updates when lr_mode='step'.
cosine_final_lr_ratio: 0.01 # final learning rate ratio (only relevant when lr_mode='cosine')
warmup_mode: linear_step # learning rate warm-up scheme; currently only 'linear_step' is supported

lr_updates:
  _target_: super_gradients.training.utils.utils.empty_list # This is a workaround to instantiate a list using _target_. If it were instantiated as "lr_updates: []",
  # we would get an error every time we wanted to overwrite lr_updates with a numpy array.

pre_prediction_callback: # callback modifying images and targets right before the forward pass.

optimizer: SGD # Optimization algorithm. One of ['Adam', 'SGD', 'RMSProp'], corresponding to the torch.optim optimizers
optimizer_params: {} # when `optimizer` is one of ['Adam', 'SGD', 'RMSProp'], it will be initialized with optimizer_params.
load_opt_params: True # Whether to also load the optimizer's parameters when loading a model's checkpoint
zero_weight_decay_on_bias_and_bn: False # whether to zero the weight decay applied to bias and batch normalization parameters

loss: # Loss function for training (str naming one of SuperGradients' built-in options, or a torch.nn.Module)
criterion_params: {} # when `loss` is one of SuperGradients' built-in options, it will be initialized with criterion_params.

ema: False # whether to use a Model Exponential Moving Average
ema_params: # parameters for the EMA model.
  decay: 0.9999
  beta: 15
  exp_activation: True

train_metrics_list: [] # Metrics to log during training. For more information on torchmetrics see https://torchmetrics.rtfd.io/en/latest/.
valid_metrics_list: [] # Metrics to log during validation. For more information on torchmetrics see https://torchmetrics.rtfd.io/en/latest/.
metric_to_watch: Accuracy # the metric according to which the model checkpoint will be saved
greater_metric_to_watch_is_better: True # when True, the checkpoint considered best is the one that maximizes metric_to_watch (otherwise, the one that minimizes it)

launch_tensorboard: False # Whether to launch a TensorBoard process.
tensorboard_port: # port for the TensorBoard process
tb_files_user_prompt: False # whether to ask the user for confirmation before deleting TensorBoard files
save_tensorboard_to_s3: False # whether to save TensorBoard files to S3

precise_bn: False # Whether to use precise BatchNorm calculation during training.
precise_bn_batch_size: # the effective batch size on which to calculate the BatchNorm statistics.

silent_mode: False # silences the print outs
mixed_precision: False # Whether to use mixed precision or not.
save_ckpt_epoch_list: [] # epoch indices at which a checkpoint will be saved automatically
average_best_models: True # If set, a snapshot dictionary file and the average model will be saved
dataset_statistics: False # add a statistical analysis of the dataset and sample images to TensorBoard
batch_accumulate: 1 # number of batches to accumulate before every backward pass
run_validation_freq: 1 # the frequency at which validation is performed during training
save_model: True # Whether to save the model checkpoints
seed: 42 # seed for reproducibility
phase_callbacks: [] # list of callbacks to be applied at specific phases.
log_installed_packages: True # when set, the list of all installed packages (and their versions) will be written to TensorBoard
save_full_train_log: False # When set, a full training log (covering all super_gradients modules, including uncaught exceptions from any other module) will be saved
clip_grad_norm: # Defines a maximal L2 norm of the gradients. Values exceeding it will be clipped
ckpt_best_name: ckpt_best.pth

enable_qat: False # enables quantization-aware training
qat_params:
  start_epoch: 0 # int, first epoch at which to start QAT. Must be lower than `max_epochs`.
  quant_modules_calib_method: percentile # str, one of [percentile, mse, entropy, max]. Statistics method for the amax computation of the quantized modules.
  per_channel_quant_modules: False # bool, whether quant modules should be per-channel.
  calibrate: True # bool, whether to perform calibration.
  calibrated_model_path: # str, path to a calibrated checkpoint (default=None).
  calib_data_loader: # torch.utils.data.DataLoader, data loader of the calibration dataset. When None, context.train_loader will be used (default=None).
  num_calib_batches: 2 # int, number of batches from which to collect the statistics.
  percentile: 99.99 # float, percentile value to use when quant_modules_calib_method='percentile'. Discarded when other methods are used (default=99.99).

sg_logger: base_sg_logger
sg_logger_params:
  tb_files_user_prompt: False # whether to ask the user for confirmation before deleting TensorBoard files
  launch_tensorboard: False
  tensorboard_port:
  save_checkpoints_remote: False # upload checkpoint files to S3
  save_tensorboard_remote: False # upload TensorBoard files to S3
  save_logs_remote: False # upload log files to S3

_convert_: all
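
Because the file above is plain YAML, its defaults can be loaded and selectively overridden in code before being handed to the trainer. Below is a minimal sketch using PyYAML; the max_epochs and initial_lr overrides and the commented-out Trainer call are illustrative assumptions based on SuperGradients' documented training_params interface, not part of this file.

import yaml

# Load the defaults shipped with the recipe.
with open("default_train_params.yaml") as f:
    train_params = yaml.safe_load(f)

# Note: plain PyYAML does not resolve Hydra's _target_ nodes, so
# train_params["lr_updates"] is the dict {"_target_": ...} at this point;
# resolving it is Hydra's job (see the next sketch).

# Override a few defaults for a specific run (illustrative values).
train_params.update({
    "max_epochs": 100,    # not present in the defaults file; set per experiment
    "lr_mode": "step",    # one of ['step', 'poly', 'cosine', 'function']
    "initial_lr": 0.1,    # also supplied per experiment
    "loss": "cross_entropy",
    "metric_to_watch": "Accuracy",
})

# Hand the dict to SuperGradients (hedged; check the API of the version you use):
# from super_gradients import Trainer
# trainer = Trainer(experiment_name="my_experiment")
# trainer.train(model=model, training_params=train_params,
#               train_loader=train_loader, valid_loader=valid_loader)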
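The lr_updates entry deserves a note: instead of a literal [], it names a callable via Hydra's _target_ mechanism, so Hydra instantiates an empty list at run time and the node can later be replaced by another object, such as a numpy array of epoch indices. A minimal sketch of that mechanism, assuming super_gradients and hydra-core are installed (the epoch values are illustrative):

import numpy as np
from hydra.utils import instantiate

# The node exactly as it appears in the YAML above.
cfg = {"_target_": "super_gradients.training.utils.utils.empty_list"}

lr_updates = instantiate(cfg)  # calls empty_list() and returns []
print(lr_updates)              # []

# Per the comment in the file, a literal "lr_updates: []" would make the
# node a typed list in the config, and overwriting it with a numpy array
# would raise; routing through _target_ sidesteps that. Illustrative override:
lr_updates = np.array([30, 60, 90])  # epoch indices for lr_mode='step'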