Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

loss1.py 1.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  1. import torch
  2. import numpy as np
  3. import torch.nn.functional as F
  4. import tqdm
  5. torch.manual_seed(4)
  6. # very inefficient but just to check values
  7. b = 64
  8. k = 64
  9. l = 20
  10. n = 196
  11. theta = -0.5
  12. tau = 0.1
  13. x = torch.randn(b, k, n, device='cuda')
  14. y = torch.randn(b, l, n, device='cuda')
  15. x = F.normalize(x, dim=-1)
  16. y = F.normalize(y, dim=-1)
  17. lp = 0
  18. lq = 0
  19. loss = 0
  20. def exp(x):
  21. return torch.exp(x)
  22. def p_k(x, y):
  23. p = 0
  24. count_p = 0
  25. for k_i in range(k):
  26. p_i = 0
  27. for j in range(l):
  28. cos_sim = F.cosine_similarity(x[k_i], y[j], dim=-1)
  29. p_i = max(p_i, cos_sim)
  30. if p_i > theta:
  31. p += p_i
  32. count_p += 1
  33. return p / (count_p + 1e-5)
  34. for b_i in tqdm.tqdm(range(b)):
  35. p_ii = exp(p_k(x[b_i], y[b_i]) / tau)
  36. p_ij = 0
  37. for b_j in range(b):
  38. p_j = exp(p_k(x[b_i], y[b_j]) / tau)
  39. p_ij += p_j
  40. print(torch.log(p_ii / p_ij))
  41. lp += torch.log(p_ii / p_ij)
  42. lp /= b
  43. lp *= -1
  44. print(f"lossq: {lp}")
  45. quit()
  46. def q_k(y, x):
  47. q = 0
  48. count_q = 0
  49. for l_j in range(l):
  50. q_i = 0
  51. for i in range(k):
  52. cos_sim = F.cosine_similarity(x[i].unsqueeze(0), y[l_j].unsqueeze(0), dim=-1)
  53. # print(cos_sim)
  54. q_i = max(q_i, cos_sim)
  55. if q_i > theta:
  56. q += q_i
  57. count_q += 1
  58. return q / (count_q + 1e-5)
  59. q_ijs = []
  60. for b_i in range(b):
  61. q_ii = exp(q_k(y[b_i], x[b_i]) / tau)
  62. q_ij = 0
  63. for b_j in range(b):
  64. q_j = exp(q_k(y[b_i], x[b_j]) / tau)
  65. q_ij += q_j
  66. q_ijs.append(q_ii)
  67. lq += q_ii / q_ij
  68. lq /= b
  69. lq *= -1
  70. loss = 0.5 * (lp + lq)
  71. print(lp)
  72. print(lq)
  73. print(loss)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...