distrib2.py 2.3 KB

from fastai.basics import *
from fastai.distributed import *
from reformer_fastai.all import *
from reformer_fastai.expscript import *
import time


@call_parse
def run_exp(
    data_path: Param(help="Path to data folder", type=str, default='./data'),
    n_epochs: Param(help="Number of epochs", type=int, default=1),
    # lr: Param(help="Learning rate", type=float, default=1e-3),
    bs: Param(help="Batch size", type=int, default=4),
    sl: Param(help="Sequence length", type=int, default=512),
    max_seq_len: Param(help="Max sequence length for model embedding and dataloader", type=int, default=8192),
    axial_shape: Param(help="Axial Positional Encoding shape, passed as a string, e.g. '64,32'", type=str, default='128,64'),
    do_wandb_logging: Param(help="Use weights and biases logging", type=bool_arg, default=False),
    run_name: Param(help="Run name for wandb tracking and model filename", type=str, default=''),
    wandb_group: Param(help="wandb group", type=str, default='TEST'),
    wandb_notes: Param(help="wandb notes", type=str, default='My experiment notes'),
    wandb_tags: Param(help="wandb tags, add tags in a single string, space separated", type=str, default='test'),
    save_model: Param(help="Save model locally in /models", type=bool_arg, default=False),
    # grad_accum: Param(help="Gradient Accumulation, set greater than 1 to implement", type=int, default=1),
):
    # Download and cache the data on the rank-0 process first; other processes
    # wait and then reuse the cached copy.
    print('Loading data...')
    path = rank0_first(download_enwik8_data, dest=data_path)

    print('Preparing dataloaders...')
    dls = rank0_first(get_enwik8_dataloader, data_path=data_path, bs=bs, sl=sl, n_workers=None)

    # Parse the axial shape string (e.g. '128,64') into a tuple of ints.
    axial_shape = tuple(map(int, axial_shape.split(',')))
    config = TransformerLMConfigEnwik8(warn=False, verbose=True, max_seq_len=max_seq_len, axial_shape=axial_shape)
    model = TransformerLM.from_config(config)
    learn = get_lm_learner(dls, model)

    # Optionally attach Weights & Biases logging callbacks.
    cbs = []
    if do_wandb_logging:
        wandb_run, cbs = init_wandb(cbs, wandb_name=run_name, wandb_group=wandb_group,
                                    wandb_notes=wandb_notes, wandb_tags=wandb_tags)

    # Train inside the distributed context so every process/GPU participates.
    print('Training...')
    with learn.distrib_ctx(): learn.fit(n_epochs, cbs=cbs)

    if save_model:
        now = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
        learn.save(f'{run_name}_{now}')
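Because @call_parse turns each Param annotation into a command-line flag, and the data setup is wrapped in rank0_first with training inside learn.distrib_ctx(), the script is meant to be launched with one Python process per GPU. A minimal sketch of an invocation, assuming fastai's launch module is used and with illustrative flag values (not taken from the repo):

    # multi-GPU launch (one process per visible GPU)
    python -m fastai.launch distrib2.py --n_epochs 2 --bs 8 --sl 512 --run_name test_run --save_model True

    # quick single-process check; distrib_ctx should then behave like ordinary training
    python distrib2.py --n_epochs 1

The flag names follow the parameter names above; do_wandb_logging and save_model use bool_arg, so they accept values such as True/False on the command line.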