train_gpt2.py

# config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB
# launch as the following (e.g. in a screen session) and wait ~5 days:
# $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py

wandb_log = True
wandb_project = 'owt'
wandb_run_name = 'gpt2-124M'

# these make the total batch size be ~0.5M
# 12 batch size * 1024 block size * 5 grad accum steps * 8 GPUs = 491,520
batch_size = 12
block_size = 1024
gradient_accumulation_steps = 5 * 8

# this makes total number of tokens be 300B
max_iters = 600000
lr_decay_iters = 600000

# eval stuff
eval_interval = 1000
eval_iters = 200
log_interval = 10

# weight decay
weight_decay = 1e-1
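The batch-size and token-count comments above can be checked with a few lines of arithmetic. The sketch below is illustrative only: it mirrors the variable names from the config, while tokens_per_iter and total_tokens are names introduced here for the check, not taken from nanoGPT.

# sanity-check the numbers quoted in the config comments
batch_size = 12
block_size = 1024
gradient_accumulation_steps = 5 * 8   # 5 micro-steps per GPU * 8 GPUs

# tokens processed per optimizer step across all GPUs
tokens_per_iter = batch_size * block_size * gradient_accumulation_steps
print(f"tokens per iteration: {tokens_per_iter:,}")        # 491,520 (~0.5M)

# tokens seen over the full run
max_iters = 600_000
total_tokens = tokens_per_iter * max_iters
print(f"total training tokens: {total_tokens / 1e9:.1f}B")  # ~294.9B (~300B)

Note that these key = value lines are a config override rather than a standalone script: roughly speaking, nanoGPT's train.py execs the file over its module-level defaults (via configurator.py), so every name here must match a default already defined in train.py.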