Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

plot_learning_curves.py 1.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
  1. #!/usr/bin/env python
  2. import os
  3. import click
  4. from glob import glob
  5. from tqdm import tqdm
  6. from pathlib import Path
  7. import numpy as np
  8. import torch
  9. import matplotlib.pyplot as plt
  10. @click.command()
  11. @click.argument('inputdir', type=click.Path(
  12. exists=True, dir_okay=True, file_okay=False, path_type=Path
  13. ))
  14. def main(inputdir):
  15. ptfiles = glob(os.path.join(inputdir, '*.pt'))
  16. train_loss = {}
  17. test_loss = {}
  18. for ptfile in tqdm(ptfiles):
  19. checkpoint = torch.load(ptfile, map_location=torch.device('cpu'))
  20. epoch = checkpoint['epoch']
  21. train_loss[epoch] = checkpoint['train_loss']
  22. test_loss[epoch] = (
  23. checkpoint['validation_loss'] if 'validation_loss' in checkpoint
  24. else checkpoint['test_loss']
  25. )
  26. epochs = np.array(sorted(train_loss.keys()))
  27. train_loss_arr = np.array([
  28. train_loss[e] for e in epochs
  29. ])
  30. test_loss_arr = np.array([
  31. test_loss[e] for e in epochs
  32. ])
  33. best_epoch = epochs[np.argmin(test_loss_arr)]
  34. print(f'Best epoch: {best_epoch}')
  35. fig, ax = plt.subplots()
  36. ax.plot(epochs, np.sqrt(train_loss_arr), 'ko-', label='Train Loss')
  37. ax.plot(epochs, np.sqrt(test_loss_arr), 'ro-', label='Validation Loss')
  38. ax.legend(loc='upper right', fontsize=16)
  39. ax.set_xlabel('Epoch', fontsize=16)
  40. ax.set_ylabel('RMSE Loss', fontsize=16)
  41. plt.show()
  42. if __name__ == '__main__':
  43. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...