Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#676 apply black formatter to all of "common" directory

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-000_blsck_format_common
@@ -1,4 +1,4 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 from super_gradients.common.aws_connection.aws_connector import AWSConnector
 from super_gradients.common.aws_connection.aws_connector import AWSConnector
 
 
-__all__ = ['AWSConnector']
+__all__ = ["AWSConnector"]
Discard
@@ -24,25 +24,23 @@ class AWSConnector:
             try:
             try:
                 if profile_name and boto3.session.Session(profile_name=profile_name).get_credentials():
                 if profile_name and boto3.session.Session(profile_name=profile_name).get_credentials():
                     # TRY USING A SPECIFIC PROFILE_NAME (USING A CREDENTIALS FILE)
                     # TRY USING A SPECIFIC PROFILE_NAME (USING A CREDENTIALS FILE)
-                    logger.info('Trying to connect to AWS using Credentials File with profile_name: ' + profile_name)
+                    logger.info("Trying to connect to AWS using Credentials File with profile_name: " + profile_name)
 
 
                     session = boto3.Session(profile_name=profile_name)
                     session = boto3.Session(profile_name=profile_name)
                     return session
                     return session
 
 
             except ProfileNotFound as profileNotFoundException:
             except ProfileNotFound as profileNotFoundException:
                 logger.debug(
                 logger.debug(
-                    '[' + current_class_name + '] - Could not find profile name - Trying using Default Profile/IAM Role' + str(
-                        profileNotFoundException))
+                    "[" + current_class_name + "] - Could not find profile name - Trying using Default Profile/IAM Role" + str(profileNotFoundException)
+                )
 
 
             # TRY USING AN IAM ROLE (OR *DEFAULT* CREDENTIALS - USING A CREDENTIALS FILE)
             # TRY USING AN IAM ROLE (OR *DEFAULT* CREDENTIALS - USING A CREDENTIALS FILE)
-            logger.info('Trying to connect to AWS using IAM role or Default Credentials')
+            logger.info("Trying to connect to AWS using IAM role or Default Credentials")
             session = boto3.Session()
             session = boto3.Session()
             return session
             return session
 
 
         except Exception as ex:
         except Exception as ex:
-            logger.critical(
-                '[' + current_class_name + '] - Caught Exception while trying to connect to AWS Credentials Manager ' + str(
-                    ex))
+            logger.critical("[" + current_class_name + "] - Caught Exception while trying to connect to AWS Credentials Manager " + str(ex))
             return None
             return None
 
 
     @staticmethod
     @staticmethod
@@ -57,7 +55,7 @@ class AWSConnector:
 
 
         aws_session = AWSConnector.__create_boto_3_session(profile_name=profile_name)
         aws_session = AWSConnector.__create_boto_3_session(profile_name=profile_name)
         if aws_session is None:
         if aws_session is None:
-            logger.error('Failed to initiate an AWS Session')
+            logger.error("Failed to initiate an AWS Session")
 
 
         return aws_session
         return aws_session
 
 
@@ -74,7 +72,7 @@ class AWSConnector:
 
 
         aws_session = AWSConnector.__create_boto_3_session(profile_name=profile_name)
         aws_session = AWSConnector.__create_boto_3_session(profile_name=profile_name)
         if aws_session is None:
         if aws_session is None:
-            logger.error('Failed to connect to AWS client: ' + str(service_name))
+            logger.error("Failed to connect to AWS client: " + str(service_name))
 
 
         return aws_session.client(service_name=service_name)
         return aws_session.client(service_name=service_name)
 
 
@@ -91,7 +89,7 @@ class AWSConnector:
 
 
         aws_session = AWSConnector.__create_boto_3_session(profile_name=profile_name)
         aws_session = AWSConnector.__create_boto_3_session(profile_name=profile_name)
         if aws_session is None:
         if aws_session is None:
-            logger.error('Failed to connect to AWS client: ' + str(service_name))
+            logger.error("Failed to connect to AWS client: " + str(service_name))
 
 
         return aws_session.resource(service_name=service_name)
         return aws_session.resource(service_name=service_name)
 
 
Discard
@@ -1,4 +1,4 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 from super_gradients.common.data_connection.s3_connector import S3Connector
 from super_gradients.common.data_connection.s3_connector import S3Connector
 
 
-__all__ = ['S3Connector']
+__all__ = ["S3Connector"]
Discard
@@ -2,4 +2,4 @@
 from super_gradients.common.data_interface.dataset_data_interface import DatasetDataInterface
 from super_gradients.common.data_interface.dataset_data_interface import DatasetDataInterface
 from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
 from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
 
 
