pytorch_geometric_networks.py
  1. """
  2. this code is adapted from
  3. unsupervised GraphSAGE example
  4. https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_sage_unsup.py
  5. """
  6. import os
  7. from typing import Union
  8. import copy
  9. import numpy as np
  10. import pandas as pd
  11. import torch
  12. from torch import Tensor
  13. import torch.nn as nn
  14. import torch.nn.functional as F
  15. import tqdm
  16. from torch_cluster import random_walk
  17. from torch_geometric.loader import NeighborSampler as RawNeighborSampler
  18. from torch_geometric.nn import SAGEConv
  19. from torch_geometric.typing import Adj, OptPairTensor, Size
  20. from torch_geometric.utils import to_undirected


class ResidualSAGEConv(SAGEConv):
    """SAGEConv with an additional residual (skip) connection from the
    target-node features. Requires in_channels == out_channels so that
    the skip term can be added to the layer output."""

    def __init__(self, **kwargs):
        super(ResidualSAGEConv, self).__init__(**kwargs)

    def forward(
        self, x: Union[Tensor, OptPairTensor], edge_index: Adj, size: Size = None
    ) -> Tensor:
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)
        # Aggregate messages from sampled neighbours, then transform them.
        out = self.propagate(edge_index, x=x, size=size)
        out = self.lin_l(out)
        x_r = x[1]
        if x_r is not None:
            if self.root_weight:
                out = out + self.lin_r(x_r)
            # Residual connection: add the raw target features as well.
            out = out + x_r
        if self.normalize:
            out = F.normalize(out, p=2.0, dim=-1)
        return out


class SAGENeighborSampler(RawNeighborSampler):
    def __init__(self, edge_index, sizes, directed=False, **kwargs):
        self.directed = directed
        super(SAGENeighborSampler, self).__init__(
            edge_index=edge_index, sizes=sizes, **kwargs
        )
        if not self.directed:
            # Symmetrise the sparse adjacency so sampling treats the
            # graph as undirected.
            self.adj_t = self.adj_t.t() + self.adj_t

    def sample(self, batch):
        batch = torch.tensor(batch)
        row, col, _ = self.adj_t.coo()
        # Positive examples: for each anchor node, the node reached by a
        # one-step random walk, i.e. a random neighbour.
        pos_batch = random_walk(row, col, batch, walk_length=1, coalesced=False)[:, 1]
        # Negative examples: nodes drawn uniformly at random.
        # Note: adj_t.size(1) is the number of nodes, not edges.
        n_nodes = self.adj_t.size(1)
        neg_batch = torch.randint(0, n_nodes, (batch.numel(),), dtype=torch.long)
        batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
        return super(SAGENeighborSampler, self).sample(batch)
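
# Note: each sampled mini-batch is laid out as
# [anchor nodes | positive samples | negative samples], all of equal size,
# which is why SAGE.forward below splits its output into three equal parts.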


class SAGE(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        num_layers,
        use_self_connection,
        sage_layer_cls=SAGEConv,
    ):
        super(SAGE, self).__init__()
        self.num_layers = num_layers
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else hidden_channels
            self.convs.append(
                sage_layer_cls(
                    in_channels=in_channels,
                    out_channels=hidden_channels,
                    root_weight=use_self_connection,
                )
            )

    def forward(self, x, adjs):
        """
        x: input embeddings of the sampled nodes; the batch is the
            concatenation of anchor nodes, positive samples (random-walk
            neighbours) and negative samples (random nodes).
        adjs: list of (edge_index, e_id, size) tuples, one per layer,
            where edge_index holds the sampled bipartite edges in COO
            format and size is (num_source_nodes, num_target_nodes).
        """
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[: size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = x.relu()
                x = F.dropout(x, p=0.5, training=self.training)
        # Split back into anchor, positive and negative embeddings.
        out, pos_out, neg_out = x.split(x.size(0) // 3, dim=0)
        return out, pos_out, neg_out

    def full_forward(self, x, edge_index):
        # Full-graph forward pass (no neighbour sampling), e.g. for inference.
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != self.num_layers - 1:
                x = x.relu()
                x = F.dropout(x, p=0.5, training=self.training)
        return x

    def loss(self, out, pos_out, neg_out):
        pos_loss = F.logsigmoid((out * pos_out).sum(-1)).mean()
        neg_loss = F.logsigmoid(-(out * neg_out).sum(-1)).mean()
        loss = -pos_loss - neg_loss
        return loss, pos_loss, neg_loss
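
# Note: SAGE.loss above implements the unsupervised GraphSAGE objective,
#     L = -log(sigmoid(z_u . z_v)) - log(sigmoid(-z_u . z_n)),
# where v is a random-walk neighbour of anchor u and n is a random node.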


class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, use_self_connection=True):
        super(Encoder, self).__init__()
        self.convs = torch.nn.ModuleList(
            [
                SAGEConv(in_channels, hidden_channels, root_weight=use_self_connection),
                SAGEConv(
                    hidden_channels, hidden_channels, root_weight=use_self_connection
                ),
                SAGEConv(
                    hidden_channels, hidden_channels, root_weight=use_self_connection
                ),
            ]
        )
        self.activations = torch.nn.ModuleList(
            [
                nn.PReLU(hidden_channels),
                nn.PReLU(hidden_channels),
                nn.PReLU(hidden_channels),
            ]
        )

    def forward(self, x, adjs):
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[: size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            x = self.activations[i](x)
        return x

    def full_forward(self, x, edge_index):
        # Full-graph forward: every node is its own target, so pass (x, x).
        for i, conv in enumerate(self.convs):
            x = conv((x, x), edge_index)
            x = self.activations[i](x)
        return x


def graph_infomax_corruption(x, edge_index):
    # Corrupt the graph by shuffling node features across nodes.
    return x[torch.randperm(x.size(0))], edge_index


def graph_infomax_summary(z, *args, **kwargs):
    # Graph-level summary: sigmoid of the mean-pooled node embeddings.
    return torch.sigmoid(z.mean(dim=0))
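
# Note: these two functions match the `corruption` and `summary` arguments
# expected by torch_geometric.nn.DeepGraphInfomax, which can wrap the
# Encoder above.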


def get_src_dst_embeddings(model, batch):
    # Expects a link-prediction style batch carrying edge_label_index
    # (e.g. as produced by torch_geometric.loader.LinkNeighborLoader).
    h = model(batch.x, batch.edge_index)
    h_src = h[batch.edge_label_index[0]]
    h_dst = h[batch.edge_label_index[1]]
    return h_src, h_dst, h_src.size(0)


def get_neg_pos_loss(batch, h_src, h_dst):
    # Binary cross-entropy on dot-product scores of positive/negative edges.
    pred = (h_src * h_dst).sum(dim=-1)
    return F.binary_cross_entropy_with_logits(pred, batch.edge_label)


def get_loss(model, batch):
    h_src, h_dst, n_src = get_src_dst_embeddings(model, batch)
    return get_neg_pos_loss(batch, h_src, h_dst), n_src


def get_model_grads_dict(model):
    # Gradient norm per parameter; skip parameters without gradients.
    return {
        name: torch.norm(p.grad.detach()).item()
        for (name, p) in model.named_parameters()
        if p.grad is not None
    }


def train_epoch(
    model,
    data,
    train_loader,
    optimizer,
    scheduler,
    device,
    logging_callback,
    model_callback,
):
    model.train()
    total_loss = 0
    for i, batch in enumerate(tqdm.tqdm(train_loader)):
        batch = batch.to(device)
        optimizer.zero_grad()
        # Run only the forward pass under autocast; the backward pass and
        # optimizer step stay outside. (For fp16 training, autocast is
        # usually paired with a torch.cuda.amp.GradScaler.)
        with torch.cuda.amp.autocast():
            loss, n_src = get_loss(model, batch)
        loss.backward()
        optimizer.step()
        scheduler.step(i)  # per-iteration scheduler
        total_loss += float(loss) * n_src
        lr = optimizer.param_groups[0]["lr"]
        logging_callback(
            {
                "loss": float(loss),
                "size": n_src,
                "lr": lr,
                **get_model_grads_dict(model),
            }
        )
        model_callback(model, i)
    loss = total_loss / data.num_nodes
    return loss


@torch.no_grad()
def get_val_loss(model, data, val_loader, device):
    model.eval()
    total_loss = 0
    for batch_size, n_id, adjs in val_loader:
        adjs = [adj.to(device) for adj in adjs]
        x = data.x[n_id].to(device)
        out, pos_out, neg_out = model(x, adjs)
        loss, _, _ = model.loss(out, pos_out, neg_out)
        total_loss += float(loss) * out.size(0)
    return total_loss / data.num_nodes
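
A minimal usage sketch (not part of the original file) showing how the sampler, model and loss fit together. The Planetoid/Cora dataset and all hyperparameters below are placeholder assumptions:

import torch
from torch_geometric.datasets import Planetoid

# Hypothetical wiring; dataset and hyperparameters are assumptions.
data = Planetoid(root="data/Planetoid", name="Cora")[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader = SAGENeighborSampler(
    data.edge_index,
    sizes=[10, 10],          # neighbours sampled per layer
    batch_size=256,
    shuffle=True,
    num_nodes=data.num_nodes,
)
model = SAGE(
    in_channels=data.num_node_features,
    hidden_channels=64,
    num_layers=2,
    use_self_connection=True,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for batch_size, n_id, adjs in train_loader:
    adjs = [adj.to(device) for adj in adjs]
    optimizer.zero_grad()
    out, pos_out, neg_out = model(data.x[n_id].to(device), adjs)
    loss, _, _ = model.loss(out, pos_out, neg_out)
    loss.backward()
    optimizer.step()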