|
@@ -139,6 +139,7 @@ def get(
|
|
pretrained_weights: str = None,
|
|
pretrained_weights: str = None,
|
|
load_backbone: bool = False,
|
|
load_backbone: bool = False,
|
|
download_required_code: bool = True,
|
|
download_required_code: bool = True,
|
|
|
|
+ checkpoint_num_classes: int = None,
|
|
) -> SgModule:
|
|
) -> SgModule:
|
|
"""
|
|
"""
|
|
:param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES
|
|
:param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES
|
|
@@ -153,11 +154,20 @@ def get(
|
|
:param load_backbone: Load the provided checkpoint to model.backbone instead of model.
|
|
:param load_backbone: Load the provided checkpoint to model.backbone instead of model.
|
|
:param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
|
|
:param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
|
|
will prevent additional code from being downloaded. This affects only models from remote client.
|
|
will prevent additional code from being downloaded. This affects only models from remote client.
|
|
|
|
+ :param checkpoint_num_classes: num_classes of checkpoint_path/ pretrained_weights, when checkpoint_path is not None.
|
|
|
|
+ Used when num_classes != checkpoint_num_class. In this case, the module will be initialized with checkpoint_num_class, then weights will be loaded. Finaly
|
|
|
|
+ replace_head(new_num_classes=num_classes) is called (useful when wanting to perform transfer learning, from a checkpoint outside of
|
|
|
|
+ then ones offered in SG model zoo).
|
|
|
|
+
|
|
|
|
|
|
NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error.
|
|
NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error.
|
|
"""
|
|
"""
|
|
|
|
+ checkpoint_num_classes = checkpoint_num_classes or num_classes
|
|
|
|
|
|
- net = instantiate_model(model_name, arch_params, num_classes, pretrained_weights, download_required_code)
|
|
|
|
|
|
+ if checkpoint_num_classes:
|
|
|
|
+ net = instantiate_model(model_name, arch_params, checkpoint_num_classes, pretrained_weights, download_required_code)
|
|
|
|
+ else:
|
|
|
|
+ net = instantiate_model(model_name, arch_params, num_classes, pretrained_weights, download_required_code)
|
|
|
|
|
|
if load_backbone and not checkpoint_path:
|
|
if load_backbone and not checkpoint_path:
|
|
raise ValueError("Please set checkpoint_path when load_backbone=True")
|
|
raise ValueError("Please set checkpoint_path when load_backbone=True")
|
|
@@ -172,4 +182,7 @@ def get(
|
|
load_weights_only=True,
|
|
load_weights_only=True,
|
|
load_ema_as_net=load_ema_as_net,
|
|
load_ema_as_net=load_ema_as_net,
|
|
)
|
|
)
|
|
|
|
+ if checkpoint_num_classes != num_classes:
|
|
|
|
+ net.replace_head(new_num_classes=num_classes)
|
|
|
|
+
|
|
return net
|
|
return net
|