al_session.py

  1. """
  2. al run managers a single AL session. It handles logging
  3. """
  4. import csv
  5. import numpy as np
  6. import os
  7. import time
  8. import tensorflow as tf
  9. from attr import attrs, attrib
  10. from datetime import datetime
  11. from src.environment import ClassiferALEnvironmentT
  12. from src.al_agent import ClassifierALAgentT
  13. from src.utils.log_utils import (
  14. set_up_experiment_logging,
  15. time_display,
  16. )
  17. from sklearn.metrics import f1_score, confusion_matrix

@attrs
class ClassiferALSessionManager:
    al_agent: ClassifierALAgentT = attrib()
    al_env: ClassiferALEnvironmentT = attrib()
    al_manager = attrib()
    session_dir: str = attrib()
    al_epochs: int = attrib()
    al_step_percentage: float = attrib()
    warm_start_percentage: float = attrib(default=0)
    retrain_model: bool = attrib(default=False)
    save_model_interval: int = attrib(default=10)
    stdout: bool = attrib(default=False)

    # only made available once reset_session is called
    start_time: float = None
    run_dir: str = None
    logger = None
    tf_summary_writer = None
    model_snapshot_dir: str = None

    def reset_session(self):
        timestamp_str = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
        self.run_dir = os.path.join(self.session_dir, timestamp_str)
        if not os.path.exists(self.run_dir):
            os.makedirs(self.run_dir)
        # set up the logger, tf summary writer, and model snapshot directory
        self.logger, self.tf_summary_writer, self.model_snapshot_dir = (
            set_up_experiment_logging(
                self.run_dir,
                log_fpath=os.path.join(self.run_dir, "session.log"),
                model_snapshot_dir=os.path.join(self.run_dir, "model_snapshots"),
                metrics_dir=os.path.join(self.run_dir, "metrics"),
                stdout=self.stdout,
                clear_old_data=True,
            )
        )
        self.al_env.reset()
        self.start_time = None

    def run_session(self):
        pool_size = self.al_manager.pool_size
        n_points_to_label = int(pool_size * self.al_step_percentage)
        self.start_time = time.monotonic()
        self.logger.info(f"Starting session in {self.run_dir}")
        # warm start
        if self.warm_start_percentage > 0:
            warm_start_count = int(self.warm_start_percentage * pool_size)
            self.logger.info(f"Warm start of {warm_start_count} labels")
            self.al_env.warm_start(warm_start_count)
            self.al_env.train_step()
        for al_epoch in range(self.al_epochs):
            n_step = self.al_env.n_step
            self.logger.info("-" * 118)
            self.logger.info(
                f"AL Epoch: {al_epoch + 1}/{self.al_epochs}"
                f"\tTrain Data Labeled: {n_step}/{pool_size}"
                f"\tElapsed Time: {time_display(time.monotonic() - self.start_time)}")
            # label step
            selection = self.al_agent.select_data_to_label(n_points_to_label)
            # TODO add metrics around selection
            self.al_env.label_step(selection)
            self.al_env.train_step(retrain=self.retrain_model)
            self.log_metrics(n_step, "train")
            self.log_metrics(n_step, "test")
            self.log_metrics(n_step, "validation")
            # save a model snapshot every `save_model_interval` epochs
            if (self.save_model_interval > 0 and
                    (al_epoch + 1) % self.save_model_interval == 0):
                model_fpath = os.path.join(
                    self.model_snapshot_dir,
                    f"model_AL_epoch_{al_epoch}_{self.al_epochs}.ckpt")
                self.al_env.model_manager.save_model(model_fpath)

    def log_metrics(
        self,
        step: int,  # training step (number of data points labeled)
        data_type: str,  # "train", "test", or "validation"
    ):
        """
        Evaluates the model on the given dataset and logs loss, F1 scores,
        class ratios, and the confusion matrix.
        """
        x, y = self.al_manager.get_dataset(data_type)
        data_size = x.shape[0]
        # TODO rely on model compiling here?
        loss_metric = tf.keras.metrics.Mean(name="loss")
        micro_f1_metric = tf.keras.metrics.Mean(name="micro_f1_metric")
        macro_f1_metric = tf.keras.metrics.Mean(name="macro_f1_metric")
        loss_metric.reset_states()
        micro_f1_metric.reset_states()
        macro_f1_metric.reset_states()
        model_num_classes = None
        total_prediction_count = None
        total_true_label_count = None
        cm = None
        for batch_x, batch_y, raw_prediction, batch_loss in \
                self.al_env.model_manager.evaluate_model(x, y):
            loss_metric.update_state(batch_loss)
            # dynamically get the number of classes from the model output
            model_num_classes = raw_prediction.shape[-1]
            if total_prediction_count is None:
                total_prediction_count = np.zeros(model_num_classes)
                total_true_label_count = np.zeros(model_num_classes)
            prediction = np.argmax(raw_prediction, axis=1)
            unique, counts = np.unique(prediction, return_counts=True)
            for i, count in zip(unique, counts):
                total_prediction_count[i] += count
            total_true_label_count += np.sum(batch_y, axis=0)
            batch_y = np.argmax(batch_y, axis=1)  # one-hot to class index
            if cm is None:
                cm = confusion_matrix(
                    batch_y, prediction, labels=np.arange(model_num_classes))
            else:
                cm += confusion_matrix(
                    batch_y, prediction, labels=np.arange(model_num_classes))
            micro_f1_metric.update_state(f1_score(
                batch_y, prediction, average="micro",
                labels=np.arange(model_num_classes)))
            macro_f1_metric.update_state(f1_score(
                batch_y, prediction, average="macro",
                labels=np.arange(model_num_classes)))
        # min/max class ratios over predictions and true labels
        min_class_prediction_ratio = np.min(total_prediction_count) / data_size
        min_class_true_ratio = np.min(total_true_label_count) / data_size
        max_class_prediction_ratio = np.max(total_prediction_count) / data_size
        max_class_true_ratio = np.max(total_true_label_count) / data_size
        metrics = {
            "loss": loss_metric.result(),
            "micro_f1_metric": micro_f1_metric.result(),
            "macro_f1_metric": macro_f1_metric.result(),
            "min_class_prediction_ratio": min_class_prediction_ratio,
            "min_class_true_ratio": min_class_true_ratio,
            "max_class_prediction_ratio": max_class_prediction_ratio,
            "max_class_true_ratio": max_class_true_ratio,
        }
        # tensorboard output
        with self.tf_summary_writer.as_default():
            for metric_key, metric_value in metrics.items():
                tf.summary.scalar(
                    f"{data_type} {metric_key}", metric_value, step=step)
        # log output
        for metric_key, metric_value in metrics.items():
            self.logger.info(f"{data_type} {metric_key}: {metric_value}")
        # save to a master csv
        # TODO keep the file open for faster runs
        metrics["step"] = step
        # cast to float before storing to csv
        for metric_key, metric_value in metrics.items():
            metrics[metric_key] = float(metric_value)
        csv_file = os.path.join(self.run_dir, f"{data_type}_results.csv")
        with open(csv_file, "a") as f:
            writer = csv.DictWriter(f, fieldnames=list(metrics.keys()))
            if f.tell() == 0:
                writer.writeheader()
            writer.writerow(metrics)
        # dump the confusion matrix as one row: step, then the flattened matrix
        data = np.expand_dims(np.append([step], cm.flatten()), 0)
        cm_file = os.path.join(self.run_dir, f"{data_type}_confusion_matrix.csv")
        with open(cm_file, "a") as f:
            np.savetxt(f, data, delimiter=",")
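
Below is a minimal usage sketch of how the session manager might be wired up. The make_agent, make_env, and make_pool_manager factories are hypothetical stand-ins (this file does not define them); all the class actually requires is that its collaborators expose the members used above (pool_size, get_dataset, select_data_to_label, the environment step methods, and model_manager).

# Hypothetical usage sketch: make_agent, make_env, and make_pool_manager
# are placeholder factories, not part of this module.
agent = make_agent()                 # a ClassifierALAgentT
env = make_env()                     # a ClassiferALEnvironmentT
pool_manager = make_pool_manager()   # exposes pool_size / get_dataset()

session = ClassiferALSessionManager(
    al_agent=agent,
    al_env=env,
    al_manager=pool_manager,
    session_dir="sessions/example_run",
    al_epochs=20,
    al_step_percentage=0.05,     # label 5% of the pool per AL epoch
    warm_start_percentage=0.10,  # seed training with 10% of the pool
)
session.reset_session()
session.run_session()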