Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

utils.py 3.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
  1. import os
  2. from collections.abc import Iterable
  3. from pathlib import Path
  4. from typing import Union
  5. import matplotlib.pyplot as plt
  6. import pandas as pd
  7. import yaml
  8. from joblib import dump
  9. from sklearn.metrics import plot_confusion_matrix
  10. from yaml import Loader
  11. def read_yaml(
  12. file: Union[str, Path], key: str = None, default: Union[str, dict] = None
  13. ) -> dict:
  14. """
  15. Read yaml file and return `dict`.
  16. Args:
  17. file: `str` or `Path`. Yaml file path.
  18. key: `str`. Yaml key you want to read.
  19. default: `str` or `dict`. Yaml key or default dict to use as default values.
  20. Returns:
  21. Yaml file content as `dict` object.
  22. """
  23. with open(file, "r") as fp:
  24. params = yaml.load(fp, Loader)
  25. default = (
  26. default
  27. if isinstance(default, dict)
  28. else (params[default] if isinstance(default, str) else dict())
  29. )
  30. result = params[key] if key else params
  31. return {**default, **result}
  32. def dump_yaml(
  33. obj: dict, file_path: Union[str, Path], key: str = None, norm: bool = True
  34. ) -> Path:
  35. """
  36. Write yaml file and return `Path`.
  37. Args:
  38. obj: `dict` to write to yaml file.
  39. file: `str` or `Path`. Yaml file path.
  40. key: `str`. dict key you want to write.
  41. norm: `bool`. flag to normalize float values or not.
  42. Returns:
  43. `Path` of yaml file after writing.
  44. """
  45. obj = obj[key] if key else obj
  46. if norm:
  47. obj = normalize(obj)
  48. with open(file_path, "w+") as file:
  49. yaml.dump(obj, file)
  50. return Path(file_path)
  51. def normalize(obj: dict, ndigits: int = 4) -> dict:
  52. """Normalizes float values to `ndigits` decimal places"""
  53. if isinstance(obj, (float,)):
  54. return round(obj, ndigits)
  55. if isinstance(obj, (str,)):
  56. return obj
  57. if isinstance(obj, dict):
  58. for key, value in obj.items():
  59. obj[key] = normalize(value, ndigits)
  60. return obj
  61. if isinstance(obj, Iterable):
  62. return [normalize(x, ndigits) for x in obj]
  63. return obj
  64. def print_results(accuracy, c_matrix, model_name=""):
  65. print(f"Finished Training {model_name}:\nStats:")
  66. print(f"\tConfusion Matrix:\n{c_matrix}")
  67. print(f"\tModel Accuracy: {accuracy}")
  68. def evaluate_model(model, X_test, y_test):
  69. cmd = plot_confusion_matrix(model, X_test, y_test, cmap=plt.cm.Reds)
  70. c_matrix = cmd.confusion_matrix.tolist()
  71. accuracy = model.score(X_test, y_test)
  72. return float(accuracy), c_matrix, cmd.figure_
  73. def save_results(out_path, model, fig):
  74. if not os.path.isdir(out_path):
  75. os.makedirs(out_path)
  76. dump(model, f"{out_path}model.gz")
  77. if fig:
  78. fig.savefig(f"{out_path}confusion_matrix.svg", format="svg")
  79. def read_data(data_path: str) -> (pd.DataFrame, pd.DataFrame, pd.Series, pd.Series):
  80. train = pd.read_csv(f"{data_path}train.csv")
  81. test = pd.read_csv(f"{data_path}test.csv")
  82. X_train, y_train = train.drop(columns=["class"]), train["class"]
  83. X_test, y_test = test.drop(columns=["class"]), test["class"]
  84. return X_train, X_test, y_train, y_test
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...