# Import PyTorch modules
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the entity grounder class
class EntityGrounder(nn.Module):
    def __init__(self, hidden_size, vocab_size, embed_size):
        super(EntityGrounder, self).__init__()
        # Initialize the Transformer encoder layer for token feature extraction
        # (batch_first=True so inputs are shaped (batch, seq, embed))
        self.tfm = nn.TransformerEncoderLayer(d_model=embed_size, nhead=8, batch_first=True)
        # Initialize the RNN for entity feature encoding
        self.rnn = nn.GRU(input_size=embed_size, hidden_size=hidden_size, batch_first=True)
        # Initialize the MLPs for segment and entity feature projection
        self.projI = nn.Linear(in_features=embed_size, out_features=hidden_size)
        self.projT = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        # Initialize the embedding layer for token embedding
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)

    def forward(self, image, caption, entity_mask):
        # image: a batch of image segment features, shape: (batch_size, num_segments, embed_size)
        # caption: a batch of token ids, shape: (batch_size, max_length)
        # entity_mask: a batch of entity masks, shape: (batch_size, max_length, num_entities);
        # entity_mask[i, j, k] = 1 if the j-th token of the i-th caption belongs to the k-th entity, 0 otherwise
        # Embed the tokens in the caption
        token_embed = self.embed(caption)  # shape: (batch_size, max_length, embed_size)
        # Apply the Transformer encoder layer to propagate information between tokens
        token_feat = self.tfm(token_embed)  # shape: (batch_size, max_length, embed_size)
        # Apply the RNN to encode the token sequence
        entity_feat, _ = self.rnn(token_feat)  # shape: (batch_size, max_length, hidden_size)
        # Apply the entity mask to merge the token features into entity features
        entity_feat = torch.bmm(entity_mask.float().transpose(1, 2), entity_feat)  # shape: (batch_size, num_entities, hidden_size)
        # Project the segment features and the entity features into the same feature space
        segment_feat = self.projI(image)  # shape: (batch_size, num_segments, hidden_size)
        entity_feat = self.projT(entity_feat)  # shape: (batch_size, num_entities, hidden_size)
        # Compute the cosine similarity between the segment features and the entity features
        similarity = F.cosine_similarity(segment_feat.unsqueeze(2), entity_feat.unsqueeze(1), dim=-1)  # shape: (batch_size, num_segments, num_entities)
        # Return the similarity matrix
        return similarity
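
# A minimal usage sketch (not part of the original code); the sizes below are
# made-up assumptions chosen only to illustrate how EntityGrounder is called.
batch_size, num_segments, max_length, num_entities = 2, 5, 12, 3
hidden_size, vocab_size, embed_size = 256, 10000, 512
grounder = EntityGrounder(hidden_size, vocab_size, embed_size)
image = torch.randn(batch_size, num_segments, embed_size)
caption = torch.randint(0, vocab_size, (batch_size, max_length))
entity_mask = torch.zeros(batch_size, max_length, num_entities)
entity_mask[:, 0, 0] = 1  # e.g. the first token of every caption belongs to entity 0
similarity = grounder(image, caption, entity_mask)
print(similarity.shape)  # torch.Size([2, 5, 3])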

# Import networkx
import networkx as nx
# Import numpy
import numpy as np

# Define a function to create the text graph from the rule parser output
def create_text_graph(rule_parser_output):
    # Create an empty graph
    text_graph = nx.Graph()
    # Loop through the (entity, relation, entity) triples in the rule parser output
    for e1, r, e2 in rule_parser_output:
        # Add the entities as nodes
        text_graph.add_node(e1)
        text_graph.add_node(e2)
        # Add the relation as an edge
        text_graph.add_edge(e1, e2, label=r)
    # Return the text graph
    return text_graph
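
# A minimal usage sketch (not part of the original code); the triples are made-up
# examples of the (entity, relation, entity) output a rule parser might produce.
example_triples = [("man", "wearing", "hat"), ("man", "riding", "horse")]
example_graph = create_text_graph(example_triples)
print(list(example_graph.nodes))                # ['man', 'hat', 'horse']
print(list(example_graph.edges(data="label")))  # [('man', 'hat', 'wearing'), ('man', 'horse', 'riding')]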

# Define a function to assign each word to the entity node that it belongs to
def assign_word_entity(coreNLP_data, annotations_data):
    # Initialize an empty list to store the word-entity assignments
    word_entity = []
    # Loop through the sentences in the CoreNLP data
    for sentence in coreNLP_data["sentences"]:
        # Loop through the tokens in the sentence
        for token in sentence["tokens"]:
            # Get the word and the POS tag of the token
            word = token["word"]
            pos = token["pos"]
            # Initialize the entity node as None
            entity_node = None
            # Check if the word is a noun or a noun phrase
            if pos in ["NN", "NNP", "NNPS", "NNS"]:
                # Use the word as the entity node
                entity_node = word
            # Check if the word is a determiner, a preposition, or a modifier
            elif pos in ["DT", "IN", "TO", "JJ", "JJR", "JJS", "RB", "RBR", "RBS"]:
                # Use the previous word's entity node, if any
                if word_entity:
                    entity_node = word_entity[-1][1]
            # Check if the word is part of an annotated entity span
            for entity_span in annotations_data["entity_spans"]:
                # Get the start and end indices of the entity span
                start = entity_span["start"]
                end = entity_span["end"]
                # Get the segment id of the entity span
                segment_id = entity_span["segment_id"]
                # Check if the (0-based) token index is within the entity span
                if start <= token["index"] - 1 < end:
                    # Use the segment id as the entity node
                    entity_node = segment_id
            # Append the word and the entity node to the word-entity list
            word_entity.append((word, entity_node))
    # Return the word-entity list
    return word_entity
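
# A minimal usage sketch (not part of the original code). The dictionaries mimic the
# structure of the CoreNLP parse and the entity-span annotations that the function
# reads; the concrete keys and values here are illustrative assumptions.
coreNLP_example = {
    "sentences": [{
        "tokens": [
            {"index": 1, "word": "A", "pos": "DT"},
            {"index": 2, "word": "man", "pos": "NN"},
            {"index": 3, "word": "rides", "pos": "VBZ"},
            {"index": 4, "word": "a", "pos": "DT"},
            {"index": 5, "word": "horse", "pos": "NN"},
        ]
    }]
}
annotations_example = {
    "entity_spans": [
        {"start": 1, "end": 2, "segment_id": "seg_0"},  # covers "man" (0-based token index 1)
        {"start": 4, "end": 5, "segment_id": "seg_1"},  # covers "horse" (0-based token index 4)
    ]
}
print(assign_word_entity(coreNLP_example, annotations_example))
# [('A', None), ('man', 'seg_0'), ('rides', None), ('a', None), ('horse', 'seg_1')]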

# Define a function to create the entity mask from the word-entity assignments
def create_entity_mask(word_entity, text_graph):
    # Get the number of words and the number of entities
    num_words = len(word_entity)
    num_entities = len(text_graph.nodes)
    # Create an empty matrix of zeros
    entity_mask = np.zeros((num_words, num_entities))
    # Loop through the word-entity list
    for i, (word, entity_node) in enumerate(word_entity):
        # Check if the word belongs to an entity that appears in the text graph
        if entity_node is not None and entity_node in text_graph.nodes:
            # Get the index of the entity node
            j = list(text_graph.nodes).index(entity_node)
            # Set the matrix element to 1
            entity_mask[i, j] = 1
    # Return the entity mask
    return entity_mask
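
# A minimal end-to-end sketch (not part of the original code): build a mask from a
# made-up word-entity assignment and text graph, then convert it to the float tensor
# with a batch dimension that EntityGrounder expects.
word_entity_example = [("man", "man"), ("rides", None), ("horse", "horse")]
graph_example = create_text_graph([("man", "riding", "horse")])
mask_np = create_entity_mask(word_entity_example, graph_example)  # shape: (3, 2)
mask_tensor = torch.from_numpy(mask_np).float().unsqueeze(0)      # shape: (1, 3, 2)
print(mask_tensor.shape)  # torch.Size([1, 3, 2])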