Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_model.py 1.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
  1. """
  2. Train classification model for MNIST
  3. """
  4. import json
  5. import pickle
  6. import numpy as np
  7. from sklearn.svm import SVC
  8. from sklearn.multiclass import OneVsRestClassifier
  9. import time
  10. def train_model():
  11. # Measure training time
  12. start_time = time.time()
  13. # Load training data
  14. print("Load training data...")
  15. train_data = np.load('./data/processed_train_data.npy')
  16. # Choose a random sample of images from the training data.
  17. # This is important since SVM training time increases quadratically with the number of training samples.
  18. print("Choosing smaller sample to shorten training time...")
  19. # Set a random seed so that we get the same "random" choices when we try to recreate the experiment.
  20. np.random.seed(42)
  21. num_samples = 5000
  22. choice = np.random.choice(train_data.shape[0], num_samples, replace=False)
  23. train_data = train_data[choice, :]
  24. # Divide loaded data-set into data and labels
  25. labels = train_data[:, 0]
  26. data = train_data[:, 1:]
  27. print("done.")
  28. # Define SVM classifier and train model
  29. print("Training model...")
  30. model = OneVsRestClassifier(SVC(kernel='linear'), n_jobs=6)
  31. model.fit(data, labels)
  32. print("done.")
  33. # Save model as pkl
  34. print("Save model and training time metric...")
  35. with open("./data/model.pkl", 'wb') as f:
  36. pickle.dump(model, f)
  37. # End training time measurement
  38. end_time = time.time()
  39. # Create metric for model training time
  40. with open('./metrics/train_metric.json', 'w') as f:
  41. json.dump({'training_time': end_time - start_time}, f)
  42. print("done.")
  43. if __name__ == '__main__':
  44. train_model()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...