Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

computestats_inference.py 2.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
  1. import argparse
  2. from functools import partial, reduce
  3. from pathlib import Path
  4. from joblib import delayed, Parallel
  5. import numpy as np
  6. import pandas as pd
  7. import rioxarray
  8. from tqdm import tqdm
  9. classes = [0, 1, 2]
  10. WORKERS = 16
  11. def process_tile(tile, *, year):
  12. with rioxarray.open_rasterio(tile, chunks=(1, 512, 512)).squeeze(drop=True) as ds:
  13. # TODO: we only consider deadtrees (not potential classes of them)
  14. # ds = ds.clip(max=1)
  15. unique, counts = np.unique(ds.values, return_counts=True)
  16. row_data = dict(zip([f"cl_{int(x)}" for x in unique], counts))
  17. for c in classes:
  18. if f"cl_{c}" not in row_data:
  19. row_data[f"cl_{c}"] = 0
  20. row_data["total"] = int(ds.count().compute())
  21. row_data["tile"] = tile.stem.replace(f"ortho_ms_{year}_EPSG3044_", "")
  22. return row_data
  23. def main():
  24. parser = argparse.ArgumentParser()
  25. parser.add_argument("datapath", type=Path, nargs="+")
  26. args = parser.parse_args()
  27. dfs = []
  28. years = [2017, 2018, 2019, 2020]
  29. for year in years:
  30. inpath = None
  31. for dpath in args.datapath:
  32. if f"predicted.{year}" in str(dpath):
  33. inpath = dpath
  34. if not inpath:
  35. raise NotImplementedError
  36. print(f"Processing year: {year}...")
  37. tiles = sorted(inpath.glob("*.tif"))
  38. results = Parallel(n_jobs=WORKERS)(
  39. delayed(partial(process_tile, year=year))(x) for x in tqdm(tiles)
  40. )
  41. df = pd.DataFrame(results)
  42. df["deadarea_m2"] = (
  43. (df["cl_1"] + df["cl_2"]) * 0.200022269188281 * 0.200022454940277
  44. ).round(1)
  45. dfs.append(df)
  46. # add suffixes to each df (but remove if again form the join column <tile>)
  47. dfs = [df.add_suffix(f"_{s}") for df, s in zip(dfs, years)]
  48. dfs = [df.rename(columns={f"tile_{s}": "tile"}) for df, s in zip(dfs, years)]
  49. dfall = reduce(lambda x, y: pd.merge(x, y, on=["tile"], how="outer"), dfs)
  50. # cleanup, drop all but one columns total and move tile, total cols to front
  51. dfall = dfall.rename(columns={f"total_{years[0]}": "total"})
  52. dfall = dfall[dfall.columns.drop(list(dfall.filter(regex="total_")))]
  53. colnames = list(dfall)
  54. colnames.insert(0, colnames.pop(colnames.index("total")))
  55. colnames.insert(0, colnames.pop(colnames.index("tile")))
  56. dfall = dfall.loc[:, colnames].convert_dtypes()
  57. dfall.to_csv(args.datapath[0].parent / "predicted.stats.csv", index=False)
  58. # merge
  59. if __name__ == "__main__":
  60. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...