create_in_situ_prediction_dataset.py

#!/usr/bin/env python
"""Build a merged in-situ prediction dataset.

Loads GEOS-FP (NetCDF4) and EPA (NetCDF) files with xarray/dask, subsets
them to the configured time/lat/lon bounds, forms weighted combinations of
the raw variables, and writes the result to a chunked Zarr store.
"""
import os
import json
import click
import xarray as xr
from glob import glob
from datetime import datetime
from dask.distributed import Client


def get_date(path, file_exprs):
    """Parse a datetime from a file's basename using the first matching
    strptime pattern in file_exprs."""
    base = os.path.basename(path)
    dt = None
    for ex in file_exprs:
        try:
            dt = datetime.strptime(base, ex)
        except ValueError:
            continue
        else:
            break  # stop at the first pattern that parses
    if dt is None:
        raise ValueError(f'Could not parse datetime from {base}')
    return dt


def filter_files(filelist, file_exprs, tstart, tend):
    """Keep only files whose basename dates fall within [tstart, tend]."""
    tstart = datetime.strptime(tstart, '%Y-%m-%d')
    tend = datetime.strptime(tend, '%Y-%m-%d')
    return [
        file for file in filelist
        if tstart <= get_date(file, file_exprs) <= tend
    ]


def load_variables(basedir, suffix, bounds, variables, file_exprs, use_collections):
    """Load the requested variables from basedir, subset them to bounds, and
    return a list of weighted-sum DataArrays (one per configured variable).

    If use_collections is True, each variable must name a 'collection'
    subdirectory of basedir to read from.
    """
    if len(variables) == 0:
        return []
    if use_collections:
        if not all('collection' in v for v in variables):
            raise ValueError('All variables must specify collection')
        subdirs = {
            os.path.join(basedir, v['collection'])
            for v in variables
        }
    else:
        subdirs = {basedir}
    datasets = []
    for subdir in subdirs:
        filelist = glob(os.path.join(subdir, f'*.{suffix}'))
        filelist = filter_files(filelist, file_exprs, *bounds['time'])
        print(f'Loading {len(filelist)} files from {subdir}...')
        ds = xr.open_mfdataset(filelist, join='override', parallel=True)
        print('...done')
        # Subset each dimension to its configured (start, stop) bounds.
        ds = ds.sel(**{
            k: slice(*v) for k, v in bounds.items()
        }).transpose('time', 'lat', 'lon')
        datasets.append(ds)
    ds = xr.merge(datasets, join='inner')
    # Each output variable is a weighted sum of raw input variables.
    data = [
        sum(
            vi['weight'] * ds[vi['name']]
            for vi in v['variables']
        ).rename(v['name']).assign_attrs(**v['attributes'])
        for v in variables
    ]
    return data


@click.command()
@click.argument('configfile')
@click.argument('geosfpdir')
@click.argument('epadir')
@click.argument('outputfile')
def main(configfile, geosfpdir, epadir, outputfile):
    client = Client()  # local dask cluster for parallel reads/writes
    with open(configfile, 'r') as f:
        config = json.load(f)
    bounds = config['bounds']
    epavars = config['epa_variables']
    geosvars = config['geos_fp_variables']
    file_exprs = config['file_exprs']
    # GEOS-FP files live in per-collection subdirectories; EPA files do not.
    all_variables = sum([
        load_variables(geosfpdir, 'nc4', bounds, geosvars, file_exprs, True),
        load_variables(epadir, 'nc', bounds, epavars, file_exprs, False)
    ], [])
    merged = xr.merge(all_variables, combine_attrs='drop_conflicts')
    # Chunk small along time, keeping lat/lon whole within each chunk.
    merged = merged.chunk(
        {
            'time': 8,
            'lat': len(merged.lat),
            'lon': len(merged.lon),
        }
    )
    print(merged)
    write_job = merged.to_zarr(
        outputfile, mode='w', compute=False, consolidated=True
    )
    print(f'Writing data, view progress: {client.dashboard_link}')
    write_job.compute()
    print('Done.')


if __name__ == '__main__':
    main()
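The script is driven by a JSON config file. The actual config is not included here, but from how main() and load_variables() read it, a minimal sketch would look like the following. Every concrete value in it (the date range, the tavg1_2d_aer_Nx collection, the DUSMASS25 and pm25_value variable names, the filename patterns) is an illustrative assumption, not taken from the repository:

    {
      "bounds": {
        "time": ["2019-01-01", "2019-12-31"],
        "lat": [25.0, 50.0],
        "lon": [-125.0, -65.0]
      },
      "file_exprs": [
        "geos_fp_%Y%m%d_%H%M.nc4",
        "epa_daily_%Y%m%d.nc"
      ],
      "geos_fp_variables": [
        {
          "name": "dust_pm25",
          "collection": "tavg1_2d_aer_Nx",
          "attributes": {"units": "kg m-3"},
          "variables": [
            {"name": "DUSMASS25", "weight": 1.0}
          ]
        }
      ],
      "epa_variables": [
        {
          "name": "obs_pm25",
          "attributes": {"units": "ug m-3"},
          "variables": [
            {"name": "pm25_value", "weight": 1.0}
          ]
        }
      ]
    }

Note that every key in "bounds" is applied as a slice in ds.sel(), so it must name a dimension present in all input datasets, and "time" bounds must be '%Y-%m-%d' strings since they are also parsed for file filtering. With a config like that saved as config.json, the click CLI would be invoked as (directory paths are placeholders):

    python create_in_situ_prediction_dataset.py config.json /path/to/geos_fp /path/to/epa output.zarr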