Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

mess3.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
  1. #!/usr/bin/env python3
  2. import plotly.express as px
  3. from tqdm import tqdm
  4. import random
  5. import torch
  6. class Mess3:
  7. def __init__(self, batch_size, block_size, device, load_from_cache=True):
  8. self.batch_size = batch_size
  9. self.block_size = block_size
  10. self.device = device
  11. self.transition_matrices = torch.tensor([
  12. [[0.765, 0.00375, 0.00375],
  13. [0.0425, 0.0675, 0.00375],
  14. [0.0425, 0.00375, 0.0675]],
  15. [[0.0675, 0.0425, 0.00375],
  16. [0.00375, 0.765, 0.00375],
  17. [0.00375, 0.0425, 0.0675]],
  18. [[0.0675, 0.00375, 0.0425],
  19. [0.00375, 0.0675, 0.0425],
  20. [0.00375, 0.00375, 0.76]]])
  21. self.load_from_cache = load_from_cache
  22. if self.load_from_cache:
  23. self.tokens = torch.load('mess3-tokens.pt', map_location=device, weights_only=True)
  24. self.belief_states = torch.load('mess3-belief_states.pt', map_location=device, weights_only=True)
  25. self.stationary_belief_state = torch.linalg.eig(self.transition_matrices.sum(0)).eigenvectors[:, 0].real
  26. def sample(self, n, start_state_idx=0):
  27. emitted = []
  28. belief_states = [self.stationary_belief_state.clone(),]
  29. current_index = start_state_idx
  30. for _ in range(n):
  31. emitted.append(self.transition_matrices[current_index].sum(1).multinomial(1).item())
  32. current_index = self.transition_matrices[current_index].sum(0).multinomial(1).item()
  33. belief_state = belief_states[-1] @ self.transition_matrices[emitted[-1]]
  34. belief_state /= belief_state.sum()
  35. belief_states.append(belief_state)
  36. return torch.tensor(emitted, dtype=torch.int64), torch.vstack(belief_states[1:])[None,]
  37. def get_batch(self, split, split_XY=True, return_belief_states=False): # retrofitting train.py, argument ignored
  38. if self.load_from_cache:
  39. idx = random.randint(0, self.tokens.size(0) - 1)
  40. tokens = self.tokens[idx]
  41. belief_states = self.belief_states[idx]
  42. else:
  43. start_indices = self.stationary_belief_state.multinomial(self.batch_size, replacement=True)
  44. tokens = []
  45. belief_states = []
  46. for start_idx in start_indices:
  47. block, belief_state = self.sample(self.block_size + 1, start_idx)
  48. tokens.append(block)
  49. belief_states.append(belief_state)
  50. tokens = torch.vstack(tokens).to(self.device)
  51. belief_states = torch.vstack(belief_states).to(self.device)
  52. if not return_belief_states: return ((tokens[:, :-1], tokens[:, 1:].contiguous()) if split_XY else tokens)
  53. else: return ((tokens[:, :-1], tokens[:, 1:].contiguous()) if split_XY else tokens), ((belief_states[:, :-1], belief_states[:, 1:].contiguous()) if split_XY else belief_states)
  54. def save_as_cache(self, n_batches):
  55. tokens = []
  56. belief_states = []
  57. for batch in tqdm(range(n_batches)):
  58. block, belief_state = self.get_batch('', split_XY=False, return_belief_states=True)
  59. tokens.append(block[None,])
  60. belief_states.append(belief_state[None,])
  61. torch.save(torch.vstack(tokens), 'mess3-tokens.pt')
  62. torch.save(torch.vstack(belief_states), 'mess3-belief_states.pt')
  63. if __name__ == '__main__':
  64. mess3 = Mess3(64, 256, device='cpu', load_from_cache=False)
  65. mess3.save_as_cache(1000)
  66. fig = px.scatter_ternary(torch.load('mess3-belief_states.pt', map_location='cpu').view(-1, 3)[:10000], a=0, b=1, c=2)
  67. fig.show()
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...