Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_tfidf_logistic.py 1.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
  1. import json
  2. import joblib
  3. import numpy as np
  4. from scipy.sparse import load_npz
  5. from sklearn.pipeline import Pipeline
  6. from news_cat.config import get_app_settings
  7. from news_cat.ml.classifiers import tfidf_logistic_classifier
  8. from news_cat.ml.config import MLConfig
  9. def train_tfidf_logistic():
  10. print("Loading data...")
  11. cfg = get_app_settings()
  12. trainX = load_npz(cfg.data_dir.joinpath(MLConfig.embedding.train_tfidf))
  13. testX = load_npz(cfg.data_dir.joinpath(MLConfig.embedding.test_tfidf))
  14. trainY = np.load(cfg.data_dir.joinpath(MLConfig.embedding.trainY))
  15. testY = np.load(cfg.data_dir.joinpath(MLConfig.embedding.testY))
  16. print("Training the Logistic classifier...")
  17. clf, eval_metrics = tfidf_logistic_classifier(trainX, trainY, testX, testY)
  18. print("Saving the model and associated metrics...")
  19. tfidf = joblib.load(cfg.artifact_dir.joinpath(MLConfig.embedding.tfidf_vectorizer))
  20. joblib.dump(
  21. Pipeline([("tfidf", tfidf), ("logistic_clf", clf)]),
  22. cfg.artifact_dir.joinpath("model_logistic.jlib"),
  23. )
  24. with cfg.metrics_dir.joinpath("metrics_logistic.json").open("w") as f:
  25. json.dump(eval_metrics, f)
  26. print("Done")
  27. if __name__ == "__main__":
  28. train_tfidf_logistic()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...