
#595 Feature/SG-492: fuzzy logic for get_param and factories

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-492_fuzzy_logic_for_get_param_and_factories
@@ -7,10 +7,11 @@ import torch
 from super_gradients.common import StrictLoad
 from super_gradients.common.plugins.deci_client import DeciClient, client_enabled
 from super_gradients.training import utils as core_utils
+from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
-from super_gradients.training.utils import HpmStruct
+from super_gradients.training.utils import HpmStruct, get_param
 from super_gradients.training.utils.checkpoint_utils import (
     load_checkpoint_to_model,
     load_pretrained_weights,
@@ -42,7 +43,9 @@ def get_architecture(model_name: str, arch_params: HpmStruct, download_required_
     is_remote = False
     if not isinstance(model_name, str):
         raise ValueError("Parameter model_name is expected to be a string.")
-    elif model_name not in ARCHITECTURES.keys():
+
+    architecture = get_param(ARCHITECTURES, model_name)
+    if model_name not in ARCHITECTURES.keys() and architecture is None:
         if client_enabled:
             logger.info(f'The required model, "{model_name}", was not found in SuperGradients. Trying to load a model from remote deci-lab')
             deci_client = DeciClient()
@@ -63,11 +66,13 @@ def get_architecture(model_name: str, arch_params: HpmStruct, download_required_
             _arch_params.override(**arch_params.to_dict())
             arch_params, is_remote = _arch_params, True
         else:
-            raise ValueError(
-                f'The required model, "{model_name}", was not found in SuperGradients. See docs or all_architectures.py for supported model names.'
+            raise UnknownTypeException(
+                message=f'The required model, "{model_name}", was not found in SuperGradients. See docs or all_architectures.py for supported model names.',
+                unknown_type=model_name,
+                choices=list(ARCHITECTURES.keys()),
             )

-    return ARCHITECTURES[model_name], arch_params, pretrained_weights_path, is_remote
+    return get_param(ARCHITECTURES, model_name), arch_params, pretrained_weights_path, is_remote


 def instantiate_model(
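A minimal sketch of the lookup behavior this hunk enables (the model-name spelling below is an illustrative assumption, not taken from the diff):

from super_gradients.training.models.all_architectures import ARCHITECTURES
from super_gradients.training.utils import get_param

# get_param now falls back to a fuzzy key comparison, so stylistic variants
# of a registered name resolve to the same entry. "ResNet_18" would match a
# "resnet18" key, if one is registered.
architecture = get_param(ARCHITECTURES, "ResNet_18")
if architecture is None:
    # No exact or fuzzy match: the code above then tries the remote deci-lab
    # lookup, or raises UnknownTypeException listing ARCHITECTURES.keys().
    pass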
@@ -38,7 +38,6 @@ from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.utils import sg_trainer_utils, get_param
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args, log_main_training_params
 from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, GPUModeNotSetupError
-from super_gradients.training.losses import LOSSES
 from super_gradients.training.metrics.metric_utils import (
     get_metrics_titles,
     get_metrics_results_tuple,
@@ -60,6 +59,7 @@ from super_gradients.training.utils.distributed_training_utils import (
 )
 from super_gradients.training.utils.ema import ModelEMA
 from super_gradients.training.utils.optimizer_utils import build_optimizer
+from super_gradients.training.utils.utils import fuzzy_idx_in_list
 from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.utils import random_seed
@@ -465,7 +465,7 @@ class Trainer:
         return loss, loss_logging_items

     def _init_monitored_items(self):
-        self.metric_idx_in_results_tuple = (self.loss_logging_items_names + get_metrics_titles(self.valid_metrics)).index(self.metric_to_watch)
+        self.metric_idx_in_results_tuple = fuzzy_idx_in_list(self.metric_to_watch, self.loss_logging_items_names + get_metrics_titles(self.valid_metrics))
         # Instantiate the values to monitor (loss/metric)
         for loss_name in self.loss_logging_items_names:
             self.train_monitored_values[loss_name] = MonitoredValue(name=loss_name, greater_is_better=False)
@@ -997,8 +997,7 @@ class Trainer:

         # Allowing loading instantiated loss or string
         if isinstance(self.training_params.loss, str):
-            criterion_cls = LOSSES[self.training_params.loss]
-            self.criterion = criterion_cls(**self.training_params.criterion_params)
+            self.criterion = LossesFactory().get({self.training_params.loss: self.training_params.criterion_params})

         elif isinstance(self.training_params.loss, Mapping):
             self.criterion = LossesFactory().get(self.training_params.loss)
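A minimal sketch of what the fuzzy_idx_in_list swap buys for metric_to_watch (the titles list is an illustrative assumption):

from super_gradients.training.utils.utils import fuzzy_idx_in_list

# Illustrative titles; the real list is loss_logging_items_names + metric titles.
titles = ["LabelSmoothingCrossEntropyLoss", "Accuracy", "Top5"]

# A plain titles.index("accuracy") would raise ValueError (exact match only);
# the fuzzy variant ignores case, underscores and symbols:
idx = fuzzy_idx_in_list("accuracy", titles)  # -> 1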
@@ -2,13 +2,13 @@ import math
 import time
 from functools import lru_cache
 from pathlib import Path
-from typing import Mapping, Optional, Tuple, Union, List, Dict
+from typing import Mapping, Optional, Tuple, Union, List, Dict, Any
 from zipfile import ZipFile
 import os
 from jsonschema import validate
 import tarfile
 from PIL import Image, ExifTags
-
+import re
 import torch
 import torch.nn as nn

@@ -168,18 +168,77 @@ def tensor_container_to_device(obj: Union[torch.Tensor, tuple, list, dict], devi
         return obj

+def fuzzy_keys(params: Mapping) -> List[str]:
+    """
+    Returns params.keys() with whitespace and underscores removed, lower-cased and symbols dropped.
+    :param params: Mapping, the mapping containing the keys to be returned.
+    :return: List[str], list of keys as described above.
+    """
+    return [fuzzy_str(s) for s in params.keys()]
+
+
+def fuzzy_str(s: str) -> str:
+    """
+    Returns s with whitespace and underscores removed, lower-cased and symbols dropped.
+    :param s: str, string to apply the manipulation described above.
+    :return: str, s after the manipulation described above.
+    """
+    return re.sub(r"[^\w]", "", s).replace("_", "").lower()
+
+
+def _get_fuzzy_attr_map(params):
+    return {fuzzy_str(a): a for a in params.__dir__()}
+
+
+def _has_fuzzy_attr(params, name):
+    return fuzzy_str(name) in _get_fuzzy_attr_map(params)
+
+
+def get_fuzzy_mapping_param(name: str, params: Mapping):
+    """
+    Returns the parameter value whose key fuzzy-matches name, with no sensitivity to case or symbols.
+    :param name: str, the key in params which is fuzzy-matched and returned.
+    :param params: Mapping, the mapping containing the param.
+    :return: the value of the fuzzy-matched key in params.
+    """
+    fuzzy_params = {fuzzy_str(key): params[key] for key in params.keys()}
+    return fuzzy_params[fuzzy_str(name)]
+
+
+def get_fuzzy_attr(params: Any, name: str):
+    """
+    Returns an attribute (same functionality as getattr), but insensitive to case and symbols.
+    :param params: Any, any object in which we look for the attribute name.
+    :param name: str, the attribute of params to be returned.
+    :return: Any, the value of the fuzzy-matched attribute.
+    """
+    return getattr(params, _get_fuzzy_attr_map(params)[fuzzy_str(name)])
+
+
+def fuzzy_idx_in_list(name: str, lst: List[str]) -> int:
+    """
+    Returns the index of name in lst, insensitive to case and symbols.
+    :param name: str, the name to be searched in lst.
+    :param lst: List[str], the list as described above.
+    :return: int, the index of name in lst, matched as described above.
+    """
+    return [fuzzy_str(x) for x in lst].index(fuzzy_str(name))
+
+
 def get_param(params, name, default_val=None):
     """
-    Retrieves a param from a parameter object/dict. If the parameter does not exist, will return default_val.
+    Retrieves a param from a parameter object/dict . If the parameter does not exist, will return default_val.
     In case the default_val is of type dictionary, and a value is found in the params - the function
     will return the default value dictionary with internal values overridden by the found value
+    IMPORTANT: the lookup is not sensitive to case, underscores or symbols.

     i.e.
     default_opt_params = {'lr':0.1, 'momentum':0.99, 'alpha':0.001}
     training_params = {'optimizer_params': {'lr':0.0001}, 'batch': 32 .... }
-    get_param(training_params, name='optimizer_params', default_val=default_opt_params)
+    get_param(training_params, name='OptimizerParams', default_val=default_opt_params)
     will return {'lr':0.0001, 'momentum':0.99, 'alpha':0.001}

+
     :param params:      an object (typically HpmStruct) or a dict holding the params
     :param name:        name of the searched parameter
     :param default_val: assumed to be the same type as the value searched in the params
@@ -187,19 +246,24 @@ def get_param(params, name, default_val=None):
     """
     if isinstance(params, Mapping):
         if name in params:
-            if isinstance(default_val, Mapping):
-                return {**default_val, **params[name]}
-            else:
-                return params[name]
+            param_val = params[name]
+
+        elif fuzzy_str(name) in fuzzy_keys(params):
+            param_val = get_fuzzy_mapping_param(name, params)
+
         else:
-            return default_val
+            param_val = default_val
     elif hasattr(params, name):
-        if isinstance(default_val, Mapping):
-            return {**default_val, **getattr(params, name)}
-        else:
-            return getattr(params, name)
+        param_val = getattr(params, name)
+    elif _has_fuzzy_attr(params, name):
+        param_val = get_fuzzy_attr(params, name)
+    else:
+        param_val = default_val
+
+    if isinstance(default_val, Mapping):
+        return {**default_val, **param_val}
     else:
-        return default_val
+        return param_val


 def static_vars(**kwargs):
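Taken together, a short sketch of the new helpers' behavior (values are illustrative; it mirrors the docstring example above):

from super_gradients.training.utils.utils import fuzzy_str, get_param

fuzzy_str(" Optimizer_Params ")  # -> "optimizerparams"

# An inexact key still hits, and a Mapping default_val is merged with the
# found value, exactly as the updated get_param docstring describes:
default_opt_params = {"lr": 0.1, "momentum": 0.99, "alpha": 0.001}
training_params = {"optimizer_params": {"lr": 0.0001}, "batch": 32}
get_param(training_params, "OptimizerParams", default_val=default_opt_params)
# -> {"lr": 0.0001, "momentum": 0.99, "alpha": 0.001}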