Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#546 Features/sg 409 check all params used

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:features/SG-409-check-all-params-used
@@ -66,7 +66,7 @@ class HpmStruct:
             `jsonschema.exceptions.SchemaError` if the schema itselfis invalid
             `jsonschema.exceptions.SchemaError` if the schema itselfis invalid
         """
         """
         if self.schema is None:
         if self.schema is None:
-            raise AttributeError('schema was not set')
+            raise AttributeError("schema was not set")
         else:
         else:
             validate(self.__dict__, self.schema)
             validate(self.__dict__, self.schema)
 
 
@@ -89,7 +89,7 @@ class Timer:
         :param device: str
         :param device: str
             'cpu'\'cuda'
             'cpu'\'cuda'
         """
         """
-        self.on_gpu = (device == 'cuda')
+        self.on_gpu = device == "cuda"
         # On GPU time is measured using cuda.events
         # On GPU time is measured using cuda.events
         if self.on_gpu:
         if self.on_gpu:
             self.starter = torch.cuda.Event(enable_timing=True)
             self.starter = torch.cuda.Event(enable_timing=True)
@@ -141,8 +141,7 @@ class AverageMeter:
     def average(self):
     def average(self):
         if self._sum is None:
         if self._sum is None:
             return 0
             return 0
-        return ((self._sum / self._count).__float__()) if self._sum.dim() < 1 else tuple(
-            (self._sum / self._count).cpu().numpy())
+        return ((self._sum / self._count).__float__()) if self._sum.dim() < 1 else tuple((self._sum / self._count).cpu().numpy())
 
 
         # return (self._sum / self._count).__float__() if self._sum.dim() < 1 or len(self._sum) == 1 \
         # return (self._sum / self._count).__float__() if self._sum.dim() < 1 or len(self._sum) == 1 \
         #     else tuple((self._sum / self._count).cpu().numpy())
         #     else tuple((self._sum / self._count).cpu().numpy())
@@ -186,16 +185,16 @@ def get_param(params, name, default_val=None):
     :param default_val: assumed to be the same type as the value searched in the params
     :param default_val: assumed to be the same type as the value searched in the params
     :return:            the found value, or default if not found
     :return:            the found value, or default if not found
     """
     """
-    if isinstance(params, dict):
+    if isinstance(params, Mapping):
         if name in params:
         if name in params:
-            if isinstance(default_val, dict):
+            if isinstance(default_val, Mapping):
                 return {**default_val, **params[name]}
                 return {**default_val, **params[name]}
             else:
             else:
                 return params[name]
                 return params[name]
         else:
         else:
             return default_val
             return default_val
     elif hasattr(params, name):
     elif hasattr(params, name):
-        if isinstance(default_val, dict):
+        if isinstance(default_val, Mapping):
             return {**default_val, **getattr(params, name)}
             return {**default_val, **getattr(params, name)}
         else:
         else:
             return getattr(params, name)
             return getattr(params, name)
@@ -239,7 +238,7 @@ def random_seed(is_ddp, device, seed):
     :param device: 'cuda','cpu', 'cuda:<device_number>'
     :param device: 'cuda','cpu', 'cuda:<device_number>'
     :param seed: int, random seed to be set
     :param seed: int, random seed to be set
     """
     """
-    rank = 0 if not is_ddp else int(device.split(':')[1])
+    rank = 0 if not is_ddp else int(device.split(":")[1])
     torch.manual_seed(seed + rank)
     torch.manual_seed(seed + rank)
     np.random.seed(seed + rank)
     np.random.seed(seed + rank)
     random.seed(seed + rank)
     random.seed(seed + rank)
@@ -266,22 +265,21 @@ def get_filename_suffix_by_framework(framework: str):
     @param framework: (str)
     @param framework: (str)
     @return: (str) the suffix for the specific framework
     @return: (str) the suffix for the specific framework
     """
     """
-    frameworks_dict = \
-        {
-            'TENSORFLOW1': '.pb',
-            'TENSORFLOW2': '.zip',
-            'PYTORCH': '.pth',
-            'ONNX': '.onnx',
-            'TENSORRT': '.pkl',
-            'OPENVINO': '.pkl',
-            'TORCHSCRIPT': '.pth',
-            'TVM': '',
-            'KERAS': '.h5',
-            'TFLITE': '.tflite'
-        }
+    frameworks_dict = {
+        "TENSORFLOW1": ".pb",
+        "TENSORFLOW2": ".zip",
+        "PYTORCH": ".pth",
+        "ONNX": ".onnx",
+        "TENSORRT": ".pkl",
+        "OPENVINO": ".pkl",
+        "TORCHSCRIPT": ".pth",
+        "TVM": "",
+        "KERAS": ".h5",
+        "TFLITE": ".tflite",
+    }
 
 
     if framework.upper() not in frameworks_dict.keys():
     if framework.upper() not in frameworks_dict.keys():
-        raise ValueError(f'Unsupported framework: {framework}')
+        raise ValueError(f"Unsupported framework: {framework}")
 
 
     return frameworks_dict[framework.upper()]
     return frameworks_dict[framework.upper()]
 
 
