1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
- import argparse
- import math
- from pathlib import Path
- from typing import List
- import numpy as np
- import rioxarray
- import torch
- from deadtrees.data.deadtreedata import val_transform
- from deadtrees.deployment.inference import PyTorchEnsembleInference, PyTorchInference
- from deadtrees.deployment.tiler import Tiler
- from PIL import Image
- from tqdm import tqdm
- def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("infile", type=Path)
- parser.add_argument(
- "-m",
- "--model",
- dest="model",
- action="append",
- type=Path,
- default=[],
- help="model artefact",
- )
- parser.add_argument(
- "-o",
- dest="outpath",
- type=Path,
- default=Path("."),
- help="output directory",
- )
- parser.add_argument(
- "--all",
- action="store_true",
- dest="all",
- default=False,
- help="process complete directory",
- )
- parser.add_argument(
- "--nopreview",
- action="store_false",
- dest="preview",
- default=True,
- help="produce preview images",
- )
- args = parser.parse_args()
- if len(args.model) == 0:
- args.model = [Path("checkpoints/bestmodel.ckpt")]
- bs = 64
- INFILE = args.infile
- def is_valid_tile(infile):
- with rioxarray.open_rasterio(infile).sel(band=1) as t:
- return False if np.isin(t, [0, 255]).all() else True
- # inference = ONNXInference("checkpoints/bestmodel.onnx")
- if len(args.model) == 1:
- print("Default inference: single model")
- inference = PyTorchInference(args.model[0])
- else:
- print(f"Ensemble inference: {len(args.model)} models")
- inference = PyTorchEnsembleInference(*args.model)
- if args.all:
- INFILES = sorted(INFILE.glob("ortho*.tif"))
- else:
- INFILES = [INFILE]
- for INFILE in INFILES:
- if not is_valid_tile(INFILE):
- continue
- tiler = Tiler()
- tiler.load_file(INFILE)
- batches = tiler.get_batches()
- batches = np.array_split(batches, math.ceil(len(batches) / bs), axis=0)
- out_batches = []
- for b, batch in enumerate(tqdm(batches, desc=INFILE.name)):
- batch_tensor = torch.stack(
- [val_transform(image=i.transpose(1, 2, 0))["image"] for i in batch]
- )
- # pytorch
- out_batch = (
- inference.run(batch_tensor.detach().to("cuda"), device="cuda")
- .cpu()
- .numpy()
- )
- out_batches.append(out_batch)
- OUTFILE = args.outpath / INFILE.name
- OUTFILE_PREVIEW = Path(str(args.outpath) + "_preview") / INFILE.name
- tiler.put_batches(np.concatenate(out_batches, axis=0))
- tiler.write_file(OUTFILE)
- if args.preview:
- image = Image.fromarray(np.uint8(tiler._target.values * 255), "L")
- image.save(OUTFILE_PREVIEW)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|