experiments.py

import gym
import torch.optim as optim

import dqn_model
from double_dqn_learn import OptimizerSpec, double_dqn_learning
from dqn_learn import dqn_learing
from dqn_learn_ex import dqn_learn_ex
from utils.gym import get_env, get_wrapper_by_name
from utils.schedule import LinearSchedule
from utils.cmd_kwargs import get_cmd_kwargs

# Default hyper-parameters; any key can be overridden at launch via get_cmd_kwargs().
default_kwargs = {
    "batch_size": 32,
    "gamma": 0.99,
    "replay_buffer_size": 1000000,
    "learning_starts": 50000,
    "learning_freq": 4,
    "frame_history_len": 4,
    "target_update_freq": 10000,
    "save_path": None,
    "log_every_n_steps": 3000,
    "exploration_steps": 1000000,
    "exploration_min": 0.1,
    "gym_task_index": 3,
    "model": "DQN_SEPARABLE_DEEP",
    "optimizer": "Adam",
    "learning_func": "double_dqn",
}


def optimizer_spec(args):
    # Build the OptimizerSpec (constructor + kwargs) for the optimizer named in args['optimizer'].
    rmsprop_args = dict(
        lr=args.get('lr', 0.00025),
        alpha=args.get('alpha', 0.95),
        eps=args.get('eps', 0.01),
        weight_decay=args.get('weight_decay', 0),
    )
    adam_args = dict(
        lr=args.get('lr', 1e-3),
        eps=args.get('eps', 1e-8),
        weight_decay=args.get('weight_decay', 0),
    )
    specs = {
        "RMSProp": OptimizerSpec(
            constructor=optim.RMSprop,
            kwargs=rmsprop_args,
        ),
        "Adam": OptimizerSpec(
            constructor=optim.Adam,
            kwargs=adam_args,
        ),
    }
    return specs[args['optimizer']]


# Map from learning-function name (as given in kwargs) to its implementation.
learning_funcs = {
    'dqn': dqn_learing,
    'dqn_ex': dqn_learn_ex,
    'double_dqn': double_dqn_learning,
}


if __name__ == '__main__':
    # Get Atari games.
    benchmark = gym.benchmark_spec('Atari40M')

    # Merge command-line overrides into the defaults.
    modified_kwargs = dict(default_kwargs, **get_cmd_kwargs())
    print("Starting with args:")
    print(modified_kwargs)

    # Change the index to select a different game.
    task = benchmark.tasks[modified_kwargs["gym_task_index"]]

    # Run training
    seed = 0  # Use a seed of zero (you may want to randomize the seed!)
    env = get_env(task, seed)
    print(task)
    print(task.max_timesteps)

    def stopping_criterion(env):
        # notice that here t is the number of steps of the wrapped env,
        # which is different from the number of steps in the underlying env
        return get_wrapper_by_name(env, "Monitor").get_total_steps() >= task.max_timesteps

    exploration_schedule = LinearSchedule(modified_kwargs["exploration_steps"], modified_kwargs["exploration_min"])
    model = getattr(dqn_model, modified_kwargs['model'])
    optimizer = optimizer_spec(modified_kwargs)
    learning_func = modified_kwargs['learning_func']

    print(str(learning_func) + ' starting with:')
    print(env)
    print(model)
    print(optimizer)
    print(exploration_schedule)
    print(stopping_criterion)

    learning_funcs[learning_func](
        env=env,
        q_func=model,
        optimizer_spec=optimizer,
        exploration=exploration_schedule,
        stopping_criterion=stopping_criterion,
        **modified_kwargs
    )
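
Note on the exploration schedule: LinearSchedule is called above with two positional arguments, which in common DQN training code correspond to the number of annealing timesteps and the final epsilon, with the initial epsilon defaulting to 1.0. The actual class lives in utils/schedule.py and is not shown here, so the sketch below is only an assumption about the behavior this script relies on; the initial_p=1.0 default in particular is not confirmed by this file.

# Hedged sketch of the linear annealing behavior assumed by
# LinearSchedule(exploration_steps, exploration_min) above.
# The real implementation is in utils/schedule.py and may differ.
class LinearScheduleSketch:
    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        # Fraction of the schedule completed so far, clipped to [0, 1].
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        # Linear interpolation from initial_p down to final_p.
        return self.initial_p + fraction * (self.final_p - self.initial_p)

# With the defaults above, epsilon would start at 1.0, decay linearly to 0.1
# over the first 1,000,000 environment steps, and then stay at 0.1.

Hyper-parameter overrides themselves are supplied at launch through get_cmd_kwargs(); its command-line format is defined in utils/cmd_kwargs.py and is not shown in this file.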