Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#378 Feature/sg 281 add kd notebook

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-281-add_kd_notebook
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
  1. import hydra
  2. from super_gradients.common import StrictLoad
  3. from super_gradients.common.plugins.deci_client import DeciClient
  4. from super_gradients.training import utils as core_utils
  5. from super_gradients.training.models import SgModule
  6. from super_gradients.training.models.all_architectures import ARCHITECTURES
  7. from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
  8. from super_gradients.training.utils import HpmStruct
  9. from super_gradients.training.utils.checkpoint_utils import (
  10. load_checkpoint_to_model,
  11. load_pretrained_weights,
  12. read_ckpt_state_dict,
  13. load_pretrained_weights_local,
  14. )
  15. from super_gradients.common.abstractions.abstract_logger import get_logger
  16. logger = get_logger(__name__)
  17. def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = None) -> SgModule:
  18. """
  19. Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required
  20. module manipulation (i.e head replacement).
  21. :param name: Defines the model's architecture from models/ALL_ARCHITECTURES
  22. :param arch_params: Architecture's parameters passed to models c'tor.
  23. :param pretrained_weights: string describing the dataset of the pretrained weights (for example "imagenent")
  24. :return: instantiated model i.e torch.nn.Module, architecture_class (will be none when architecture is not str)
  25. """
  26. if pretrained_weights is not None:
  27. if hasattr(arch_params, "num_classes"):
  28. num_classes_new_head = arch_params.num_classes
  29. else:
  30. num_classes_new_head = PRETRAINED_NUM_CLASSES[pretrained_weights]
  31. arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
  32. remote_model = False
  33. if isinstance(name, str) and name in ARCHITECTURES.keys():
  34. architecture_cls = ARCHITECTURES[name]
  35. net = architecture_cls(arch_params=arch_params)
  36. elif isinstance(name, str):
  37. logger.info(f'Required model {name} not found in local SuperGradients. Trying to load a model from remote deci lab')
  38. deci_client = DeciClient()
  39. _arch_params = deci_client.get_model_arch_params(name)
  40. if _arch_params is not None:
  41. _arch_params = hydra.utils.instantiate(_arch_params)
  42. base_name = _arch_params["model_name"]
  43. _arch_params = HpmStruct(**_arch_params)
  44. architecture_cls = ARCHITECTURES[base_name]
  45. _arch_params.override(**arch_params.to_dict())
  46. net = architecture_cls(arch_params=_arch_params)
  47. remote_model = True
  48. else:
  49. raise ValueError("Unsupported model name " + str(name) + ", see docs or all_architectures.py for supported nets.")
  50. else:
  51. raise ValueError("Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")
  52. if pretrained_weights:
  53. if remote_model:
  54. weights_path = deci_client.get_model_weights(name)
  55. load_pretrained_weights_local(net, name, weights_path)
  56. else:
  57. load_pretrained_weights(net, name, pretrained_weights)
  58. if num_classes_new_head != arch_params.num_classes:
  59. net.replace_head(new_num_classes=num_classes_new_head)
  60. arch_params.num_classes = num_classes_new_head
  61. return net
  62. def get(model_name: str, arch_params: dict = {}, num_classes: int = None,
  63. strict_load: StrictLoad = StrictLoad.NO_KEY_MATCHING, checkpoint_path: str = None,
  64. pretrained_weights: str = None, load_backbone: bool = False) -> SgModule:
  65. """
  66. :param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES
  67. :param num_classes: Number of classes (defines the net's structure). If None is given, will try to derrive from
  68. pretrained_weight's corresponding dataset.
  69. :param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
  70. :param strict_load: See super_gradients.common.data_types.enum.strict_load.StrictLoad class documentation for details
  71. (default=NO_KEY_MATCHING to suport SG trained checkpoints)
  72. :param load_backbone: loads the provided checkpoint to model.backbone instead of model.
  73. :param checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative
  74. (ie: path/to/checkpoint.pth). If provided, will automatically attempt to
  75. load the checkpoint.
  76. :param pretrained_weights: a string describing the dataset of the pretrained weights (for example "imagenent").
  77. NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error.
  78. """
  79. if arch_params.get("num_classes") is not None:
  80. logger.warning("Passing num_classes through arch_params is dperecated and will be removed in the next version. "
  81. "Pass num_classes explicitly to models.get")
  82. num_classes = num_classes or arch_params.get("num_classes")
  83. if pretrained_weights is None and num_classes is None:
  84. raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")
  85. if num_classes is not None:
  86. arch_params["num_classes"] = num_classes
  87. arch_params = core_utils.HpmStruct(**arch_params)
  88. net = instantiate_model(model_name, arch_params, pretrained_weights)
  89. if checkpoint_path:
  90. load_ema_as_net = 'ema_net' in read_ckpt_state_dict(ckpt_path=checkpoint_path).keys()
  91. _ = load_checkpoint_to_model(ckpt_local_path=checkpoint_path,
  92. load_backbone=load_backbone,
  93. net=net,
  94. strict=strict_load.value if hasattr(strict_load, "value") else strict_load,
  95. load_weights_only=True,
  96. load_ema_as_net=load_ema_as_net)
  97. return net
Discard
Tip!

Press p or to see the previous file or, n or to see the next file