test.py

# Import PyTorch modules
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the entity grounder class
class EntityGrounder(nn.Module):
    def __init__(self, hidden_size, vocab_size, embed_size):
        super(EntityGrounder, self).__init__()
        # Initialize the Transformer encoder for token feature extraction
        # (batch_first=True so it accepts (batch_size, max_length, embed_size) inputs)
        self.tfm = nn.TransformerEncoderLayer(d_model=embed_size, nhead=8, batch_first=True)
        # Initialize the RNN for entity feature encoding
        self.rnn = nn.GRU(input_size=embed_size, hidden_size=hidden_size, batch_first=True)
        # Initialize the MLPs for segment and entity feature projection
        self.projI = nn.Linear(in_features=embed_size, out_features=hidden_size)
        self.projT = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        # Initialize the embedding layer for token embedding
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)

    def forward(self, image, caption, entity_mask):
        # image: a batch of image features, shape: (batch_size, num_segments, embed_size)
        # caption: a batch of token ids, shape: (batch_size, max_length)
        # entity_mask: a batch of entity masks, shape: (batch_size, max_length, num_entities)
        # entity_mask[i, j, k] = 1 if the j-th token in the i-th caption belongs to the k-th entity, 0 otherwise
        # Embed the tokens in the caption
        token_embed = self.embed(caption)  # shape: (batch_size, max_length, embed_size)
        # Apply the Transformer encoder to propagate information between tokens
        token_feat = self.tfm(token_embed)  # shape: (batch_size, max_length, embed_size)
        # Apply the RNN to encode the tokens into entity features
        entity_feat, _ = self.rnn(token_feat)  # shape: (batch_size, max_length, hidden_size)
        # Apply the entity mask to merge the token features into entity features
        # (cast the mask to float so torch.bmm accepts it)
        entity_feat = torch.bmm(entity_mask.transpose(1, 2).float(), entity_feat)  # shape: (batch_size, num_entities, hidden_size)
        # Project the segment features and the entity features into the same feature space
        segment_feat = self.projI(image)  # shape: (batch_size, num_segments, hidden_size)
        entity_feat = self.projT(entity_feat)  # shape: (batch_size, num_entities, hidden_size)
        # Compute the cosine similarity between the segment features and the entity features
        similarity = F.cosine_similarity(segment_feat.unsqueeze(2), entity_feat.unsqueeze(1), dim=-1)  # shape: (batch_size, num_segments, num_entities)
        # Return the similarity matrix
        return similarity
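
# A minimal usage sketch (not part of the original file; all sizes below are
# illustrative assumptions) showing how the grounder is called on dummy inputs.
def _demo_grounder():
    batch_size, num_segments, max_length, num_entities = 2, 4, 10, 3
    hidden_size, vocab_size, embed_size = 256, 1000, 128
    grounder = EntityGrounder(hidden_size, vocab_size, embed_size)
    image = torch.randn(batch_size, num_segments, embed_size)
    caption = torch.randint(0, vocab_size, (batch_size, max_length))
    entity_mask = torch.zeros(batch_size, max_length, num_entities)
    entity_mask[:, 0, 0] = 1  # e.g. the first token of each caption belongs to entity 0
    similarity = grounder(image, caption, entity_mask)
    print(similarity.shape)  # torch.Size([2, 4, 3])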

# Import networkx
import networkx as nx
# Import numpy
import numpy as np

# Define a function to create the text graph from the rule parser output
def create_text_graph(rule_parser_output):
    # Create an empty graph
    text_graph = nx.Graph()
    # Loop through the rule parser output
    for e1, r, e2 in rule_parser_output:
        # Add the entities as nodes
        text_graph.add_node(e1)
        text_graph.add_node(e2)
        # Add the relation as an edge
        text_graph.add_edge(e1, e2, label=r)
    # Return the text graph
    return text_graph
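
# A small usage sketch (an assumption, not from the original file): the rule
# parser output is taken to be a list of (entity, relation, entity) triples.
def _demo_text_graph():
    triples = [("dog", "chasing", "ball"), ("ball", "on", "grass")]
    g = create_text_graph(triples)
    print(list(g.nodes))         # ['dog', 'ball', 'grass']
    print(list(g.edges(data=True)))  # [('dog', 'ball', {'label': 'chasing'}), ('ball', 'grass', {'label': 'on'})]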

# Define a function to assign each word to the entity node that it belongs to
def assign_word_entity(coreNLP_data, annotations_data):
    # Initialize an empty list to store the word-entity assignments
    word_entity = []
    # Loop through the sentences in the coreNLP data
    for sentence in coreNLP_data["sentences"]:
        # Loop through the tokens in the sentence
        for token in sentence["tokens"]:
            # Get the word and the pos tag of the token
            word = token["word"]
            pos = token["pos"]
            # Initialize the entity node as None
            entity_node = None
            # Check if the word is a noun or a noun phrase
            if pos in ["NN", "NNP", "NNPS", "NNS"]:
                # Use the word as the entity node
                entity_node = word
            # Check if the word is a determiner, a preposition, or a modifier
            elif pos in ["DT", "IN", "TO", "JJ", "JJR", "JJS", "RB", "RBR", "RBS"]:
                # Use the previous word's entity node, if any
                if word_entity:
                    entity_node = word_entity[-1][1]
            # Check if the word is part of an entity span
            for entity_span in annotations_data["entity_spans"]:
                # Get the start and end indices of the entity span
                start = entity_span["start"]
                end = entity_span["end"]
                # Get the segment id of the entity span
                segment_id = entity_span["segment_id"]
                # Check if the token index is within the entity span
                if start <= token["index"] - 1 < end:
                    # Use the segment id as the entity node
                    entity_node = segment_id
            # Append the word and the entity node to the word-entity list
            word_entity.append((word, entity_node))
    # Return the word-entity list
    return word_entity
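
# An input sketch (an assumption, not from the original file): minimal
# CoreNLP-style and annotation dicts in the shape the function reads.
def _demo_assign_word_entity():
    coreNLP_data = {"sentences": [{"tokens": [
        {"word": "a", "pos": "DT", "index": 1},
        {"word": "dog", "pos": "NN", "index": 2},
    ]}]}
    annotations_data = {"entity_spans": [
        {"start": 1, "end": 2, "segment_id": 7},  # covers the token "dog"
    ]}
    print(assign_word_entity(coreNLP_data, annotations_data))
    # [('a', None), ('dog', 7)]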

# Define a function to create the entity mask from the word-entity assignments
def create_entity_mask(word_entity, text_graph):
    # Get the number of words and the number of entities
    num_words = len(word_entity)
    num_entities = len(text_graph.nodes)
    # Create an empty matrix of zeros
    entity_mask = np.zeros((num_words, num_entities))
    # Loop through the word-entity list
    for i, (word, entity_node) in enumerate(word_entity):
        # Check if the word belongs to an entity
        if entity_node is not None:
            # Get the index of the entity node
            j = list(text_graph.nodes).index(entity_node)
            # Set the matrix element to 1
            entity_mask[i, j] = 1
    # Return the entity mask
    return entity_mask
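
# An end-to-end sketch (an assumption, not part of the original file): wires
# the three helpers to the grounder. It assumes every entity node assigned by
# assign_word_entity also appears as a node in the text graph, since
# create_entity_mask looks the node up by index and would otherwise raise.
def _demo_pipeline(coreNLP_data, annotations_data, rule_parser_output, image, caption, grounder):
    # Build the text graph and the per-word entity assignments
    text_graph = create_text_graph(rule_parser_output)
    word_entity = assign_word_entity(coreNLP_data, annotations_data)
    # Build the (num_words, num_entities) mask, then lift it to a float tensor
    # with a batch dimension, as expected by EntityGrounder.forward
    mask = torch.from_numpy(create_entity_mask(word_entity, text_graph)).float().unsqueeze(0)
    # Return the (1, num_segments, num_entities) similarity matrix
    return grounder(image, caption, mask)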