
sample_residuals.py

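#!/usr/bin/env python3
"""Collect the activations fed into lm_head for every batch of mess3 tokens."""
import torch

from model import GPT, GPTConfig

device = 'cuda'

# Hyperparameters; these must match the ones used to train out-mess3/ckpt.pt.
block_size = 256
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2  # inactive once model.eval() is called below

residuals = []


def hook(module, inputs, output):
    # inputs is the tuple of positional args passed to lm_head; inputs[0] has shape
    # (batch, time, n_embd). Flatten it to (batch * time, n_embd) and move it to the CPU
    # so activations do not accumulate in GPU memory.
    residuals.append(inputs[0].reshape(-1, n_embd).detach().cpu())


if __name__ == '__main__':
    gptconf = GPTConfig(block_size=block_size,
                        vocab_size=3,  # three-token vocabulary
                        n_layer=n_layer,
                        n_head=n_head,
                        n_embd=n_embd,
                        dropout=dropout,
                        bias=False)
    model = GPT(gptconf)
    # The checkpoint is loaded into the compiled wrapper, so its keys must carry the
    # '_orig_mod.' prefix, as they do when the state dict was saved from a compiled model.
    model = torch.compile(model.to(device))
    model.load_state_dict(torch.load('out-mess3/ckpt.pt', map_location=device)['model'])
    model.eval()  # disable dropout so the collected activations are deterministic

    # Record the input to the final projection (the residual stream after ln_f in a
    # nanoGPT-style GPT). Note: upstream nanoGPT's GPT.forward applies lm_head only to the
    # last position when no targets are passed; if this model.py does the same, each
    # sequence contributes a single row.
    model.lm_head.register_forward_hook(hook)

    tokens = torch.load('mess3-tokens.pt')
    # Merge the two leading dimensions into one batch of sequences and drop the final
    # token of each sequence.
    tokens = tokens.view(-1, tokens.size(2))[:, :-1].to(device)
    with torch.no_grad():
        for token_subset in tokens.chunk(10):  # 10 chunks to bound peak memory
            model(token_subset)
    torch.save(torch.vstack(residuals), 'mess3-residuals.pt')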
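For readers without the project's `model.py` or a GPU, here is a minimal CPU-only sketch of the same forward-hook capture. `ToyLM`, its sizes, and the random tokens are hypothetical stand-ins, not the mess3 model or data; only the hook, reshape, chunk, and vstack steps mirror the script.

```python
import torch
import torch.nn as nn

n_embd = 8       # hypothetical toy width
vocab_size = 3


class ToyLM(nn.Module):
    """Hypothetical stand-in for GPT: embedding -> final LayerNorm -> lm_head."""

    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx):
        x = self.ln_f(self.wte(idx))  # (batch, time, n_embd)
        return self.lm_head(x)        # lm_head applied at every position


residuals = []


def hook(module, inputs, output):
    # inputs[0] is the tensor passed into lm_head: (batch, time, n_embd).
    residuals.append(inputs[0].reshape(-1, n_embd).detach().cpu())


model = ToyLM().eval()
handle = model.lm_head.register_forward_hook(hook)

tokens = torch.randint(0, vocab_size, (4, 5, 11))  # random (n1, n2, seq_len + 1)
tokens = tokens.view(-1, tokens.size(2))[:, :-1]  # (20, 10) after dropping last token
with torch.no_grad():
    for chunk in tokens.chunk(2):
        model(chunk)
handle.remove()  # stop capturing once collection is done

stacked = torch.vstack(residuals)
print(stacked.shape)  # torch.Size([200, 8])
```

Keeping the handle returned by `register_forward_hook` lets later forward passes (for example, evaluation) run without appending to the list.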