Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_model.py 1.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
  1. import pandas as pd
  2. from . import shared
  def fit_model(params: dict):
      """Load the processed training data and fit a logistic-regression model.

      Args:
          params: Keyword arguments forwarded unchanged to the
              ``LogisticRegression`` constructor.

      Returns:
          A ``(X, y, clf)`` tuple: the features and labels exactly as returned
          by ``shared.load_data_and_labels`` and the fitted classifier.
      """
      print("Loading training data")
      features, labels = shared.load_data_and_labels(shared.train_processed)
      print("Done")

      # Kept as a function-scope import, after the data has been loaded.
      from sklearn.linear_model import LogisticRegression as Classifier

      model = Classifier(**params)
      print("Training model ", model)

      # sklearn only understands scipy sparse matrices; handing it the pandas
      # sparse DataFrame directly would make it inflate the data to a dense
      # matrix, so convert to COO form for efficient training.
      model.fit(features.sparse.to_coo(), labels)
      print("Done")

      return features, labels, model
  def eval_on_train_data(clf, X, y):
      """Return the metrics of ``clf`` on ``X``/``y``.

      Delegates to ``shared.compute_metrics``, passing ``"train"`` as the
      final argument.
      """
      train_metrics = shared.compute_metrics(clf, X, y, "train")
      return train_metrics
  def main(params: dict):
      """Fit the model, log the run, save the classifier and return everything.

      Args:
          params: Hyperparameters passed straight through to ``fit_model``.

      Returns:
          A ``(X, y, clf, metrics)`` tuple with the training features, labels,
          fitted classifier and the metrics from ``eval_on_train_data``.
      """
      features, labels, model = fit_model(params)
      train_metrics = eval_on_train_data(model, features, labels)

      from dagshub import dagshub_logger

      # Record the hyperparameters and training metrics through dagshub_logger.
      with dagshub_logger() as run_log:
          run_log.log_hyperparams(
              model.get_params(),
              classifier_type=type(model).__name__,
          )
          run_log.log_metrics(train_metrics)

      shared.save(model, shared.classifier_pkl)

      # Returned so the results can be inspected in an interactive session.
      return features, labels, model, train_metrics
  if __name__ == "__main__":
      # Command-line entry point. Arguments are consumed as alternating
      # key/value pairs, e.g. ``--param-max-iter 200 --param-C 0.5``.
      import ast
      import sys

      def param_key(arg):
          """Strip a leading ``--param-`` (if present) and turn dashes into underscores."""
          prefix = '--param-'
          if arg.startswith(prefix):
              arg = arg[len(prefix):]
          arg = arg.replace('-', '_')
          return arg

      def param_value(raw):
          """Parse a CLI value as a Python literal, falling back to the raw string.

          Command-line arguments are always strings, but the classifier expects
          real numbers/booleans/None for most hyperparameters. Values that are
          not valid literals (e.g. ``lbfgs``) are passed through unchanged.
          """
          try:
              return ast.literal_eval(raw)
          except (ValueError, SyntaxError):
              return raw

      cli_args = sys.argv[1:]
      if len(cli_args) % 2:
          # zip() below drops an unpaired trailing key; make that visible.
          print('Ignoring trailing argument without a value: ', cli_args[-1],
                file=sys.stderr)

      params = {
          param_key(k): param_value(v)
          for k, v in zip(cli_args[::2], cli_args[1::2])
      }
      print('Running training with params: ', params)
      main(params)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...