Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_model.py 1.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
  1. """
  2. Train classification model for MNIST
  3. """
import json
import os
import pickle
import time

import numpy as np
import sklearn
from sklearn.linear_model import LogisticRegression
  10. def train_model():
  11. # Measure training time
  12. start_time = time.time()
  13. # Load training data
  14. print("Load training data...")
  15. train_data = np.load('./data/processed_train_data.npy')
  16. # Choose a random sample of images from the training data.
  17. # This is important since SVM training time increases quadratically with the number of training samples.
  18. print("Choosing smaller sample to shorten training time...")
  19. # Set a random seed so that we get the same "random" choices when we try to recreate the experiment.
  20. np.random.seed(42)
  21. num_samples = 5000
  22. choice = np.random.choice(train_data.shape[0], num_samples, replace=False)
  23. train_data = train_data[choice, :]
  24. # Divide loaded data-set into data and labels
  25. labels = train_data[:, 0]
  26. data = train_data[:, 1:]
  27. print("done.")
  28. # Define SVM classifier and train model
  29. print("Training model...")
  30. model = LogisticRegression(random_state=0)
  31. model.fit(data, labels)
  32. print("done.")
  33. # Save model as pkl
  34. print("Save model and training time metric...")
  35. with open("./data/model3.pkl", 'wb') as f:
  36. pickle.dump(model, f)
  37. # End training time measurement
  38. end_time = time.time()
  39. # Create metric for model training time
  40. with open('./metrics/train_metric.json', 'w') as f:
  41. json.dump({'training_time': end_time - start_time}, f)
  42. print("done.")
  43. if __name__ == '__main__':
  44. train_model()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...