1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
- #!/usr/bin/env python3
- import plotly.express as px
- from tqdm import tqdm
- import random
- import torch
class Mess3:
    """Batch generator for a 3-state, 3-token hidden Markov process ("Mess3").

    ``transition_matrices[x, i, j]`` is interpreted as the joint probability of
    emitting token ``x`` and moving from hidden state ``i`` to hidden state
    ``j``.  Alongside every token sequence the class tracks the belief state
    (posterior over hidden states given the tokens seen so far).

    Batches are either sampled on the fly or read from two cache files,
    ``mess3-tokens.pt`` and ``mess3-belief_states.pt``, in the working
    directory (see ``save_as_cache``).
    """

    def __init__(self, batch_size, block_size, device, load_from_cache=True):
        """Store the batch geometry and optionally load the cached batches.

        Args:
            batch_size: number of sequences per freshly sampled batch.
            block_size: sampled sequences hold ``block_size + 1`` tokens so
                they can be split into inputs and shifted targets.
            device: device that batches are moved to / cache files mapped to.
            load_from_cache: when true, ``get_batch`` serves batches from the
                cache files instead of sampling.
        """
        self.batch_size = batch_size
        self.block_size = block_size
        self.device = device
        # Indexed [token, from_state, to_state].  Summed over tokens this must
        # be row-stochastic; the last entry is 0.765 (a value of 0.76 leaves
        # the third row summing to 0.995).
        self.transition_matrices = torch.tensor([
            [[0.765, 0.00375, 0.00375],
             [0.0425, 0.0675, 0.00375],
             [0.0425, 0.00375, 0.0675]],
            [[0.0675, 0.0425, 0.00375],
             [0.00375, 0.765, 0.00375],
             [0.00375, 0.0425, 0.0675]],
            [[0.0675, 0.00375, 0.0425],
             [0.00375, 0.0675, 0.0425],
             [0.00375, 0.00375, 0.765]],
        ])
        self.load_from_cache = load_from_cache
        if self.load_from_cache:
            self.tokens = torch.load('mess3-tokens.pt', map_location=device, weights_only=True)
            self.belief_states = torch.load('mess3-belief_states.pt', map_location=device, weights_only=True)
        self.stationary_belief_state = self._stationary_distribution(self.transition_matrices.sum(0))

    @staticmethod
    def _stationary_distribution(state_transitions):
        """Return the probability vector ``pi`` with ``pi @ state_transitions == pi``."""
        # A stationary distribution is a *left* eigenvector, i.e. a right
        # eigenvector of the transpose.  torch.linalg.eig does not order its
        # eigenvalues, so pick the one closest to 1 explicitly.
        decomposition = torch.linalg.eig(state_transitions.T)
        unit_idx = (decomposition.eigenvalues - 1).abs().argmin()
        pi = decomposition.eigenvectors[:, unit_idx].real
        # Eigenvectors come back with unit L2 norm and arbitrary sign; dividing
        # by the sum fixes both.  Clamp guards against tiny negative round-off
        # so the result is always usable as multinomial weights.
        pi = (pi / pi.sum()).clamp_min(0)
        return pi / pi.sum()

    def sample(self, n, start_state_idx=0):
        """Sample ``n`` tokens starting from hidden state ``start_state_idx``.

        Returns:
            A pair ``(tokens, belief_states)``: ``tokens`` is an int64 tensor
            of shape ``(n,)``; ``belief_states`` has shape ``(1, n, n_states)``
            and row ``t`` is the belief after observing ``tokens[:t + 1]``,
            starting from the stationary belief.
        """
        n_states = self.transition_matrices.size(-1)
        state = int(start_state_idx)
        belief = self.stationary_belief_state.clone()
        emitted = []
        beliefs = []
        for _ in range(n):
            # Token and next state are correlated, so draw them jointly from
            # P(token, next_state | state) rather than from two marginals.
            joint = self.transition_matrices[:, state, :]
            flat_idx = joint.reshape(-1).multinomial(1).item()
            token, state = divmod(flat_idx, n_states)
            emitted.append(token)
            # Bayesian filter step: propagate through the observed token's
            # matrix, then renormalise.
            belief = belief @ self.transition_matrices[token]
            belief = belief / belief.sum()
            beliefs.append(belief)
        tokens = torch.tensor(emitted, dtype=torch.int64)
        if not beliefs:
            # n == 0: vstack rejects an empty list, so build the empty result.
            return tokens, belief.new_empty((1, 0, n_states))
        return tokens, torch.vstack(beliefs)[None,]

    def get_batch(self, split, split_XY=True, return_belief_states=False):
        """Return one batch of tokens (and optionally belief states).

        Args:
            split: ignored; present only so the signature matches what
                train.py expects.
            split_XY: when true, return ``(x, y)`` with ``y`` shifted one
                position to the right of ``x``; otherwise the full sequences.
            return_belief_states: when true, return a pair whose second item
                is the belief states, split the same way as the tokens.

        When sampling, tokens have shape ``(batch_size, block_size + 1)`` and
        belief states ``(batch_size, block_size + 1, n_states)`` before the
        split.  Cached batches are whatever one entry of the cache files holds.
        """
        if self.load_from_cache:
            idx = random.randint(0, self.tokens.size(0) - 1)
            tokens = self.tokens[idx]
            belief_states = self.belief_states[idx]
        else:
            start_indices = self.stationary_belief_state.multinomial(self.batch_size, replacement=True)
            samples = [self.sample(self.block_size + 1, start_idx) for start_idx in start_indices]
            tokens = torch.vstack([seq for seq, _ in samples]).to(self.device)
            belief_states = torch.vstack([bel for _, bel in samples]).to(self.device)

        if split_XY:
            token_out = (tokens[:, :-1], tokens[:, 1:].contiguous())
            belief_out = (belief_states[:, :-1], belief_states[:, 1:].contiguous())
        else:
            token_out = tokens
            belief_out = belief_states

        if not return_belief_states:
            return token_out
        return token_out, belief_out

    def save_as_cache(self, n_batches):
        """Draw ``n_batches`` unsplit batches and write them to the cache files.

        Each file receives one stacked tensor with a leading ``n_batches``
        dimension.
        """
        tokens = []
        belief_states = []
        for _ in tqdm(range(n_batches)):
            block, belief_state = self.get_batch('', split_XY=False, return_belief_states=True)
            tokens.append(block[None,])
            belief_states.append(belief_state[None,])
        torch.save(torch.vstack(tokens), 'mess3-tokens.pt')
        torch.save(torch.vstack(belief_states), 'mess3-belief_states.pt')
if __name__ == '__main__':
    # Sample 1000 batches of 64 sequences (256 + 1 tokens each) on the CPU and
    # write them to the two cache files.
    generator = Mess3(64, 256, device='cpu', load_from_cache=False)
    generator.save_as_cache(1000)

    # Reload the cached belief states, flatten them to one 3-vector per row,
    # and show the first 10000 on a ternary scatter plot.
    cached_beliefs = torch.load('mess3-belief_states.pt', map_location='cpu')
    points = cached_beliefs.view(-1, 3)[:10000]
    figure = px.scatter_ternary(points, a=0, b=1, c=2)
    figure.show()
|