Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

ema_decay_schedules.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
  1. import math
  2. from abc import abstractmethod
  # Names exported by ``from <module> import *``.
  __all__ = [
      "IDecayFunction",
      "ConstantDecay",
      "ThresholdDecay",
      "ExpDecay",
      "EMA_DECAY_FUNCTIONS",
  ]
  class IDecayFunction:
      """
      Interface for an EMA decay schedule.

      A schedule is a callable that maps the maximum decay value and the training progress to the
      decay to use at the current step. Usually it gradually increases the EMA decay up to the
      maximum value; the exact ramp-up is defined by the concrete implementation.

      NOTE: this class does not derive from ``abc.ABC``, so ``@abstractmethod`` only marks
      ``__call__`` as abstract; it does not prevent instantiation.
      """

      @abstractmethod
      def __call__(self, decay: float, step: int, total_steps: int) -> float:
          """
          Compute the decay value for a given training step.

          :param decay:       The maximum decay value.
          :param step:        Current training step. The unit-range training percentage can be
                              obtained by `step / total_steps`.
          :param total_steps: Total number of training steps.
          :return:            Computed decay value for a given step.
          """
          ...
  class ConstantDecay(IDecayFunction):
      """
      Schedule that returns the maximum decay value at every step.
      """

      def __init__(self, **kwargs):
          """
          :param kwargs: Accepted and ignored; this schedule has no settings.
          """

      def __call__(self, decay: float, step: int, total_steps: int) -> float:
          """
          :param decay:       The maximum decay value.
          :param step:        Current training step (unused).
          :param total_steps: Total number of training steps (unused).
          :return:            `decay`, unchanged.
          """
          return decay
  class ThresholdDecay(IDecayFunction):
      """
      Ramp the EMA decay up from 0.1 (at step 0) and cap it at the maximum value, using the
      following formula: min(decay, (1 + step) / (10 + step))
      """

      def __init__(self, **kwargs):
          """
          :param kwargs: Accepted and ignored; this schedule has no settings.
          """

      def __call__(self, decay: float, step, total_steps: int) -> float:
          """
          :param decay:       The maximum decay value; acts as the upper cap.
          :param step:        Current training step.
          :param total_steps: Total number of training steps (unused by this schedule).
          :return:            The smaller of `decay` and the step-dependent ramp value.
          """
          ramp = (1 + step) / (10 + step)
          return min(decay, ramp)
  class ExpDecay(IDecayFunction):
      """
      Exponential ramp-up of the EMA decay using the following formula:
      decay * (1 - math.exp(-x * self.beta)), where x = step / total_steps.

      The value is 0 at step 0 and decay * (1 - math.exp(-self.beta)) at step == total_steps, so the
      maximum value is approached but not reached; for positive beta, a larger beta gets closer to it.
      """

      def __init__(self, beta: float, **kwargs):
          """
          :param beta:   Rate of the exponential ramp-up.
          :param kwargs: Accepted and ignored.
          """
          self.beta = beta

      def __call__(self, decay: float, step: int, total_steps: int) -> float:
          """
          :param decay:       The maximum decay value.
          :param step:        Current training step.
          :param total_steps: Total number of training steps.
          :return:            Computed decay value for a given step.
          :raises ZeroDivisionError: If `total_steps` is 0.
          """
          x = step / total_steps
          # -expm1(-t) equals 1 - exp(-t) mathematically, but does not cancel to 0.0 when t is tiny.
          return decay * (-math.expm1(-x * self.beta))
  # Maps a schedule name to the class that implements it.
  EMA_DECAY_FUNCTIONS = {
      "constant": ConstantDecay,
      "threshold": ThresholdDecay,
      "exp": ExpDecay,
  }
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...