replay_buffer.py

  1. """
  2. This file is copied/apdated from https://github.com/berkeleydeeprlcourse/homework/tree/master/hw3
  3. """
  4. import numpy as np
  5. import random
  6. def sample_n_unique(sampling_f, n):
  7. """Helper function. Given a function `sampling_f` that returns
  8. comparable objects, sample n such unique objects.
  9. """
  10. res = []
  11. while len(res) < n:
  12. candidate = sampling_f()
  13. if candidate not in res:
  14. res.append(candidate)
  15. return res
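
# Example (illustrative, not from the original file): drawing three distinct
# indices below 100:
#   idxes = sample_n_unique(lambda: random.randint(0, 99), 3)
# Note the loop only terminates if `sampling_f` can produce at least `n`
# distinct values.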


class ReplayBuffer(object):
    def __init__(self, size, frame_history_len):
        """This is a memory-efficient implementation of the replay buffer.

        The specific memory optimizations used here are:
            - only store each frame once rather than k times,
              even if every observation normally consists of the k last frames
            - store frames as np.uint8 (it is usually most time-efficient to
              cast them back to float32 on the GPU, to minimize memory
              transfer time)
            - store frame_t and frame_(t+1) in the same buffer.

        For the typical Atari Deep RL use case of a buffer with 1M frames, the
        total memory footprint of this buffer is 10^6 * 84 * 84 bytes ~= 7 gigabytes.

        Warning! Assumes that returning a frame of zeros at the beginning
        of the episode, when there are fewer frames than `frame_history_len`,
        is acceptable.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows, the oldest memories are dropped.
        frame_history_len: int
            Number of memories to be retrieved for each observation.
        """
        self.size = size
        self.frame_history_len = frame_history_len

        self.next_idx = 0
        self.num_in_buffer = 0

        self.obs = None
        self.action = None
        self.reward = None
        self.done = None
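
    # Note (added comment): the storage arrays above are allocated lazily in
    # `store_frame`, once the shape of the first frame is known, so constructing
    # a large buffer costs almost no memory until frames are actually stored.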

    def can_sample(self, batch_size):
        """Returns true if `batch_size` different transitions can be sampled from the buffer."""
        return batch_size + 1 <= self.num_in_buffer

    def _encode_sample(self, idxes):
        obs_batch = np.concatenate([self._encode_observation(idx)[np.newaxis, :] for idx in idxes], 0)
        act_batch = self.action[idxes]
        rew_batch = self.reward[idxes]
        next_obs_batch = np.concatenate([self._encode_observation(idx + 1)[np.newaxis, :] for idx in idxes], 0)
        done_mask = np.array([1.0 if self.done[idx] else 0.0 for idx in idxes], dtype=np.float32)

        return obs_batch, act_batch, rew_batch, next_obs_batch, done_mask

    def sample(self, batch_size):
        """Sample `batch_size` different transitions.

        The i-th sampled transition is the following:
        when observing `obs_batch[i]`, action `act_batch[i]` was taken,
        after which reward `rew_batch[i]` was received and subsequent
        observation `next_obs_batch[i]` was observed, unless the episode
        ended, which is represented by `done_mask[i]`, equal to 1 if the
        episode has ended as a result of that action.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.

        Returns
        -------
        obs_batch: np.array
            Array of shape
            (batch_size, img_c * frame_history_len, img_h, img_w)
            and dtype np.uint8
        act_batch: np.array
            Array of shape (batch_size,) and dtype np.int32
        rew_batch: np.array
            Array of shape (batch_size,) and dtype np.float32
        next_obs_batch: np.array
            Array of shape
            (batch_size, img_c * frame_history_len, img_h, img_w)
            and dtype np.uint8
        done_mask: np.array
            Array of shape (batch_size,) and dtype np.float32
        """
        assert self.can_sample(batch_size)
        idxes = sample_n_unique(lambda: random.randint(0, self.num_in_buffer - 2), batch_size)
        return self._encode_sample(idxes)
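
    # Illustration (assumed usage, not part of the original file): `done_mask`
    # is typically used to zero the bootstrap term of a Q-learning target, e.g.
    #   target = rew_batch + gamma * (1.0 - done_mask) * next_q.max(axis=1)
    # where `gamma` and `next_q` come from the surrounding training code.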

    def encode_recent_observation(self):
        """Return the most recent `frame_history_len` frames.

        Returns
        -------
        observation: np.array
            Array of shape (img_c * frame_history_len, img_h, img_w)
            and dtype np.uint8, where observation[i*img_c:(i+1)*img_c, :, :]
            encodes the frame at time `t - frame_history_len + i`
        """
        assert self.num_in_buffer > 0
        return self._encode_observation((self.next_idx - 1) % self.size)

    def _encode_observation(self, idx):
        end_idx = idx + 1  # make noninclusive
        start_idx = end_idx - self.frame_history_len
        # this checks if we are using low-dimensional observations, such as RAM
        # state, in which case we just directly return the latest RAM.
        if len(self.obs.shape) == 2:
            return self.obs[end_idx - 1]
        # if there weren't enough frames ever in the buffer for context
        if start_idx < 0 and self.num_in_buffer != self.size:
            start_idx = 0
        for idx in range(start_idx, end_idx - 1):
            if self.done[idx % self.size]:
                start_idx = idx + 1
        missing_context = self.frame_history_len - (end_idx - start_idx)
        # if zero padding is needed for missing context
        # or we are on the boundary of the buffer
        if start_idx < 0 or missing_context > 0:
            frames = [np.zeros_like(self.obs[0]) for _ in range(missing_context)]
            for idx in range(start_idx, end_idx):
                frames.append(self.obs[idx % self.size])
            return np.concatenate(frames, 0)
        else:
            # this optimization has the potential to save about 30% compute time \o/
            img_h, img_w = self.obs.shape[2], self.obs.shape[3]
            return self.obs[start_idx:end_idx].reshape(-1, img_h, img_w)
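
    # Example layout (illustrative): with frame_history_len=4 and 84x84 grayscale
    # frames stored as (1, 84, 84), _encode_observation returns an array of shape
    # (4, 84, 84) whose i-th channel holds the frame at time t - 3 + i; frames
    # from before the start of the current episode are replaced by zeros.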

    def store_frame(self, frame):
        """Store a single frame in the buffer at the next available index, overwriting
        old frames if necessary.

        Parameters
        ----------
        frame: np.array
            Array of shape (img_h, img_w, img_c) and dtype np.uint8;
            the frame will be transposed to shape (img_c, img_h, img_w) before being stored.

        Returns
        -------
        idx: int
            Index at which the frame is stored. To be used for `store_effect` later.
        """
        # make sure we are not using low-dimensional observations, such as RAM
        if len(frame.shape) > 1:
            # transpose image frame into (img_c, img_h, img_w)
            frame = frame.transpose(2, 0, 1)

        if self.obs is None:
            self.obs = np.empty([self.size] + list(frame.shape), dtype=np.uint8)
            self.action = np.empty([self.size], dtype=np.int32)
            self.reward = np.empty([self.size], dtype=np.float32)
            self.done = np.empty([self.size], dtype=bool)
        self.obs[self.next_idx] = frame

        ret = self.next_idx
        self.next_idx = (self.next_idx + 1) % self.size
        self.num_in_buffer = min(self.size, self.num_in_buffer + 1)

        return ret
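
    # Illustration (assumed values): with size=4, storing frames f0..f5 leaves
    # the buffer holding [f4, f5, f2, f3]; next_idx wraps around to 2 and
    # num_in_buffer saturates at 4, so the oldest frames are overwritten first.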

    def store_effect(self, idx, action, reward, done):
        """Store the effects of the action taken after observing the frame stored
        at index idx. The reason `store_frame` and `store_effect` are broken
        up into two functions is so that one can call `encode_recent_observation`
        in between.

        Parameters
        ----------
        idx: int
            Index in buffer of the recently observed frame (returned by `store_frame`).
        action: int
            Action that was performed upon observing this frame.
        reward: float
            Reward that was received when the action was performed.
        done: bool
            True if the episode was finished after performing that action.
        """
        self.action[idx] = action
        self.reward[idx] = reward
        self.done[idx] = done
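

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original homework file): a minimal, hedged
# example of the intended calling pattern, using random uint8 arrays in place
# of real Atari frames. The shapes, episode length, and batch size below are
# illustrative assumptions, not values prescribed by the original code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    buf = ReplayBuffer(size=1000, frame_history_len=4)

    def fake_frame():
        # stand-in for an environment observation of shape (img_h, img_w, img_c)
        return rng.randint(0, 255, size=(84, 84, 1), dtype=np.uint8)

    for t in range(200):
        idx = buf.store_frame(fake_frame())
        stacked = buf.encode_recent_observation()  # (img_c * frame_history_len, img_h, img_w)
        action = int(rng.randint(4))               # random stand-in for a policy decision
        reward = float(rng.rand())
        done = (t % 50 == 49)                      # end an "episode" every 50 steps
        buf.store_effect(idx, action, reward, done)

    if buf.can_sample(32):
        obs_batch, act_batch, rew_batch, next_obs_batch, done_mask = buf.sample(32)
        print(obs_batch.shape, act_batch.shape, rew_batch.shape,
              next_obs_batch.shape, done_mask.shape)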