test_torch.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from pprint import pprint
from torchvision import models


class Net(nn.Module):
    # Small MNIST-style CNN: two conv layers and two fully connected
    # layers; the graph tests below assert on its five modules.
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class Sequence(nn.Module):
    # Two stacked LSTM cells with a linear readout, unrolled manually
    # one timestep at a time.
    def __init__(self):
        super(Sequence, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input, future=0):
        outputs = []
        h_t = torch.zeros(input.size(0), 51, dtype=torch.double)
        c_t = torch.zeros(input.size(0), 51, dtype=torch.double)
        h_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
        c_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)

        # Feed the input one timestep at a time.
        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
        for i in range(future):  # if we should predict the future
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
        outputs = torch.stack(outputs, 1).squeeze(2)
        return outputs


def test_no_requires_grad(history):
    # log_stats() used to fail on tensors that didn't have .requires_grad = True
    history.torch.log_stats(torch.randn(3, 3))
    history.torch.log_stats(torch.autograd.Variable(torch.randn(3, 3)))


def test_simple_net():
    net = Net()
    graph = wandb.Graph.hook_torch(net)
    output = net.forward(torch.ones((64, 1, 28, 28), requires_grad=True))
    grads = torch.ones(64, 10)
    output.backward(grads)
    graph = wandb.Graph.transform(graph)
    assert len(graph["nodes"]) == 5
    assert graph["nodes"][0]['class_name'] == "Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))"
    assert graph["nodes"][0]['name'] == "conv1"


def test_sequence_net():
    net = Sequence()
    net.double()  # run in double precision to match the inputs below
    graph = wandb.Graph.hook_torch(net)
    output = net.forward(torch.ones(
        (97, 999), requires_grad=True, dtype=torch.double))
    output.backward(torch.zeros((97, 999), dtype=torch.double))
    graph = wandb.Graph.transform(graph)
    pprint(graph)
    assert len(graph["nodes"]) == 3
    assert len(graph["nodes"][0]['parameters']) == 4
    assert graph["nodes"][0]['class_name'] == "LSTMCell(1, 51)"
    assert graph["nodes"][0]['name'] == "lstm1"


def test_multi_net():
    net = Net()
    wandb.run = wandb.wandb_run.Run.from_environment_or_defaults()
    graphs = wandb.hook_torch((net, net))  # hook the same net twice
    wandb.run = None
    output = net.forward(torch.ones((64, 1, 28, 28), requires_grad=True))
    grads = torch.ones(64, 10)
    output.backward(grads)
    graph1 = wandb.Graph.transform(graphs[0])
    graph2 = wandb.Graph.transform(graphs[1])
    assert len(graph1["nodes"]) == 5
    assert len(graph2["nodes"]) == 5


def test_alex_net():
    alex = models.AlexNet()
    graph = wandb.Graph.hook_torch(alex)
    output = alex.forward(torch.ones((2, 3, 224, 224), requires_grad=True))
    grads = torch.ones(2, 1000)
    output.backward(grads)
    graph = wandb.Graph.transform(graph)
    assert len(graph["nodes"]) == 20
    assert graph["nodes"][0]['class_name'] == "Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))"
    assert graph["nodes"][0]['name'] == "features.0"
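
Each test follows the same hook-then-transform pattern: attach hooks to a module with wandb.Graph.hook_torch, run a forward and backward pass so the hooks fire, then call wandb.Graph.transform to get a plain dict whose "nodes" list one entry per hooked submodule. A minimal standalone sketch of that pattern, reusing the Net class defined above and assuming the same legacy wandb.Graph API this file imports (recent wandb releases expose wandb.watch() instead); note that test_no_requires_grad additionally expects a history fixture, presumably supplied by the surrounding suite's conftest:

# Minimal sketch of the hook-then-transform pattern exercised by the
# tests above; assumes the legacy wandb.Graph API used in this file.
import torch
import wandb

net = Net()
graph = wandb.Graph.hook_torch(net)  # attach hooks to every submodule
out = net.forward(torch.ones((1, 1, 28, 28), requires_grad=True))
out.backward(torch.ones(1, 10))      # backward pass populates the graph
nodes = wandb.Graph.transform(graph)["nodes"]
print([n["name"] for n in nodes])    # first node is "conv1", per the tests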