Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

wandb_utils.py 6.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
  1. import json
  2. import shutil
  3. import sys
  4. from datetime import datetime
  5. from pathlib import Path
  6. import torch
  7. sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
  8. from utils.general import colorstr, xywh2xyxy
  # Weights & Biases is optional: fall back to ``wandb = None`` so callers can
  # test for its presence, and tell the user how to enable it.
  try:
      import wandb
  except ImportError:
      wandb = None

  if wandb is None:
      print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")

  # Dataset paths beginning with this scheme refer to W&B artifacts, not local directories.
  WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
  def remove_prefix(from_string, prefix):
      """Return *from_string* with a leading *prefix* removed.

      If *from_string* does not start with *prefix* (or *prefix* is empty), the
      string is returned unchanged. This matches ``str.removeprefix`` on Python 3.9+,
      so characters are never chopped off a string that lacks the prefix.
      """
      if prefix and from_string.startswith(prefix):
          return from_string[len(prefix):]
      return from_string
  class WandbLogger():
      """Log YOLOv5 training to Weights & Biases (W&B) and manage W&B artifacts.

      If the ``wandb`` package could not be imported, ``self.wandb`` and
      ``self.wandb_run`` are ``None``, and ``log``, ``end_epoch`` and ``finish_run``
      do nothing.
      """

      def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
          """Start (or resume) a W&B run and, for training jobs, prepare datasets.

          Args:
              opt: parsed options. Passed to ``wandb.init`` as config and read for
                  ``project``, ``epochs``, ``bbox_interval``, ``save_period``,
                  ``artifact_alias`` and ``resume_from_artifact``.
                  ``bbox_interval``/``save_period`` values of -1 are replaced in place.
              name: W&B run name.
              run_id: W&B run id passed to ``wandb.init(id=...)`` (``resume="allow"``).
              data_dict: dict with ``'train'`` and ``'val'`` entries. Entries that are
                  W&B artifact URIs are rewritten in place to local image directories.
              job_type: W&B job type. Dataset/model setup only happens for
                  ``'Training'``.
          """
          self.wandb = wandb
          self.wandb_run = wandb.init(config=opt, resume="allow",
                                      project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
                                      name=name,
                                      job_type=job_type,
                                      id=run_id) if self.wandb else None
          # Defaults, so that log()/end_epoch()/finish_run() never hit a missing
          # attribute when setup_training() is skipped (no wandb, or a
          # non-training job type).
          self.log_dict = {}
          self.train_artifact_path, self.trainset_artifact = None, None
          self.test_artifact_path, self.testset_artifact = None, None
          self.result_artifact, self.result_table, self.weights = None, None, None
          if job_type == 'Training':
              # setup_training() uses the wandb API unconditionally, so it can only
              # run when a W&B run was actually started.
              if self.wandb_run:
                  self.setup_training(opt, data_dict)
              # -1 means "pick automatically" (see _default_interval).
              if opt.bbox_interval == -1:
                  opt.bbox_interval = self._default_interval(opt.epochs)
              if opt.save_period == -1:
                  opt.save_period = self._default_interval(opt.epochs)

      @staticmethod
      def _default_interval(epochs):
          """Return ``epochs // 10`` for runs longer than 10 epochs, otherwise ``epochs``."""
          return (epochs // 10) if epochs > 10 else epochs

      def setup_training(self, opt, data_dict):
          """Download dataset/model artifacts and create the evaluation-results artifact.

          Rewrites ``data_dict['train']``/``data_dict['val']`` to the ``data/images``
          folder of the downloaded artifact when the entry is a W&B artifact URI.
          If ``opt.resume_from_artifact`` is set, downloads that model artifact and
          points ``opt.weights`` at its ``best.pt``.
          """
          self.log_dict = {}
          self.train_artifact_path, self.trainset_artifact = \
              self.download_dataset_artifact(data_dict['train'], opt.artifact_alias)
          self.test_artifact_path, self.testset_artifact = \
              self.download_dataset_artifact(data_dict['val'], opt.artifact_alias)
          self.result_artifact, self.result_table, self.weights = None, None, None
          if self.train_artifact_path is not None:
              train_path = Path(self.train_artifact_path) / 'data/images/'
              data_dict['train'] = str(train_path)
          if self.test_artifact_path is not None:
              test_path = Path(self.test_artifact_path) / 'data/images/'
              data_dict['val'] = str(test_path)
          self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
          # No rows are added in this class; presumably the evaluation code fills it.
          self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
          if opt.resume_from_artifact:
              modeldir, _ = self.download_model_artifact(opt.resume_from_artifact)
              if modeldir:
                  self.weights = Path(modeldir) / "best.pt"
                  # Store a plain string: downstream code may use str methods on opt.weights.
                  opt.weights = str(self.weights)

      def download_dataset_artifact(self, path, alias):
          """Download a dataset artifact when *path* is a ``wandb-artifact://`` URI.

          Args:
              path: dataset location from the data dict. Non-string values (e.g.
                  ``None`` or a list of directories) are treated as ordinary paths.
              alias: artifact version alias, appended as ``name:alias``.

          Returns:
              ``(download_dir, artifact)`` for an artifact URI, otherwise ``(None, None)``.
          """
          if not (isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX)):
              return None, None
          dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
          assert dataset_artifact is not None, "Error: W&B dataset artifact doesn't exist"
          datadir = dataset_artifact.download()
          # Labels are stored as one zip inside the artifact (see log_dataset_artifact).
          labels_zip = Path(datadir) / "data/labels.zip"
          shutil.unpack_archive(labels_zip, Path(datadir) / 'data/labels', 'zip')
          print("Downloaded dataset to : ", datadir)
          return datadir, dataset_artifact

      def download_model_artifact(self, name):
          """Download the ``latest`` version of model artifact *name*.

          Returns:
              ``(download_dir, artifact)``.
          """
          model_artifact = wandb.use_artifact(name + ":latest")
          assert model_artifact is not None, "Error: W&B model artifact doesn't exist"
          modeldir = model_artifact.download()
          print("Downloaded model to : ", modeldir)
          return modeldir, model_artifact

      def log_model(self, path, opt, epoch):
          """Upload ``last.pt`` and ``best.pt`` from directory *path* as a model artifact.

          Args:
              path: directory (str or Path) containing ``last.pt`` and ``best.pt``.
              opt: options; ``save_period`` and ``project`` are stored as metadata.
              epoch: epoch index, recorded and printed as ``epoch + 1``
                  (presumably zero-based).
          """
          path = Path(path)  # accept str as well as Path for the '/' joins below
          datetime_suffix = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
          model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
              'original_url': str(path),
              'epoch': epoch + 1,
              'save period': opt.save_period,
              'project': opt.project,
              'datetime': datetime_suffix
          })
          model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
          model_artifact.add_file(str(path / 'best.pt'), name='best.pt')
          wandb.log_artifact(model_artifact)
          print("Saving model artifact on epoch ", epoch + 1)

      def log_dataset_artifact(self, dataset, class_to_id, name='dataset'):
          """Upload a dataset (images, zipped labels and a preview table) as a W&B artifact.

          Args:
              dataset: iterable yielding ``(img, labels, path, shapes)`` and exposing
                  ``.path`` (the image directory, used as a string). Assumes ``labels``
                  rows are ``(?, class, x, y, w, h)`` with normalised xywh boxes and that
                  ``shapes[0] == (height, width)``. TODO: confirm against the dataset class.
              class_to_id: mapping from integer class id to class name (despite its name).
              name: artifact name; also used for the table and the labels zip file name.
          """
          artifact = wandb.Artifact(name=name, type="dataset")
          image_path = dataset.path
          artifact.add_dir(image_path, name='data/images')
          table = wandb.Table(columns=["id", "train_image", "Classes"])
          class_set = wandb.Classes([{'id': cid, 'name': cname} for cid, cname in class_to_id.items()])
          for si, (_img, labels, paths, shapes) in enumerate(dataset):
              height, width = shapes[0]
              # Normalised xywh -> pixel xyxy (modifies this sample's label tensor in place).
              labels[:, 2:] = xywh2xyxy(labels[:, 2:].view(-1, 4))
              labels[:, 2:] *= torch.Tensor([width, height, width, height])
              box_data = []
              img_classes = {}
              for cls, *xyxy in labels[:, 1:].tolist():
                  cls = int(cls)
                  box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
                                   "class_id": cls,
                                   "box_caption": str(class_to_id[cls]),
                                   "scores": {"acc": 1},
                                   "domain": "pixel"})
                  img_classes[cls] = class_to_id[cls]
              boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}}  # pixel coordinates
              table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes))
          artifact.add(table, name)
          # Replace the last occurrence of 'images' in the path with 'labels'.
          labels_path = 'labels'.join(image_path.rsplit('images', 1))
          zip_path = Path(labels_path).parent / (name + '_labels.zip')
          # Reuse an existing zip (even if stale): make_archive would rebuild it unconditionally.
          if not zip_path.is_file():
              shutil.make_archive(str(zip_path.with_suffix('')), 'zip', labels_path)
          artifact.add_file(str(zip_path), name='data/labels.zip')
          wandb.log_artifact(artifact)
          print("Saving data to W&B...")

      def log(self, log_dict):
          """Buffer *log_dict* entries until the next ``end_epoch``/``finish_run``."""
          if self.wandb_run:
              self.log_dict.update(log_dict)

      def end_epoch(self):
          """Send buffered metrics to W&B (if any) and clear the buffer."""
          if self.wandb_run and self.log_dict:
              wandb.log(self.log_dict)
          self.log_dict = {}

      def finish_run(self):
          """Upload the results artifact, flush buffered metrics and finish the W&B run."""
          if self.wandb_run:
              if self.result_artifact is not None:
                  print("Add Training Progress Artifact")
                  self.result_artifact.add(self.result_table, 'result')
                  # A joined table needs the validation table, which only exists
                  # when the val set itself was downloaded as a W&B artifact.
                  if self.testset_artifact is not None:
                      train_results = wandb.JoinedTable(self.testset_artifact.get("val"), self.result_table, "id")
                      self.result_artifact.add(train_results, 'joined_result')
                  wandb.log_artifact(self.result_artifact)
              if self.log_dict:
                  wandb.log(self.log_dict)
              wandb.run.finish()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...