-__all__ = ['DatasetDataInterface', 'ADNNModelRepositoryDataInterfaces']
+__all__ = ["DatasetDataInterface", "ADNNModelRepositoryDataInterfaces"]
Discard
@@ -1,4 +1,3 @@
-from super_gradients.common.data_types.enum import StrictLoad, DeepLearningTask, MultiGPUMode, EvaluationType,\
-    UpsampleMode
+from super_gradients.common.data_types.enum import StrictLoad, DeepLearningTask, MultiGPUMode, EvaluationType, UpsampleMode
 
 
-__all__ = ['StrictLoad', 'DeepLearningTask', 'EvaluationType', 'MultiGPUMode', 'UpsampleMode']
+__all__ = ["StrictLoad", "DeepLearningTask", "EvaluationType", "MultiGPUMode", "UpsampleMode"]
Discard
@@ -5,4 +5,4 @@ from super_gradients.common.data_types.enum.multi_gpu_mode import MultiGPUMode
 from super_gradients.common.data_types.enum.upsample_mode import UpsampleMode
 from super_gradients.common.data_types.enum.upsample_mode import UpsampleMode
 
 
 
 
-__all__ = ['StrictLoad', 'DeepLearningTask', 'EvaluationType', 'MultiGPUMode', 'UpsampleMode']
+__all__ = ["StrictLoad", "DeepLearningTask", "EvaluationType", "MultiGPUMode", "UpsampleMode"]
Discard
@@ -2,10 +2,10 @@ from enum import Enum
 
 
 
 
 class DeepLearningTask(str, Enum):
 class DeepLearningTask(str, Enum):
-    CLASSIFICATION = 'classification'
-    SEMANTIC_SEGMENTATION = 'semantic_segmentation'
-    OBJECT_DETECTION = 'object_detection'
-    DEPTH_ESTIMATION = 'depth_estimation'
-    POSE_ESTIMATION = 'pose_estimation'
-    NLP = 'nlp'
-    OTHER = 'other'
+    CLASSIFICATION = "classification"
+    SEMANTIC_SEGMENTATION = "semantic_segmentation"
+    OBJECT_DETECTION = "object_detection"
+    DEPTH_ESTIMATION = "depth_estimation"
+    POSE_ESTIMATION = "pose_estimation"
+    NLP = "nlp"
+    OTHER = "other"
Discard
@@ -12,5 +12,6 @@ class EvaluationType(str, Enum):
             VALIDATION
             VALIDATION
 
 
     """
     """
-    TEST = 'TEST'
-    VALIDATION = 'VALIDATION'
+
+    TEST = "TEST"
+    VALIDATION = "VALIDATION"
Discard
@@ -2,4 +2,4 @@
 from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
 from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
 from super_gradients.common.decorators.singleton import singleton
 from super_gradients.common.decorators.singleton import singleton
 
 
-__all__ = ['explicit_params_validation', 'singleton']
+__all__ = ["explicit_params_validation", "singleton"]
Discard
@@ -1,4 +1,4 @@
-def deci_func_logger(_func=None, *, name: str = 'abstract_decorator'):
+def deci_func_logger(_func=None, *, name: str = "abstract_decorator"):
     """
     """
     This decorator is used to wrap our functions with logs.
     This decorator is used to wrap our functions with logs.
     It will log every enter and exit of the functon with the equivalent parameters as extras.
     It will log every enter and exit of the functon with the equivalent parameters as extras.
Discard
@@ -4,7 +4,7 @@ from typing import Callable
 
 
 
 
 class _ExplicitParamsValidator:
 class _ExplicitParamsValidator:
-    def __init__(self, function: Callable, validation_type: str = 'None'):
+    def __init__(self, function: Callable, validation_type: str = "None"):
         """
         """
         ExplicitParamsValidator
         ExplicitParamsValidator
             :param function:
             :param function:
@@ -23,7 +23,7 @@ class _ExplicitParamsValidator:
             :param kwargs:
             :param kwargs:
             :return:
             :return:
         """
         """
-        if not hasattr(self, 'func'):
+        if not hasattr(self, "func"):
             self.func = args[0]
             self.func = args[0]
             return self
             return self
 
 
@@ -39,12 +39,12 @@ class _ExplicitParamsValidator:
         """
         """
         var_names = inspect.getfullargspec(self.func)[0]
         var_names = inspect.getfullargspec(self.func)[0]
 
 
-        explicit_args_var_names = list(var_names[:len(args)])
+        explicit_args_var_names = list(var_names[: len(args)])
 
 
         # FOR CLASS METHOD REMOVE THE EXPLICIT DEMAND FOR self PARAMETER
         # FOR CLASS METHOD REMOVE THE EXPLICIT DEMAND FOR self PARAMETER
         for params_list in [explicit_args_var_names, list(kwargs.keys())]:
         for params_list in [explicit_args_var_names, list(kwargs.keys())]:
