Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#595 Feature/sg 492 fuzzy logic for get param and factories

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-492_fuzzy_logic_for_get_param_and_factories
@@ -7,10 +7,11 @@ import torch
 from super_gradients.common import StrictLoad
 from super_gradients.common import StrictLoad
 from super_gradients.common.plugins.deci_client import DeciClient, client_enabled
 from super_gradients.common.plugins.deci_client import DeciClient, client_enabled
 from super_gradients.training import utils as core_utils
 from super_gradients.training import utils as core_utils
+from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
 from super_gradients.training.models import SgModule
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
-from super_gradients.training.utils import HpmStruct
+from super_gradients.training.utils import HpmStruct, get_param
 from super_gradients.training.utils.checkpoint_utils import (
 from super_gradients.training.utils.checkpoint_utils import (
     load_checkpoint_to_model,
     load_checkpoint_to_model,
     load_pretrained_weights,
     load_pretrained_weights,
@@ -42,7 +43,9 @@ def get_architecture(model_name: str, arch_params: HpmStruct, download_required_
     is_remote = False
     is_remote = False
     if not isinstance(model_name, str):
     if not isinstance(model_name, str):
         raise ValueError("Parameter model_name is expected to be a string.")
         raise ValueError("Parameter model_name is expected to be a string.")
-    elif model_name not in ARCHITECTURES.keys():
+
+    architecture = get_param(ARCHITECTURES, model_name)
+    if model_name not in ARCHITECTURES.keys() and architecture is None:
         if client_enabled:
         if client_enabled:
             logger.info(f'The required model, "{model_name}", was not found in SuperGradients. Trying to load a model from remote deci-lab')
             logger.info(f'The required model, "{model_name}", was not found in SuperGradients. Trying to load a model from remote deci-lab')
             deci_client = DeciClient()
             deci_client = DeciClient()
@@ -63,11 +66,13 @@ def get_architecture(model_name: str, arch_params: HpmStruct, download_required_
             _arch_params.override(**arch_params.to_dict())
             _arch_params.override(**arch_params.to_dict())
             arch_params, is_remote = _arch_params, True
             arch_params, is_remote = _arch_params, True
         else:
         else:
-            raise ValueError(
-                f'The required model, "{model_name}", was not found in SuperGradients. See docs or all_architectures.py for supported model names.'
+            raise UnknownTypeException(
+                message=f'The required model, "{model_name}", was not found in SuperGradients. See docs or all_architectures.py for supported model names.',
+                unknown_type=model_name,
+                choices=list(ARCHITECTURES.keys()),
             )
             )
 
 
-    return ARCHITECTURES[model_name], arch_params, pretrained_weights_path, is_remote
+    return get_param(ARCHITECTURES, model_name), arch_params, pretrained_weights_path, is_remote
 
 
 
 
 def instantiate_model(
 def instantiate_model(
Discard
Tip!

Press p or to see the previous file or, n or to see the next file