Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_model.py 1.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
  1. import pandas as pd
  2. from . import shared
  3. def fit_model(params: dict):
  4. tfidf = shared.load(shared.train_tfidf)
  5. y = shared.load_labels(shared.train_data)
  6. from sklearn.dummy import DummyClassifier as Classifier
  7. clf = Classifier(strategy='stratified', **params)
  8. clf.fit(tfidf, y)
  9. return y, clf.predict(tfidf), clf
  10. def eval_on_train_data(y, preds):
  11. return shared.compute_metrics(y, preds, "train")
  12. def main(params: dict):
  13. y, preds, clf = fit_model(params)
  14. metrics = eval_on_train_data(y, preds)
  15. from dagshub import dagshub_logger
  16. with dagshub_logger() as logger:
  17. logger.log_hyperparams(clf.get_params(), classifier_type=type(clf).__name__)
  18. logger.log_metrics(metrics)
  19. shared.save(clf, shared.classifier_pkl)
  20. # For possible interactive use
  21. return y, preds, clf, metrics
  22. if __name__ == "__main__":
  23. def param_key(arg):
  24. prefix = '--param-'
  25. if arg.startswith(prefix):
  26. arg = arg[len(prefix):]
  27. arg = arg.replace('-','_')
  28. return arg
  29. import sys
  30. params = {param_key(k):v for k,v in zip(sys.argv[1::2],sys.argv[2::2])}
  31. print('Running training with params: ', params)
  32. main(params)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...