Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#609 Ci fix

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/infra-000_ci
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
  1. from pathlib import Path
  2. import hydra
  3. import torch
  4. from omegaconf import DictConfig
  5. import numpy as np
  6. from torch.nn import Identity
  7. from super_gradients.common.abstractions.abstract_logger import get_logger
  8. from super_gradients.common.decorators.factory_decorator import resolve_param
  9. from super_gradients.common.factories.transforms_factory import TransformsFactory
  10. from super_gradients.training import models
  11. from super_gradients.training.utils.checkpoint_utils import get_checkpoints_dir_path
  12. from super_gradients.training.utils.hydra_utils import load_experiment_cfg
  13. from super_gradients.training.utils.sg_trainer_utils import parse_args
  14. import os
  15. import pathlib
  16. logger = get_logger(__name__)
  17. class ConvertableCompletePipelineModel(torch.nn.Module):
  18. """
  19. Exportable nn.Module that wraps the model, preprocessing and postprocessing.
  20. Args:
  21. model: torch.nn.Module, the main model. takes input from pre_process' output, and feeds pre_process.
  22. pre_process: torch.nn.Module, preprocessing module, its output will be model's input. When none (default), set to Identity().
  23. pre_process: torch.nn.Module, postprocessing module, its output is the final output. When none (default), set to Identity().
  24. **prep_model_for_conversion_kwargs: for SgModules- args to be passed to model.prep_model_for_conversion
  25. prior to torch.onnx.export call.
  26. """
  27. def __init__(self, model: torch.nn.Module, pre_process: torch.nn.Module = None, post_process: torch.nn.Module = None, **prep_model_for_conversion_kwargs):
  28. super(ConvertableCompletePipelineModel, self).__init__()
  29. model.eval()
  30. pre_process = pre_process or Identity()
  31. post_process = post_process or Identity()
  32. if hasattr(model, "prep_model_for_conversion"):
  33. model.prep_model_for_conversion(**prep_model_for_conversion_kwargs)
  34. self.model = model
  35. self.pre_process = pre_process
  36. self.post_process = post_process
  37. def forward(self, x):
  38. return self.post_process(self.model(self.pre_process(x)))
  39. @resolve_param("pre_process", TransformsFactory())
  40. @resolve_param("post_process", TransformsFactory())
  41. def convert_to_onnx(
  42. model: torch.nn.Module,
  43. out_path: str,
  44. input_shape: tuple,
  45. pre_process: torch.nn.Module = None,
  46. post_process: torch.nn.Module = None,
  47. prep_model_for_conversion_kwargs=None,
  48. torch_onnx_export_kwargs=None,
  49. ):
  50. """
  51. Exports model to ONNX.
  52. :param model: torch.nn.Module, model to export to ONNX.
  53. :param out_path: str, destination path for the .onnx file.
  54. :param input_shape: tuple, input shape, excluding batch_size (i.e (3, 224, 224)).
  55. :param pre_process: torch.nn.Module, preprocessing pipeline, will be resolved by TransformsFactory()
  56. :param post_process: torch.nn.Module, postprocessing pipeline, will be resolved by TransformsFactory()
  57. :param prep_model_for_conversion_kwargs: dict, for SgModules- args to be passed to model.prep_model_for_conversion
  58. prior to torch.onnx.export call.
  59. :param torch_onnx_export_kwargs: kwargs (EXCLUDING: FIRST 3 KWARGS- MODEL, F, ARGS). to be unpacked in torch.onnx.export call
  60. :return: out_path
  61. """
  62. if not os.path.isdir(pathlib.Path(out_path).parent.resolve()):
  63. raise FileNotFoundError(f"Could not find destination directory {out_path} for the ONNX file.")
  64. torch_onnx_export_kwargs = torch_onnx_export_kwargs or dict()
  65. prep_model_for_conversion_kwargs = prep_model_for_conversion_kwargs or dict()
  66. onnx_input = torch.Tensor(np.zeros([1, *input_shape]))
  67. if not out_path.endswith(".onnx"):
  68. out_path = out_path + ".onnx"
  69. complete_model = ConvertableCompletePipelineModel(model, pre_process, post_process, **prep_model_for_conversion_kwargs)
  70. torch.onnx.export(model=complete_model, args=onnx_input, f=out_path, **torch_onnx_export_kwargs)
  71. return out_path
  72. def prepare_conversion_cfgs(cfg: DictConfig):
  73. """
  74. Builds the cfg (i.e conversion_params) and experiment_cfg (i.e recipe config according to cfg.experiment_name)
  75. to be used by convert_recipe_example
  76. :param cfg: DictConfig, converion_params config
  77. :return: cfg, experiment_cfg
  78. """
  79. cfg = hydra.utils.instantiate(cfg)
  80. # CREATE THE EXPERIMENT CFG
  81. experiment_cfg = load_experiment_cfg(cfg.experiment_name, cfg.ckpt_root_dir)
  82. hydra.utils.instantiate(experiment_cfg)
  83. if cfg.checkpoint_path is None:
  84. logger.info(
  85. "checkpoint_params.checkpoint_path was not provided, so the model will be converted using weights from "
  86. "checkpoints_dir/training_hyperparams.ckpt_name "
  87. )
  88. checkpoints_dir = Path(get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir))
  89. cfg.checkpoint_path = str(checkpoints_dir / cfg.ckpt_name)
  90. cfg.out_path = cfg.out_path or cfg.checkpoint_path.replace(".ckpt", ".onnx")
  91. logger.info(f"Exporting checkpoint: {cfg.checkpoint_path} to ONNX.")
  92. return cfg, experiment_cfg
  93. def convert_from_config(cfg: DictConfig) -> str:
  94. """
  95. Exports model according to cfg.
  96. See:
  97. super_gradients/recipes/conversion_params/default_conversion_params.yaml for the full cfg content documentation,
  98. and super_gradients/examples/convert_recipe_example/convert_recipe_example.py for usage.
  99. :param cfg:
  100. :return: out_path, the path of the saved .onnx file.
  101. """
  102. cfg, experiment_cfg = prepare_conversion_cfgs(cfg)
  103. model = models.get(
  104. model_name=experiment_cfg.architecture,
  105. num_classes=experiment_cfg.arch_params.num_classes,
  106. arch_params=experiment_cfg.arch_params,
  107. strict_load=cfg.strict_load,
  108. checkpoint_path=cfg.checkpoint_path,
  109. )
  110. cfg = parse_args(cfg, models.convert_to_onnx)
  111. out_path = models.convert_to_onnx(model=model, **cfg)
  112. return out_path
Discard
Tip!

Press p or to see the previous file or, n or to see the next file