Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dqn_model.py 4.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch.nn as nn
import torch.nn.functional as F
  3. class DQN(nn.Module):
  4. def __init__(self, in_channels=4, num_actions=18):
  5. """
  6. Initialize a deep Q-learning network as described in
  7. https://storage.googleapis.com/deepmind-data/assets/papers/DeepMindNature14236Paper.pdf
  8. Arguments:
  9. in_channels: number of channel of input.
  10. i.e The number of most recent frames stacked together as describe in the paper
  11. num_actions: number of action-value to output, one-to-one correspondence to action in game.
  12. """
  13. super(DQN, self).__init__()
  14. self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
  15. self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
  16. self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
  17. self.fc4 = nn.Linear(7 * 7 * 64, 512)
  18. self.fc5 = nn.Linear(512, num_actions)
  19. def forward(self, x):
  20. x = F.relu(self.conv1(x))
  21. x = F.relu(self.conv2(x))
  22. x = F.relu(self.conv3(x))
  23. x = F.relu(self.fc4(x.view(x.size(0), -1)))
  24. return self.fc5(x)
  25. class DQN_RAM(nn.Module):
  26. def __init__(self, in_features=4, num_actions=18):
  27. """
  28. Initialize a deep Q-learning network for testing algorithm
  29. in_features: number of features of input.
  30. num_actions: number of action-value to output, one-to-one correspondence to action in game.
  31. """
  32. super(DQN_RAM, self).__init__()
  33. self.fc1 = nn.Linear(in_features, 256)
  34. self.fc2 = nn.Linear(256, 128)
  35. self.fc3 = nn.Linear(128, 64)
  36. self.fc4 = nn.Linear(64, num_actions)
  37. def forward(self, x):
  38. x = F.relu(self.fc1(x))
  39. x = F.relu(self.fc2(x))
  40. x = F.relu(self.fc3(x))
  41. return self.fc4(x)
  42. class DQN_SEPARABLE(nn.Module):
  43. def __init__(self, in_channels=4, num_actions=18):
  44. """
  45. Similar architecture to DQN above, but the classic conv2d is replaced with depthwise separable convolutions:
  46. https://arxiv.org/abs/1704.04861
  47. This should yield much more efficient models, with an order of magnitude less trainable parameters and therefore
  48. much more efficient forward and backprop.
  49. Also, replaced the ReLUs with leaky ReLUs, to prevent "dead" neurons.
  50. """
  51. super(DQN_SEPARABLE, self).__init__()
  52. self.conv1_depth = nn.Conv2d(in_channels, in_channels, kernel_size=8, stride=4, groups=in_channels)
  53. self.conv1_point = nn.Conv2d(in_channels, 32, kernel_size=1)
  54. self.conv2_depth = nn.Conv2d(32, 32, kernel_size=4, stride=2, groups=32)
  55. self.conv2_point = nn.Conv2d(32, 64, kernel_size=1)
  56. self.conv3_depth = nn.Conv2d(64, 64, kernel_size=3, stride=1, groups=64)
  57. self.conv3_point = nn.Conv2d(64, 64, kernel_size=1)
  58. self.fc4 = nn.Linear(7 * 7 * 64, 512)
  59. self.fc5 = nn.Linear(512, num_actions)
  60. def forward(self, x):
  61. x = F.leaky_relu(self.conv1_point(self.conv1_depth(x)))
  62. x = F.leaky_relu(self.conv2_point(self.conv2_depth(x)))
  63. x = F.leaky_relu(self.conv3_point(self.conv3_depth(x)))
  64. x = F.leaky_relu(self.fc4(x.view(x.size(0), -1)))
  65. return self.fc5(x)
  66. class DQN_SEPARABLE_DEEP(nn.Module):
  67. def __init__(self, in_channels=4, num_actions=18, num_layers=18, features=32, in_height=84, in_width=84):
  68. """
  69. Similar to DQN_SEPARABLE, but (almost) arbitrarily deep, and actually DOES have an order of magnitude less
  70. weights than DQN (since the final linear layer was the actual source of most of the weights, and it should be smaller here)
  71. """
  72. super(DQN_SEPARABLE_DEEP, self).__init__()
  73. self.num_layers = num_layers
  74. self.conv0_depth = nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=2, groups=in_channels)
  75. self.conv0_point = nn.Conv2d(in_channels, features, kernel_size=1)
  76. for i in range(1, num_layers):
  77. setattr(self, 'conv{}_depth'.format(i), nn.Conv2d(features, features, kernel_size=3, groups=features))
  78. setattr(self, 'conv{}_point'.format(i), nn.Conv2d(features, features, kernel_size=1))
  79. def out_size(in_size):
  80. return ((in_size - 4) / 2) - (2 * (num_layers - 1))
  81. out_height = out_size(in_height)
  82. out_width = out_size(in_width)
  83. self.fc1 = nn.Linear(int(out_height * out_width * features), 512)
  84. self.fc2 = nn.Linear(512, num_actions)
  85. def forward(self, x):
  86. for i in range(self.num_layers):
  87. conv_depth = getattr(self, 'conv{}_depth'.format(i))
  88. conv_point = getattr(self, 'conv{}_point'.format(i))
  89. x = conv_depth(x)
  90. x = conv_point(x)
  91. x = F.leaky_relu(x)
  92. x = F.leaky_relu(self.fc1(x.view(x.size(0), -1)))
  93. return self.fc2(x)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...