Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

visualize_classifications.py 3.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
  1. #!/usr/bin/env python
  2. import click
  3. from pathlib import Path
  4. import pickle
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from matplotlib.patches import Rectangle
  # Face colors for the category rows drawn by plot_forecast, looked up by
  # category position (row 0 uses the first color, row 1 the second, ...).
  CAT_COLORS = ('gray', 'white', 'green', 'yellow', 'red')
  def _find_forecast_sample(dataset, date):
      """Return ``(index, target_dates)`` of the sample whose first target date is *date*.

      Raises:
          ValueError: If no sample in *dataset* starts its target dates on *date*.
      """
      # Only ``len()`` and ``get_dates(i)`` are known to be supported by the
      # dataset, so scan by position rather than iterating it directly.
      for idx in range(len(dataset)):
          _, target_dates = dataset.get_dates(idx)
          if target_dates[0] == date:
              return idx, target_dates
      raise ValueError(f'Date {date} not found in test dataset')


  def _draw_probability_grid(ax, probabilities, observed, categories):
      """Draw one shaded, labelled cell per (row of *probabilities*, category).

      Each cell is tinted with the category's color at an opacity equal to the
      predicted probability, annotated with that probability as a percentage,
      and outlined in black where ``observed[step]`` equals the category index.
      """
      n_steps = len(probabilities)
      n_categories = len(categories)

      for cat_idx in range(n_categories):
          face_color = CAT_COLORS[cat_idx]
          for step in range(n_steps):
              prob = float(probabilities[step][cat_idx])
              ax.add_patch(Rectangle(
                  (step, cat_idx), 1, 1,
                  fc=face_color,
                  ec='none',
                  alpha=prob,
              ))
              ax.text(
                  step + 0.5, cat_idx + 0.5, f'{100 * prob:.1f}%',
                  ha='center', va='center', fontsize=14,
              )
              if observed[step] == cat_idx:
                  # Outline the category that matches the dataset's label.
                  ax.add_patch(Rectangle(
                      (step, cat_idx), 1, 1,
                      fc='none',
                      ec='black',
                      lw=3,
                  ))

      # Vertical separators between columns, including both outer edges.
      for edge in range(n_steps + 1):
          ax.plot([edge, edge], [0, n_categories], 'k-', lw=2)


  def plot_forecast(site_info, datasets, models, site, date):
      """Plot predicted category probabilities for one site and forecast date.

      Args:
          site_info: Mapping whose ``'names'`` entry lists the site names; a
              site's position in that list indexes ``datasets`` and the
              classifiers.
          datasets: Per-site collection; ``datasets[i]['test']`` must provide
              ``density_categories``, ``len()``, ``get_dates(i)`` and indexing
              that yields an ``(X, y)`` pair.
          models: Mapping whose ``'classifiers'`` entry holds one classifier
              per site, each exposing ``predict_proba``.
          site: Name of the site to plot.
          date: Forecast start date; converted with ``np.datetime64(date, 'D')``.

      Returns:
          The matplotlib figure containing the probability grid.

      Raises:
          ValueError: If *site* is not a known site name, or *date* is not the
              first target date of any sample in the site's test dataset.
      """
      date = np.datetime64(date, 'D')

      # Get site index
      site_names = list(site_info['names'])
      try:
          site_idx = site_names.index(site)
      except ValueError:
          raise ValueError(f'Site {site} not in list') from None

      site_dataset = datasets[site_idx]['test']
      site_classifier = models['classifiers'][site_idx]
      categories = site_dataset.density_categories

      # The target dates of the matching sample double as the x tick labels,
      # so keep them from the search instead of fetching them a second time.
      date_idx, target_dates = _find_forecast_sample(site_dataset, date)
      X, y = site_dataset[date_idx]

      # NOTE(review): presumably predict_proba returns one array per forecast
      # step, so stacking gives rows = steps, columns = categories -- confirm.
      ypred = np.vstack(site_classifier.predict_proba(X.reshape((1, -1))))

      fig, ax = plt.subplots(figsize=(10, 4))
      _draw_probability_grid(ax, ypred, y, categories)

      ax.set_ylim(0, len(categories))
      ax.set_xlim(0, len(ypred))
      # Ticks sit at cell centers so labels line up with the grid cells.
      ax.set_yticks(np.arange(0.5, len(categories)))
      ax.set_yticklabels(categories, fontsize=12)
      ax.set_xticks(np.arange(0.5, len(ypred)))
      ax.set_xticklabels(target_dates, fontsize=12)
      ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
      ax.set_ylabel('NOAA HMS Smoke Level', fontsize=16)
      return fig
  @click.command()
  @click.argument('datafile', type=click.Path(path_type=Path, exists=True))
  @click.argument('modelfile', type=click.Path(path_type=Path, exists=True))
  @click.argument('outputfile', type=click.Path(path_type=Path))
  @click.option('--site', '-s', default='San Francisco', type=str,
                show_default=True, help='Name of the site to plot.')
  @click.option('--date', '-d', default='2018-07-01', type=str,
                show_default=True, help='Forecast start date.')
  def main(datafile, modelfile, outputfile, site, date):
      """Plot classifier forecast probabilities for one site and date.

      Loads the pickled models from MODELFILE, rebuilds the datasets from
      DATAFILE using the configuration stored with the models, and saves the
      resulting figure to OUTPUTFILE.
      """
      # Kept as a function-level import, as in the original script.
      from train_classifiers import load_datasets

      print('Loading models...')
      # SECURITY: pickle.load can execute arbitrary code embedded in the file.
      # Only pass model files that come from a trusted source.
      with open(modelfile, 'rb') as model_fh:
          models = pickle.load(model_fh)

      print('Loading datasets...')
      config = models['configuration']
      site_info, datasets = load_datasets(datafile, config)

      print('Generating plot...')
      fig = plot_forecast(site_info, datasets, models, site, date)

      # Create the target directory so the save does not fail after all the
      # loading and plotting work has already been done.
      outputfile.parent.mkdir(parents=True, exist_ok=True)
      fig.savefig(outputfile)


  if __name__ == '__main__':
      main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...