1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
|
- from einops import rearrange
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torchcluster.zoo.spectrum import SpectrumClustering
class SegmentMerger(nn.Module):
    """Loss that fits segment-to-segment similarity to entity-grounding agreement.

    Each segment feature is matched to its most similar valid entity embedding.
    Two segments get a target similarity of 1 when they match the same entity
    and both matches are above ``theta``; otherwise the target is 0. The
    segment similarity matrix (rescaled to [0, 1]) is fit to this target with
    an MSE loss.
    """

    def __init__(self, K, H, theta=-0.5):
        """
        Args:
            K: the number of grouping stages (stored; not read by the methods here).
            H: the number of segments at each stage (stored; not read by the
                methods here).
            theta: a segment whose best entity cosine similarity is not strictly
                greater than this gets an all-zero target row and column.
                Defaults to -0.5, the previously hard-coded value.
        """
        super().__init__()
        self.K = K  # the number of grouping stages
        self.H = H  # the number of segments at each stage
        self.theta = theta

    def forward(self, x: torch.Tensor, y: torch.Tensor, text_mask):
        """Compute the segment-grouping loss.

        Args:
            x: segment features, shape [B, Hk, D].
            y: entity embeddings, shape [B, N, D].
            text_mask: [B, N] mask of valid entities. Any dtype is accepted;
                nonzero / True marks a valid entity.

        Returns:
            ``(loss, sim_k)``. ``loss`` is the scalar MSE between ``sim_k`` and
            the grounding-derived target. ``sim_k`` is the [B, Hk, Hk]
            segment cosine similarity mapped from [-1, 1] to [0, 1].
        """
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)

        # The rows are unit vectors, so a batched matmul gives the cosine
        # similarity without materialising a [B, Hk, Hk, D] broadcast tensor.
        sim_k = (x @ x.transpose(-1, -2) + 1) / 2

        # Segment-to-entity cosine similarity, [B, Hk, N]. Invalid entities get
        # the lowest representable value so they never win the max. Using
        # finfo(...).min instead of -1e9 avoids overflow in half precision.
        grounding = x @ y.transpose(-1, -2)
        valid = text_mask.bool().unsqueeze(1)  # [B, 1, N]; torch needs a bool mask
        grounding = grounding.masked_fill(~valid, torch.finfo(grounding.dtype).min)

        # Best entity index (the grounding label) and its similarity, each [B, Hk].
        best_sim, best_entity = grounding.max(dim=-1)

        # Target is 1 where both segments ground to the same entity...
        same_entity = best_entity.unsqueeze(2) == best_entity.unsqueeze(1)
        # ...and both groundings are above theta. Low-confidence segments are
        # zeroed out in both their row and their column.
        confident = (best_sim > self.theta).to(sim_k.dtype)
        sim_target_k = (
            same_entity.to(sim_k.dtype)
            * confident.unsqueeze(2)
            * confident.unsqueeze(1)
        )

        loss = F.mse_loss(sim_k, sim_target_k)
        return loss, sim_k

    def inference(self, segment_features, masks, n_clusters=8):
        """Cluster segments by their pairwise cosine similarity.

        Args:
            segment_features: segment features, shape [Hk, D].
            masks: currently unused; kept for interface compatibility.
            n_clusters: number of clusters passed to ``SpectrumClustering``.
                Defaults to 8, the previously hard-coded value.

        Returns:
            The labels produced by ``SpectrumClustering`` for the [Hk, Hk]
            similarity matrix. The previous version printed them and returned None.
        """
        feats = F.normalize(segment_features, dim=-1)
        similarity_matrix = feats @ feats.transpose(-1, -2)
        labels = SpectrumClustering(n_clusters)(similarity_matrix)
        return labels
-
|