|
@@ -1,10 +1,18 @@
|
|
|
|
+import hydra
|
|
|
|
+
|
|
from super_gradients.common import StrictLoad
|
|
from super_gradients.common import StrictLoad
|
|
|
|
+from super_gradients.common.plugins.deci_client import DeciClient
|
|
from super_gradients.training import utils as core_utils
|
|
from super_gradients.training import utils as core_utils
|
|
from super_gradients.training.models import SgModule
|
|
from super_gradients.training.models import SgModule
|
|
from super_gradients.training.models.all_architectures import ARCHITECTURES
|
|
from super_gradients.training.models.all_architectures import ARCHITECTURES
|
|
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
|
|
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
|
|
-from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model, load_pretrained_weights, \
|
|
|
|
- read_ckpt_state_dict
|
|
|
|
|
|
+from super_gradients.training.utils import HpmStruct
|
|
|
|
+from super_gradients.training.utils.checkpoint_utils import (
|
|
|
|
+ load_checkpoint_to_model,
|
|
|
|
+ load_pretrained_weights,
|
|
|
|
+ read_ckpt_state_dict,
|
|
|
|
+ load_pretrained_weights_local,
|
|
|
|
+)
|
|
from super_gradients.common.abstractions.abstract_logger import get_logger
|
|
from super_gradients.common.abstractions.abstract_logger import get_logger
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
logger = get_logger(__name__)
|
|
@@ -31,14 +39,34 @@ def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = No
|
|
|
|
|
|
arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
|
|
arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
|
|
|
|
|
|
|
|
+ remote_model = False
|
|
if isinstance(name, str) and name in ARCHITECTURES.keys():
|
|
if isinstance(name, str) and name in ARCHITECTURES.keys():
|
|
architecture_cls = ARCHITECTURES[name]
|
|
architecture_cls = ARCHITECTURES[name]
|
|
net = architecture_cls(arch_params=arch_params)
|
|
net = architecture_cls(arch_params=arch_params)
|
|
|
|
+ elif isinstance(name, str):
|
|
|
|
+ logger.info(f'Required model {name} not found in local SuperGradients. Trying to load a model from remote deci lab')
|
|
|
|
+ deci_client = DeciClient()
|
|
|
|
+ _arch_params = deci_client.get_model_arch_params(name)
|
|
|
|
+
|
|
|
|
+ if _arch_params is not None:
|
|
|
|
+ _arch_params = hydra.utils.instantiate(_arch_params)
|
|
|
|
+ base_name = _arch_params["model_name"]
|
|
|
|
+ _arch_params = HpmStruct(**_arch_params)
|
|
|
|
+ architecture_cls = ARCHITECTURES[base_name]
|
|
|
|
+ _arch_params.override(**arch_params.to_dict())
|
|
|
|
+
|
|
|
|
+ net = architecture_cls(arch_params=_arch_params)
|
|
|
|
+ remote_model = True
|
|
|
|
+ else:
|
|
|
|
+ raise ValueError("Unsupported model name " + str(name) + ", see docs or all_architectures.py for supported nets.")
|
|
else:
|
|
else:
|
|
- raise ValueError(
|
|
|
|
- "Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")
|
|
|
|
|
|
+ raise ValueError("Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")
|
|
if pretrained_weights:
|
|
if pretrained_weights:
|
|
- load_pretrained_weights(net, name, pretrained_weights)
|
|
|
|
|
|
+ if remote_model:
|
|
|
|
+ weights_path = deci_client.get_model_weights(name)
|
|
|
|
+ load_pretrained_weights_local(net, name, weights_path)
|
|
|
|
+ else:
|
|
|
|
+ load_pretrained_weights(net, name, pretrained_weights)
|
|
if num_classes_new_head != arch_params.num_classes:
|
|
if num_classes_new_head != arch_params.num_classes:
|
|
net.replace_head(new_num_classes=num_classes_new_head)
|
|
net.replace_head(new_num_classes=num_classes_new_head)
|
|
arch_params.num_classes = num_classes_new_head
|
|
arch_params.num_classes = num_classes_new_head
|