Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

eval_seq2seq.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
  1. #!/usr/bin/env python
  2. import os
  3. import json
  4. import click
  5. import numpy as np
  6. import xarray as xr
  7. from tqdm import tqdm
  8. import torch
  9. import matplotlib.pyplot as plt
  10. from epa_seq2seq import (
  11. EpaDataset, EpaSeq2Seq,
  12. eval_model, mse_nan_loss,
  13. )
  14. @click.command()
  15. @click.argument('modelfile')
  16. @click.argument('datafile')
  17. @click.argument('outputfile')
  18. def main(modelfile, datafile, outputfile):
  19. device = torch.device('cpu')
  20. checkpoint = torch.load(modelfile, map_location=device)
  21. config = checkpoint['config']
  22. window_size = config['sequence_length']
  23. in_channels = config['in_channels']
  24. out_channels = config['out_channels']
  25. model_params = config['model_params']
  26. model_params['device'] = device
  27. dataset_names = ('test',)
  28. datasets = {}
  29. with xr.open_zarr(datafile) as ds:
  30. for s in dataset_names:
  31. ds_sub = ds.sel(time=slice(*config[f'{s}_date_range']))
  32. datasets[s] = EpaDataset(ds_sub, window_size, in_channels, out_channels)
  33. model = EpaSeq2Seq(
  34. in_channels=len(in_channels),
  35. out_channels=len(out_channels),
  36. frame_size=datasets['test'].frame_size,
  37. **model_params
  38. )
  39. model.load_state_dict(checkpoint['model_state_dict'])
  40. model.to(device)
  41. model.eval()
  42. test = datasets['test']
  43. all_times = []
  44. predictions = []
  45. for i in tqdm(list(range(len(test)))):
  46. X, y = test[i]
  47. times = test.get_time(i)
  48. X = torch.tensor(X[np.newaxis, ...])
  49. y = torch.tensor(y[np.newaxis, ...])
  50. ypred = np.squeeze(model(X).detach().numpy())
  51. predictions.append(ypred[-1])
  52. all_times.append(times[-1])
  53. predictions = np.stack(predictions, axis=0)
  54. dataset = xr.Dataset(
  55. data_vars={
  56. 'PM25': (
  57. ['time', 'lat', 'lon'],
  58. predictions,
  59. test.y['PM25'].attrs
  60. )
  61. },
  62. coords={
  63. 'time': xr.Variable('time', np.array(all_times)),
  64. 'lat': test.dataset.lat,
  65. 'lon': test.dataset.lon,
  66. }
  67. ).chunk(time=8)
  68. dataset.to_zarr(outputfile)
  69. if __name__ == '__main__':
  70. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...