optimizers.py 2.4 KB

#coding:utf-8
import os, sys
import os.path as osp
import numpy as np
import torch
from torch import nn
from torch.optim import Optimizer
from functools import reduce
from torch.optim import AdamW


class MultiOptimizer:
    """Wraps several optimizers (and their schedulers) behind a single
    optimizer-like interface, keyed by name (e.g. one optimizer per sub-module)."""

    def __init__(self, optimizers={}, schedulers={}):
        self.optimizers = optimizers
        self.schedulers = schedulers
        self.keys = list(optimizers.keys())
        # Flatten all param groups so the wrapper looks like a single optimizer.
        self.param_groups = reduce(lambda x, y: x + y,
                                   [v.param_groups for v in self.optimizers.values()])

    def state_dict(self):
        # List of (key, state_dict) pairs, one per wrapped optimizer.
        state_dicts = [(key, self.optimizers[key].state_dict())
                       for key in self.keys]
        return state_dicts

    def load_state_dict(self, state_dict):
        for key, val in state_dict:
            try:
                self.optimizers[key].load_state_dict(val)
            except Exception:
                print("Unloaded %s" % key)

    def step(self, key=None, scaler=None):
        # Step a single optimizer if a key is given, otherwise step all of them.
        keys = [key] if key is not None else self.keys
        _ = [self._step(key, scaler) for key in keys]

    def _step(self, key, scaler=None):
        if scaler is not None:
            # AMP path: the GradScaler unscales gradients and skips steps on inf/NaN.
            scaler.step(self.optimizers[key])
            scaler.update()
        else:
            self.optimizers[key].step()

    def zero_grad(self, key=None):
        if key is not None:
            self.optimizers[key].zero_grad()
        else:
            _ = [self.optimizers[key].zero_grad() for key in self.keys]

    def scheduler(self, *args, key=None):
        # Advance one scheduler (by key) or all of them.
        if key is not None:
            self.schedulers[key].step(*args)
        else:
            _ = [self.schedulers[key].step(*args) for key in self.keys]


def define_scheduler(optimizer, params):
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=params.get('max_lr', 2e-4),
        epochs=params.get('epochs', 200),
        steps_per_epoch=params.get('steps_per_epoch', 1000),
        pct_start=params.get('pct_start', 0.0),
        div_factor=1,
        final_div_factor=1)
    return scheduler


def build_optimizer(parameters_dict, scheduler_params_dict, lr):
    # One AdamW optimizer per named parameter group, each with its own OneCycleLR schedule.
    optim = dict([(key, AdamW(params, lr=lr, weight_decay=1e-4, betas=(0.0, 0.99), eps=1e-9))
                  for key, params in parameters_dict.items()])
    schedulers = dict([(key, define_scheduler(opt, scheduler_params_dict[key]))
                       for key, opt in optim.items()])
    multi_optim = MultiOptimizer(optim, schedulers)
    return multi_optim
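
A minimal usage sketch of build_optimizer and MultiOptimizer. The module names (generator, discriminator), the dummy loss, and all hyperparameter values below are illustrative assumptions, not taken from this repository; they only show the intended call pattern of one AdamW/OneCycleLR pair per named sub-module.

# Hypothetical example; modules, loss, and hyperparameters are placeholders.
import torch
from torch import nn

generator = nn.Linear(16, 16)
discriminator = nn.Linear(16, 1)

sched_params = {'max_lr': 2e-4, 'epochs': 10, 'steps_per_epoch': 100, 'pct_start': 0.0}
optimizer = build_optimizer(
    parameters_dict={'generator': generator.parameters(),
                     'discriminator': discriminator.parameters()},
    scheduler_params_dict={'generator': sched_params,
                           'discriminator': sched_params},
    lr=1e-4)

x = torch.randn(8, 16)

# Update only the generator for this (dummy) step.
optimizer.zero_grad('generator')
loss_g = generator(x).pow(2).mean()
loss_g.backward()
optimizer.step('generator')           # steps only the 'generator' AdamW
optimizer.scheduler(key='generator')  # advances only its OneCycleLR schedule

# Calling without a key touches every registered optimizer/scheduler.
optimizer.zero_grad()

The per-key interface is convenient when different sub-modules are updated at different times within one iteration, and the optional scaler argument to step() matches the torch.cuda.amp.GradScaler API (scaler.step followed by scaler.update), so the same wrapper can be driven under mixed-precision training.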