Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

inference.py 2.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
  1. import argparse
  2. import math
  3. from pathlib import Path
  4. from typing import List
  5. import numpy as np
  6. import rioxarray
  7. import torch
  8. from deadtrees.data.deadtreedata import val_transform
  9. from deadtrees.deployment.inference import PyTorchEnsembleInference, PyTorchInference
  10. from deadtrees.deployment.tiler import Tiler
  11. from PIL import Image
  12. from tqdm import tqdm
  13. def main():
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument("infile", type=Path)
  16. parser.add_argument(
  17. "-m",
  18. "--model",
  19. dest="model",
  20. action="append",
  21. type=Path,
  22. default=[],
  23. help="model artefact",
  24. )
  25. parser.add_argument(
  26. "-o",
  27. dest="outpath",
  28. type=Path,
  29. default=Path("."),
  30. help="output directory",
  31. )
  32. parser.add_argument(
  33. "--all",
  34. action="store_true",
  35. dest="all",
  36. default=False,
  37. help="process complete directory",
  38. )
  39. parser.add_argument(
  40. "--nopreview",
  41. action="store_false",
  42. dest="preview",
  43. default=True,
  44. help="produce preview images",
  45. )
  46. args = parser.parse_args()
  47. if len(args.model) == 0:
  48. args.model = [Path("checkpoints/bestmodel.ckpt")]
  49. bs = 64
  50. INFILE = args.infile
  51. def is_valid_tile(infile):
  52. with rioxarray.open_rasterio(infile).sel(band=1) as t:
  53. return False if np.isin(t, [0, 255]).all() else True
  54. # inference = ONNXInference("checkpoints/bestmodel.onnx")
  55. if len(args.model) == 1:
  56. print("Default inference: single model")
  57. inference = PyTorchInference(args.model[0])
  58. else:
  59. print(f"Ensemble inference: {len(args.model)} models")
  60. inference = PyTorchEnsembleInference(*args.model)
  61. if args.all:
  62. INFILES = sorted(INFILE.glob("ortho*.tif"))
  63. else:
  64. INFILES = [INFILE]
  65. for INFILE in INFILES:
  66. if not is_valid_tile(INFILE):
  67. continue
  68. tiler = Tiler()
  69. tiler.load_file(INFILE)
  70. batches = tiler.get_batches()
  71. batches = np.array_split(batches, math.ceil(len(batches) / bs), axis=0)
  72. out_batches = []
  73. for b, batch in enumerate(tqdm(batches, desc=INFILE.name)):
  74. batch_tensor = torch.stack(
  75. [val_transform(image=i.transpose(1, 2, 0))["image"] for i in batch]
  76. )
  77. # pytorch
  78. out_batch = (
  79. inference.run(batch_tensor.detach().to("cuda"), device="cuda")
  80. .cpu()
  81. .numpy()
  82. )
  83. out_batches.append(out_batch)
  84. OUTFILE = args.outpath / INFILE.name
  85. OUTFILE_PREVIEW = Path(str(args.outpath) + "_preview") / INFILE.name
  86. tiler.put_batches(np.concatenate(out_batches, axis=0))
  87. tiler.write_file(OUTFILE)
  88. if args.preview:
  89. image = Image.fromarray(np.uint8(tiler._target.values * 255), "L")
  90. image.save(OUTFILE_PREVIEW)
  91. if __name__ == "__main__":
  92. main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...