wandb_utils.py
  1. """Utilities and tools for tracking runs with Weights & Biases."""
  2. import logging
  3. import os
  4. import sys
  5. from contextlib import contextmanager
  6. from pathlib import Path
  7. import yaml
  8. from tqdm import tqdm
  9. sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
  10. from utils.datasets import LoadImagesAndLabels
  11. from utils.datasets import img2label_paths
  12. from utils.general import colorstr, check_dataset, check_file
  13. try:
  14. import wandb
  15. from wandb import init, finish
  16. except ImportError:
  17. wandb = None
  18. RANK = int(os.getenv('RANK', -1))
  19. WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
  20. def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
  21. return from_string[len(prefix):]


def check_wandb_config_file(data_config_file):
    wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1))  # updated data.yaml path
    if Path(wandb_config).is_file():
        return wandb_config
    return data_config_file


def get_run_info(run_path):
    run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
    run_id = run_path.stem
    project = run_path.parent.stem
    entity = run_path.parent.parent.stem
    model_artifact_name = 'run_' + run_id + '_model'
    return entity, project, run_id, model_artifact_name
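# Illustration (hypothetical values):
#   get_run_info('wandb-artifact://my_entity/my_project/3tk4aabb')
# returns ('my_entity', 'my_project', '3tk4aabb', 'run_3tk4aabb_model').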


def check_wandb_resume(opt):
    if RANK not in [-1, 0]:
        process_wandb_config_ddp_mode(opt)
    if isinstance(opt.resume, str):
        if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
            if RANK not in [-1, 0]:  # For resuming DDP runs
                entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
                api = wandb.Api()
                artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
                modeldir = artifact.download()
                opt.weights = str(Path(modeldir) / "last.pt")
            return True
    return None


def process_wandb_config_ddp_mode(opt):
    with open(check_file(opt.data)) as f:
        data_dict = yaml.safe_load(f)  # data dict
    train_dir, val_dir = None, None
    if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
        api = wandb.Api()
        train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias)
        train_dir = train_artifact.download()
        train_path = Path(train_dir) / 'data/images/'
        data_dict['train'] = str(train_path)
    if isinstance(data_dict['val'], str) and data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX):
        api = wandb.Api()
        val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias)
        val_dir = val_artifact.download()
        val_path = Path(val_dir) / 'data/images/'
        data_dict['val'] = str(val_path)
    if train_dir or val_dir:
        ddp_data_path = str(Path(val_dir or train_dir) / 'wandb_local_data.yaml')  # val_dir may be None when only train is an artifact
        with open(ddp_data_path, 'w') as f:
            yaml.safe_dump(data_dict, f)
        opt.data = ddp_data_path
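# Illustration of process_wandb_config_ddp_mode (hypothetical values): a data.yaml entry such as
#   train: wandb-artifact://my_entity/my_project/train
# is rewritten to a local path like
#   train: <artifact_download_dir>/data/images/
# and the modified dict is written to wandb_local_data.yaml for the non-zero DDP ranks to read.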


class WandbLogger():
    """Log training runs, datasets, models, and predictions to Weights & Biases.

    This logger sends information to W&B at wandb.ai. By default, this information
    includes hyperparameters, system configuration and metrics, model metrics,
    and basic data metrics and analyses.

    By providing additional command line arguments to train.py, datasets,
    models and predictions can also be logged.

    For more on how this logger is used, see the Weights & Biases documentation:
    https://docs.wandb.com/guides/integrations/yolov5
    """

    def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
        # Pre-training routine --
        self.job_type = job_type
        self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
        # It's more elegant to stick to one wandb.init call, but useful config data would be
        # overwritten by the WandbLogger's wandb.init call when resuming
        if isinstance(opt.resume, str):  # checks resume from artifact
            if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
                entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
                model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
                assert wandb, 'install wandb to resume wandb runs'
                # Resume wandb-artifact:// runs here as a workaround, so wandb.config is not overwritten
                self.wandb_run = wandb.init(id=run_id,
                                            project=project,
                                            entity=entity,
                                            resume='allow',
                                            allow_val_change=True)
                opt.resume = model_artifact_name
        elif self.wandb:
            self.wandb_run = wandb.init(config=opt,
                                        resume="allow",
                                        project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
                                        entity=opt.entity,
                                        name=name,
                                        job_type=job_type,
                                        id=run_id,
                                        allow_val_change=True) if not wandb.run else wandb.run
        if self.wandb_run:
            if self.job_type == 'Training':
                if not opt.resume:
                    wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
                    # Info useful for resuming from artifacts
                    self.wandb_run.config.update({'opt': vars(opt), 'data_dict': wandb_data_dict},
                                                 allow_val_change=True)
                self.data_dict = self.setup_training(opt, data_dict)
            if self.job_type == 'Dataset Creation':
                self.data_dict = self.check_and_upload_dataset(opt)
        else:
            prefix = colorstr('wandb: ')
            print(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")

    def check_and_upload_dataset(self, opt):
        assert wandb, 'Install wandb to upload dataset'
        config_path = self.log_dataset_artifact(check_file(opt.data),
                                                opt.single_cls,
                                                'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
        print("Created dataset config file ", config_path)
        with open(config_path) as f:
            wandb_data_dict = yaml.safe_load(f)
        return wandb_data_dict

    def setup_training(self, opt, data_dict):
        self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16  # Logging Constants
        self.bbox_interval = opt.bbox_interval
        if isinstance(opt.resume, str):
            modeldir, _ = self.download_model_artifact(opt)
            if modeldir:
                self.weights = Path(modeldir) / "last.pt"
                config = self.wandb_run.config
                opt.weights = str(self.weights)
                opt.save_period = config.save_period
                opt.batch_size = config.total_batch_size
                opt.bbox_interval = config.bbox_interval
                opt.epochs = config.epochs
                opt.hyp = config.opt['hyp']
            data_dict = dict(self.wandb_run.config.data_dict)  # eliminates the need for a config file to resume
        if 'val_artifact' not in self.__dict__:  # If --upload_dataset is set, use the existing artifact, don't download
            self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
                                                                                           opt.artifact_alias)
            self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
                                                                                       opt.artifact_alias)
            self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
            if self.train_artifact_path is not None:
                train_path = Path(self.train_artifact_path) / 'data/images/'
                data_dict['train'] = str(train_path)
            if self.val_artifact_path is not None:
                val_path = Path(self.val_artifact_path) / 'data/images/'
                data_dict['val'] = str(val_path)
                self.val_table = self.val_artifact.get("val")
                self.map_val_table_path()
                wandb.log({"validation dataset": self.val_table})
        if self.val_artifact is not None:
            self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
            self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
        if opt.bbox_interval == -1:
            self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
        return data_dict

    def download_dataset_artifact(self, path, alias):
        if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
            artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
            dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\", "/"))
            assert dataset_artifact is not None, "Error: W&B dataset artifact doesn't exist"
            datadir = dataset_artifact.download()
            return datadir, dataset_artifact
        return None, None

    def download_model_artifact(self, opt):
        if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
            model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
            assert model_artifact is not None, "Error: W&B model artifact doesn't exist"
            modeldir = model_artifact.download()
            epochs_trained = model_artifact.metadata.get('epochs_trained')
            total_epochs = model_artifact.metadata.get('total_epochs')
            is_finished = total_epochs is None
            assert not is_finished, 'training is finished, can only resume incomplete runs.'
            return modeldir, model_artifact
        return None, None

    def log_model(self, path, opt, epoch, fitness_score, best_model=False):
        model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
            'original_url': str(path),
            'epochs_trained': epoch + 1,
            'save period': opt.save_period,
            'project': opt.project,
            'total_epochs': opt.epochs,
            'fitness_score': fitness_score
        })
        model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
        wandb.log_artifact(model_artifact,
                           aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
        print("Saving model artifact on epoch ", epoch + 1)
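    # Illustration (hypothetical run id 'abc123' at current_epoch 4): this logs artifact
    # 'run_abc123_model' containing last.pt with aliases ['latest', 'last', 'epoch 4'],
    # plus 'best' when best_model=True.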

    def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
        with open(data_file) as f:
            data = yaml.safe_load(f)  # data dict
        check_dataset(data)
        nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
        names = {k: v for k, v in enumerate(names)}  # to index dictionary
        self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
            data['train'], rect=True, batch_size=1), names, name='train') if data.get('train') else None
        self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
            data['val'], rect=True, batch_size=1), names, name='val') if data.get('val') else None
        if data.get('train'):
            data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
        if data.get('val'):
            data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
        path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1))  # updated data.yaml path
        data.pop('download', None)
        data.pop('path', None)
        with open(path, 'w') as f:
            yaml.safe_dump(data, f)

        if self.job_type == 'Training':  # builds correct artifact pipeline graph
            self.wandb_run.use_artifact(self.val_artifact)
            self.wandb_run.use_artifact(self.train_artifact)
            self.val_artifact.wait()
            self.val_table = self.val_artifact.get('val')
            self.map_val_table_path()
        else:
            self.wandb_run.log_artifact(self.train_artifact)
            self.wandb_run.log_artifact(self.val_artifact)
        return path

    def map_val_table_path(self):
        self.val_table_map = {}  # maps image filename -> row id in the validation table
        print("Mapping dataset")
        for i, data in enumerate(tqdm(self.val_table.data)):
            self.val_table_map[data[3]] = data[0]  # table columns are [id, train_image, Classes, name]
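    # Illustration (hypothetical filenames): after mapping, self.val_table_map looks like
    #   {'000001.jpg': 0, '000002.jpg': 1, ...}
    # so log_training_progress() can look up a validation image's table row by file name.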

    def create_dataset_table(self, dataset, class_to_id, name='dataset'):
        # TODO: Explore multiprocessing to split this loop and run it in parallel; essential for speeding up logging
        artifact = wandb.Artifact(name=name, type="dataset")
        img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None
        img_files = tqdm(dataset.img_files) if not img_files else img_files
        for img_file in img_files:
            if Path(img_file).is_dir():
                artifact.add_dir(img_file, name='data/images')
                labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
                artifact.add_dir(labels_path, name='data/labels')
            else:
                artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
                label_file = Path(img2label_paths([img_file])[0])
                if label_file.exists():
                    artifact.add_file(str(label_file), name='data/labels/' + label_file.name)
        table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
        class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
        for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
            box_data, img_classes = [], {}
            for cls, *xywh in labels[:, 1:].tolist():
                cls = int(cls)
                box_data.append({"position": {"middle": [xywh[0], xywh[1]], "width": xywh[2], "height": xywh[3]},
                                 "class_id": cls,
                                 "box_caption": "%s" % (class_to_id[cls])})
                img_classes[cls] = class_to_id[cls]
            boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}}  # inference-space
            table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), list(img_classes.values()),
                           Path(paths).name)
        artifact.add(table, name)
        return artifact

    def log_training_progress(self, predn, path, names):
        if self.val_table and self.result_table:
            class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
            box_data = []
            total_conf = 0
            for *xyxy, conf, cls in predn.tolist():
                if conf >= 0.25:
                    box_data.append(
                        {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
                         "class_id": int(cls),
                         "box_caption": "%s %.3f" % (names[cls], conf),
                         "scores": {"class_score": conf},
                         "domain": "pixel"})
                    total_conf = total_conf + conf
            boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
            id = self.val_table_map[Path(path).name]
            self.result_table.add_data(self.current_epoch,
                                       id,
                                       self.val_table.data[id][1],
                                       wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
                                       total_conf / max(1, len(box_data))
                                       )

    def log(self, log_dict):
        if self.wandb_run:
            for key, value in log_dict.items():
                self.log_dict[key] = value

    def end_epoch(self, best_result=False):
        if self.wandb_run:
            with all_logging_disabled():
                wandb.log(self.log_dict)
                self.log_dict = {}
            if self.result_artifact:
                self.result_artifact.add(self.result_table, 'result')
                wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
                                                                  ('best' if best_result else '')])
                wandb.log({"evaluation": self.result_table})
                self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
                self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")

    def finish_run(self):
        if self.wandb_run:
            if self.log_dict:
                with all_logging_disabled():
                    wandb.log(self.log_dict)
            wandb.run.finish()


@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """Source: https://gist.github.com/simon-weber/7853144
    A context manager that prevents any logging messages triggered during the body from being processed.
    :param highest_level: the maximum logging level in use.
        This would only need to be changed if a custom level greater than CRITICAL is defined.
    """
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(previous_level)
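
# Example usage of all_logging_disabled (illustrative): any logging-module messages
# emitted inside the block are suppressed, then the previous logging state is restored:
#
#   with all_logging_disabled():
#       wandb.log({'train/loss': 0.1})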