Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

sgd.py 2.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Classify the source of a signal from its features with an SGD classifier."""
# TODO: find a way to get rid of the sklearn deprecation warning.
  4. import pandas as pd
  5. import numpy as np
  6. import sklearn as sk
  7. import time
  8. import json
  9. import pickle
  10. def sgd():
  11. print("Reading in dataset and preparing...")
  12. s = pd.read_csv('./data/processed/stars_proc.csv',
  13. low_memory=False)
  14. # Drop Unnamed: 0, because it doesn't seem to drop for some reason
  15. s = s.drop(['Unnamed: 0'], axis=1)
  16. print("Splitting into labels and features...")
  17. y = s['Source']
  18. x = s[['Frequency(MHz)', 'Signal_to_noise_ratio', 'Drift_rate(Hz/sec)']]
  19. print("Scaling data...")
  20. # Using robust scaling due to high kurtosis; robust to outliers
  21. from sklearn.preprocessing import RobustScaler
  22. rb_sc = RobustScaler()
  23. x = rb_sc.fit_transform(x)
  24. print("Splitting into training and test sets...")
  25. from sklearn.model_selection import train_test_split
  26. x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1)
  27. start_time = time.time()
  28. print("Training model...")
  29. from sklearn.linear_model import SGDClassifier
  30. sgd_class = SGDClassifier(n_jobs = -1, max_iter=10)
  31. sgd_class.fit(x_train, y_train)
  32. y_pred = sgd_class.predict(x_test)
  33. from sklearn.metrics import accuracy_score
  34. acc_sgd = accuracy_score(y_test, y_pred)
  35. end_time = time.time()
  36. print("Saving model and standard scaler...")
  37. with open("./models/sgd.pkl", 'wb') as f:
  38. pickle.dump(sgd_class, f)
  39. with open("./models/rb_sc_sgd.pkl", 'wb') as f:
  40. pickle.dump(rb_sc, f)
  41. print("Saving accuracy and training time...")
  42. with open("./metrics/sgd_acc.json", 'w') as f:
  43. json.dump({'SGD_accuracy': acc_sgd}, f)
  44. with open("./metrics/sgd_time.json", 'w') as f:
  45. json.dump({'SGD_training_time': end_time - start_time}, f)
  46. print("Accuracy: " + str(acc_sgd*100) + " %")
  47. print("Training time: " + str(end_time - start_time) + " seconds")
  48. if __name__ == '__main__':
  49. sgd()
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...