#595 Feature/sg 492 fuzzy logic for get param and factories

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-492_fuzzy_logic_for_get_param_and_factories
@@ -33,3 +33,4 @@ wheel>=0.38.0
 pygments>=2.7.4
 stringcase>=1.2.0
 numpy<=1.23
+rapidfuzz
super_gradients/common/exceptions/factory_exceptions.py (new file)

from typing import List

from rapidfuzz import process, fuzz


class UnknownTypeException(Exception):
    """Type error with a message, followed by a type suggestion chosen by fuzzy matching
    (out of the 'choices' arg passed in __init__).

    Attributes:
        message -- explanation of the error
    """

    def __init__(self, unknown_type: str, choices: List, message: str = None):
        message = message or f"Unknown object type: {unknown_type} in configuration. valid types are: {choices}"
        err_msg_tip = ""  # default: no suggestion (also covers non-string unknown_type)
        if isinstance(unknown_type, str):
            # Pick the closest registered name; only suggest it when the match is strong.
            choice, score, _ = process.extractOne(unknown_type, choices, scorer=fuzz.WRatio)
            if score > 70:
                err_msg_tip = f"\n Did you mean: {choice}?"
        self.message = message + err_msg_tip
        super().__init__(self.message)
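
For a sense of what this produces at runtime, a minimal sketch (the choices list is invented for illustration; exact scores depend on the rapidfuzz version):

    try:
        raise UnknownTypeException("resnet_18", ["ResNet18", "ResNet50", "RegNetY"])
    except UnknownTypeException as e:
        print(e)
    # Expected output, roughly:
    # Unknown object type: resnet_18 in configuration. valid types are: ['ResNet18', 'ResNet50', 'RegNetY']
    #  Did you mean: ResNet18?

The 70-point threshold keeps the tip from firing on names that share little with any valid choice.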
super_gradients/common/factories/base_factory.py

@@ -1,5 +1,8 @@
 from typing import Union, Mapping, Dict
 
+from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
+from super_gradients.training.utils.utils import fuzzy_str, fuzzy_keys, get_fuzzy_mapping_param
+
 
 class AbstractFactory:
     """
@@ -23,6 +26,7 @@ class BaseFactory(AbstractFactory):
     """
     The basic factory fo a *single* object generation.
     """
+
     def __init__(self, type_dict: Dict[str, type]):
         """
         :param type_dict: a dictionary mapping a name to a type
@@ -31,29 +35,35 @@ class BaseFactory(AbstractFactory):
 
     def get(self, conf: Union[str, dict]):
         """
-         Get an instantiated object.
-            :param conf: a configuration
-            if string - assumed to be a type name (not the real name, but a name defined in the Factory)
-            if dictionary - assumed to be {type_name(str): {parameters...}} (single item in dict)
+        Get an instantiated object.
+           :param conf: a configuration
+           if string - assumed to be a type name (not the real name, but a name defined in the Factory)
+           if dictionary - assumed to be {type_name(str): {parameters...}} (single item in dict)
 
-            If provided value is not one of the three above, the value will be returned as is
+           If provided value is not one of the three above, the value will be returned as is
         """
         if isinstance(conf, str):
             if conf in self.type_dict:
                 return self.type_dict[conf]()
+            elif fuzzy_str(conf) in fuzzy_keys(self.type_dict):
+                return get_fuzzy_mapping_param(conf, self.type_dict)()
             else:
-                raise RuntimeError(f"Unknown object type: {conf} in configuration. valid types are: {self.type_dict.keys()}")
+                raise UnknownTypeException(conf, list(self.type_dict.keys()))
         elif isinstance(conf, Mapping):
             if len(conf.keys()) > 1:
-                raise RuntimeError("Malformed object definition in configuration. Expecting either a string of object type or a single entry dictionary"
-                                   "{type_name(str): {parameters...}}."
-                                   f"received: {conf}")
+                raise RuntimeError(
+                    "Malformed object definition in configuration. Expecting either a string of object type or a single entry dictionary"
+                    "{type_name(str): {parameters...}}."
+                    f"received: {conf}"
+                )
 
             _type = list(conf.keys())[0]  # THE TYPE NAME
             _params = list(conf.values())[0]  # A DICT CONTAINING THE PARAMETERS FOR INIT
             if _type in self.type_dict:
                 return self.type_dict[_type](**_params)
+            elif fuzzy_str(_type) in fuzzy_keys(self.type_dict):
+                return get_fuzzy_mapping_param(_type, self.type_dict)(**_params)
             else:
-                raise RuntimeError(f"Unknown object type: {_type} in configuration. valid types are: {self.type_dict.keys()}")
+                raise UnknownTypeException(_type, list(self.type_dict.keys()))
         else:
             return conf
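
The net effect on lookups, as a rough sketch (the one-entry registry below is invented; real factories wrap the library's own type dictionaries):

    from super_gradients.common.factories.base_factory import BaseFactory

    class MyLoss:  # stand-in type, for illustration only
        pass

    factory = BaseFactory(type_dict={"cross_entropy": MyLoss})
    factory.get("cross_entropy")  # exact hit, returns MyLoss()
    factory.get("CrossEntropy")   # fuzzy hit: both keys normalize to "crossentropy"
    factory.get("focal_loss")     # no match at all: raises UnknownTypeException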
@@ -2,7 +2,9 @@ from typing import Dict, Union, Type
 from enum import Enum
 import importlib
 
+from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
 from super_gradients.common.factories.base_factory import AbstractFactory
+from super_gradients.training.utils import get_param
 
 
 class TypeFactory(AbstractFactory):
@@ -32,6 +34,8 @@ class TypeFactory(AbstractFactory):
         if isinstance(conf, str) or isinstance(conf, bool):
             if conf in self.type_dict:
                 return self.type_dict[conf]
+            elif isinstance(conf, str) and get_param(self.type_dict, conf) is not None:
+                return get_param(self.type_dict, conf)
             else:
                 try:
                     lib = ".".join(conf.split(".")[:-1])
@@ -40,9 +44,12 @@ class TypeFactory(AbstractFactory):
                     class_type = lib.__dict__[module]
                     return class_type
                 except RuntimeError:
-                    raise RuntimeError(
-                        f"Unknown object type: {conf} in configuration. valid types are: {self.type_dict.keys()} or a class "
-                        "type available in the env (or the form 'package_name.sub_package.MyClass'"
+                    raise UnknownTypeException(
+                        unknown_type=conf,
+                        choices=list(self.type_dict.keys()),
+                        message=f"Unknown object type: {conf} in configuration. valid types are: {self.type_dict.keys()} or a class "
+                        "type available in the env (or the form "
+                        "'package_name.sub_package.MyClass' ",
                    )
         else:
             return conf
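
Worth noting: because get_param (see the utils diff below) now falls back to fuzzy key matching and returns None on a miss, get_param(self.type_dict, conf) doubles as a fuzzy dictionary lookup here. A sketch with an invented mapping:

    from torch import nn
    from super_gradients.training.utils import get_param

    type_dict = {"LeakyReLU": nn.LeakyReLU}  # hypothetical registry entry
    get_param(type_dict, "leaky_relu")  # -> nn.LeakyReLU, via fuzzy key match
    get_param(type_dict, "gelu")        # -> None, so the importlib branch runs instead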
@@ -7,10 +7,11 @@ import torch
 from super_gradients.common import StrictLoad
 from super_gradients.common.plugins.deci_client import DeciClient, client_enabled
 from super_gradients.training import utils as core_utils
+from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
-from super_gradients.training.utils import HpmStruct
+from super_gradients.training.utils import HpmStruct, get_param
 from super_gradients.training.utils.checkpoint_utils import (
     load_checkpoint_to_model,
     load_pretrained_weights,
@@ -42,7 +43,9 @@ def get_architecture(model_name: str, arch_params: HpmStruct, download_required_
     is_remote = False
     if not isinstance(model_name, str):
         raise ValueError("Parameter model_name is expected to be a string.")
-    elif model_name not in ARCHITECTURES.keys():
+
+    architecture = get_param(ARCHITECTURES, model_name)
+    if model_name not in ARCHITECTURES.keys() and architecture is None:
         if client_enabled:
             logger.info(f'The required model, "{model_name}", was not found in SuperGradients. Trying to load a model from remote deci-lab')
             deci_client = DeciClient()
@@ -63,11 +66,13 @@ def get_architecture(model_name: str, arch_params: HpmStruct, download_required_
             _arch_params.override(**arch_params.to_dict())
             arch_params, is_remote = _arch_params, True
         else:
-            raise ValueError(
-                f'The required model, "{model_name}", was not found in SuperGradients. See docs or all_architectures.py for supported model names.'
+            raise UnknownTypeException(
+                message=f'The required model, "{model_name}", was not found in SuperGradients. See docs or all_architectures.py for supported model names.',
+                unknown_type=model_name,
+                choices=list(ARCHITECTURES.keys()),
             )
 
-    return ARCHITECTURES[model_name], arch_params, pretrained_weights_path, is_remote
+    return get_param(ARCHITECTURES, model_name), arch_params, pretrained_weights_path, is_remote
 
 
 def instantiate_model(
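
In practice this is what lets a misspelled model name resolve; a sketch of the intended behavior, assuming "resnet18" is the canonical key in ARCHITECTURES (the spelling below is the one exercised by the new test further down):

    from super_gradients.training import models

    # "Resnet___18" normalizes to "resnet18", so get_param(ARCHITECTURES, "Resnet___18")
    # finds the same entry as an exact lookup would.
    net = models.get("Resnet___18", num_classes=5)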
    @@ -38,7 +38,6 @@ from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
     from super_gradients.training.utils import sg_trainer_utils, get_param
     from super_gradients.training.utils import sg_trainer_utils, get_param
     from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args, log_main_training_params
     from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args, log_main_training_params
     from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, GPUModeNotSetupError
     from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, GPUModeNotSetupError
    -from super_gradients.training.losses import LOSSES
     from super_gradients.training.metrics.metric_utils import (
     from super_gradients.training.metrics.metric_utils import (
         get_metrics_titles,
         get_metrics_titles,
         get_metrics_results_tuple,
         get_metrics_results_tuple,
    @@ -60,6 +59,7 @@ from super_gradients.training.utils.distributed_training_utils import (
     )
     )
     from super_gradients.training.utils.ema import ModelEMA
     from super_gradients.training.utils.ema import ModelEMA
     from super_gradients.training.utils.optimizer_utils import build_optimizer
     from super_gradients.training.utils.optimizer_utils import build_optimizer
    +from super_gradients.training.utils.utils import fuzzy_idx_in_list
     from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
     from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
     from super_gradients.training.metrics import Accuracy, Top5
     from super_gradients.training.metrics import Accuracy, Top5
     from super_gradients.training.utils import random_seed
     from super_gradients.training.utils import random_seed
    @@ -465,7 +465,7 @@ class Trainer:
             return loss, loss_logging_items
             return loss, loss_logging_items
     
     
         def _init_monitored_items(self):
         def _init_monitored_items(self):
    -        self.metric_idx_in_results_tuple = (self.loss_logging_items_names + get_metrics_titles(self.valid_metrics)).index(self.metric_to_watch)
    +        self.metric_idx_in_results_tuple = fuzzy_idx_in_list(self.metric_to_watch, self.loss_logging_items_names + get_metrics_titles(self.valid_metrics))
             # Instantiate the values to monitor (loss/metric)
             # Instantiate the values to monitor (loss/metric)
             for loss_name in self.loss_logging_items_names:
             for loss_name in self.loss_logging_items_names:
                 self.train_monitored_values[loss_name] = MonitoredValue(name=loss_name, greater_is_better=False)
                 self.train_monitored_values[loss_name] = MonitoredValue(name=loss_name, greater_is_better=False)
    @@ -997,8 +997,7 @@ class Trainer:
     
     
             # Allowing loading instantiated loss or string
             # Allowing loading instantiated loss or string
             if isinstance(self.training_params.loss, str):
             if isinstance(self.training_params.loss, str):
    -            criterion_cls = LOSSES[self.training_params.loss]
    -            self.criterion = criterion_cls(**self.training_params.criterion_params)
    +            self.criterion = LossesFactory().get({self.training_params.loss: self.training_params.criterion_params})
     
     
             elif isinstance(self.training_params.loss, Mapping):
             elif isinstance(self.training_params.loss, Mapping):
                 self.criterion = LossesFactory().get(self.training_params.loss)
                 self.criterion = LossesFactory().get(self.training_params.loss)
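
Routing the string case through LossesFactory gives bare string losses the same fuzzy handling as the Mapping form. A sketch, assuming the registry key "cross_entropy" and this import path:

    from super_gradients.common.factories.losses_factory import LossesFactory

    # Both calls should now resolve to the same registered loss class.
    criterion = LossesFactory().get({"cross_entropy": {}})
    criterion = LossesFactory().get({"crossEnt_ropy": {}})  # fuzzy key match

The fuzzy_idx_in_list change above does the same for metric_to_watch, so a name like "Accurac_Y" still indexes into the monitored values.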
super_gradients/training/utils/utils.py

@@ -2,13 +2,13 @@ import math
 import time
 from functools import lru_cache
 from pathlib import Path
-from typing import Mapping, Optional, Tuple, Union, List, Dict
+from typing import Mapping, Optional, Tuple, Union, List, Dict, Any
 from zipfile import ZipFile
 import os
 from jsonschema import validate
 import tarfile
 from PIL import Image, ExifTags
-
+import re
 import torch
 import torch.nn as nn
 
@@ -168,18 +168,77 @@ def tensor_container_to_device(obj: Union[torch.Tensor, tuple, list, dict], devi
         return obj
 
 
+def fuzzy_keys(params: Mapping) -> List[str]:
+    """
+    Returns params.keys() with each key normalized: whitespace and symbols dropped, underscores removed, lower-cased.
+    :param params: Mapping, the mapping containing the keys to be returned.
+    :return: List[str], list of keys as described above.
+    """
+    return [fuzzy_str(s) for s in params.keys()]
+
+
+def fuzzy_str(s: str):
+    """
+    Returns s with non-word characters and underscores dropped, lower-cased.
+    :param s: str, string to normalize as described above.
+    :return: str, s after the normalization described above.
+    """
+    return re.sub(r"[^\w]", "", s).replace("_", "").lower()
+
+
+def _get_fuzzy_attr_map(params):
+    return {fuzzy_str(a): a for a in params.__dir__()}
+
+
+def _has_fuzzy_attr(params, name):
+    return fuzzy_str(name) in _get_fuzzy_attr_map(params)
+
+
+def get_fuzzy_mapping_param(name: str, params: Mapping):
+    """
+    Returns the parameter value whose key fuzzy-matches name, insensitive to case and symbols.
+    :param name: str, the key in params which is fuzzy-matched and returned.
+    :param params: Mapping, the mapping containing the param.
+    :return: the value mapped to the fuzzy-matched key.
+    """
+    fuzzy_params = {fuzzy_str(key): params[key] for key in params.keys()}
+    return fuzzy_params[fuzzy_str(name)]
+
+
+def get_fuzzy_attr(params: Any, name: str):
+    """
+    Returns an attribute (same functionality as getattr), but insensitive to case and symbols.
+    :param params: Any, any object in which we look for the attribute name.
+    :param name: str, the attribute of params to be returned.
+    :return: Any, the value of the fuzzy-matched attribute.
+    """
+    return getattr(params, _get_fuzzy_attr_map(params)[fuzzy_str(name)])
+
+
+def fuzzy_idx_in_list(name: str, lst: List[str]) -> int:
+    """
+    Returns the index of name in lst, insensitive to case and symbols.
+    :param name: str, the name to be searched in lst.
+    :param lst: List[str], the list as described above.
+    :return: int, index of name in lst in the manner described above.
+    """
+    return [fuzzy_str(x) for x in lst].index(fuzzy_str(name))
+
+
 def get_param(params, name, default_val=None):
     """
     Retrieves a param from a parameter object/dict. If the parameter does not exist, will return default_val.
     In case the default_val is of type dictionary, and a value is found in the params - the function
     will return the default value dictionary with internal values overridden by the found value
+    IMPORTANT: Not sensitive to lowercase, uppercase and symbols.
 
     i.e.
     default_opt_params = {'lr':0.1, 'momentum':0.99, 'alpha':0.001}
     training_params = {'optimizer_params': {'lr':0.0001}, 'batch': 32 .... }
-    get_param(training_params, name='optimizer_params', default_val=default_opt_params)
+    get_param(training_params, name='OptimizerParams', default_val=default_opt_params)
     will return {'lr':0.0001, 'momentum':0.99, 'alpha':0.001}
 
+
     :param params:      an object (typically HpmStruct) or a dict holding the params
     :param name:        name of the searched parameter
     :param default_val: assumed to be the same type as the value searched in the params
@@ -187,19 +246,24 @@ def get_param(params, name, default_val=None):
     """
     if isinstance(params, Mapping):
         if name in params:
-            if isinstance(default_val, Mapping):
-                return {**default_val, **params[name]}
-            else:
-                return params[name]
+            param_val = params[name]
+
+        elif fuzzy_str(name) in fuzzy_keys(params):
+            param_val = get_fuzzy_mapping_param(name, params)
+
         else:
-            return default_val
+            param_val = default_val
     elif hasattr(params, name):
-        if isinstance(default_val, Mapping):
-            return {**default_val, **getattr(params, name)}
-        else:
-            return getattr(params, name)
+        param_val = getattr(params, name)
+    elif _has_fuzzy_attr(params, name):
+        param_val = get_fuzzy_attr(params, name)
+    else:
+        param_val = default_val
+
+    if isinstance(default_val, Mapping):
+        return {**default_val, **param_val}
     else:
-        return default_val
+        return param_val
 
 
 def static_vars(**kwargs):
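
The normalization rule all of these helpers share is small enough to demonstrate standalone; a self-contained sketch mirroring fuzzy_str from the diff (the example strings come from the new test below):

    import re

    def fuzzy_str(s: str) -> str:
        # Drop non-word characters, then underscores, then lower-case.
        return re.sub(r"[^\w]", "", s).replace("_", "").lower()

    assert fuzzy_str("AdAm_") == "adam"
    assert fuzzy_str("crossEnt_ropy") == "crossentropy"
    assert fuzzy_str("Accurac_Y") == "accuracy"

With a Mapping default, get_param merges after the fuzzy lookup: get_param({"optimizer_params": {"lr": 1e-4}}, "OptimizerParams", default_val={"lr": 0.1, "momentum": 0.99}) now returns {"lr": 1e-4, "momentum": 0.99}.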
@@ -7,6 +7,7 @@ from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
+from super_gradients.training.losses import LabelSmoothingCrossEntropyLoss
 from super_gradients.training.metrics import Accuracy, Top5
 from torch import nn
 
@@ -38,6 +39,32 @@ class FactoriesTest(unittest.TestCase):
         self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
         self.assertIsInstance(trainer.optimizer, torch.optim.ASGD)
 
+    def test_training_with_factories_with_typos(self):
+        trainer = Trainer("test_train_with_factories_with_typos")
+        net = models.get("Resnet___18", num_classes=5)
+        train_params = {
+            "max_epochs": 2,
+            "lr_updates": [1],
+            "lr_decay_factor": 0.1,
+            "lr_mode": "step",
+            "lr_warmup_epochs": 0,
+            "initial_lr": 0.1,
+            "loss": "crossEnt_ropy",
+            "optimizer": "AdAm_",  # use an optimizer by factory
+            "criterion_params": {},
+            "train_metrics_list": ["accur_acy", "Top_5"],  # use a metric by factory
+            "valid_metrics_list": ["aCCuracy", "Top5"],  # use a metric by factory
+            "metric_to_watch": "Accurac_Y",
+            "greater_metric_to_watch_is_better": True,
+        }
+
+        trainer.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
+
+        self.assertIsInstance(trainer.train_metrics.Accuracy, Accuracy)
+        self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
+        self.assertIsInstance(trainer.optimizer, torch.optim.Adam)
+        self.assertIsInstance(trainer.criterion, LabelSmoothingCrossEntropyLoss)
+
     def test_activations_factory(self):
         class DummyModel(nn.Module):
             @resolve_param("activation_in_head", ActivationsTypeFactory())