#291 Fix flake8 errors in different folders

Merged
GitHub User merged 1 commit into Deci-AI:master from deci-ai:hotfix/SG-000-fix_flake8_errors_misc_folders
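This PR sweeps flake8 errors across miscellaneous folders: long lines are wrapped (E501), package __init__ modules gain __all__ declarations so their re-exports stop tripping the unused-import check (F401), genuinely unused imports and unused exception variables are dropped (F401, F841), an unquoted placeholder name becomes a string literal (F821), and blank-line/whitespace issues are cleaned up.

A minimal sketch (illustrative, not taken from this PR) of why declaring __all__ quiets flake8's F401 "imported but unused" in a package __init__.py; the module and class names here are hypothetical:

    # __init__.py that re-exports a symbol for external use.
    # Without __all__, flake8/pyflakes flags the import as F401
    # ("imported but unused") because nothing in this file uses it:
    from mypackage.connector import Connector

    # Listing the name in __all__ marks it as a deliberate public
    # re-export, so F401 is no longer reported:
    __all__ = ['Connector']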
@@ -1 +1,3 @@
 from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
+
+__all__ = ['AutoLoggerConfig']
@@ -1,2 +1,4 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 from super_gradients.common.aws_connection.aws_connector import AWSConnector
+
+__all__ = ['AWSConnector']
@@ -4,8 +4,6 @@ import boto3
 import logging
 from botocore.exceptions import ClientError, ProfileNotFound
 
-from super_gradients.common import explicit_params_validation
-
 
 class AWSConnector:
     """
@@ -60,19 +60,21 @@ class AWSSecretsManagerConnector:
             for secret_key_property in db_properties_set:
                 secret_key_to_retrieve = '.'.join([env.upper(), secret_key, secret_key_property])
                 if secret_key_to_retrieve not in aws_secrets_dict:
-                    error = f'[{current_class_name}] - Error retrieving data from AWS Secrets Manager for Secret Key "{secret_name}": The secret property "{secret_key_property}" Does Not Exist'
+                    error = f'[{current_class_name}] - Error retrieving data from AWS Secrets Manager for Secret Key "{secret_name}": ' \
+                            f'The secret property "{secret_key_property}" Does Not Exist'
                     logger.error(error)
                     raise EnvironmentError(error)
                 else:
                     env_stripped_key_name = secret_key_to_retrieve.lstrip(env.upper()).lstrip('.')
                     aws_env_safe_secrets[env_stripped_key_name] = aws_secrets_dict[secret_key_to_retrieve]
         else:
-            # "db_properties_set" is not specified - validating and returning all the secret keys and values for the secret name.
+            # "db_properties_set" is not specified - validating and returning all the secret keys and values for
+            # the secret name.
             for secret_key_name, secret_value in aws_secrets_dict.items():
                 secret_key_to_retrieve = '.'.join([env.upper(), secret_key])
                 assert secret_key_name.startswith(
-                    env.upper()), f'The secret key property "{secret_key_name}", found in secret named {secret_name},' \
-                                  f' is not following the convention of environment prefix. please add the environment prefix "{env.upper()}" to property "{secret_key_name}"'
+                    env.upper()), f'The secret key property "{secret_key_name}", found in secret named {secret_name}, is not following the convention of ' \
+                                  f'environment prefix. please add the environment prefix "{env.upper()}" to property "{secret_key_name}"'
                 if secret_key_name.startswith(secret_key_to_retrieve):
                     env_stripped_key_name = secret_key_name.lstrip(env.upper()).lstrip('.')
                     aws_env_safe_secrets[env_stripped_key_name] = secret_value
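A small illustration (hypothetical values, not from this PR) of the naming convention the assertion above enforces: every key stored in the secret must carry the upper-cased environment prefix, which is stripped before the secrets are returned:

    # Secret fetched for env='prod' and secret_key='DB':
    aws_secrets_dict = {
        'PROD.DB.HOST': 'db.example.internal',
        'PROD.DB.PASSWORD': 's3cr3t',
    }
    # Every key must start with the 'PROD' prefix; after validation the
    # prefix is stripped, so the caller receives:
    #   {'DB.HOST': 'db.example.internal', 'DB.PASSWORD': 's3cr3t'}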
@@ -1,2 +1,4 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 from super_gradients.common.data_connection.s3_connector import S3Connector
+
+__all__ = ['S3Connector']
@@ -1,3 +1,5 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
 from super_gradients.common.decorators.singleton import singleton
+
+__all__ = ['explicit_params_validation', 'singleton']
@@ -13,7 +13,8 @@ if AWS_ENV_NAME not in AWS_ENVIRONMENTS:
             )
         else:
             print(
-                f'Bad AWS environment name: {AWS_ENV_NAME}. Please set an environment variable named ENVIRONMENT_NAME with one of the values: {",".join(AWS_ENVIRONMENTS)}'
+                f'Bad AWS environment name: {AWS_ENV_NAME}. Please set an environment variable named ENVIRONMENT_NAME '
+                f'with one of the values: {",".join(AWS_ENVIRONMENTS)}'
             )
 
 # Controlling the default logging level via environment variable
@@ -1,12 +1,8 @@
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.datasets.dataset_interfaces import TestDatasetInterface, \
-    LibraryDatasetInterface, \
-    ClassificationDatasetInterface, Cifar10DatasetInterface, Cifar100DatasetInterface, \
-    ImageNetDatasetInterface, TinyImageNetDatasetInterface, \
-    CoCoDetectionDatasetInterface, CoCoSegmentationDatasetInterface, CoCo2014DetectionDatasetInterface, \
-    PascalAUG2012SegmentationDataSetInterface, PascalVOC2012SegmentationDataSetInterface
-from super_gradients.training.datasets.dataset_interfaces.dataset_interface import \
-    PascalVOCUnifiedDetectionDataSetInterface, \
+from super_gradients.training.datasets.dataset_interfaces import LibraryDatasetInterface, ClassificationDatasetInterface, Cifar10DatasetInterface,\
+    Cifar100DatasetInterface, ImageNetDatasetInterface, TinyImageNetDatasetInterface, CoCoDetectionDatasetInterface, CoCoSegmentationDatasetInterface, \
+    CoCo2014DetectionDatasetInterface, PascalAUG2012SegmentationDataSetInterface, PascalVOC2012SegmentationDataSetInterface
+from super_gradients.training.datasets.dataset_interfaces.dataset_interface import PascalVOCUnifiedDetectionDataSetInterface, \
     ClassificationTestDatasetInterface, CityscapesDatasetInterface, CocoDetectionDatasetInterfaceV2
 
 
@@ -36,4 +36,3 @@ class TransformsFactory(BaseFactory):
         if isinstance(conf, Mapping) and 'Compose' in conf:
             conf['Compose']['transforms'] = ListFactory(TransformsFactory()).get(conf['Compose']['transforms'])
         return super().get(conf)
-
@@ -3,7 +3,6 @@ import os
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
 from super_gradients.common.environment.env_helpers import multi_process_safe
-from super_gradients.training.params import TrainingParams
 
 logger = get_logger(__name__)
 
@@ -22,25 +22,26 @@ except (ModuleNotFoundError, ImportError, NameError):
 WANDB_ID_PREFIX = 'wandb_id.'
 WANDB_INCLUDE_FILE_NAME = '.wandbinclude'
 
+
 class WandBSGLogger(BaseSGLogger):
 
-    def __init__(self, project_name: str, experiment_name: str, storage_location: str, resumed: bool, training_params: dict, checkpoints_dir_path: str, tb_files_user_prompt: bool = False,
-                 launch_tensorboard: bool = False, tensorboard_port: int = None, save_checkpoints_remote: bool = True, save_tensorboard_remote: bool = True,
-                 save_logs_remote: bool = True, entity: Optional[str] = None, api_server: Optional[str] = None, save_code: bool = False, **kwargs):
+    def __init__(self, project_name: str, experiment_name: str, storage_location: str, resumed: bool, training_params: dict, checkpoints_dir_path: str,
+                 tb_files_user_prompt: bool = False, launch_tensorboard: bool = False, tensorboard_port: int = None, save_checkpoints_remote: bool = True,
+                 save_tensorboard_remote: bool = True, save_logs_remote: bool = True, entity: Optional[str] = None, api_server: Optional[str] = None,
+                 save_code: bool = False, **kwargs):
         """
 
-        :param experiment_name: Used for logging and loading purposes
-        :param s3_path: If set to 's3' (i.e. s3://my-bucket) saves the Checkpoints in AWS S3 otherwise saves the Checkpoints Locally
-        :param checkpoint_loaded: if true, then old tensorboard files will *not* be deleted when tb_files_user_prompt=True
-        :param max_epochs: the number of epochs planned for this training
-        :param tb_files_user_prompt: Asks user for Tensorboard deletion prompt.
-        :param launch_tensorboard: Whether to launch a TensorBoard process.
-        :param tensorboard_port: Specific port number for the tensorboard to use when launched (when set to None, some free port
-                    number will be used
+        :param experiment_name:         Used for logging and loading purposes
+        :param s3_path:                 If set to 's3' (i.e. s3://my-bucket) saves the Checkpoints in AWS S3 otherwise saves the Checkpoints Locally
+        :param checkpoint_loaded:       If true, then old tensorboard files will *not* be deleted when tb_files_user_prompt=True
+        :param max_epochs:              Number of epochs planned for this training
+        :param tb_files_user_prompt:    Asks user for Tensorboard deletion prompt.
+        :param launch_tensorboard:      Whether to launch a TensorBoard process.
+        :param tensorboard_port:        Specific port number for the tensorboard to use when launched (when set to None, some free port number will be used)
         :param save_checkpoints_remote: Saves checkpoints in s3.
         :param save_tensorboard_remote: Saves tensorboard in s3.
-        :param save_logs_remote: Saves log files in s3.
-        :param save_code: save current code to wandb
+        :param save_logs_remote:        Saves log files in s3.
+        :param save_code:               Save current code to wandb
         """
         self.s3_location_available = storage_location.startswith('s3')
         super().__init__(project_name, experiment_name, storage_location, resumed, training_params,
@@ -103,7 +104,6 @@ class WandBSGLogger(BaseSGLogger):
         else:
             wandb.run.log_code(".", include_fn=include_fn)
 
-
     @multi_process_safe
     def add_config(self, tag: str, config: dict):
         super(WandBSGLogger, self).add_config(tag=tag, config=config)
@@ -206,7 +206,7 @@ class WandBSGLogger(BaseSGLogger):
     def _get_tensorboard_file_name(self):
         try:
             tb_file_path = self.tensorboard_writer.file_writer.event_writer._file_name
-        except RuntimeError as e:
+        except RuntimeError:
             logger.warning('tensorboard file could not be located for ')
             return None
 
@@ -272,7 +272,7 @@ class WandBSGLogger(BaseSGLogger):
                     return os.path.join(cur_dir, file_name)
                 else:
                     cur_dir = os.path.dirname(cur_dir)
-        except RuntimeError as e:
+        except RuntimeError:
             return None
 
         return None
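For context, a hypothetical instantiation based only on the signature shown in this hunk (all values are placeholders; in practice the training framework constructs the logger itself):

    logger = WandBSGLogger(
        project_name='my-project',               # wandb project
        experiment_name='exp-001',               # run name, also used for loading
        storage_location='s3://my-bucket/ckpt',  # 's3://...' enables remote saving
        resumed=False,
        training_params={},
        checkpoints_dir_path='/tmp/checkpoints',
        save_code=False,                         # upload current code to wandb
    )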
@@ -20,8 +20,8 @@ from deci_lab_client.models import (
 def main(architecture_name: str):
     # Empty on purpose so that it can be fit to the trainer use case
     checkpoint_dir = ""
-    
-    auth_token = YOUR_API_TOKEN_HERE
+
+    auth_token = "YOUR_API_TOKEN_HERE"
 
     model = SgModel(
         f"lab_optimization_{architecture_name}_example",
@@ -1,3 +1,3 @@
 from super_gradients.sanity_check.env_sanity_check import env_sanity_check
 
-__all__ = ['env_sanity_check']
+__all__ = ['env_sanity_check']
@@ -2,7 +2,7 @@ import logging
 import os
 import sys
 from pip._internal.operations.freeze import freeze
-from typing import List, Dict, Union, Tuple
+from typing import List, Dict, Union
 from pathlib import Path
 from packaging.version import Version
 
@@ -82,7 +82,9 @@ def verify_installed_libraries() -> List[str]:
 
         is_constraint_respected = {
             ">=": installed_version >= required_version,
-            "~=": installed_version.major == required_version.major and installed_version.minor == required_version.minor and installed_version.micro >= required_version.micro,
+            "~=": (installed_version.major == required_version.major and
+                   installed_version.minor == required_version.minor and
+                   installed_version.micro >= required_version.micro),
             "==": installed_version == required_version
         }
         if not is_constraint_respected[constraint]:
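As a side note, a short sketch (not part of this PR) of what the reformatted "~=" branch computes, using packaging.version.Version as the surrounding file already does; "~=X.Y.Z" (compatible release) accepts any X.Y.* version at or above X.Y.Z:

    from packaging.version import Version

    installed_version = Version('1.4.7')
    required_version = Version('1.4.2')   # constraint: ~=1.4.2

    compatible = (installed_version.major == required_version.major and
                  installed_version.minor == required_version.minor and
                  installed_version.micro >= required_version.micro)
    assert compatible  # 1.4.7 satisfies ~=1.4.2; 1.5.0 would not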
@@ -133,8 +135,7 @@ def env_sanity_check():
         logger.log(stdout_log_level, '_' * 20)
 
     if sanity_check_errors:
-        logger.log(stdout_log_level, 
-            f'The current environment does not meet Deci\'s needs, errors found in: {", ".join(list(sanity_check_errors.keys()))}')
+        logger.log(stdout_log_level, f'The current environment does not meet Deci\'s needs, errors found in: {", ".join(list(sanity_check_errors.keys()))}')
     elif lib_check_is_impossible:
         logger.log(stdout_log_level, LIB_CHECK_IMPOSSIBLE_MSG)
     else:
@@ -142,10 +143,10 @@ def env_sanity_check():
 
     # The last message needs to be displayed independently of DISPLAY_SANITY_CHECK
     if display_sanity_check:
-        logger.info(f'** This check can be hidden by setting the env variable DISPLAY_SANITY_CHECK=False prior to import. **')
+        logger.info('** This check can be hidden by setting the env variable DISPLAY_SANITY_CHECK=False prior to import. **')
     else:
-        logger.info(f'** A sanity check is done when importing super_gradients for the first time. **\n'
-                    f'-> You can see the details by setting the env variable DISPLAY_SANITY_CHECK=True prior to import.')
+        logger.info('** A sanity check is done when importing super_gradients for the first time. **\n'
+                    '-> You can see the details by setting the env variable DISPLAY_SANITY_CHECK=True prior to import.')
 
 
 if __name__ == '__main__':