dqn_learn.py
  1. """
  2. This file is copied/apdated from https://github.com/berkeleydeeprlcourse/homework/tree/master/hw3
  3. """
  4. import sys
  5. import pickle
  6. import numpy as np
  7. from collections import namedtuple
  8. from itertools import count
  9. import random
  10. import gym.spaces
  11. from utils.saved_state import SavedState
  12. import torch
  13. import torch.optim as optim
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. import torch.autograd as autograd
  17. from utils.replay_buffer import ReplayBuffer
  18. from utils.gym import get_wrapper_by_name
  19. USE_CUDA = torch.cuda.is_available()
  20. print("USE_CUDA=", USE_CUDA)
  21. print("tomer11")
  22. dtype = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
  23. longType = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor

class Variable(autograd.Variable):
    def __init__(self, data, *args, **kwargs):
        if USE_CUDA:
            data = data.cuda()
        super(Variable, self).__init__(data, *args, **kwargs)


"""
OptimizerSpec contains the following attributes:
    constructor: the optimizer constructor, e.g. RMSprop
    kwargs: {dict} arguments for constructing the optimizer
"""
OptimizerSpec = namedtuple("OptimizerSpec", ["constructor", "kwargs"])
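# A minimal usage sketch of OptimizerSpec (the RMSprop hyperparameter values below are
# illustrative, not taken from this repo):
#
#     optimizer_spec = OptimizerSpec(
#         constructor=optim.RMSprop,
#         kwargs=dict(lr=0.00025, alpha=0.95, eps=0.01),
#     )
#
# dqn_learing() then builds the optimizer as
#     optimizer_spec.constructor(Q.parameters(), **optimizer_spec.kwargs)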

def dqn_learing(
    env,
    q_func,
    optimizer_spec,
    exploration,
    stopping_criterion=None,
    replay_buffer_size=1000000,
    batch_size=32,
    gamma=0.99,
    learning_starts=50000,
    learning_freq=4,
    frame_history_len=4,
    target_update_freq=10000,
    save_path=None,
    save_freq=100000,
    **kwargs
):
    """Run the Deep Q-learning algorithm.

    You can specify your own convnet using q_func.
    All schedules are w.r.t. the total number of steps taken in the environment.

    Parameters
    ----------
    env: gym.Env
        gym environment to train on.
    q_func: function
        Model to use for computing the q function. It should accept the
        following named arguments:
            input_channel: int
                number of channels of input.
            num_actions: int
                number of actions
    optimizer_spec: OptimizerSpec
        Specifies the constructor and kwargs, as well as the learning rate schedule,
        for the optimizer.
    exploration: Schedule (defined in utils.schedule)
        schedule for the probability of choosing a random action.
    stopping_criterion: (env) -> bool
        should return true when it's ok for the RL algorithm to stop.
        takes in env and the number of steps executed so far.
    replay_buffer_size: int
        How many memories to store in the replay buffer.
    batch_size: int
        How many transitions to sample each time experience is replayed.
    gamma: float
        Discount factor.
    learning_starts: int
        After how many environment steps to start replaying experiences.
    learning_freq: int
        How many steps of the environment to take between every experience replay.
    frame_history_len: int
        How many past frames to include as input to the model.
    target_update_freq: int
        How many experience replay rounds (not steps!) to perform between
        each update to the target Q network.
    save_path: str or None
        If given, training state is periodically pickled to this path and
        restored from it on startup (see SavedState).
    save_freq: int
        How many environment steps between checkpoints written to save_path.
    """
    assert type(env.observation_space) == gym.spaces.Box
    assert type(env.action_space) == gym.spaces.Discrete

    ###############
    # BUILD MODEL #
    ###############

    if len(env.observation_space.shape) == 1:
        # This means we are running on low-dimensional observations (e.g. RAM)
        input_arg = env.observation_space.shape[0]
    else:
        img_h, img_w, img_c = env.observation_space.shape
        input_arg = frame_history_len * img_c
    num_actions = env.action_space.n

    def to_pytorch(obs, type=dtype, normalize=True):
        t = torch.from_numpy(obs).type(type)
        if normalize:
            return t / 255.0
        else:
            return t

    def to_pytorch_var(x, grad=False, type=dtype, normalize=True):
        return Variable(to_pytorch(x, type=type, normalize=normalize), requires_grad=grad)
    # Construct an epsilon greedy policy with the given exploration schedule
    def select_epsilon_greedy_action(model, obs, t):
        sample = random.random()
        eps_threshold = exploration.value(t)
        if sample > eps_threshold:
            obs = to_pytorch(obs).unsqueeze(0)
            # Use volatile=True if the variable is only used in inference mode, i.e. don't save the history
            return model(Variable(obs, volatile=True)).data.max(1)[1].cpu()
        else:
            return torch.IntTensor([[random.randrange(num_actions)]])
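    # In other words: with probability eps_threshold (taken from the exploration
    # schedule at step t) a uniformly random action is returned; otherwise the
    # action is the argmax of the current Q network's output for this observation.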
    # Initialize target q function and q function, i.e. build the model.
    ######
    # YOUR CODE HERE
    print("Input and output size of network:")
    print(input_arg, num_actions)
    Q = q_func(input_arg, num_actions)
    Q_target = q_func(input_arg, num_actions)
    if USE_CUDA:
        Q = Q.cuda()
        Q_target = Q_target.cuda()

    def update_Q_target():
        print("Updating Q_target")
        Q_target.load_state_dict(Q.state_dict())

    update_Q_target()
    ######

    # Construct Q network optimizer function
    optimizer = optimizer_spec.constructor(Q.parameters(), **optimizer_spec.kwargs)
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma_scheduler)
    replay_buffer = ReplayBuffer(replay_buffer_size, frame_history_len)
    start_step = 0
    Statistic = {
        "mean_episode_rewards": [],
        "best_mean_episode_rewards": []
    }
    mean_episode_reward = float('nan')
    best_mean_episode_reward = -float('inf')

    if save_path is not None:
        try:
            print("Trying to load state from ", save_path)
            with open(save_path, 'rb') as f:
                saved_state = pickle.load(f)
            start_step = saved_state.timestep
            Q.load_state_dict(saved_state.state_dict)
            Q_target.load_state_dict(saved_state.state_dict)
            Statistic = saved_state.stats
            mean_episode_reward = Statistic["mean_episode_rewards"][-1][1]
            best_mean_episode_reward = Statistic["best_mean_episode_rewards"][-1][1]
        except Exception:
            print("Saved state doesn't exist yet")
    def save_state(t):
        """
        Saves the current stable network weights, together with the current time step and statistics, for resuming later.
        """
        if save_path is not None:
            print("Saving state")
            with open(save_path, 'wb') as f:
                pickle.dump(SavedState(Q_target.state_dict(), t, Statistic), f, pickle.HIGHEST_PROTOCOL)

    ###############
    # RUN ENV #
    ###############
    num_param_updates = 0
    last_obs = env.reset()
    LOG_EVERY_N_STEPS = 1000
    for t in count(start=start_step):
        ### 1. Check stopping criterion
        if stopping_criterion is not None and stopping_criterion(env):
            break

        ### 2. Step the env and store the transition
        # At this point, "last_obs" contains the latest observation that was
        # recorded from the simulator. Here, your code needs to store this
        # observation and its outcome (reward, next observation, etc.) into
        # the replay buffer while stepping the simulator forward one step.
        # At the end of this block of code, the simulator should have been
        # advanced one step, and the replay buffer should contain one more
        # transition.
        # Specifically, last_obs must point to the new latest observation.
        # Useful functions you'll need to call:
        #     obs, reward, done, info = env.step(action)
        #         this steps the environment forward one step
        #     obs = env.reset()
        #         this resets the environment if you reached an episode boundary.
        # Don't forget to call env.reset() to get a new observation if done
        # is true!!
        # Note that you cannot use "last_obs" directly as input
        # into your network, since it needs to be processed to include context
        # from previous frames. You should check out the replay buffer
        # implementation in utils/replay_buffer.py to see what functionality the
        # replay buffer exposes. The replay buffer has a function called
        # encode_recent_observation that will take the latest observation
        # that you pushed into the buffer and compute the corresponding
        # input that should be given to a Q network by appending some
        # previous frames.
        # Don't forget to include epsilon greedy exploration!
        # And remember that the first time you enter this loop, the model
        # may not yet have been initialized (but of course, the first step
        # might as well be random, since you haven't trained your net...)
        #####
        # YOUR CODE HERE
        last_frame_idx = replay_buffer.store_frame(last_obs)
        enc_last_obs = replay_buffer.encode_recent_observation()
        action = select_epsilon_greedy_action(Q, enc_last_obs, t)
        new_frame, r, done, _ = env.step(action)
        replay_buffer.store_effect(last_frame_idx, action, r, done)
        if done:
            last_obs = env.reset()
        else:
            last_obs = new_frame
        #####
        # At this point, the environment should have been advanced one step (and
        # reset if done was true), and last_obs should point to the new latest
        # observation.
        ### 3. Perform experience replay and train the network.
        # Note that this is only done if the replay buffer contains enough samples
        # for us to learn something useful -- until then, the model will not be
        # initialized and random actions should be taken.
        if (t > learning_starts and
                t % learning_freq == 0 and
                replay_buffer.can_sample(batch_size)):
            # Here, you should perform training. Training consists of four steps:
            # 3.a: use the replay buffer to sample a batch of transitions (see the
            #      replay buffer code for the function definition; each batch that you sample
            #      should consist of current observations, current actions, rewards,
            #      next observations, and a done indicator).
            #      Note: move the variables to the GPU if available.
            # 3.b: fill in your own code to compute the Bellman error. This requires
            #      evaluating the current and next Q-values and constructing the corresponding error.
            #      Note: don't forget to clip the error to [-1, 1], multiply it by -1 (since PyTorch minimizes) and
            #      mask out post-terminal Q-values (see the ReplayBuffer code).
            # 3.c: train the model. To do this, use the Bellman error you calculated previously.
            #      PyTorch will differentiate this error for you; to backpropagate the error use the following API:
            #          current.backward(d_error.data.unsqueeze(1))
            #      where "current" is the variable holding the current Q-values and d_error is the clipped Bellman error.
            #      Your code should produce one scalar-valued tensor.
            #      Note: don't forget to call optimizer.zero_grad() before the backward call and
            #      optimizer.step() after the backward call.
            # 3.d: periodically update the target network by loading the current Q network weights into the
            #      target_Q network. See the state_dict() and load_state_dict() methods.
            #      You should update every target_update_freq steps, and you may find the
            #      variable num_param_updates useful for this (it was initialized to 0).
            #####
            # YOUR CODE HERE
            optimizer.zero_grad()
            # TODO: Add a learning rate scheduler step here
            obs_batch, act_batch, r_batch, next_obs_batch, done_mask = replay_buffer.sample(batch_size)
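            # The lines below form the standard DQN target
            #     y = r + gamma * (1 - done) * max_a' Q_target(s', a')
            # and the TD error (y - Q(s, a)), clipped to [-1, 1]. Passing the negated
            # clipped error as the gradient in vals_of_actions_taken.backward(...)
            # is the usual way to descend the clipped (Huber-like) Bellman loss
            # without building an explicit loss tensor.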
            Q_val_batch = Q(to_pytorch_var(obs_batch))
            Q_target_val_batch = Q_target(to_pytorch_var(next_obs_batch)).detach()
            # The following code will take only one cell from each vector of the Q_val_batch tensor.
            # Each vector corresponds to a single output of the Q net, and each cell corresponds to a single action.
            # This means we take only the cells of the actions that we actually took, since all others are irrelevant
            # when calculating the loss.
            act_batch_var = to_pytorch_var(act_batch, type=longType, normalize=False).unsqueeze(1)
            vals_of_actions_taken = Q_val_batch.gather(1, act_batch_var)
            Q_target_val_max, _ = Q_target_val_batch.max(1)
            Q_target_val_max = Q_target_val_max.unsqueeze(1)
            reverse_done_mask = 1 - to_pytorch_var(done_mask, normalize=False).unsqueeze(1)
            Q_target_masked_val_max = (reverse_done_mask * Q_target_val_max)
            Q_target_discounted = (gamma * Q_target_masked_val_max)
            r_batch_var = to_pytorch_var(r_batch, normalize=False).unsqueeze(1)
            Q_masked_target = r_batch_var + Q_target_discounted
            bellman_error = Q_masked_target - vals_of_actions_taken
            clipped_error = bellman_error.clamp(-1, 1)
            vals_of_actions_taken.backward(-clipped_error)
            optimizer.step()
            # scheduler.step()
            num_param_updates += 1
            if num_param_updates % target_update_freq == 0:
                update_Q_target()
            #####
        ### 4. Log progress and keep track of statistics
        episode_rewards = get_wrapper_by_name(env, "Monitor").get_episode_rewards()
        if len(episode_rewards) > 0:
            mean_episode_reward = np.mean(episode_rewards[-100:])
        if len(episode_rewards) > 100:
            best_mean_episode_reward = max(best_mean_episode_reward, mean_episode_reward)
        Statistic["mean_episode_rewards"].append((t, mean_episode_reward))
        Statistic["best_mean_episode_rewards"].append((t, best_mean_episode_reward))

        if t % LOG_EVERY_N_STEPS == 0 and t > learning_starts and t > start_step:
            print("Timestep %d" % (t,))
            print("mean reward (100 episodes) %f" % mean_episode_reward)
            print("best mean reward %f" % best_mean_episode_reward)
            print("episodes %d" % len(episode_rewards))
            print("exploration %f" % exploration.value(t))
            sys.stdout.flush()

        if t % save_freq == 0 and t > start_step:
            save_state(t)
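
# A hypothetical invocation sketch (DQN, exploration_schedule, stopping_criterion and
# the monitored env are assumed to be defined by the caller, e.g. in a training script;
# they are not part of this file):
#
#     spec = OptimizerSpec(constructor=optim.RMSprop,
#                          kwargs=dict(lr=0.00025, alpha=0.95, eps=0.01))
#     dqn_learing(env, q_func=DQN, optimizer_spec=spec,
#                 exploration=exploration_schedule,
#                 stopping_criterion=stopping_criterion,
#                 save_path="checkpoint.pkl")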