1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
|
- #!/usr/bin/env python
- import click
- from pathlib import Path
- import pickle
- import numpy as np
- import matplotlib.pyplot as plt
- from matplotlib.patches import Rectangle
# Fill colour per density category, looked up by integer category index when
# drawing the probability grid (position 0 -> 'gray', ..., 4 -> 'red').
# NOTE(review): presumably ordered to match the dataset's density_categories — confirm.
CAT_COLORS = ('gray', 'white', 'green', 'yellow', 'red')
def _find_date_index(dataset, date):
    """Return the first sample index whose second ``get_dates`` entry starts on *date*.

    Returns None when no sample in *dataset* matches.
    """
    # NOTE(review): get_dates(i) is unpacked as a pair; the second element is
    # presumably the dates belonging to the target y — confirm.
    for idx in range(len(dataset)):
        _, target_dates = dataset.get_dates(idx)
        if target_dates[0] == date:
            return idx
    return None


def _draw_probability_grid(ax, ypred, y_true, n_categories):
    """Draw one shaded, labelled cell per (column d, category c) on *ax*.

    A cell's fill alpha is its probability ``ypred[d][c]``. Cells where
    ``y_true[d] == c`` get a thick black outline. Black vertical lines
    separate the columns.
    """
    n_cols = len(ypred)
    for c in range(n_categories):
        fc = CAT_COLORS[c]
        for d in range(n_cols):
            prob = float(ypred[d][c])
            ax.add_patch(Rectangle((d, c), 1, 1, fc=fc, ec='none', alpha=prob))
            ax.text(
                d + 0.5, c + 0.5, '%0.01f%%' % (100 * prob),
                ha='center', va='center', fontsize=14,
            )
            if y_true[d] == c:
                # Outline the observed category for this column.
                ax.add_patch(Rectangle((d, c), 1, 1, fc='none', ec='black', lw=3))
    for d in range(n_cols + 1):
        ax.plot([d, d], [0, n_categories], 'k-', lw=2)


def plot_forecast(site_info, datasets, models, site, date):
    """Plot one site's predicted category probabilities for the test sample on *date*.

    Parameters
    ----------
    site_info : mapping
        Must contain ``'names'``, an iterable of site names. A site's position
        in it indexes both ``datasets`` and ``models['classifiers']``.
    datasets : sequence
        Per-site mappings; only the ``'test'`` entry is used.
    models : mapping
        Must contain ``'classifiers'``, a per-site sequence of objects exposing
        ``predict_proba``.
    site : str
        Name of the site to plot.
    date : str or datetime-like
        Converted with ``np.datetime64(date, 'D')`` and matched against the
        first entry of the second ``get_dates`` element of each test sample.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the probability grid.

    Raises
    ------
    ValueError
        If *site* is not a known site name or no test sample matches *date*.
    """
    date = np.datetime64(date, 'D')

    site_names = list(site_info['names'])
    try:
        site_idx = site_names.index(site)
    except ValueError:
        raise ValueError(f'Site {site} not in list') from None

    site_dataset = datasets[site_idx]['test']
    site_classifier = models['classifiers'][site_idx]
    categories = site_dataset.density_categories

    date_idx = _find_date_index(site_dataset, date)
    if date_idx is None:
        raise ValueError(f'Date {date} not found in test dataset')

    X, y = site_dataset[date_idx]
    _, dy = site_dataset.get_dates(date_idx)

    # Predict on a single row. np.vstack turns whatever predict_proba returns
    # into a 2-D array, indexed below as ypred[column][category].
    # NOTE(review): presumably a multi-output classifier returning one
    # (1, n_classes) array per forecast step — confirm.
    ypred = np.vstack(site_classifier.predict_proba(X.reshape((1, -1))))

    n_categories = len(categories)
    n_cols = len(ypred)

    fig, ax = plt.subplots(figsize=(10, 4))
    _draw_probability_grid(ax, ypred, y, n_categories)

    ax.set_ylim(0, n_categories)
    ax.set_xlim(0, n_cols)
    # Ticks sit at cell centres.
    ax.set_yticks(np.arange(0.5, n_categories))
    ax.set_yticklabels(categories, fontsize=12)
    ax.set_xticks(np.arange(0.5, n_cols))
    ax.set_xticklabels(dy, fontsize=12)
    # Column (date) labels are shown along the top edge only.
    ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
    ax.set_ylabel('NOAA HMS Smoke Level', fontsize=16)
    return fig
# CLI entry point: load pickled models and datasets, render the forecast plot
# for --site / --date, and save it to OUTPUTFILE. DATAFILE and MODELFILE must
# already exist (enforced by click.Path(exists=True)). Documented with comments
# rather than a docstring because click would show a docstring as --help text.
@click.command()
@click.argument('datafile', type=click.Path(path_type=Path, exists=True))
@click.argument('modelfile', type=click.Path(path_type=Path, exists=True))
@click.argument('outputfile', type=click.Path(path_type=Path))
@click.option('--site', '-s', default='San Francisco', type=str)
@click.option('--date', '-d', default='2018-07-01', type=str)
def main(datafile, modelfile, outputfile, site, date):
    # Imported at call time rather than at module top level
    # (presumably to keep this module importable without it — confirm).
    from train_classifiers import load_datasets

    print('Loading models...')
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only pass model files from a trusted source.
    with open(modelfile, 'rb') as handle:
        models = pickle.load(handle)

    print('Loading datasets...')
    site_info, datasets = load_datasets(datafile, models['configuration'])

    print('Generating plot...')
    figure = plot_forecast(site_info, datasets, models, site, date)
    figure.savefig(outputfile)
if __name__ == '__main__':
    # Run the click command only when executed as a script, not on import.
    main()
|