train_model.py

  1. """
  2. Train classification model for MNIST
  3. """
  4. import json
  5. import pickle
  6. import numpy as np
  7. import time
  8. import yaml
  9. import os
  10. # New imports
  11. import torch
  12. import torch.utils.data
  13. import torch.nn.functional as F
  14. import torch.optim as optim
  15. from my_torch_model import Net
  16. from dagshub import dagshub_logger
  17. def relpath(path):
  18. return os.path.join(os.path.dirname(__file__), path)
  19. # New function
  20. def train(model, device, train_loader, optimizer, epoch, logger):
  21. log_interval = 100
  22. steps_so_far = (epoch - 1) * len(train_loader)
  23. model.train()
  24. for batch_idx, (data, target) in enumerate(train_loader):
  25. data, target = data.to(device), target.to(device)
  26. optimizer.zero_grad()
  27. output = model(data)
  28. loss = F.nll_loss(output, target)
  29. loss.backward()
  30. logger.log_metrics(loss=loss.item(), step_num=batch_idx + steps_so_far)
  31. optimizer.step()
  32. if batch_idx % log_interval == 0:
  33. print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  34. epoch, batch_idx * len(data), len(train_loader.dataset),
  35. 100. * batch_idx / len(train_loader), loss.item()))
  36. def train_model(params: dict):
  37. # Setting up network
  38. print("Setting up Params...")
  39. print(params)
  40. device = torch.device("cpu")
  41. batch_size = params['batch_size']
  42. epochs = params['epochs']
  43. learning_rate = params['learning_rate']
  44. momentum = params['momentum']
  45. print("done.")
  46. # Load training data
  47. print("Load training data...")
  48. train_data = np.load('./data/processed_train_data.npy')
  49. # Divide loaded data-set into data and labels
  50. labels = torch.Tensor(train_data[:, 0]).long()
  51. data = torch.Tensor(train_data[:, 1:].reshape([train_data.shape[0], 1, 28, 28]))
  52. torch_train_data = torch.utils.data.TensorDataset(data, labels)
  53. train_loader = torch.utils.data.DataLoader(torch_train_data,
  54. batch_size=batch_size,
  55. shuffle=True)
  56. print("done.")
  57. # Define SVM classifier and train model
  58. print("Training model...")
  59. model = Net(**params).to(device)
  60. optimizer = optim.SGD(model.parameters(),
  61. lr=learning_rate,
  62. momentum=momentum)
  63. with dagshub_logger(relpath('../metrics/train_metrics.csv')) as logger:
  64. # Measure training time
  65. start_time = time.time()
  66. for epoch in range(1, epochs + 1):
  67. logger.log_metrics(epoch=epoch, step_num=((epoch - 1) * len(train_loader)))
  68. train(model, device, train_loader, optimizer, epoch, logger)
  69. print("done.")
  70. print("Save model and training time metric...")
  71. # End training time measurement
  72. end_time = time.time()
  73. logger.log_metrics(training_time=end_time - start_time)
  74. # Save model as pkl
  75. with open("./data/model.pkl", 'wb') as f:
  76. pickle.dump(model, f)
  77. print("done.")
  78. if __name__ == '__main__':
  79. with open(relpath('params.yml')) as f:
  80. params = yaml.safe_load(f)
  81. train_model(params)
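
For reference, the __main__ block reads hyperparameters from a params.yml file next to the script, and train_model() expects the four keys below. A minimal sketch of that file; the values here are illustrative assumptions, not the project's actual settings:

# Hypothetical params.yml -- keys match what train_model() reads; values are placeholders
batch_size: 64
epochs: 3
learning_rate: 0.01
momentum: 0.5

Since train_model() also passes the entire dict to Net(**params), any key added here has to be accepted by Net's constructor.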
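The Net class is imported from my_torch_model.py, which isn't shown on this page. The training loop does constrain it, though: F.nll_loss expects log-probabilities, the reshape above implies a 1x28x28 input, and MNIST has 10 classes. A minimal sketch consistent with those constraints; the actual architecture may differ:

# Hypothetical sketch of my_torch_model.py -- the real file is not shown here
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # Net(**params) receives every hyperparameter, so accept and ignore
    # the ones the architecture doesn't use (learning_rate, momentum, ...)
    def __init__(self, **kwargs):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)   # 1x28x28 -> 32x26x26
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # -> 64x24x24
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.fc2 = nn.Linear(128, 10)         # 10 MNIST classes

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)                # -> 64x12x12
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # log_softmax pairs with the F.nll_loss call in train()
        return F.log_softmax(self.fc2(x), dim=1)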

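Finally, the trained model is pickled to ./data/model.pkl. A hypothetical way to load it back for evaluation; note that unpickling requires the Net class to be importable, since pickle stores a reference to the class rather than its code:

import pickle
import torch

# Load the pickled model; my_torch_model.Net must be importable here
with open("./data/model.pkl", 'rb') as f:
    model = pickle.load(f)
model.eval()

with torch.no_grad():
    sample = torch.zeros(1, 1, 28, 28)  # placeholder input, not real data
    prediction = model(sample).argmax(dim=1)
    print(prediction)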