1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
- from argparse import ArgumentParser
- import torch
- import pytorch_lightning as pl
- from torch.nn import functional as F
- from torch.utils.data import DataLoader, random_split
- from torchvision.datasets.mnist import MNIST
- from torchvision import transforms
class Backbone(torch.nn.Module):
    """Two-layer fully connected network that maps a batch to class scores.

    Args:
        hidden_dim: Width of the hidden layer.
        input_dim: Number of features per sample once every non-batch
            dimension has been flattened. Defaults to ``28 * 28``.
        num_classes: Number of scores produced per sample. Defaults to 10.
    """

    def __init__(self, hidden_dim=128, input_dim=28 * 28, num_classes=10):
        super().__init__()
        # Attribute names are kept as ``l1``/``l2`` so state-dict keys of
        # previously saved weights still match.
        self.l1 = torch.nn.Linear(input_dim, hidden_dim)
        self.l2 = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Flatten each sample and run it through both linear layers.

        Args:
            x: Tensor whose first dimension is the batch; the remaining
                dimensions must multiply out to ``input_dim``.

        Returns:
            Tensor of shape ``(batch, num_classes)``; every entry is
            non-negative because a ReLU follows the last layer.
        """
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # input (e.g. a transposed batch), ``reshape`` copies only when it
        # has to and otherwise behaves the same.
        x = x.reshape(x.size(0), -1)
        x = torch.relu(self.l1(x))
        # NOTE(review): this ReLU clamps the final scores to >= 0 before they
        # are handed to a cross-entropy loss elsewhere in the file. Kept
        # as-is so outputs of existing weights do not change; confirm whether
        # it is intentional.
        x = torch.relu(self.l2(x))
        return x
class LitClassifier(pl.LightningModule):
    """Lightning wrapper that trains a backbone with cross-entropy loss.

    Args:
        backbone: Module applied to the inputs of every batch; its output is
            passed to ``F.cross_entropy`` together with the batch targets.
        learning_rate: Learning rate given to the Adam optimizer.
    """

    def __init__(self, backbone, learning_rate=1e-3):
        super().__init__()
        # Must be called directly from __init__ with no extra arguments so
        # that ``self.hparams.learning_rate`` is available later on.
        self.save_hyperparameters()
        self.backbone = backbone

    def forward(self, x):
        """Return the backbone output for ``x`` (inference/prediction path)."""
        return self.backbone(x)

    def _batch_loss(self, batch):
        """Unpack ``(inputs, targets)`` and return the cross-entropy loss."""
        inputs, targets = batch
        predictions = self.backbone(inputs)
        return F.cross_entropy(predictions, targets)

    def training_step(self, batch, batch_idx):
        """Log the batch loss as ``train_loss`` and return it."""
        batch_loss = self._batch_loss(batch)
        self.log('train_loss', batch_loss, on_epoch=True)
        return batch_loss

    def validation_step(self, batch, batch_idx):
        """Log the batch loss as ``valid_loss``; nothing is returned."""
        self.log('valid_loss', self._batch_loss(batch), on_step=True)

    def test_step(self, batch, batch_idx):
        """Log the batch loss as ``test_loss``; nothing is returned."""
        self.log('test_loss', self._batch_loss(batch))

    def configure_optimizers(self):
        """Build an Adam optimizer over all parameters of this module."""
        # ``self.hparams`` is populated by save_hyperparameters() in __init__.
        step_size = self.hparams.learning_rate
        return torch.optim.Adam(self.parameters(), lr=step_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with model flags.

        Adds ``--learning_rate`` (float). Its command-line default is 0.0001,
        which differs from the constructor default of 1e-3.
        """
        extended = ArgumentParser(parents=[parent_parser], add_help=False)
        extended.add_argument('--learning_rate', type=float, default=0.0001)
        return extended
def cli_main():
    """Train and test an MNIST classifier configured from the command line.

    Seeds the run with 1234, parses ``--batch_size``, ``--hidden_dim``, the
    Trainer flags and the model flags, builds the train/validation/test
    loaders, fits the model and prints the result of the test run.
    """
    pl.seed_everything(1234)

    # --- command-line arguments -------------------------------------------
    cli = ArgumentParser()
    cli.add_argument('--batch_size', default=32, type=int)
    cli.add_argument('--hidden_dim', type=int, default=128)
    cli = pl.Trainer.add_argparse_args(cli)
    cli = LitClassifier.add_model_specific_args(cli)
    args = cli.parse_args()

    # --- data -------------------------------------------------------------
    # Training set first, then the test set, both rooted at ''.
    train_full = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    test_set = MNIST('', train=False, download=True, transform=transforms.ToTensor())
    # The split happens before the model is built so the seeded random state
    # is consumed in the same order on every run.
    train_part, val_part = random_split(train_full, [55000, 5000])
    train_loader, val_loader, test_loader = (
        DataLoader(part, batch_size=args.batch_size)
        for part in (train_part, val_part, test_set)
    )

    # --- model ------------------------------------------------------------
    backbone = Backbone(hidden_dim=args.hidden_dim)
    model = LitClassifier(backbone, args.learning_rate)

    # --- training ---------------------------------------------------------
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, train_loader, val_loader)

    # --- testing ----------------------------------------------------------
    # NOTE(review): the ``test_dataloaders`` keyword is tied to the installed
    # PyTorch Lightning version - confirm it is still accepted on upgrade.
    print(trainer.test(test_dataloaders=test_loader))
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    cli_main()
|