import gym
import torch.optim as optim

import dqn_model
from double_dqn_learn import OptimizerSpec, double_dqn_learning
from dqn_learn import dqn_learing
from dqn_learn_ex import dqn_learn_ex
from utils.gym import get_env, get_wrapper_by_name
from utils.schedule import LinearSchedule
from utils.cmd_kwargs import get_cmd_kwargs
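
# Default hyperparameters. Any of these can be overridden by values supplied
# on the command line (parsed by get_cmd_kwargs) before training starts.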
default_kwargs = {
    "batch_size": 32,
    "gamma": 0.99,
    "replay_buffer_size": 1000000,
    "learning_starts": 50000,
    "learning_freq": 4,
    "frame_history_len": 4,
    "target_update_freq": 10000,
    "save_path": None,
    "log_every_n_steps": 3000,
    "exploration_steps": 1000000,
    "exploration_min": 0.1,
    "gym_task_index": 3,
    "model": "DQN_SEPARABLE_DEEP",
    "optimizer": "Adam",
    "learning_func": "double_dqn",
}
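

# Build an OptimizerSpec (constructor plus keyword arguments) for the
# optimizer named in args['optimizer']; RMSProp and Adam are supported.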
def optimizer_spec(args):
    rmsprop_args = dict(
        lr=args.get('lr', 0.00025),
        alpha=args.get('alpha', 0.95),
        eps=args.get('eps', 0.01),
        weight_decay=args.get('weight_decay', 0),
    )
    adam_args = dict(
        lr=args.get('lr', 1e-3),
        eps=args.get('eps', 1e-8),
        weight_decay=args.get('weight_decay', 0),
    )
    specs = {
        "RMSProp": OptimizerSpec(
            constructor=optim.RMSprop,
            kwargs=rmsprop_args,
        ),
        "Adam": OptimizerSpec(
            constructor=optim.Adam,
            kwargs=adam_args,
        ),
    }
    return specs[args['optimizer']]
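

# Map the 'learning_func' setting onto the available training loops.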
learning_funcs = {
    'dqn': dqn_learing,
    'dqn_ex': dqn_learn_ex,
    'double_dqn': double_dqn_learning,
}


if __name__ == '__main__':
    # Get Atari games.
    benchmark = gym.benchmark_spec('Atari40M')

    modified_kwargs = dict(default_kwargs, **get_cmd_kwargs())
    print("Starting with args:")
    print(modified_kwargs)

    # Change the index to select a different game.
    task = benchmark.tasks[modified_kwargs["gym_task_index"]]

    # Run training.
    seed = 0  # Use a seed of zero (you may want to randomize the seed!)
    env = get_env(task, seed)
    print(task)
    print(task.max_timesteps)
    def stopping_criterion(env):
        # The Monitor wrapper counts steps of the wrapped env, which can
        # differ from the number of steps taken in the underlying env.
        return get_wrapper_by_name(env, "Monitor").get_total_steps() >= task.max_timesteps
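
    # Anneal the exploration epsilon linearly from its initial value down to
    # exploration_min over the first exploration_steps environment steps.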
    exploration_schedule = LinearSchedule(
        modified_kwargs["exploration_steps"],
        modified_kwargs["exploration_min"],
    )
    model = getattr(dqn_model, modified_kwargs['model'])
    optimizer = optimizer_spec(modified_kwargs)
    learning_func = modified_kwargs['learning_func']

    print(str(learning_func) + ' starting with:')
    print(env)
    print(model)
    print(optimizer)
    print(exploration_schedule)
    print(stopping_criterion)
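
    # Launch the selected training loop. modified_kwargs is forwarded
    # wholesale, so the learning function is expected to accept (or ignore)
    # the extra configuration keys such as 'model' and 'gym_task_index'.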
    learning_funcs[learning_func](
        env=env,
        q_func=model,
        optimizer_spec=optimizer,
        exploration=exploration_schedule,
        stopping_criterion=stopping_criterion,
        **modified_kwargs
    )
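
    # Hypothetical example invocation (assuming get_cmd_kwargs parses
    # key=value pairs from the command line; the exact format is defined
    # in utils.cmd_kwargs and is not shown here):
    #   python main.py learning_func=dqn optimizer=RMSProp lr=0.00025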