Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

MLP-MNIST-Test.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
  1. import sys
  2. sys.path.append("../..")
  3. from TorchUtils.DatasetGenerator.FromPublicDatasets import load_public_dataset
  4. from TorchUtils.Trainer.MLPTrainer import MLPClassificationTrainer, MLPClassificationMPTrainer
  5. from TorchUtils.ModelGenerator.MLP import MLP
  6. from TorchUtils.Core.ShapeChecker import check_shape
  7. import torch
  8. import torch.nn as nn
  9. import torch.optim as optim
  10. import torchvision.transforms as transforms
  11. from torch_lr_finder import LRFinder
  12. import warnings
  13. warnings.simplefilter("ignore")
  14. class Model(nn.Module):
  15. def __init__(self):
  16. super(Model, self).__init__()
  17. self.classifier = nn.Sequential(
  18. nn.Linear(28*28, 512),
  19. nn.BatchNorm1d(512),
  20. nn.ReLU(True),
  21. nn.Linear(512, 256),
  22. nn.BatchNorm1d(256),
  23. nn.ReLU(True),
  24. nn.Linear(256, 128),
  25. nn.BatchNorm1d(128),
  26. nn.ReLU(True),
  27. nn.Linear(128, 64),
  28. nn.ReLU(True),
  29. nn.Linear(64, 10),
  30. nn.Softmax()
  31. )
  32. def forward(self, x):
  33. x = x.view(x.size(0), -1)
  34. return self.classifier(x)
  35. if __name__ == "__main__":
  36. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
  37. train_loader, test_loader = load_public_dataset("MNIST", transform=transform, batch_size=256)
  38. model = Model()
  39. optimizer = optim.Adam(model.parameters(), lr=1e-2)
  40. criterion = nn.CrossEntropyLoss()
  41. # lr_finder = LRFinder(model, optimizer, criterion, device="cpu")
  42. # lr_finder.range_test(train_loader, end_lr=100, num_iter=100, accumulation_steps=1)
  43. # lr_finder.plot()
  44. # lr_finder.reset()
  45. trainer = MLPClassificationTrainer(model, criterion, optimizer)
  46. trainer.fit(train_loader, epochs=2, reshape_size=(-1, 28**2), validation_loader=test_loader)
  47. # trainer.fit(train_loader, epochs=10, reshape_size=(-1, 28**2))
  48. trainer.plot_result()
  49. trainer.save(is_parameter_only=False)
Tip!

Press p or the left arrow key to see the previous file, or n or the right arrow key to see the next file.

Comments

Loading...