@@ -294,15 +292,15 @@ def check_models_have_same_weights(model_1: torch.nn.Module, model_2: torch.nn.M
     @param model_2: Net to be checked
     @param model_2: Net to be checked
     @return: True iff the two networks have the same weights
     @return: True iff the two networks have the same weights
     """
     """
-    model_1, model_2 = model_1.to('cpu'), model_2.to('cpu')
+    model_1, model_2 = model_1.to("cpu"), model_2.to("cpu")
     models_differ = 0
     models_differ = 0
     for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
     for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
         if torch.equal(key_item_1[1], key_item_2[1]):
         if torch.equal(key_item_1[1], key_item_2[1]):
             pass
             pass
         else:
         else:
             models_differ += 1
             models_differ += 1
-            if (key_item_1[0] == key_item_2[0]):
-                print(f'Layer names match but layers have different weights for layers: {key_item_1[0]}')
+            if key_item_1[0] == key_item_2[0]:
+                print(f"Layer names match but layers have different weights for layers: {key_item_1[0]}")
     if models_differ == 0:
     if models_differ == 0:
         return True
         return True
     else:
     else:
@@ -320,7 +318,7 @@ def recursive_override(base: dict, extension: dict):
             base[k] = extension[k]
             base[k] = extension[k]
 
 
 
 
-def download_and_unzip_from_url(url, dir='.', unzip=True, delete=True):
+def download_and_unzip_from_url(url, dir=".", unzip=True, delete=True):
     """
     """
     Downloads a zip file from url to dir, and unzips it.
     Downloads a zip file from url to dir, and unzips it.
 
 
@@ -341,14 +339,14 @@ def download_and_unzip_from_url(url, dir='.', unzip=True, delete=True):
         if Path(url).is_file():  # exists in current path
         if Path(url).is_file():  # exists in current path
             Path(url).rename(f)  # move to dir
             Path(url).rename(f)  # move to dir
         elif not f.exists():
         elif not f.exists():
-            print(f'Downloading {url} to {f}...')
+            print(f"Downloading {url} to {f}...")
             torch.hub.download_url_to_file(url, f, progress=True)  # torch download
             torch.hub.download_url_to_file(url, f, progress=True)  # torch download
-        if unzip and f.suffix in ('.zip', '.gz'):
-            print(f'Unzipping {f}...')
-            if f.suffix == '.zip':
+        if unzip and f.suffix in (".zip", ".gz"):
+            print(f"Unzipping {f}...")
+            if f.suffix == ".zip":
                 ZipFile(f).extractall(path=dir)  # unzip
                 ZipFile(f).extractall(path=dir)  # unzip
-            elif f.suffix == '.gz':
-                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
+            elif f.suffix == ".gz":
+                os.system(f"tar xfz {f} --directory {f.parent}")  # unzip
             if delete:
             if delete:
                 f.unlink()  # remove zip
                 f.unlink()  # remove zip
 
 
@@ -358,7 +356,7 @@ def download_and_unzip_from_url(url, dir='.', unzip=True, delete=True):
         download_one(u, dir)
         download_one(u, dir)
 
 
 
 
-def download_and_untar_from_url(urls: List[str], dir: Union[str, Path] = '.'):
+def download_and_untar_from_url(urls: List[str], dir: Union[str, Path] = "."):
     """
     """
     Download a file from url and untar.
     Download a file from url and untar.
 
 
@@ -375,13 +373,13 @@ def download_and_untar_from_url(urls: List[str], dir: Union[str, Path] = '.'):
         if url_path.is_file():
         if url_path.is_file():
             url_path.rename(filepath)
             url_path.rename(filepath)
         elif not filepath.exists():
         elif not filepath.exists():
-            logger.info(f'Downloading {url} to {filepath}...')
+            logger.info(f"Downloading {url} to {filepath}...")
             torch.hub.download_url_to_file(url, str(filepath), progress=True)
             torch.hub.download_url_to_file(url, str(filepath), progress=True)
 
 
         modes = {".tar.gz": "r:gz", ".tar": "r:"}
         modes = {".tar.gz": "r:gz", ".tar": "r:"}
         assert filepath.suffix in modes.keys(), f"{filepath} has {filepath.suffix} suffix which is not supported"
         assert filepath.suffix in modes.keys(), f"{filepath} has {filepath.suffix} suffix which is not supported"
 
 
-        logger.info(f'Extracting to {dir}...')
+        logger.info(f"Extracting to {dir}...")
         with tarfile.open(filepath, mode=modes[filepath.suffix]) as f:
         with tarfile.open(filepath, mode=modes[filepath.suffix]) as f:
             f.extractall(dir)
             f.extractall(dir)
         filepath.unlink()
         filepath.unlink()
@@ -418,7 +416,7 @@ def get_orientation_key() -> int:
     """Get the orientation key according to PIL, which is useful to get the image size for instance
     """Get the orientation key according to PIL, which is useful to get the image size for instance
     :return: Orientation key according to PIL"""
     :return: Orientation key according to PIL"""
     for key, value in ExifTags.TAGS.items():
     for key, value in ExifTags.TAGS.items():
-        if value == 'Orientation':
+        if value == "Orientation":
             return key
             return key
 
 
 
 
@@ -442,14 +440,14 @@ def exif_size(image: Image) -> Tuple[int, int]:
             elif rotation == 8:
             elif rotation == 8:
                 image_size = (image_size[1], image_size[0])
                 image_size = (image_size[1], image_size[0])
     except Exception as ex:
     except Exception as ex:
-        print('Caught Exception trying to rotate: ' + str(image) + str(ex))
+        print("Caught Exception trying to rotate: " + str(image) + str(ex))
     width, height = image_size
     width, height = image_size
     return height, width
     return height, width
 
 
 
 
 def get_image_size_from_path(img_path: str) -> Tuple[int, int]:
 def get_image_size_from_path(img_path: str) -> Tuple[int, int]:
     """Get the image size of an image at a specific path"""
     """Get the image size of an image at a specific path"""
-    with open(img_path, 'rb') as f:
+    with open(img_path, "rb") as f:
         return exif_size(Image.open(f))
         return exif_size(Image.open(f))
 
 
 
 
Discard
Tip!

Press p or to see the previous file or, n or to see the next file