
critic_network.py

from torch import nn
from nets.graph_encoder import GraphAttentionEncoder


class CriticNetwork(nn.Module):
    """Graph-attention critic: encodes a graph and predicts a scalar value for it."""

    def __init__(
        self,
        input_dim,
        embedding_dim,
        hidden_dim,
        n_layers,
        encoder_normalization
    ):
        super(CriticNetwork, self).__init__()

        self.hidden_dim = hidden_dim

        # Graph attention encoder (8 heads) that embeds every node of the input graph.
        self.encoder = GraphAttentionEncoder(
            node_dim=input_dim,
            n_heads=8,
            embed_dim=embedding_dim,
            n_layers=n_layers,
            normalization=encoder_normalization
        )

        # Two-layer MLP mapping a graph embedding to a single value estimate.
        self.value_head = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, inputs):
        """
        :param inputs: (batch_size, graph_size, input_dim)
        :return: (batch_size, 1) value estimate for each graph
        """
        # The encoder returns (node_embeddings, graph_embeddings); only the
        # graph-level embedding, shape (batch_size, embedding_dim), is used here.
        _, graph_embeddings = self.encoder(inputs)
        return self.value_head(graph_embeddings)
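The forward pass above reduces each graph in a batch to one scalar: the encoder's graph embedding is passed through Linear, ReLU, Linear. The dependency-free sketch below mirrors only that value head in plain Python so the shapes can be traced by hand. It assumes, without this file confirming it, that the encoder's graph embedding is the mean of its node embeddings; the attention layers are omitted entirely, and `mean_pool`, `linear`, `relu`, `value_head`, `init_params` and `critic_values` are hypothetical helpers, not part of the repository.

```python
import random
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape (out_dim, in_dim) like nn.Linear.weight
Params = Tuple[Matrix, Vector, Matrix, Vector]


def mean_pool(node_embeddings: List[Vector]) -> Vector:
    """Average node embeddings into one graph embedding (assumed encoder behaviour)."""
    n = len(node_embeddings)
    dim = len(node_embeddings[0])
    return [sum(node[d] for node in node_embeddings) / n for d in range(dim)]


def linear(weight: Matrix, bias: Vector, x: Vector) -> Vector:
    """Compute W x + b for one sample, the same affine map nn.Linear applies."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(x: Vector) -> Vector:
    """Element-wise max(0, x)."""
    return [max(0.0, v) for v in x]


def value_head(params: Params, graph_embedding: Vector) -> float:
    """Linear(embedding_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, 1)."""
    w1, b1, w2, b2 = params
    hidden = relu(linear(w1, b1, graph_embedding))
    return linear(w2, b2, hidden)[0]


def init_params(embedding_dim: int, hidden_dim: int, seed: int = 0) -> Params:
    """Small random weights and zero biases; not PyTorch's default initialisation."""
    rng = random.Random(seed)
    w1 = [[rng.uniform(-0.1, 0.1) for _ in range(embedding_dim)] for _ in range(hidden_dim)]
    w2 = [[rng.uniform(-0.1, 0.1) for _ in range(hidden_dim)]]
    return w1, [0.0] * hidden_dim, w2, [0.0]


def critic_values(params: Params, batch: List[List[Vector]]) -> List[float]:
    """One value per graph for node embeddings shaped (batch_size, graph_size, embedding_dim)."""
    return [value_head(params, mean_pool(graph)) for graph in batch]


if __name__ == "__main__":
    rng = random.Random(1)
    # Hypothetical stand-in for encoder node output: 3 graphs, 5 nodes, embedding_dim = 4.
    batch = [[[rng.random() for _ in range(4)] for _ in range(5)] for _ in range(3)]
    values = critic_values(init_params(embedding_dim=4, hidden_dim=8), batch)
    print(len(values))  # one value per graph, matching the (batch_size, 1) output
```

Because the pooling happens before the MLP, the value head's size depends only on `embedding_dim` and `hidden_dim`, so the same critic can score graphs with any number of nodes.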