Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate_on_coco.py 12 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
  1. """
  2. A script to evaluate the model's performance using pre-trained weights using COCO API.
  3. Example usage: python evaluate_on_coco.py -dir D:\cocoDataset\val2017\val2017 -gta D:\cocoDataset\annotatio
  4. ns_trainval2017\annotations\instances_val2017.json -c cfg/yolov4-smaller-input.cfg -g 0
  5. Explanation: set where your images can be found using -dir, then use -gta to point to the ground truth annotations file
  6. and finally -c to point to the config file you want to use to load the network using.
  7. """
  8. import argparse
  9. import datetime
  10. import json
  11. import logging
  12. import os
  13. import sys
  14. import time
  15. from collections import defaultdict
  16. import numpy as np
  17. import torch
  18. from PIL import Image, ImageDraw
  19. from easydict import EasyDict as edict
  20. from pycocotools.coco import COCO
  21. from pycocotools.cocoeval import COCOeval
  22. from cfg import Cfg
  23. from tool.darknet2pytorch import Darknet
  24. from tool.utils import load_class_names
  25. from tool.torch_utils import do_detect
  26. def get_class_name(cat):
  27. class_names = load_class_names("./data/coco.names")
  28. if cat >= 1 and cat <= 11:
  29. cat = cat - 1
  30. elif cat >= 13 and cat <= 25:
  31. cat = cat - 2
  32. elif cat >= 27 and cat <= 28:
  33. cat = cat - 3
  34. elif cat >= 31 and cat <= 44:
  35. cat = cat - 5
  36. elif cat >= 46 and cat <= 65:
  37. cat = cat - 6
  38. elif cat == 67:
  39. cat = cat - 7
  40. elif cat == 70:
  41. cat = cat - 9
  42. elif cat >= 72 and cat <= 82:
  43. cat = cat - 10
  44. elif cat >= 84 and cat <= 90:
  45. cat = cat - 11
  46. return class_names[cat]
  47. def convert_cat_id_and_reorientate_bbox(single_annotation):
  48. cat = single_annotation['category_id']
  49. bbox = single_annotation['bbox']
  50. x, y, w, h = bbox
  51. x1, y1, x2, y2 = x - w / 2, y - h / 2, x + w / 2, y + h / 2
  52. if 0 <= cat <= 10:
  53. cat = cat + 1
  54. elif 11 <= cat <= 23:
  55. cat = cat + 2
  56. elif 24 <= cat <= 25:
  57. cat = cat + 3
  58. elif 26 <= cat <= 39:
  59. cat = cat + 5
  60. elif 40 <= cat <= 59:
  61. cat = cat + 6
  62. elif cat == 60:
  63. cat = cat + 7
  64. elif cat == 61:
  65. cat = cat + 9
  66. elif 62 <= cat <= 72:
  67. cat = cat + 10
  68. elif 73 <= cat <= 79:
  69. cat = cat + 11
  70. single_annotation['category_id'] = cat
  71. single_annotation['bbox'] = [x1, y1, w, h]
  72. return single_annotation
  73. def myconverter(obj):
  74. if isinstance(obj, np.integer):
  75. return int(obj)
  76. elif isinstance(obj, np.floating):
  77. return float(obj)
  78. elif isinstance(obj, np.ndarray):
  79. return obj.tolist()
  80. elif isinstance(obj, datetime.datetime):
  81. return obj.__str__()
  82. else:
  83. return obj
  84. def evaluate_on_coco(cfg, resFile):
  85. annType = "bbox" # specify type here
  86. with open(resFile, 'r') as f:
  87. unsorted_annotations = json.load(f)
  88. sorted_annotations = list(sorted(unsorted_annotations, key=lambda single_annotation: single_annotation["image_id"]))
  89. sorted_annotations = list(map(convert_cat_id_and_reorientate_bbox, sorted_annotations))
  90. reshaped_annotations = defaultdict(list)
  91. for annotation in sorted_annotations:
  92. reshaped_annotations[annotation['image_id']].append(annotation)
  93. with open('temp.json', 'w') as f:
  94. json.dump(sorted_annotations, f)
  95. cocoGt = COCO(cfg.gt_annotations_path)
  96. cocoDt = cocoGt.loadRes('temp.json')
  97. with open(cfg.gt_annotations_path, 'r') as f:
  98. gt_annotation_raw = json.load(f)
  99. gt_annotation_raw_images = gt_annotation_raw["images"]
  100. gt_annotation_raw_labels = gt_annotation_raw["annotations"]
  101. rgb_label = (255, 0, 0)
  102. rgb_pred = (0, 255, 0)
  103. for i, image_id in enumerate(reshaped_annotations):
  104. image_annotations = reshaped_annotations[image_id]
  105. gt_annotation_image_raw = list(filter(
  106. lambda image_json: image_json['id'] == image_id, gt_annotation_raw_images
  107. ))
  108. gt_annotation_labels_raw = list(filter(
  109. lambda label_json: label_json['image_id'] == image_id, gt_annotation_raw_labels
  110. ))
  111. if len(gt_annotation_image_raw) == 1:
  112. image_path = os.path.join(cfg.dataset_dir, gt_annotation_image_raw[0]["file_name"])
  113. actual_image = Image.open(image_path).convert('RGB')
  114. draw = ImageDraw.Draw(actual_image)
  115. for annotation in image_annotations:
  116. x1_pred, y1_pred, w, h = annotation['bbox']
  117. x2_pred, y2_pred = x1_pred + w, y1_pred + h
  118. cls_id = annotation['category_id']
  119. label = get_class_name(cls_id)
  120. draw.text((x1_pred, y1_pred), label, fill=rgb_pred)
  121. draw.rectangle([x1_pred, y1_pred, x2_pred, y2_pred], outline=rgb_pred)
  122. for annotation in gt_annotation_labels_raw:
  123. x1_truth, y1_truth, w, h = annotation['bbox']
  124. x2_truth, y2_truth = x1_truth + w, y1_truth + h
  125. cls_id = annotation['category_id']
  126. label = get_class_name(cls_id)
  127. draw.text((x1_truth, y1_truth), label, fill=rgb_label)
  128. draw.rectangle([x1_truth, y1_truth, x2_truth, y2_truth], outline=rgb_label)
  129. actual_image.save("./data/outcome/predictions_{}".format(gt_annotation_image_raw[0]["file_name"]))
  130. else:
  131. print('please check')
  132. break
  133. if (i + 1) % 100 == 0: # just see first 100
  134. break
  135. imgIds = sorted(cocoGt.getImgIds())
  136. cocoEval = COCOeval(cocoGt, cocoDt, annType)
  137. cocoEval.params.imgIds = imgIds
  138. cocoEval.evaluate()
  139. cocoEval.accumulate()
  140. cocoEval.summarize()
  141. def test(model, annotations, cfg):
  142. if not annotations["images"]:
  143. print("Annotations do not have 'images' key")
  144. return
  145. images = annotations["images"]
  146. # images = images[:10]
  147. resFile = 'data/coco_val_outputs.json'
  148. if torch.cuda.is_available():
  149. use_cuda = 1
  150. else:
  151. use_cuda = 0
  152. # do one forward pass first to circumvent cold start
  153. throwaway_image = Image.open('data/dog.jpg').convert('RGB').resize((model.width, model.height))
  154. do_detect(model, throwaway_image, 0.5, 80, 0.4, use_cuda)
  155. boxes_json = []
  156. for i, image_annotation in enumerate(images):
  157. logging.info("currently on image: {}/{}".format(i + 1, len(images)))
  158. image_file_name = image_annotation["file_name"]
  159. image_id = image_annotation["id"]
  160. image_height = image_annotation["height"]
  161. image_width = image_annotation["width"]
  162. # open and resize each image first
  163. img = Image.open(os.path.join(cfg.dataset_dir, image_file_name)).convert('RGB')
  164. sized = img.resize((model.width, model.height))
  165. if use_cuda:
  166. model.cuda()
  167. start = time.time()
  168. boxes = do_detect(model, sized, 0.0, 80, 0.4, use_cuda)
  169. finish = time.time()
  170. if type(boxes) == list:
  171. for box in boxes:
  172. box_json = {}
  173. category_id = box[-1]
  174. score = box[-2]
  175. bbox_normalized = box[:4]
  176. box_json["category_id"] = int(category_id)
  177. box_json["image_id"] = int(image_id)
  178. bbox = []
  179. for i, bbox_coord in enumerate(bbox_normalized):
  180. modified_bbox_coord = float(bbox_coord)
  181. if i % 2:
  182. modified_bbox_coord *= image_height
  183. else:
  184. modified_bbox_coord *= image_width
  185. modified_bbox_coord = round(modified_bbox_coord, 2)
  186. bbox.append(modified_bbox_coord)
  187. box_json["bbox_normalized"] = list(map(lambda x: round(float(x), 2), bbox_normalized))
  188. box_json["bbox"] = bbox
  189. box_json["score"] = round(float(score), 2)
  190. box_json["timing"] = float(finish - start)
  191. boxes_json.append(box_json)
  192. # print("see box_json: ", box_json)
  193. with open(resFile, 'w') as outfile:
  194. json.dump(boxes_json, outfile, default=myconverter)
  195. else:
  196. print("warning: output from model after postprocessing is not a list, ignoring")
  197. return
  198. # namesfile = 'data/coco.names'
  199. # class_names = load_class_names(namesfile)
  200. # plot_boxes(img, boxes, 'data/outcome/predictions_{}.jpg'.format(image_id), class_names)
  201. with open(resFile, 'w') as outfile:
  202. json.dump(boxes_json, outfile, default=myconverter)
  203. evaluate_on_coco(cfg, resFile)
  204. def get_args(**kwargs):
  205. cfg = kwargs
  206. parser = argparse.ArgumentParser(description='Test model on test dataset',
  207. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  208. parser.add_argument('-f', '--load', dest='load', type=str, default=None,
  209. help='Load model from a .pth file')
  210. parser.add_argument('-g', '--gpu', metavar='G', type=str, default='-1',
  211. help='GPU', dest='gpu')
  212. parser.add_argument('-dir', '--data-dir', type=str, default=None,
  213. help='dataset dir', dest='dataset_dir')
  214. parser.add_argument('-gta', '--ground_truth_annotations', type=str, default='instances_val2017.json',
  215. help='ground truth annotations file', dest='gt_annotations_path')
  216. parser.add_argument('-w', '--weights_file', type=str, default='weights/yolov4.weights',
  217. help='weights file to load', dest='weights_file')
  218. parser.add_argument('-c', '--model_config', type=str, default='cfg/yolov4.cfg',
  219. help='model config file to load', dest='model_config')
  220. args = vars(parser.parse_args())
  221. for k in args.keys():
  222. cfg[k] = args.get(k)
  223. return edict(cfg)
  224. def init_logger(log_file=None, log_dir=None, log_level=logging.INFO, mode='w', stdout=True):
  225. """
  226. log_dir: 日志文件的文件夹路径
  227. mode: 'a', append; 'w', 覆盖原文件写入.
  228. """
  229. import datetime
  230. def get_date_str():
  231. now = datetime.datetime.now()
  232. return now.strftime('%Y-%m-%d_%H-%M-%S')
  233. fmt = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s'
  234. if log_dir is None:
  235. log_dir = '~/temp/log/'
  236. if log_file is None:
  237. log_file = 'log_' + get_date_str() + '.txt'
  238. if not os.path.exists(log_dir):
  239. os.makedirs(log_dir)
  240. log_file = os.path.join(log_dir, log_file)
  241. # 此处不能使用logging输出
  242. print('log file path:' + log_file)
  243. logging.basicConfig(level=logging.DEBUG,
  244. format=fmt,
  245. filename=log_file,
  246. filemode=mode)
  247. if stdout:
  248. console = logging.StreamHandler(stream=sys.stdout)
  249. console.setLevel(log_level)
  250. formatter = logging.Formatter(fmt)
  251. console.setFormatter(formatter)
  252. logging.getLogger('').addHandler(console)
  253. return logging
  254. if __name__ == "__main__":
  255. logging = init_logger(log_dir='log')
  256. cfg = get_args(**Cfg)
  257. os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu
  258. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  259. logging.info(f'Using device {device}')
  260. model = Darknet(cfg.model_config)
  261. model.print_network()
  262. model.load_weights(cfg.weights_file)
  263. model.eval() # set model away from training
  264. if torch.cuda.device_count() > 1:
  265. model = torch.nn.DataParallel(model)
  266. model.to(device=device)
  267. annotations_file_path = cfg.gt_annotations_path
  268. with open(annotations_file_path) as annotations_file:
  269. try:
  270. annotations = json.load(annotations_file)
  271. except:
  272. print("annotations file not a json")
  273. exit()
  274. test(model=model,
  275. annotations=annotations,
  276. cfg=cfg, )
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...