Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

segment_merger.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
  1. from einops import rearrange
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from torchcluster.zoo.spectrum import SpectrumClustering
  7. class SegmentMerger(nn.Module):
  8. def __init__(self, K, H):
  9. super(SegmentMerger, self).__init__()
  10. self.K = K # the number of grouping stages
  11. self.H = H # the number of segments at each stage
  12. self.theta = -0.5
  13. """
  14. * x is a tensor of segment feature tensors, each of shape [B, Hk, D]
  15. * l is a list of indices of the grounding results for each entity
  16. * y is a tensor of entity embeddings, of shape [B, N, D]
  17. """
  18. def forward(self, x: torch.Tensor, y: torch.Tensor, text_mask):
  19. loss = 0
  20. x = F.normalize(x, dim=-1)
  21. y = F.normalize(y, dim=-1)
  22. sim_k = torch.cosine_similarity(x.unsqueeze(2), x.unsqueeze(1), dim=-1)
  23. sim_k = (sim_k + 1) / 2
  24. entity_grounding_results = torch.cosine_similarity(x.unsqueeze(2), y.unsqueeze(1), dim=-1)
  25. entity_grounding_results = torch.where(rearrange(text_mask, 'b c -> b 1 c'), entity_grounding_results, torch.full_like(entity_grounding_results, fill_value=-1e9))
  26. l = entity_grounding_results.argmax(dim=-1)
  27. sim_target_k = (l.unsqueeze(2) == l.unsqueeze(1))
  28. mask_k = (entity_grounding_results.max(dim=-1).values > self.theta).float() # [Hk]
  29. # compute the cosine similarity between each segment feature x_k_i and its corresponding entity embedding y_l^{k}_i, and filter out those below a threshold theta
  30. sim_target_k = sim_target_k * mask_k.unsqueeze(2) * mask_k.unsqueeze(1) # filter out low similarity segments
  31. loss_k = nn.MSELoss()(sim_k, sim_target_k)
  32. loss += loss_k
  33. return loss, sim_k
  34. """
  35. * segment_features is a list of segment feature tensors, each of shape [Hk, D]
  36. * k_inf is the index of the grouping stage at which to perform inference
  37. """
  38. def inference(self, segment_features, masks):
  39. similarity_matrix = torch.cosine_similarity(segment_features.unsqueeze(0), segment_features.unsqueeze(1), dim=-1)
  40. cluster = SpectrumClustering(8)
  41. labels = cluster(similarity_matrix)
  42. print(labels)
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...