Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_model.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
  1. import pandas as pd
  2. from . import shared
  3. import sklearn.linear_model
  4. import numpy as np
  5. class RidgeClassifierPredictProba(sklearn.linear_model.RidgeClassifier):
  6. def __init(self, **kwargs):
  7. super().__init__(self, **kwargs)
  8. def predict_proba(self, X):
  9. # The linear model outputs values in the [-1,+1] range
  10. raw_output = self.decision_function(X)
  11. pos_prob = (raw_output + 1) / 2
  12. negative_prob = 1 - pos_prob
  13. return np.stack([negative_prob, pos_prob], axis=1)
  14. def fit_model(params: dict):
  15. print("Loading training data")
  16. X, y = shared.load_data_and_labels(shared.train_processed)
  17. print("Done")
  18. from sklearn.ensemble import AdaBoostClassifier as Classifier
  19. clf = Classifier(**params)
  20. print("Training model ", clf)
  21. # Required for efficient training, so that sklearn doesn't inflate the pandas sparse DF to a dense matrix.
  22. # sklearn only supports scipy sparse matrices.
  23. X_sparse = X.sparse.to_coo()
  24. clf.fit(X_sparse, y)
  25. print("Done")
  26. return X, y, clf
  27. def eval_on_train_data(clf, X, y):
  28. return shared.compute_metrics(clf, X, y, "train")
  29. def main(params: dict):
  30. X, y, clf = fit_model(params)
  31. shared.save(clf, shared.classifier_pkl)
  32. metrics = eval_on_train_data(clf, X, y)
  33. from dagshub import dagshub_logger
  34. with dagshub_logger() as logger:
  35. logger.log_hyperparams(clf.get_params(), classifier_type=type(clf).__name__)
  36. logger.log_metrics(metrics)
  37. # For possible interactive use
  38. return X, y, clf, metrics
  39. if __name__ == "__main__":
  40. def param_key(arg):
  41. prefix = '--param-'
  42. if arg.startswith(prefix):
  43. arg = arg[len(prefix):]
  44. arg = arg.replace('-','_')
  45. return arg
  46. import sys
  47. params = {param_key(k):v for k,v in zip(sys.argv[1::2],sys.argv[2::2])}
  48. print('Running training with params: ', params)
  49. main(params)
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...