Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

MNIST-AnomalyDetection.py 4.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
  1. import sys
  2. sys.path.append("../..")
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. import torch.optim as optim
  8. import torchvision.transforms as transforms
  9. # from torch_lr_finder import LRFinder
  10. from sklearn.manifold import TSNE
  11. # from TorchUtils.Analyzer.LayerAnalyzer import AnalyzedLinear
  12. from TorchUtils.Analyzer.ManifoldAnalyzer import (LDAAnalyzer, PCAAnalyzer,
  13. SVCAnalyzer,
  14. TruncatedSVDAnalyzer,
  15. TSNEAnalyzer)
  16. from TorchUtils.Core.Arguments import parse_args
  17. from TorchUtils.Core.EnvironmentChecker import get_device_type, convert_device
  18. from TorchUtils.Core.LayerCalculator import calculate_listed_layer
  19. from TorchUtils.Core.ShapeChecker import check_shape
  20. from TorchUtils.DatasetGenerator.FromPublicDatasets import get_custom_MNIST
  21. from TorchUtils.ModelGenerator.MLP import MLP
  22. from TorchUtils.PipeLine.PipeLine import PipeLine
  23. from TorchUtils.Trainer.CNNTrainer import CNNClassificationTrainer
  24. from TorchUtils.Layers.ConvLayers import FireModule
  25. import warnings
  26. warnings.simplefilter("ignore")
  27. class Model(nn.Module):
  28. def __init__(self):
  29. super(Model, self).__init__()
  30. self.conv1 = nn.Conv2d(1, 16, kernel_size=3)
  31. self.pool1 = nn.MaxPool2d(kernel_size=2)
  32. self.activation1 = nn.ReLU(True)
  33. self.features = nn.ModuleList([self.conv1, self.pool1, self.activation1])
  34. self.linear1 = nn.Linear(13*13*16, 512)
  35. self.activation2 = nn.ReLU(True)
  36. self.linear2 = nn.Linear(512, 128)
  37. self.activation3 = nn.ReLU(True)
  38. self.pre_classifier = nn.ModuleList([self.linear1, self.activation2, self.linear2, self.activation3])
  39. # self.linear3 = AnalyzedLinear(128, 10)
  40. self.linear3 = nn.Linear(128, 10)
  41. self.activation4 = nn.Softmax()
  42. self.classifier = nn.ModuleList([self.linear3, self.activation4])
  43. def forward(self, x):
  44. x = calculate_listed_layer(self.features, x)
  45. x = x.view(x.size(0), -1)
  46. x = calculate_listed_layer(self.pre_classifier, x)
  47. return calculate_listed_layer(self.classifier, x)
  48. def predict(model, test_loader, reshape_size=None):
  49. total_outputs = None
  50. total_labels = None
  51. for images, labels in test_loader:
  52. images, labels = convert_device(images, labels)
  53. outputs = calculate_listed_layer(model.features, images)
  54. outputs = outputs.view(outputs.size(0), -1)
  55. outputs = calculate_listed_layer(model.pre_classifier, outputs)
  56. if total_outputs is None:
  57. total_outputs = outputs.detach().numpy()
  58. total_labels = labels.detach().numpy()
  59. total_labels = total_labels.reshape(total_labels.shape[0], 1)
  60. else:
  61. total_outputs = np.vstack((total_outputs, outputs.detach().numpy()))
  62. labels_temp = labels.detach().numpy()
  63. total_labels = np.vstack((total_labels, labels_temp.reshape(labels_temp.shape[0], 1)))
  64. return total_outputs, total_labels
  65. def train_model(args):
  66. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
  67. train_loader, val_loader, test_loader = get_custom_MNIST(train_labels=[0, 2, 4, 6, 8], transform=transform,
  68. val_rate=0.2)
  69. model = Model().to(get_device_type())
  70. optimizer = optim.SGD(model.parameters(), lr=0.05, weight_decay=1e-5)
  71. criterion = nn.CrossEntropyLoss()
  72. trainer = CNNClassificationTrainer(model, criterion, optimizer)
  73. trainer.fit(train_loader, epochs=args.epochs, validation_loader=val_loader)
  74. trainer.evaluate(test_loader)
  75. # outputs, labels = predict(model, test_loader)
  76. # return outputs, labels
  77. def train_svd(data):
  78. svd = TruncatedSVDAnalyzer()
  79. svd.fit(data[0])
  80. outputs = svd.predict(data[0])
  81. return (outputs, data[1])
  82. def train_svm(data):
  83. svc_visualizer = SVCAnalyzer()
  84. svc_visualizer.fit(data[0], data[1])
  85. svc_visualizer.evaluate(data[0], data[1])
  86. svc_visualizer.plot(data[0], data[1])
  87. def train_lda(data):
  88. lda_visualizer = LDAAnalyzer()
  89. lda_visualizer.fit(data[0], data[1])
  90. lda_visualizer.plot(data[0], data[1])
  91. if __name__ == "__main__":
  92. args = parse_args()
  93. pipeline = PipeLine()
  94. pipeline.add_function(train_model, False, args)
  95. # pipeline.add_function(train_svd)
  96. # pipeline.add_function(train_lda)
  97. # pipeline.add_function(train_svm)
  98. pipeline.execute()
  99. pipeline.save_pipeline()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...