Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

utils.py 1.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
  1. import os
  2. import matplotlib.pyplot as plt
  3. import pandas as pd
  4. import yaml
  5. from dagshub import dagshub_logger
  6. from joblib import dump, load
  7. from sklearn.metrics import plot_confusion_matrix
  8. from yaml import CLoader as Loader
  def log_experiment(out_path, params: dict, metrics: dict):
      """Record one experiment's hyperparameters and metrics with DagsHub.

      Args:
          out_path: Prefix prepended verbatim to ``metrics.csv`` and
              ``params.yml``. It is concatenated, not joined, so a directory
              prefix should end with a path separator.
          params: Hyperparameters handed to ``log_hyperparams``.
          metrics: Metric values handed to ``log_metrics``.
      """
      metrics_file = f'{out_path}metrics.csv'
      params_file = f'{out_path}params.yml'
      with dagshub_logger(metrics_path=metrics_file, hparams_path=params_file) as experiment_logger:
          experiment_logger.log_hyperparams(params=params)
          experiment_logger.log_metrics(metrics=metrics)
  def print_results(accuracy, c_matrix, model_name=''):
      """Print a short training summary to stdout.

      Args:
          accuracy: Score to report, printed via its string form.
          c_matrix: Confusion matrix to report, printed via its string form.
          model_name: Optional label shown in the header line.
      """
      report_lines = (
          f'Finished Training {model_name}:\nStats:',
          f'\tConfusion Matrix:\n{c_matrix}',
          f'\tModel Accuracy: {accuracy}',
      )
      for report_line in report_lines:
          print(report_line)
  def evaluate_model(model, X_test, y_test):
      """Score a fitted classifier on held-out data and draw its confusion matrix.

      Args:
          model: A fitted scikit-learn estimator exposing ``score``.
          X_test: Features of the evaluation set.
          y_test: True labels of the evaluation set.

      Returns:
          An ``(accuracy, c_matrix, figure)`` tuple containing three items:
          the value of ``model.score(X_test, y_test)``, the confusion-matrix
          array computed by the display object, and the matplotlib figure it
          was drawn on.
      """
      # plot_confusion_matrix was deprecated in scikit-learn 1.0 and removed in
      # 1.2. Prefer ConfusionMatrixDisplay.from_estimator when this sklearn
      # version has it, and fall back to the legacy helper only on older
      # releases. ConfusionMatrixDisplay itself exists since 0.22, the same
      # release that introduced plot_confusion_matrix.
      # NOTE(review): the module-level `from sklearn.metrics import
      # plot_confusion_matrix` still fails on sklearn >= 1.2 and should be
      # made conditional as well.
      from sklearn.metrics import ConfusionMatrixDisplay

      from_estimator = getattr(ConfusionMatrixDisplay, 'from_estimator', None)
      if from_estimator is not None:
          cmd = from_estimator(model, X_test, y_test, cmap=plt.cm.Reds)
      else:
          cmd = plot_confusion_matrix(model, X_test, y_test, cmap=plt.cm.Reds)
      c_matrix = cmd.confusion_matrix
      accuracy = model.score(X_test, y_test)
      return accuracy, c_matrix, cmd.figure_
  def save_results(out_path, model, fig):
      """Persist a trained model and, optionally, its confusion-matrix figure.

      Args:
          out_path: Output directory, created if missing. An empty string
              means the current working directory. A trailing separator is
              optional.
          model: Object serialised with ``joblib.dump`` to ``model.gz``.
          fig: Matplotlib figure saved as ``confusion_matrix.svg``, or a
              falsy value (e.g. ``None``) to skip saving a figure.
      """
      # exist_ok removes the isdir()/makedirs() race. The emptiness guard
      # is needed because os.makedirs('') raises even though '' is a valid
      # "current directory" prefix for the files below.
      if out_path:
          os.makedirs(out_path, exist_ok=True)
      # Join instead of concatenating, so the files always land inside the
      # directory just created, even when out_path has no trailing separator.
      dump(model, os.path.join(out_path, 'model.gz'))
      if fig:
          fig.savefig(os.path.join(out_path, 'confusion_matrix.svg'), format='svg')
  def read_data(data_path: str) -> (pd.DataFrame, pd.DataFrame, pd.Series, pd.Series):
      """Load the train/test CSV splits and separate features from labels.

      Args:
          data_path: Prefix prepended verbatim to ``train.csv`` and
              ``test.csv``. It is concatenated, not joined, so a directory
              prefix should end with a path separator.

      Returns:
          ``(X_train, X_test, y_train, y_test)``. The ``X_*`` frames hold every
          column except ``class``, and the ``y_*`` series hold the ``class``
          column.
      """
      label_column = 'class'
      train_frame = pd.read_csv(f'{data_path}train.csv')
      test_frame = pd.read_csv(f'{data_path}test.csv')
      return (
          train_frame.drop(columns=[label_column]),
          test_frame.drop(columns=[label_column]),
          train_frame[label_column],
          test_frame[label_column],
      )
  def load_model(path):
      """Load the ``model.gz`` file stored in directory *path*.

      ``model.gz`` is the filename that ``save_results`` writes.

      Args:
          path: Directory containing ``model.gz``. An empty string means the
              current working directory. A trailing separator is optional.

      Returns:
          Whatever object ``joblib.load`` deserialises from the file.
      """
      # os.path.join keeps an empty path relative ('model.gz' rather than the
      # filesystem-root '/model.gz') and does not double the separator when
      # path already ends with '/'.
      # NOTE(review): joblib.load unpickles the file; only load trusted models.
      return load(os.path.join(path, 'model.gz'))
  def read_params(file='params.yaml', model='pca'):
      """Return the parameter section for *model* from a YAML config file.

      Args:
          file: Path to the YAML file.
          model: Top-level key whose value is returned.

      Returns:
          The value stored under *model* in the parsed document (indexing
          raises if the key is absent).
      """
      # NOTE(review): Loader is yaml's CLoader, the libyaml-backed full loader,
      # which can construct arbitrary Python objects. Only use it on trusted
      # config files, or consider CSafeLoader.
      with open(file, 'r') as config_file:
          config = yaml.load(config_file, Loader)
      section = config[model]
      return section
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...