Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

transfer-learning-resnet.py 7.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
  1. #%%
  2. import torch
  3. from torchvision import datasets, models, transforms
  4. #%%
  5. mean = [0.485, 0.456, 0.406] # images fed to pre-trained models have to be normalized using these parameters https://pytorch.org/docs/stable/torchvision/models.html#id3
  6. std = [0.229, 0.224, 0.225]
  7. #%%
  8. train_transform = transforms.Compose([transforms.Resize(256), transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean, std)]) # perform arbitrary transforms and normalize the input images to be fed into the pretrained model
  9. test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean, std)]) # no random flipping or cropping
  10. #%%
  11. import zipfile
  12. zip = zipfile.ZipFile('C:/Users/pcuci/Downloads/pytorch-building-deep-learning-models/datasets/flowers_.zip')
  13. zip.extractall('datasets')
  14. data_dir = 'datasets/flowers_'
  15. image_datasets = {}
  16. image_datasets['train'] = datasets.ImageFolder(data_dir + '/train', train_transform) # applies a series of tranformations to an image folder path
  17. image_datasets['test'] = datasets.ImageFolder(data_dir + '/test', test_transform)
  18. print("Training data size: %d" % len(image_datasets['train']))
  19. print("Test data size: %d" % len(image_datasets['test']))
  20. #%%
  21. class_names = image_datasets['train'].classes
  22. class_names # 5 types of flowers
  23. #%%
  24. image_datasets # a dictionary with two keys: train, test
  25. #%%
  26. dataloaders = {} # used to iterate over the datasets
  27. dataloaders['train'] = torch.utils.data.DataLoader(image_datasets['train'], batch_size=8, shuffle=True, num_workers=4)
  28. dataloaders['test'] = torch.utils.data.DataLoader(image_datasets['test'], batch_size=8, shuffle=True, num_workers=4)
  29. dataloaders
  30. #%%
  31. # input images to pre-trained models should be in the format [batch_size, num_channels, height, width]
  32. inputs, labels = next(iter(dataloaders['train']))
  33. inputs.shape # a 4D tensor
  34. #%%
  35. labels # numeric values of 0 to 4 corresponding to the 5 categories of flowers
  36. #%%
  37. import torchvision
  38. inp = torchvision.utils.make_grid(inputs)
  39. inp.shape # stacked all images side by side
  40. #%%
  41. inp.max() # however, Matplotlib requires floating point RGB values to be in the 0-1 range
  42. #%%
  43. import numpy as np
  44. np.clip(inp, 0, 1).max()
  45. #%%
  46. inp.numpy().transpose((1, 2, 0)).shape # matplotlib expects channels in the last dimension
  47. #%%
  48. import matplotlib.pyplot as plt
  49. plt.ion()
  50. def img_show(inp, title=None):
  51. inp = inp.numpy().transpose((1, 2, 0))
  52. inp = std * inp + mean # denormalize the image
  53. inp = np.clip(inp, 0, 1)
  54. plt.figure(figsize=(16, 4))
  55. plt.axis('off')
  56. plt.imshow(inp)
  57. if title is not None:
  58. plt.title(title)
  59. #%%
  60. img_show(inp, title=[class_names[x] for x in labels])
  61. #%%
  62. model = models.resnet18(pretrained=True)
  63. num_ftrs = model.fc.in_features
  64. num_ftrs
  65. #%%
  66. import torch.nn as nn
  67. model.fc = nn.Linear(num_ftrs, 5) # 512 features as input to classify into 5 categories; this replaces the existing linear layer in the model
  68. import torch.optim as optim
  69. optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # use momentum to accelerate model convergence
  70. #%%
  71. from torch.optim import lr_scheduler # learning rate scheduler
  72. exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) # the learning rate scheduler which decays the learning rate as we get close to convergence, reduce the learning rate by 0.1 every 7 epochs
  73. #%%
  74. def calculate_accuracy(phase, running_loss, running_corrects):
  75. epoch_loss = running_loss / len(image_datasets[phase])
  76. epoch_acc = running_corrects.double() / len(image_datasets[phase])
  77. print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
  78. return (epoch_loss, epoch_acc)
  79. #%%
  80. def phase_train(model, criterion, optimizer, scheduler): # the training phase
  81. scheduler.step()
  82. model.train()
  83. running_loss = 0.0
  84. running_corrects = 0
  85. for inputs, labels in dataloaders['train']:
  86. optimizer.zero_grad()
  87. with torch.set_grad_enabled(True):
  88. outputs = model(inputs)
  89. _, preds = torch.max(outputs, 1)
  90. loss = criterion(outputs, labels) # calculate the cross entropy loss
  91. loss.backward() # calculate gradients
  92. optimizer.step() # update model parameters
  93. running_loss += loss.item() * inputs.size(0)
  94. running_corrects += torch.sum(preds == labels.data)
  95. calculate_accuracy('train', running_loss, running_corrects)
  96. #%%
  97. import copy
  98. criterion = nn.CrossEntropyLoss()
  99. best_acc = 0.0 # save only the best model parameters on test data
  100. def phase_test(model, criterion, optimizer):
  101. model.eval() # to run the model in the test phase
  102. running_loss = 0.0
  103. running_corrects = 0
  104. global best_acc # keep track of the model weights which produce the best accuracy on the test data
  105. for inputs, labels in dataloaders['test']:
  106. optimizer.zero_grad()
  107. with torch.no_grad(): # don't calculate gradients in the test phase
  108. outputs = model(inputs)
  109. _, preds = torch.max(outputs, 1)
  110. loss = criterion(outputs, labels)
  111. running_loss += loss.item() * inputs.size(0)
  112. running_corrects += torch.sum(preds == labels.data)
  113. epoch_loss, epoch_acc = calculate_accuracy('test', running_loss, running_corrects)
  114. if epoch_acc > best_acc:
  115. best_acc = epoch_acc
  116. best_model_wts = copy.deepcopy(model.state_dict())
  117. return best_model_wts
  118. #%%
  119. def build_model(model, criterion, optimizer, scheduler, num_epochs=10): # train the model with the flowers dataset
  120. best_model_wts = copy.deepcopy(model.state_dict())
  121. for epoch in range(num_epochs):
  122. print('Epoch {}/{}'.format(epoch, num_epochs -1))
  123. print('-' * 10)
  124. phase_train(model, criterion, optimizer, scheduler)
  125. best_model_wts = phase_test(model, criterion, optimizer)
  126. print()
  127. print('Best test Acc: {:4f}'.format(best_acc))
  128. model.load_state_dict(best_model_wts) # at the end of the training, load the model that has the best accuracy in test
  129. return model
  130. #%%
  131. model = build_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=1)
  132. #%%
  133. # run the model for predictions
  134. with torch.no_grad():
  135. # retrieve one batch of test images
  136. inputs, labels = iter(dataloaders['test']).next()
  137. inp = torchvision.utils.make_grid(inputs) # turn them into a grid
  138. outputs = model(inputs)
  139. _, preds = torch.max(outputs, 1)
  140. for j in range(len(inputs)): # display the predicted label for each image
  141. inp = inputs.data[j]
  142. img_show(inp, 'predicted:' + class_names[preds[j]])
  143. #%%
  144. # no need to train on all the layers (more typical usecase)
  145. frozen_model = models.resnet18(pretrained=True)
  146. for param in frozen_model.parameters():
  147. param.requires_grad = False # freezes model weights so they don't get updated during training
  148. frozen_model.fc = nn.Linear(num_ftrs, 5) # replace the last layer, which also
  149. optimizer = optim.SGD(frozen_model.fc.parameters(), lr=0.001, momentum=0.9)
  150. exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
  151. criterion = nn.CrossEntropyLoss()
  152. best_acc = 0
  153. #%%
  154. frozen_model = build_model(frozen_model, criterion, optimizer, exp_lr_scheduler, num_epochs=1) # the accuracy is less because of the frozen layers
  155. #%%
  156. with torch.no_grad():
  157. inputs, labels = iter(dataloaders['test']).next()
  158. inp = torchvision.utils.make_grid(inputs)
  159. outputs = frozen_model(inputs)
  160. _, preds = torch.max(outputs, 1)
  161. for j in range(len(inputs)):
  162. inp = inputs.data[j]
  163. img_show(inp, 'predicted:' + class_names[preds[j]])
  164. #%%
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...