schedulers.py 1.0 KB
"""
@name schedulers.py
@info This module contains implementations of learning rate schedulers.
@organization: University Laval
@author Gabriel Ramos
@email gabriel.ramos.1@ulaval.ca
"""
from typing import Callable

import ml_collections
import optax


def create_learning_rate_fn(
    config: ml_collections.ConfigDict,
    base_learning_rate: float,
    steps_per_epoch: int
) -> Callable:
    """Create a learning rate schedule: linear warmup followed by cosine decay."""
    # Linear warmup from 0 to the base learning rate over the warmup epochs.
    warmup_fn = optax.linear_schedule(
        init_value=0.,
        end_value=base_learning_rate,
        transition_steps=config.warmup_epochs * steps_per_epoch
    )
    # Cosine decay over the remaining epochs (at least one, to avoid a
    # zero-length schedule when warmup covers all epochs).
    cosine_epochs = max(config.num_epochs - config.warmup_epochs, 1)
    cosine_fn = optax.cosine_decay_schedule(
        init_value=base_learning_rate,
        decay_steps=cosine_epochs * steps_per_epoch
    )
    # Hand off from warmup to cosine decay at the end of warmup.
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, cosine_fn],
        boundaries=[config.warmup_epochs * steps_per_epoch]
    )
    return schedule_fn
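For reference, a minimal pure-Python sketch of the shape this schedule produces, with no optax or ml_collections dependency. The step counts and base learning rate below are illustrative assumptions, not values from the original config:

```python
import math

def warmup_cosine_lr(step, base_lr=0.1, warmup_steps=200, decay_steps=800):
    """Sketch of a linear-warmup + cosine-decay schedule.

    Linearly ramps from 0 to base_lr over warmup_steps, then decays
    from base_lr to 0 along a cosine curve over decay_steps. The
    defaults are hypothetical, chosen only to illustrate the shape.
    """
    if step < warmup_steps:
        # Linear warmup phase: fraction of the way to base_lr.
        return base_lr * step / warmup_steps
    # Cosine decay phase: t goes from 0 (peak LR) to 1 (LR reaches 0).
    t = min(step - warmup_steps, decay_steps) / decay_steps
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * t))

print(warmup_cosine_lr(0))     # 0.0  (start of warmup)
print(warmup_cosine_lr(200))   # 0.1  (peak, end of warmup)
print(warmup_cosine_lr(1000))  # 0.0  (end of cosine decay)
```

This mirrors what `optax.join_schedules` does in `create_learning_rate_fn`: the `boundaries` argument marks the step at which evaluation switches from the warmup schedule to the cosine schedule.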