#468 Bug/sg 399 external checkpoints fix

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bug/SG-399_external_checkpoints_fix
import torch.optim as optim
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.factories.optimizers_type_factory import OptimizersTypeFactory
from super_gradients.training.params import (
    DEFAULT_OPTIMIZER_PARAMS_SGD,
    DEFAULT_OPTIMIZER_PARAMS_ADAM,
    DEFAULT_OPTIMIZER_PARAMS_RMSPROP,
    DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF,
)
from super_gradients.training.utils import get_param
from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF

logger = get_logger(__name__)

OPTIMIZERS_DEFAULT_PARAMS = {
    optim.SGD: DEFAULT_OPTIMIZER_PARAMS_SGD,
    optim.Adam: DEFAULT_OPTIMIZER_PARAMS_ADAM,
    optim.RMSprop: DEFAULT_OPTIMIZER_PARAMS_RMSPROP,
    RMSpropTF: DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF,
}


def separate_zero_wd_params_groups_for_optimizer(module: nn.Module, net_named_params, weight_decay: float):
    """
    Separate the param groups into batchnorm/bias params and all other params, and return a list of param groups in
    the format required by torch Optimizer classes: bias + BN params with weight_decay=0 and the rest with the given
    weight decay.

    :param module: train net module.
    :param net_named_params: list of param groups, output of SgModule.initialize_param_groups
    :param weight_decay: value to set for the non-BN and non-bias parameters
    """
    # FIXME - replace usage of id addresses to find batchnorm and bias params.
    #  This solution iterates twice over the module parameters; find a way to iterate only once.
    no_decay_ids = _get_no_decay_param_ids(module)

    # split param groups for the optimizer
    optimizer_param_groups = []
    for param_group in net_named_params:
        no_decay_params = []
        decay_params = []
        for name, param in param_group["named_params"]:
            if id(param) in no_decay_ids:
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        # append two param groups from the original param group, with and without weight decay.
        extra_optim_params = {key: param_group[key] for key in param_group if key not in ["named_params", "weight_decay"]}
        optimizer_param_groups.append({"params": no_decay_params, "weight_decay": 0.0, **extra_optim_params})
        optimizer_param_groups.append({"params": decay_params, "weight_decay": weight_decay, **extra_optim_params})

    return optimizer_param_groups


def _get_no_decay_param_ids(module: nn.Module):
    # FIXME - replace usage of id addresses to find batchnorm and bias params.
    #  Use another common way to identify torch parameters other than id or layer names.
    """
    Iterate over module.modules() and return the id addresses of batch-norm and bias params.

    NOTE - ALL MODULES WITH AN ATTRIBUTE NAMED `bias` THAT IS AN INSTANCE OF nn.Parameter WILL BE CONSIDERED A BIAS
    PARAM FOR ZERO WEIGHT DECAY.
    """
    batchnorm_types = (_BatchNorm,)
    torch_weight_with_bias_types = (_ConvNd, nn.Linear)
    no_decay_ids = []
    for name, m in module.named_modules():
        if isinstance(m, batchnorm_types):
            no_decay_ids.append(id(m.weight))
            no_decay_ids.append(id(m.bias))
        elif hasattr(m, "bias") and isinstance(m.bias, nn.Parameter):
            if not isinstance(m, torch_weight_with_bias_types):
                logger.warning(
                    f"Module class: {m.__class__}, has a `bias` parameter attribute but is not an instance of a"
                    f" torch primitive module; this bias parameter will be part of the param group with zero"
                    f" weight decay."
                )
            no_decay_ids.append(id(m.bias))
    return no_decay_ids
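
A quick illustration (not part of this file) of what the two helpers above produce: for a toy Conv-BN-Linear module, the BatchNorm weight and bias plus every bias parameter land in the weight_decay=0.0 group, while the convolution and linear weight tensors keep the requested decay. The toy module and the 1e-4 value below are arbitrary choices for the sketch.

import torch.nn as nn

# Toy module - only its parameters are inspected, it is never run forward.
toy = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, bias=True), nn.BatchNorm2d(8), nn.Linear(8, 10))

no_decay = _get_no_decay_param_ids(toy)  # ids of conv.bias, bn.weight, bn.bias, linear.bias

param_groups = separate_zero_wd_params_groups_for_optimizer(
    module=toy,
    net_named_params=[{"named_params": toy.named_parameters()}],
    weight_decay=1e-4,
)
for group in param_groups:
    print(group["weight_decay"], [tuple(p.shape) for p in group["params"]])
# -> 0.0    [(8,), (8,), (8,), (10,)]   (all biases + BN weight/bias)
# -> 0.0001 [(8, 3, 3, 3), (10, 8)]     (conv and linear weights)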


def build_optimizer(net, lr, training_params):
    """
    Wrapper function for initializing the optimizer.

    :param net: the nn_module to build the optimizer for
    :param lr: initial learning rate
    :param training_params: training_parameters
    """
    if isinstance(training_params.optimizer, str):
        optimizer_cls = OptimizersTypeFactory().get(training_params.optimizer)
    else:
        optimizer_cls = training_params.optimizer

    optimizer_params = OPTIMIZERS_DEFAULT_PARAMS[optimizer_cls].copy() if optimizer_cls in OPTIMIZERS_DEFAULT_PARAMS.keys() else dict()
    optimizer_params.update(**training_params.optimizer_params)
    training_params.optimizer_params = optimizer_params

    weight_decay = get_param(training_params.optimizer_params, "weight_decay", 0.0)

    # OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT
    if hasattr(net.module, "initialize_param_groups"):
        # INITIALIZE_PARAM_GROUPS MUST RETURN A LIST OF DICTS WITH 'named_params' AND OPTIMIZER's ATTRIBUTES PER GROUP
        net_named_params = net.module.initialize_param_groups(lr, training_params)
    else:
        net_named_params = [{"named_params": net.named_parameters()}]

    if training_params.zero_weight_decay_on_bias_and_bn:
        optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(net.module, net_named_params, weight_decay)
    else:
        # Overwrite groups to include params instead of named params
        for ind_group, param_group in enumerate(net_named_params):
            param_group["params"] = [param[1] for param in list(param_group["named_params"])]
            del param_group["named_params"]
            net_named_params[ind_group] = param_group
        optimizer_training_params = net_named_params

    # CREATE AN OPTIMIZER OBJECT AND INITIALIZE IT
    optimizer = optimizer_cls(optimizer_training_params, lr=lr, **training_params.optimizer_params)
    return optimizer
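
For completeness, a minimal end-to-end sketch of calling build_optimizer (an illustration, not taken from the repo). It assumes the network is wrapped so that net.module exists (e.g. via nn.DataParallel) and uses a SimpleNamespace as a stand-in for the library's training-params object; the attribute names mirror the ones read above, and the concrete values are arbitrary.

from types import SimpleNamespace

import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
net = nn.DataParallel(model)  # build_optimizer reads net.module and net.named_parameters()

training_params = SimpleNamespace(
    optimizer=optim.SGD,                      # passing a class skips the OptimizersTypeFactory lookup
    optimizer_params={"weight_decay": 1e-4},  # merged on top of DEFAULT_OPTIMIZER_PARAMS_SGD
    zero_weight_decay_on_bias_and_bn=True,    # route through separate_zero_wd_params_groups_for_optimizer
)

optimizer = build_optimizer(net, lr=0.1, training_params=training_params)
for group in optimizer.param_groups:
    print(group["weight_decay"], len(group["params"]))
# Expected: one group with weight_decay=0.0 (BN weight/bias + conv bias) and one with weight_decay=1e-4 (conv weight).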