Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

lit_autoencoder.py 2.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  1. from argparse import ArgumentParser
  2. import torch
  3. from torch import nn
  4. import torch.nn.functional as F
  5. from torch.utils.data import DataLoader
  6. import pytorch_lightning as pl
  7. from torch.utils.data import random_split
  8. from torchvision.datasets.mnist import MNIST
  9. from torchvision import transforms
  class LitAutoEncoder(pl.LightningModule):
      """Fully connected autoencoder for 28x28 images flattened to 784 values.

      The encoder maps 784 inputs through one hidden ReLU layer to a 3-value
      embedding; the decoder mirrors that path back to 784 outputs. Training,
      validation and test all use the mean-squared reconstruction error.

      Args:
          hidden_dim: width of the hidden layer in both encoder and decoder.
              Defaults to 64, the width this model originally hard-coded.
      """

      def __init__(self, hidden_dim=64):
          super().__init__()
          # Store ``hidden_dim`` in the checkpoint so the model can be rebuilt
          # with ``load_from_checkpoint``.
          self.save_hyperparameters()
          self.encoder = nn.Sequential(
              nn.Linear(28 * 28, hidden_dim),
              nn.ReLU(),
              nn.Linear(hidden_dim, 3),
          )
          self.decoder = nn.Sequential(
              nn.Linear(3, hidden_dim),
              nn.ReLU(),
              nn.Linear(hidden_dim, 28 * 28),
          )

      def forward(self, x):
          """Return the embedding of ``x`` (prediction/inference in Lightning).

          ``x`` must already be flattened: its last dimension has to be 784.
          Only the encoder runs here; no reconstruction is produced.
          """
          embedding = self.encoder(x)
          return embedding

      def _reconstruction_loss(self, batch):
          """Return the MSE between a batch's flattened images and their reconstruction."""
          x, _ = batch  # labels are ignored: the reconstruction target is the input itself
          x = x.view(x.size(0), -1)
          z = self.encoder(x)
          x_hat = self.decoder(z)
          return F.mse_loss(x_hat, x)

      def training_step(self, batch, batch_idx):
          """Compute, log and return the training reconstruction loss."""
          loss = self._reconstruction_loss(batch)
          self.log('train_loss', loss)
          return loss

      def validation_step(self, batch, batch_idx):
          """Log the validation reconstruction loss.

          Without this hook Lightning skips the validation loop even when a
          validation loader is passed to ``fit``.
          """
          loss = self._reconstruction_loss(batch)
          self.log('val_loss', loss)

      def test_step(self, batch, batch_idx):
          """Log the test reconstruction loss.

          Needed for ``trainer.test`` to evaluate anything; the logged value
          is what ``trainer.test`` returns.
          """
          loss = self._reconstruction_loss(batch)
          self.log('test_loss', loss)

      def configure_optimizers(self):
          """Use Adam over all parameters with a fixed learning rate of 1e-3."""
          return torch.optim.Adam(self.parameters(), lr=1e-3)


  def cli_main():
      """Train the autoencoder on MNIST from the command line, then test it.

      Parses ``--batch_size``, ``--hidden_dim`` and every ``pl.Trainer`` flag,
      splits the MNIST training set 55000/5000 into train/validation, fits the
      model, and prints what ``trainer.test`` returns on the MNIST test set.
      """
      pl.seed_everything(1234)

      # ------------
      # args
      # ------------
      parser = ArgumentParser()
      parser.add_argument('--batch_size', default=32, type=int)
      parser.add_argument('--hidden_dim', type=int, default=128)
      parser = pl.Trainer.add_argparse_args(parser)
      args = parser.parse_args()

      # ------------
      # data
      # ------------
      # Root '' means MNIST is downloaded to / read from a path relative to
      # the current working directory.
      dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
      mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
      mnist_train, mnist_val = random_split(dataset, [55000, 5000])
      # Shuffle the training data each epoch; evaluation order does not matter.
      train_loader = DataLoader(mnist_train, batch_size=args.batch_size, shuffle=True)
      val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
      test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

      # ------------
      # model
      # ------------
      # The hidden-layer width is controlled by --hidden_dim.
      model = LitAutoEncoder(hidden_dim=args.hidden_dim)

      # ------------
      # training
      # ------------
      trainer = pl.Trainer.from_argparse_args(args)
      trainer.fit(model, train_loader, val_loader)

      # ------------
      # testing
      # ------------
      # The test loader is passed positionally because its keyword was renamed
      # (``test_dataloaders`` -> ``dataloaders``) during Lightning 1.x.
      result = trainer.test(model, test_loader)
      print(result)


  if __name__ == '__main__':
      cli_main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...