Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

schedule.py 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
This file is copied/adapted from https://github.com/berkeleydeeprlcourse/homework/tree/master/hw3
"""
  4. class Schedule(object):
  5. def value(self, t):
  6. """Value of the schedule at time t"""
  7. raise NotImplementedError()
  8. class ConstantSchedule(object):
  9. def __init__(self, value):
  10. """Value remains constant over time.
  11. Parameters
  12. ----------
  13. value: float
  14. Constant value of the schedule
  15. """
  16. self._v = value
  17. def value(self, t):
  18. """See Schedule.value"""
  19. return self._v
  20. def linear_interpolation(l, r, alpha):
  21. return l + alpha * (r - l)
  22. class PiecewiseSchedule(object):
  23. def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None):
  24. """Piecewise schedule.
  25. endpoints: [(int, int)]
  26. list of pairs `(time, value)` meanining that schedule should output
  27. `value` when `t==time`. All the values for time must be sorted in
  28. an increasing order. When t is between two times, e.g. `(time_a, value_a)`
  29. and `(time_b, value_b)`, such that `time_a <= t < time_b` then value outputs
  30. `interpolation(value_a, value_b, alpha)` where alpha is a fraction of
  31. time passed between `time_a` and `time_b` for time `t`.
  32. interpolation: lambda float, float, float: float
  33. a function that takes value to the left and to the right of t according
  34. to the `endpoints`. Alpha is the fraction of distance from left endpoint to
  35. right endpoint that t has covered. See linear_interpolation for example.
  36. outside_value: float
  37. if the value is requested outside of all the intervals sepecified in
  38. `endpoints` this value is returned. If None then AssertionError is
  39. raised when outside value is requested.
  40. """
  41. idxes = [e[0] for e in endpoints]
  42. assert idxes == sorted(idxes)
  43. self._interpolation = interpolation
  44. self._outside_value = outside_value
  45. self._endpoints = endpoints
  46. def value(self, t):
  47. """See Schedule.value"""
  48. for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]):
  49. if l_t <= t and t < r_t:
  50. alpha = float(t - l_t) / (r_t - l_t)
  51. return self._interpolation(l, r, alpha)
  52. # t does not belong to any of the pieces, so doom.
  53. assert self._outside_value is not None
  54. return self._outside_value
  55. class LinearSchedule(object):
  56. def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
  57. """Linear interpolation between initial_p and final_p over
  58. schedule_timesteps. After this many timesteps pass final_p is
  59. returned.
  60. Parameters
  61. ----------
  62. schedule_timesteps: int
  63. Number of timesteps for which to linearly anneal initial_p
  64. to final_p
  65. initial_p: float
  66. initial output value
  67. final_p: float
  68. final output value
  69. """
  70. self.schedule_timesteps = schedule_timesteps
  71. self.final_p = final_p
  72. self.initial_p = initial_p
  73. def value(self, t):
  74. """See Schedule.value"""
  75. fraction = min(float(t) / self.schedule_timesteps, 1.0)
  76. return self.initial_p + fraction * (self.final_p - self.initial_p)
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...