Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_model.py 1.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
  1. import pandas as pd
  2. import numpy as np
  3. from catboost import CatBoostClassifier, Pool
  4. from dagshub import dagshub_logger
  5. from catboost import CatBoostClassifier
  6. from sklearn.metrics import fbeta_score, precision_recall_curve
  7. from catboost.utils import get_roc_curve, select_threshold
  8. import pickle
  9. def eval_model(model):
  10. with dagshub_logger() as logger:
  11. logger.log_hyperparams(model_class=type(model).__name__)
  12. logger.log_hyperparams({'model': model.get_params()})
  13. model.fit(X_train, y_train)
  14. train_preds = model.predict(X_train)
  15. test_preds = model.predict(X_val)
  16. logger.log_metrics(
  17. {'train__f2': fbeta_score(y_train, train_preds, beta=2), 'val_f2': fbeta_score(y_val, test_preds, beta=2)})
  18. X_train = np.load('processed_data\\X_train_processed.npy', allow_pickle=True)
  19. X_val = np.load('processed_data\\X_val_processed.npy', allow_pickle=True)
  20. y_train = np.load('processed_data\\y_train_processed.npy', allow_pickle=True)
  21. y_val = np.load('processed_data\\y_val_processed.npy', allow_pickle=True)
  22. model = CatBoostClassifier(iterations=200, cat_features=[0,1,2,3,4,5], random_state=0, scale_pos_weight=13.791)
  23. eval_model(model)
  24. with open('trained_models\\catboost.pk1', 'wb') as f:
  25. pickle.dump(model,f)
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...