Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#970 Update YoloNASQuickstart.md

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/SG-000_fix_readme_yolonas_snippets
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
  1. import os
  2. import torch
  3. import numpy as np
  4. from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict
  5. from super_gradients.training.utils.utils import move_state_dict_to_device
  6. class ModelWeightAveraging:
  7. """
  8. Utils class for managing the averaging of the best several snapshots into a single model.
  9. A snapshot dictionary file and the average model will be saved / updated at every epoch and evaluated only when
  10. training is completed. The snapshot file will only be deleted upon completing the training.
  11. The snapshot dict will be managed on cpu.
  12. """
  13. def __init__(
  14. self,
  15. ckpt_dir,
  16. greater_is_better,
  17. metric_to_watch="acc",
  18. metric_idx=1,
  19. load_checkpoint=False,
  20. number_of_models_to_average=10,
  21. ):
  22. """
  23. Init the ModelWeightAveraging
  24. :param ckpt_dir: the directory where the checkpoints are saved
  25. :param metric_to_watch: monitoring loss or acc, will be identical to that which determines best_model
  26. :param metric_idx:
  27. :param load_checkpoint: whether to load pre-existing snapshot dict.
  28. :param number_of_models_to_average: number of models to average
  29. """
  30. self.averaging_snapshots_file = os.path.join(ckpt_dir, "averaging_snapshots.pkl")
  31. self.number_of_models_to_average = number_of_models_to_average
  32. self.metric_to_watch = metric_to_watch
  33. self.metric_idx = metric_idx
  34. self.greater_is_better = greater_is_better
  35. # if continuing training, copy previous snapshot dict if exist
  36. if load_checkpoint and ckpt_dir is not None and os.path.isfile(self.averaging_snapshots_file):
  37. averaging_snapshots_dict = read_ckpt_state_dict(self.averaging_snapshots_file)
  38. else:
  39. averaging_snapshots_dict = {"snapshot" + str(i): None for i in range(self.number_of_models_to_average)}
  40. # if metric to watch is acc, hold a zero array, if loss hold inf array
  41. if self.greater_is_better:
  42. averaging_snapshots_dict["snapshots_metric"] = -1 * np.inf * np.ones(self.number_of_models_to_average)
  43. else:
  44. averaging_snapshots_dict["snapshots_metric"] = np.inf * np.ones(self.number_of_models_to_average)
  45. torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)
  46. def update_snapshots_dict(self, model, validation_results_tuple):
  47. """
  48. Update the snapshot dict and returns the updated average model for saving
  49. :param model: the latest model
  50. :param validation_results_tuple: performance of the latest model
  51. """
  52. averaging_snapshots_dict = self._get_averaging_snapshots_dict()
  53. # IF CURRENT MODEL IS BETTER, TAKING HIS PLACE IN ACC LIST AND OVERWRITE THE NEW AVERAGE
  54. require_update, update_ind = self._is_better(averaging_snapshots_dict, validation_results_tuple)
  55. if require_update:
  56. # moving state dict to cpu
  57. new_sd = model.state_dict()
  58. new_sd = move_state_dict_to_device(new_sd, "cpu")
  59. averaging_snapshots_dict["snapshot" + str(update_ind)] = new_sd
  60. averaging_snapshots_dict["snapshots_metric"][update_ind] = validation_results_tuple[self.metric_idx]
  61. return averaging_snapshots_dict
  62. def get_average_model(self, model, validation_results_tuple=None):
  63. """
  64. Returns the averaged model
  65. :param model: will be used to determine arch
  66. :param validation_results_tuple: if provided, will update the average model before returning
  67. :param target_device: if provided, return sd on target device
  68. """
  69. # If validation tuple is provided, update the average model
  70. if validation_results_tuple is not None:
  71. averaging_snapshots_dict = self.update_snapshots_dict(model, validation_results_tuple)
  72. else:
  73. averaging_snapshots_dict = self._get_averaging_snapshots_dict()
  74. torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)
  75. average_model_sd = averaging_snapshots_dict["snapshot0"]
  76. for n_model in range(1, self.number_of_models_to_average):
  77. if averaging_snapshots_dict["snapshot" + str(n_model)] is not None:
  78. net_sd = averaging_snapshots_dict["snapshot" + str(n_model)]
  79. # USING MOVING AVERAGE
  80. for key in average_model_sd:
  81. average_model_sd[key] = torch.true_divide(average_model_sd[key] * n_model + net_sd[key], (n_model + 1))
  82. return average_model_sd
  83. def cleanup(self):
  84. """
  85. Delete snapshot file when reaching the last epoch
  86. """
  87. os.remove(self.averaging_snapshots_file)
  88. def _is_better(self, averaging_snapshots_dict, validation_results_tuple):
  89. """
  90. Determines if the new model is better according to the specified metrics
  91. :param averaging_snapshots_dict: snapshot dict
  92. :param validation_results_tuple: latest model performance
  93. """
  94. snapshot_metric_array = averaging_snapshots_dict["snapshots_metric"]
  95. val = validation_results_tuple[self.metric_idx]
  96. if self.greater_is_better:
  97. update_ind = np.argmin(snapshot_metric_array)
  98. else:
  99. update_ind = np.argmax(snapshot_metric_array)
  100. if (self.greater_is_better and val > snapshot_metric_array[update_ind]) or (not self.greater_is_better and val < snapshot_metric_array[update_ind]):
  101. return True, update_ind
  102. return False, None
  103. def _get_averaging_snapshots_dict(self):
  104. return torch.load(self.averaging_snapshots_file)
Discard
Tip!

Press p or ← to see the previous file, or n or → to see the next file