Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

model_factory.py 6.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
  1. from typing import Optional
  2. import hydra
  3. from super_gradients.common import StrictLoad
  4. from super_gradients.common.plugins.deci_client import DeciClient
  5. from super_gradients.training import utils as core_utils
  6. from super_gradients.training.models import SgModule
  7. from super_gradients.training.models.all_architectures import ARCHITECTURES
  8. from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
  9. from super_gradients.training.utils import HpmStruct
  10. from super_gradients.training.utils.checkpoint_utils import (
  11. load_checkpoint_to_model,
  12. load_pretrained_weights,
  13. read_ckpt_state_dict,
  14. load_pretrained_weights_local,
  15. )
  16. from super_gradients.common.abstractions.abstract_logger import get_logger
  # Module-level logger; used by instantiate_model (remote-lookup info) and get (deprecation warning).
  logger = get_logger(__name__)
  def instantiate_model(name: str, arch_params: HpmStruct, pretrained_weights: str = None) -> SgModule:
      """
      Instantiate an SgModule by architecture name, optionally load pretrained weights, and replace the
      classification head when the requested number of classes differs from the pretrained one.

      Local architectures (keys of ARCHITECTURES) are built directly. Any other string name is looked up on the
      remote Deci platform through DeciClient; the fetched arch params are overridden by the given `arch_params`.

      :param name:               Architecture name: a key of ARCHITECTURES, or a model name known to the Deci platform.
      :param arch_params:        Architecture parameters passed to the model's constructor. They are read by attribute
                                 (e.g. `arch_params.num_classes`) and via `to_dict()`, so an HpmStruct is expected.
                                 NOTE: when pretrained weights are used, `arch_params.num_classes` is updated in place.
      :param pretrained_weights: Name of the dataset the pretrained weights come from (for example "imagenet"); must be a
                                 key of PRETRAINED_NUM_CLASSES. None means no pretrained weights are loaded.
      :return:                   The instantiated model.
      :raises ValueError:        If `name` is not a string, or is neither a local architecture nor available remotely.
      """
      num_classes_new_head = None
      if pretrained_weights is not None:
          # The net is first built with the head size of the pretrained weights so they load cleanly; the head is
          # swapped for the requested size afterwards. An explicit num_classes=None counts as "not requested"
          # (previously it was forwarded to replace_head as new_num_classes=None).
          requested_num_classes = getattr(arch_params, "num_classes", None)
          pretrained_num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
          num_classes_new_head = pretrained_num_classes if requested_num_classes is None else requested_num_classes
          arch_params.num_classes = pretrained_num_classes

      if not isinstance(name, str):
          raise ValueError("Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")

      remote_model = False
      deci_client = None
      if name in ARCHITECTURES:
          net = ARCHITECTURES[name](arch_params=arch_params)
      else:
          logger.info(f"Required model {name} not found in local SuperGradients. Trying to load a model from remote deci lab")
          deci_client = DeciClient()
          remote_arch_params = deci_client.get_model_arch_params(name)
          if remote_arch_params is None:
              raise ValueError("Unsupported model name " + str(name) + ", see docs or all_architectures.py for supported nets.")
          # NOTE(review): hydra.utils.instantiate may import and call any `_target_` present in this remotely fetched
          # config; this is only safe as long as the Deci endpoint is trusted.
          remote_arch_params = hydra.utils.instantiate(remote_arch_params)
          base_name = remote_arch_params["model_name"]
          remote_arch_params = HpmStruct(**remote_arch_params)
          architecture_cls = ARCHITECTURES[base_name]
          # Values supplied by the caller take precedence over the remote defaults.
          remote_arch_params.override(**arch_params.to_dict())
          net = architecture_cls(arch_params=remote_arch_params)
          remote_model = True

      if pretrained_weights:
          if remote_model:
              weights_path = deci_client.get_model_weights(name)
              load_pretrained_weights_local(net, name, weights_path)
          else:
              load_pretrained_weights(net, name, pretrained_weights)
          if num_classes_new_head != arch_params.num_classes:
              net.replace_head(new_num_classes=num_classes_new_head)
              arch_params.num_classes = num_classes_new_head

      return net
  def get(model_name: str, arch_params: Optional[dict] = None, num_classes: int = None,
          strict_load: StrictLoad = StrictLoad.NO_KEY_MATCHING, checkpoint_path: str = None,
          pretrained_weights: str = None, load_backbone: bool = False) -> SgModule:
      """
      Instantiate a model by name, optionally initialized from pretrained weights and/or a checkpoint file.

      :param model_name:         Defines the model's architecture from models/ALL_ARCHITECTURES (or a model name known
                                 to the remote Deci platform).
      :param arch_params:        Architecture hyper parameters, e.g. block, num_blocks, etc. The caller's dict is not
                                 modified.
      :param num_classes:        Number of classes (defines the net's structure). If None is given, it is taken from
                                 arch_params["num_classes"] (deprecated) or derived from the pretrained_weights'
                                 corresponding dataset.
      :param strict_load:        See super_gradients.common.data_types.enum.strict_load.StrictLoad class documentation
                                 for details (default=NO_KEY_MATCHING to support SG trained checkpoints).
      :param checkpoint_path:    Path (absolute or relative, e.g. path/to/checkpoint.pth) to an external checkpoint. If
                                 provided, it is loaded into the model (weights only); load_ema_as_net is set when the
                                 checkpoint contains an "ema_net" entry.
      :param pretrained_weights: Name of the dataset of the pretrained weights (for example "imagenet").
      :param load_backbone:      Load the provided checkpoint into model.backbone instead of the model itself.
      :return:                   The instantiated model.
      :raises ValueError:        If neither num_classes (directly or via arch_params) nor pretrained_weights is given.

      NOTE: passing both pretrained_weights and checkpoint_path does not raise; the pretrained weights are loaded
      first and the checkpoint is then loaded on top of them.
      """
      # Shallow copy: num_classes is injected below and must not leak back into the caller's dict.
      arch_params = {} if arch_params is None else dict(arch_params)
      if arch_params.get("num_classes") is not None:
          logger.warning("Passing num_classes through arch_params is deprecated and will be removed in the next version. "
                         "Pass num_classes explicitly to models.get")
      num_classes = num_classes or arch_params.get("num_classes")
      if pretrained_weights is None and num_classes is None:
          raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")
      if num_classes is not None:
          arch_params["num_classes"] = num_classes

      net = instantiate_model(model_name, core_utils.HpmStruct(**arch_params), pretrained_weights)

      if checkpoint_path:
          # The checkpoint is read here only to detect EMA weights; load_checkpoint_to_model presumably reads the
          # file again (it receives only the path).
          load_ema_as_net = "ema_net" in read_ckpt_state_dict(ckpt_path=checkpoint_path).keys()
          load_checkpoint_to_model(ckpt_local_path=checkpoint_path,
                                   load_backbone=load_backbone,
                                   net=net,
                                   strict=strict_load.value if hasattr(strict_load, "value") else strict_load,
                                   load_weights_only=True,
                                   load_ema_as_net=load_ema_as_net)
      return net
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...