Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

create_onnx.py 1.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
  1. import argparse
  2. from pathlib import Path
  3. from deadtrees.data.deadtreedata import DeadtreesDataModule
  4. from deadtrees.network.segmodel import SemSegment
  5. from torch.utils import data
  6. def main():
  7. parser = argparse.ArgumentParser()
  8. parser.add_argument("modelfile", type=Path)
  9. args = parser.parse_args()
  10. datamodule = DeadtreesDataModule(
  11. "../data/dataset/train_balanced_short/",
  12. pattern="train-balanced-short-000*.tar",
  13. train_dataloader_conf={"batch_size": 8, "num_workers": 4},
  14. val_dataloader_conf={"batch_size": 8, "num_workers": 2},
  15. test_dataloader_conf={"batch_size": 1, "num_workers": 1},
  16. )
  17. datamodule.setup()
  18. model = SemSegment.load_from_checkpoint(args.modelfile)
  19. model.eval()
  20. input_sample = next(iter(datamodule.train_dataloader()))[0]
  21. print(input_sample)
  22. filepath = args.modelfile
  23. model.to_onnx(
  24. filepath.with_suffix(".onnx"),
  25. input_sample,
  26. export_params=True,
  27. opset_version=11,
  28. do_constant_folding=True, # whether to execute constant folding for optimization
  29. input_names=["input"], # the model's input names
  30. output_names=["output"], # the model's output names
  31. dynamic_axes={
  32. "input": {0: "batch_size"}, # variable length axes
  33. "output": {0: "batch_size"},
  34. },
  35. )
  36. if __name__ == "__main__":
  37. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...