bench.py

  1. """
  2. A much shorter version of train.py for benchmarking
  3. """
  4. import os
  5. from contextlib import nullcontext
  6. import numpy as np
  7. import time
  8. import torch
  9. from model import GPTConfig, GPT

# -----------------------------------------------------------------------------
batch_size = 12
block_size = 1024
bias = False
real_data = True
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = True # use PyTorch 2.0 to compile the model to be faster
profile = False # use pytorch profiler, or just simple benchmarking?
exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
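# note: inside ctx, autocast runs eligible ops (e.g. matmuls) in ptdtype while parameters stay float32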

# data loading init
if real_data:
    dataset = 'openwebtext'
    data_dir = os.path.join('data', dataset)
    train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    def get_batch(split):
        data = train_data # note: the split argument is ignored in this benchmarking script
        ix = torch.randint(len(data) - block_size, (batch_size,))
        x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
        y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
        return x, y
else:
    # alternatively, use fixed random data to take data loading out of the measurement
    x = torch.randint(50304, (batch_size, block_size), device=device)
    y = torch.randint(50304, (batch_size, block_size), device=device)
    get_batch = lambda split: (x, y)
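
# note: in the real-data path, pin_memory() plus non_blocking=True lets the
# host-to-device copy overlap with GPU compute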

# model init
gptconf = GPTConfig(
    block_size = block_size, # how far back does the model look? i.e. context size
    n_layer = 12, n_head = 12, n_embd = 768, # size of the model
    dropout = 0, # for determinism
    bias = bias,
)
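# (n_layer=12, n_head=12, n_embd=768 corresponds to the ~124M-parameter GPT-2 size)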
model = GPT(gptconf)
model.to(device)

optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)

if compile:
    print("Compiling model...")
    model = torch.compile(model) # pytorch 2.0

if profile:
    # useful docs on pytorch profiler:
    # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html
    # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile
    wait, warmup, active = 5, 5, 5
    num_steps = wait + warmup + active
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'),
        record_shapes=False,
        profile_memory=False,
        with_stack=False, # incurs an additional overhead, disable if not needed
        with_flops=True,
        with_modules=False, # only for torchscript models atm
    ) as prof:

        X, Y = get_batch('train')
        for k in range(num_steps):
            with ctx:
                logits, loss = model(X, Y)
            X, Y = get_batch('train')
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            lossf = loss.item()
            print(f"{k}/{num_steps} loss: {lossf:.4f}")

            prof.step() # notify the profiler at end of each step
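
    # the trace written to ./bench_log can be viewed in TensorBoard
    # (e.g. `tensorboard --logdir=./bench_log`, with the torch-tb-profiler plugin installed)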

else:

    # simple benchmarking
    torch.cuda.synchronize()
    for stage, num_steps in enumerate([10, 20]): # burn-in stage, then the measured benchmark stage
        t0 = time.time()
        X, Y = get_batch('train')
        for k in range(num_steps):
            with ctx:
                logits, loss = model(X, Y)
            X, Y = get_batch('train')
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            lossf = loss.item()
            print(f"{k}/{num_steps} loss: {lossf:.4f}")
        torch.cuda.synchronize()
        t1 = time.time()
        dt = t1-t0
        mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt)
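        # (the factor of 1 stands in for gradient accumulation steps, which is 1 in this benchmark)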
        if stage == 1:
            print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%")