-            if 'self' in params_list:
-                params_list.remove('self')
+            if "self" in params_list:
+                params_list.remove("self")
 
 
         # FIRST OF ALL HANDLE ALL OF THE KEYWORD ARGUMENTS
         # FIRST OF ALL HANDLE ALL OF THE KEYWORD ARGUMENTS
         for kwarg, value in kwargs.items():
         for kwarg, value in kwargs.items():
@@ -63,19 +63,20 @@ class _ExplicitParamsValidator:
         :param value:
         :param value:
         :return:
         :return:
         """
         """
-        if self.validation_type == 'NoneOrEmpty':
+        if self.validation_type == "NoneOrEmpty":
             if not value:
             if not value:
-                raise ValueError('Input param: ' + str(input_param) + ' is Empty')
+                raise ValueError("Input param: " + str(input_param) + " is Empty")
 
 
         if value is None:
         if value is None:
-            raise ValueError('Input param: ' + str(input_param) + ' is None')
+            raise ValueError("Input param: " + str(input_param) + " is None")
 
 
 
 
 # WRAPS THE RETRY DECORATOR CLASS TO ENABLE CALLING WITHOUT PARAMS
 # WRAPS THE RETRY DECORATOR CLASS TO ENABLE CALLING WITHOUT PARAMS
-def explicit_params_validation(function: Callable = None, validation_type: str = 'None'):
+def explicit_params_validation(function: Callable = None, validation_type: str = "None"):
     if function is not None:
     if function is not None:
         return _ExplicitParamsValidator(function=function)
         return _ExplicitParamsValidator(function=function)
     else:
     else:
+
         def wrapper(function):
         def wrapper(function):
             return _ExplicitParamsValidator(function=function, validation_type=validation_type)
             return _ExplicitParamsValidator(function=function, validation_type=validation_type)
 
 
Discard
@@ -34,5 +34,7 @@ def resolve_param(param_name: str, factory: AbstractFactory):
                         new_value = factory.get(args[index])
                         new_value = factory.get(args[index])
                         args = _assign_tuple(args, index, new_value)
                         args = _assign_tuple(args, index, new_value)
             return func(*args, **kwargs)
             return func(*args, **kwargs)
+
         return wrapper
         return wrapper
+
     return inner
     return inner
Discard
@@ -63,7 +63,7 @@ class AbstractSGLogger(ABC):
         raise NotImplementedError
         raise NotImplementedError
 
 
     @abstractmethod
     @abstractmethod
-    def add_image(self, tag: str, image: Union[torch.Tensor, np.array, Image.Image], data_format: str = 'CHW', global_step: int = None):
+    def add_image(self, tag: str, image: Union[torch.Tensor, np.array, Image.Image], data_format: str = "CHW", global_step: int = None):
         """
         """
         Add a single image to SGLogger.
         Add a single image to SGLogger.
         Typically, this function will add an image to tensorboard, save it to disk or add it to experiment management framework.
         Typically, this function will add an image to tensorboard, save it to disk or add it to experiment management framework.
@@ -76,7 +76,7 @@ class AbstractSGLogger(ABC):
         raise NotImplementedError
         raise NotImplementedError
 
 
     @abstractmethod
     @abstractmethod
-    def add_images(self, tag: str, images: Union[torch.Tensor, np.array], data_format='NCHW', global_step: int = None):
+    def add_images(self, tag: str, images: Union[torch.Tensor, np.array], data_format="NCHW", global_step: int = None):
         """
         """
         Add multiple images to SGLogger.
         Add multiple images to SGLogger.
         Typically, this function will add images to tensorboard, save them to disk or add them to experiment management framework.
         Typically, this function will add images to tensorboard, save them to disk or add them to experiment management framework.
@@ -89,7 +89,7 @@ class AbstractSGLogger(ABC):
         raise NotImplementedError
         raise NotImplementedError
 
 
     @abstractmethod
     @abstractmethod
-    def add_histogram(self, tag: str, values: Union[torch.Tensor, np.array], bins: Union[str, np.array, list, int] = 'auto', global_step: int = None):
+    def add_histogram(self, tag: str, values: Union[torch.Tensor, np.array], bins: Union[str, np.array, list, int] = "auto", global_step: int = None):
         """
         """
         Add a histogram to SGLogger.
         Add a histogram to SGLogger.
         Typically, this function will add a histogram to tensorboard or add it to experiment management framework.
         Typically, this function will add a histogram to tensorboard or add it to experiment management framework.
Discard