Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_model.py 1.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
  1. import sys
  2. import os
  3. import pandas as pd
  4. from scipy.sparse import load_npz
  5. import pickle
  6. from sklearn.neighbors import KNeighborsClassifier
  7. from sklearn.model_selection import cross_val_score
  8. from dagshub import DAGsHubLogger
  9. from tqdm import tqdm
  10. def train(processed_data_path, model_folder_path):
  11. X_train = load_npz(processed_data_path + "X_train.npz")
  12. y_train = pd.read_csv(processed_data_path + "y_train.csv")
  13. logger = DAGsHubLogger(
  14. metrics_path="reports/training_metrics.csv",
  15. hparams_path="src/models/training_params.yml",
  16. )
  17. val_error_rate = []
  18. neighbors_range = range(1, 500, 5)
  19. for i in tqdm(neighbors_range):
  20. knn = KNeighborsClassifier(n_neighbors=i)
  21. val_error = (
  22. 1 - cross_val_score(knn, X_train, y_train.values.ravel(), cv=2).mean()
  23. )
  24. logger.log_metrics({"val_error": val_error}, step_num=i)
  25. val_error_rate.append(val_error)
  26. best_k = neighbors_range[val_error_rate.index(min(val_error_rate))]
  27. logger.log_hyperparams(best_k=best_k)
  28. knn = KNeighborsClassifier(n_neighbors=best_k)
  29. knn.fit(X_train, y_train.values.ravel())
  30. # source, destination
  31. pickle.dump(knn, open(model_folder_path + "model.pkl", "wb"))
  32. logger.save()
  33. logger.close()
  34. if __name__ == "__main__":
  35. if not (1 <= len(sys.argv) <= 3):
  36. print(
  37. "usage: %s <processed_data_folder (optional)> <out_folder (optional)>"
  38. % sys.argv[0],
  39. file=sys.stderr,
  40. )
  41. sys.exit(0)
  42. in_folder = sys.argv[1] if len(sys.argv) >= 2 else "data/processed/"
  43. out_folder = sys.argv[2] if len(sys.argv) == 3 else "models/"
  44. if not os.path.exists(out_folder):
  45. os.makedirs(out_folder)
  46. train(in_folder, out_folder)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...