Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

gandlf_collectStats 6.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import argparse
  5. import pandas as pd
  6. import seaborn as sns
  7. import matplotlib.pyplot as plt
  8. from pathlib import Path
  9. from GANDLF.cli import copyrightMessage
  10. def plot_all(df_training, df_validation, df_testing, output_plot_dir):
  11. """
  12. Plots training, validation, and testing data for loss and other metrics.
  13. TODO: this function needs to be moved under utils and then called after every training epoch.
  14. Args:
  15. df_training (pd.DataFrame): DataFrame containing training data.
  16. df_validation (pd.DataFrame): DataFrame containing validation data.
  17. df_testing (pd.DataFrame): DataFrame containing testing data.
  18. output_plot_dir (str): Directory to save the plots.
  19. Returns:
  20. tuple: Tuple containing the modified training, validation, and testing DataFrames.
  21. """
  22. # Drop any columns that might have "_" in the values of their rows
  23. banned_cols = [
  24. col
  25. for col in df_training.columns
  26. if any("_" in str(val) for val in df_training[col].values)
  27. ]
  28. # Determine metrics from the column names by removing the "train_" prefix
  29. metrics = [
  30. col.replace("train_", "")
  31. for col in df_training.columns
  32. if "train_" in col and col not in banned_cols
  33. ]
  34. # Split the values of the banned columns into multiple columns
  35. # for df in [df_training, df_validation, df_testing]:
  36. # for col in banned_cols:
  37. # if df[col].dtype == "object":
  38. # split_cols = (
  39. # df[col]
  40. # .str.split("_", expand=True)
  41. # .apply(pd.to_numeric, errors="coerce")
  42. # )
  43. # split_cols.columns = [f"{col}_{i}" for i in range(split_cols.shape[1])]
  44. # df.drop(columns=col, inplace=True)
  45. # df = pd.concat([df, split_cols], axis=1)
  46. # Check if any of the metrics is present in the column names of the dataframe
  47. assert any(
  48. any(metric in col for col in df_training.columns) for metric in metrics
  49. ), "None of the specified metrics is in the dataframe."
  50. required_cols = ["epoch_no", "train_loss"]
  51. # Check if the required columns are in the dataframe
  52. assert all(
  53. col in df_training.columns for col in required_cols
  54. ), "Not all required columns are in the dataframe."
  55. epochs = len(df_training)
  56. # Plot for loss
  57. plt.figure(figsize=(12, 6))
  58. if "train_loss" in df_training.columns:
  59. sns.lineplot(data=df_training, x="epoch_no", y="train_loss", label="Training")
  60. if "valid_loss" in df_validation.columns:
  61. sns.lineplot(
  62. data=df_validation, x="epoch_no", y="valid_loss", label="Validation"
  63. )
  64. if df_testing is not None and "test_loss" in df_testing.columns:
  65. sns.lineplot(data=df_testing, x="epoch_no", y="test_loss", label="Testing")
  66. plt.xlim(0, epochs - 1)
  67. plt.xlabel("Epoch")
  68. plt.ylabel("Loss")
  69. plt.title("Loss Plot")
  70. plt.legend()
  71. Path(output_plot_dir).mkdir(parents=True, exist_ok=True)
  72. plt.savefig(os.path.join(output_plot_dir, "loss_plot.png"), dpi=300)
  73. plt.close()
  74. # Plot for other metrics
  75. for metric in metrics:
  76. metric_cols = [col for col in df_training.columns if metric in col]
  77. for metric_col in metric_cols:
  78. plt.figure(figsize=(12, 6))
  79. if metric_col in df_training.columns:
  80. sns.lineplot(
  81. data=df_training,
  82. x="epoch_no",
  83. y=metric_col,
  84. label=f"Training {metric_col}",
  85. )
  86. if metric_col.replace("train", "valid") in df_validation.columns:
  87. sns.lineplot(
  88. data=df_validation,
  89. x="epoch_no",
  90. y=metric_col.replace("train", "valid"),
  91. label=f"Validation {metric_col}",
  92. )
  93. if (
  94. df_testing is not None
  95. and metric_col.replace("train", "test") in df_testing.columns
  96. ):
  97. sns.lineplot(
  98. data=df_testing,
  99. x="epoch_no",
  100. y=metric_col.replace("train", "test"),
  101. label=f"Testing {metric_col}",
  102. )
  103. plt.xlim(0, epochs - 1)
  104. plt.xlabel("Epoch")
  105. plt.ylabel(metric.capitalize())
  106. plt.title(f"{metric.capitalize()} Plot")
  107. plt.legend()
  108. plt.savefig(os.path.join(output_plot_dir, f"{metric}_plot.png"), dpi=300)
  109. plt.close()
  110. print("Plots saved successfully.")
  111. return df_training, df_validation, df_testing
  112. if __name__ == "__main__":
  113. parser = argparse.ArgumentParser(
  114. prog="GANDLF_CollectStats",
  115. formatter_class=argparse.RawTextHelpFormatter,
  116. description="Collect statistics from different testing/validation combinations from output directory.\n\n"
  117. + copyrightMessage,
  118. )
  119. parser.add_argument(
  120. "-m",
  121. "--modeldir",
  122. metavar="",
  123. type=str,
  124. help="Input directory which contains testing and validation models log files",
  125. )
  126. parser.add_argument(
  127. "-o",
  128. "--outputdir",
  129. metavar="",
  130. type=str,
  131. help="Output directory to save stats and plot",
  132. )
  133. args = parser.parse_args()
  134. inputDir = os.path.normpath(args.modeldir)
  135. outputDir = os.path.normpath(args.outputdir)
  136. Path(outputDir).mkdir(parents=True, exist_ok=True)
  137. outputFile = os.path.join(outputDir, "data.csv") # data file name
  138. outputPlot = os.path.join(outputDir, "plot.png") # plot file
  139. trainingLogs = os.path.join(inputDir, "logs_training.csv")
  140. validationLogs = os.path.join(inputDir, "logs_validation.csv")
  141. testingLogs = os.path.join(inputDir, "logs_testing.csv")
  142. # Read all the files
  143. df_training = pd.read_csv(trainingLogs)
  144. df_validation = pd.read_csv(validationLogs)
  145. df_testing = pd.read_csv(testingLogs) if os.path.isfile(testingLogs) else None
  146. # Check for metrics in columns and do tight plots
  147. plot_all(df_training, df_validation, df_testing, outputPlot)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...