Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

segment.py 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
  1. import pickle
  2. from typing import Tuple
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. import pandas as pd
  6. from omegaconf import DictConfig
  7. from sklearn.cluster import KMeans
  8. from sklearn.decomposition import PCA
  9. from yellowbrick.cluster import KElbowVisualizer
  10. from logger import BaseLogger
  11. import mlflow
  12. import hydra
  13. def get_pca_model(data: pd.DataFrame) -> PCA:
  14. pca = PCA(n_components=3)
  15. pca.fit(data)
  16. return pca
  17. def reduce_dimension(df: pd.DataFrame, pca: PCA) -> pd.DataFrame:
  18. return pd.DataFrame(pca.transform(df), columns=["col1", "col2", "col3"])
  19. def get_3d_projection(pca_df: pd.DataFrame) -> dict:
  20. """A 3D Projection Of Data In The Reduced Dimensionality Space"""
  21. return {"x": pca_df["col1"], "y": pca_df["col2"], "z": pca_df["col3"]}
  22. def get_best_k_cluster(
  23. pca_df: pd.DataFrame, image_path: str, logger: BaseLogger
  24. ) -> pd.DataFrame:
  25. fig = plt.figure(figsize=(10, 8))
  26. fig.add_subplot(111)
  27. elbow = KElbowVisualizer(KMeans(), metric="distortion")
  28. elbow.fit(pca_df)
  29. elbow.fig.savefig(image_path)
  30. k_best = elbow.elbow_value_
  31. # Log
  32. logger.log_metrics(
  33. {
  34. "k_best": k_best,
  35. "score_best": elbow.elbow_score_,
  36. }
  37. )
  38. return k_best
  39. def get_clusters_model(
  40. pca_df: pd.DataFrame, k: int, logger: BaseLogger
  41. ) -> Tuple[pd.DataFrame, pd.DataFrame]:
  42. model = KMeans(n_clusters=k)
  43. # Log model
  44. logger.log_params({"model_class": type(model).__name__})
  45. logger.log_params({"model": model.get_params()})
  46. # Fit model
  47. return model.fit(pca_df)
  48. def predict(model, pca_df: pd.DataFrame):
  49. return model.predict(pca_df)
  50. def insert_clusters_to_df(
  51. df: pd.DataFrame, clusters: np.ndarray
  52. ) -> pd.DataFrame:
  53. return df.assign(clusters=clusters)
  54. def plot_clusters(
  55. pca_df: pd.DataFrame, preds: np.ndarray, projections: dict, image_path: str
  56. ) -> None:
  57. pca_df["clusters"] = preds
  58. plt.figure(figsize=(10, 8))
  59. ax = plt.subplot(111, projection="3d")
  60. ax.scatter(
  61. projections["x"],
  62. projections["y"],
  63. projections["z"],
  64. s=40,
  65. c=pca_df["clusters"],
  66. marker="o",
  67. cmap="Accent",
  68. )
  69. ax.set_title("The Plot Of The Clusters")
  70. plt.savefig(image_path)
  71. @hydra.main(
  72. config_path="../config",
  73. config_name="main",
  74. )
  75. def segment(config: DictConfig) -> None:
  76. # initialize logger
  77. mlflow.set_tracking_uri(
  78. "https://dagshub.com/khuyentran1401/dagshub-demo.mlflow"
  79. )
  80. with mlflow.start_run():
  81. logger = BaseLogger()
  82. logger.log_params(dict(config.process))
  83. logger.log_params({"num_columns": len(config.process.keep_columns)})
  84. data = pd.read_csv(config.intermediate.path)
  85. pca = get_pca_model(data)
  86. pca_df = reduce_dimension(data, pca)
  87. projections = get_3d_projection(pca_df)
  88. k_best = get_best_k_cluster(pca_df, config.image.kmeans, logger)
  89. model = get_clusters_model(pca_df, k_best, logger)
  90. preds = predict(model, pca_df)
  91. data = insert_clusters_to_df(data, preds)
  92. plot_clusters(
  93. pca_df,
  94. preds,
  95. projections,
  96. config.image.clusters,
  97. )
  98. data.to_csv(config.final.path, index=False)
  99. pickle.dump(model, open(config.model.path, "wb"))
  100. if __name__ == "__main__":
  101. segment()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...