Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

eval_metric_calculation.py 2.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
  1. import numpy as np
  2. from PIL import Image
  3. from tqdm import tqdm
  4. def compute_errors(target, prediction):
  5. thresh = np.maximum((target / prediction), (prediction / target))
  6. a1 = (thresh < 1.25).mean()
  7. a2 = (thresh < 1.25 ** 2).mean()
  8. a3 = (thresh < 1.25 ** 3).mean()
  9. abs_rel = np.mean(np.abs(target - prediction) / target)
  10. sq_rel = np.mean(((target - prediction) ** 2) / target)
  11. rmse = (target - prediction) ** 2
  12. rmse = np.sqrt(rmse.mean())
  13. rmse_log = (np.log(target) - np.log(prediction)) ** 2
  14. rmse_log = np.sqrt(rmse_log.mean())
  15. err = np.log(prediction) - np.log(target)
  16. silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
  17. log_10 = (np.abs(np.log10(target) - np.log10(prediction))).mean()
  18. return a1, a2, a3, abs_rel, sq_rel, rmse, rmse_log, silog, log_10
  19. def compute_eval_metrics(test_files):
  20. min_depth_eval = 1e-3
  21. max_depth_eval = 10
  22. num_samples = len(test_files)
  23. a1 = np.zeros(num_samples, np.float32)
  24. a2 = np.zeros(num_samples, np.float32)
  25. a3 = np.zeros(num_samples, np.float32)
  26. abs_rel = np.zeros(num_samples, np.float32)
  27. sq_rel = np.zeros(num_samples, np.float32)
  28. rmse = np.zeros(num_samples, np.float32)
  29. rmse_log = np.zeros(num_samples, np.float32)
  30. silog = np.zeros(num_samples, np.float32)
  31. log10 = np.zeros(num_samples, np.float32)
  32. for i in tqdm(range(num_samples), desc="Calculating metrics for test data", total=num_samples):
  33. sample_path = test_files[i]
  34. target_path = str(sample_path.parent/(sample_path.stem + "_depth.png"))
  35. pred_path = "src/eval/" + str(sample_path.stem) + "_pred.png"
  36. target_image = Image.open(target_path)
  37. pred_image = Image.open(pred_path)
  38. target = np.asarray(target_image)
  39. pred = np.asarray(pred_image)
  40. target = target / 25.0
  41. pred = pred / 25.0
  42. pred[pred < min_depth_eval] = min_depth_eval
  43. pred[pred > max_depth_eval] = max_depth_eval
  44. pred[np.isinf(pred)] = max_depth_eval
  45. target[np.isinf(target)] = 0
  46. target[np.isnan(target)] = 0
  47. valid_mask = np.logical_and(target > min_depth_eval, target < max_depth_eval)
  48. a1[i], a2[i], a3[i], abs_rel[i], sq_rel[i], rmse[i], rmse_log[i], silog[i], log10[i] = \
  49. compute_errors(target[valid_mask], pred[valid_mask])
  50. print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format(
  51. 'd1', 'd2', 'd3', 'AbsRel', 'SqRel', 'RMSE', 'RMSElog', 'SILog', 'log10'))
  52. print("{:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}".format(
  53. a1.mean(), a2.mean(), a3.mean(),
  54. abs_rel.mean(), sq_rel.mean(), rmse.mean(), rmse_log.mean(), silog.mean(), log10.mean()))
  55. return dict(a1=a1.mean(), a2=a2.mean(), a3=a3.mean(),
  56. abs_rel=abs_rel.mean(), sq_rel=sq_rel.mean(),
  57. rmse=rmse.mean(), rmse_log=rmse_log.mean(),
  58. log10=log10.mean(), silog=silog.mean())
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...