Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

utils.py 1.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
  1. import os
  2. import matplotlib.pyplot as plt
  3. import pandas as pd
  4. from dagshub import dagshub_logger
  5. from joblib import dump, load
  6. from sklearn.metrics import plot_confusion_matrix
  def log_experiment(out_path, params: dict, metrics: dict):
      """Log a run's hyperparameters and metrics with DagsHub's logger.

      The logger is given ``metrics.csv`` and ``params.yml`` paths inside
      *out_path*.

      Args:
          out_path: Directory for the log files. The paths are built with
              ``os.path.join``, so a trailing separator is optional (the old
              string concatenation silently produced ``<dir>metrics.csv``
              when it was missing).
          params: Hyperparameters, forwarded to ``logger.log_hyperparams``.
          metrics: Metric values, forwarded to ``logger.log_metrics``.
      """
      metrics_path = os.path.join(out_path, 'metrics.csv')
      params_path = os.path.join(out_path, 'params.yml')
      with dagshub_logger(metrics_path=metrics_path, hparams_path=params_path) as logger:
          logger.log_hyperparams(params=params)
          logger.log_metrics(metrics=metrics)
  def print_results(accuracy, c_matrix, model_name=''):
      """Print a short training report: header, confusion matrix, then accuracy.

      Args:
          accuracy: Accuracy value, printed as-is.
          c_matrix: Confusion matrix, printed using its ``str`` form.
          model_name: Optional label shown in the header line.
      """
      report_lines = (
          f'Finished Training {model_name}:\nStats:',
          f'\tConfusion Matrix:\n{c_matrix}',
          f'\tModel Accuracy: {accuracy}',
      )
      for line in report_lines:
          print(line)
  def evaluate_model(model, X_test, y_test):
      """Score *model* on the test split and plot its confusion matrix.

      Args:
          model: A fitted estimator exposing ``score`` (and ``predict`` for the
              confusion-matrix display).
          X_test: Test features.
          y_test: Test labels.

      Returns:
          ``(accuracy, c_matrix, figure)``: the result of
          ``model.score(X_test, y_test)``, the confusion matrix held by the
          display object, and the matplotlib figure containing the plot.
      """
      from sklearn.metrics import ConfusionMatrixDisplay

      # plot_confusion_matrix was deprecated in scikit-learn 1.0 and removed in
      # 1.2; ConfusionMatrixDisplay.from_estimator is its replacement. Fall back
      # to the old helper only on releases that predate from_estimator.
      # NOTE(review): the module-level `from sklearn.metrics import
      # plot_confusion_matrix` must also be removed for scikit-learn >= 1.2.
      plot = getattr(ConfusionMatrixDisplay, 'from_estimator', None)
      if plot is None:
          from sklearn.metrics import plot_confusion_matrix as plot
      cmd = plot(model, X_test, y_test, cmap=plt.cm.Reds)
      c_matrix = cmd.confusion_matrix
      accuracy = model.score(X_test, y_test)
      return accuracy, c_matrix, cmd.figure_
  def save_results(out_path, model, fig):
      """Persist a trained model and its confusion-matrix figure.

      Args:
          out_path: Output directory, created (with parents) if missing. An
              empty string means the current working directory. Paths are
              built with ``os.path.join``, so a trailing separator is optional.
          model: Object serialized with ``joblib.dump`` to ``model.gz``.
          fig: Figure written to ``confusion_matrix.svg`` via ``fig.savefig``.

      Raises:
          FileExistsError: If *out_path* exists but is not a directory.
      """
      if out_path:
          # exist_ok=True avoids the race in a separate isdir() check followed
          # by makedirs(); os.makedirs('') would raise, hence the guard.
          os.makedirs(out_path, exist_ok=True)
      dump(model, os.path.join(out_path, 'model.gz'))
      fig.savefig(os.path.join(out_path, 'confusion_matrix.svg'), format='svg')
  def read_data(
      data_path: str, target: str = 'class'
  ) -> 'tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]':
      """Load the train/test CSV splits and separate the features from the label.

      Args:
          data_path: Directory containing ``train.csv`` and ``test.csv``. Paths
              are built with ``os.path.join``, so a trailing separator is
              optional.
          target: Name of the label column (default ``'class'``).

      Returns:
          ``(X_train, X_test, y_train, y_test)``: the feature frames with the
          label column dropped, followed by the label columns as Series.
      """
      train = pd.read_csv(os.path.join(data_path, 'train.csv'))
      test = pd.read_csv(os.path.join(data_path, 'test.csv'))
      X_train, y_train = train.drop(columns=[target]), train[target]
      X_test, y_test = test.drop(columns=[target]), test[target]
      return X_train, X_test, y_train, y_test
  def load_model(path):
      """Load the ``model.gz`` file that ``save_results`` writes with ``joblib.dump``.

      Args:
          path: Directory containing ``model.gz``. The path is built with
              ``os.path.join``, matching how ``save_results`` builds it.

      Returns:
          The object deserialized by ``joblib.load``.
      """
      # SECURITY: joblib.load unpickles the file and can execute arbitrary code;
      # only load model files from trusted locations.
      return load(os.path.join(path, 'model.gz'))
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...