#620 Black on factories and data_interface

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:hotfix/SG-000-black_on_some_common
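
This PR runs Black over the common data-interface module and several factory modules: single-quoted strings are normalized to double quotes, manually wrapped calls are consolidated onto single lines up to the configured line length, and stray blank lines after class declarations are removed. The changes are formatting-only; no behavior changes. Below is a minimal sketch of reproducing the reformatting via Black's Python API; the file path and the 160-character line length are assumptions inferred from the wrapped line widths in the diff, not confirmed by this PR.

    # Sketch only: the path and line_length are assumptions, not confirmed here.
    import black

    path = "src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py"
    with open(path) as f:
        src = f.read()

    # Black's default string normalization turns 'single' quotes into "double"
    # quotes; line_length controls how wrapped call chains are consolidated.
    formatted = black.format_str(src, mode=black.Mode(line_length=160))
    print(formatted)
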
@@ -13,7 +13,7 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
     ResearchModelRepositoryDataInterface
     """

-    def __init__(self, data_connection_location: str = 'local', data_connection_credentials: str = None):
+    def __init__(self, data_connection_location: str = "local", data_connection_credentials: str = None):
         """
         """
         ModelCheckpointsDataInterface
         ModelCheckpointsDataInterface
             :param data_connection_location: 'local' or s3 bucket 's3://my-bucket-name'
             :param data_connection_location: 'local' or s3 bucket 's3://my-bucket-name'
@@ -22,22 +22,22 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
                     AWS_PROFILE if left empty
         """
         super().__init__()
-        self.tb_events_file_prefix = 'events.out.tfevents'
-        self.log_file_prefix = 'log_'
-        self.latest_checkpoint_filename = 'ckpt_latest.pth'
-        self.best_checkpoint_filename = 'ckpt_best.pth'
+        self.tb_events_file_prefix = "events.out.tfevents"
+        self.log_file_prefix = "log_"
+        self.latest_checkpoint_filename = "ckpt_latest.pth"
+        self.best_checkpoint_filename = "ckpt_best.pth"

-        if data_connection_location.startswith('s3'):
-            assert data_connection_location.index('s3://') >= 0, 'S3 path must be formatted s3://bucket-name'
-            self.model_repo_bucket_name = data_connection_location.split('://')[1]
-            self.data_connection_source = 's3'
+        if data_connection_location.startswith("s3"):
+            assert data_connection_location.index("s3://") >= 0, "S3 path must be formatted s3://bucket-name"
+            self.model_repo_bucket_name = data_connection_location.split("://")[1]
+            self.data_connection_source = "s3"

             if data_connection_credentials is None:
-                data_connection_credentials = os.getenv('AWS_PROFILE')
+                data_connection_credentials = os.getenv("AWS_PROFILE")

             self.s3_connector = S3Connector(data_connection_credentials, self.model_repo_bucket_name)

-    @explicit_params_validation(validation_type='None')
+    @explicit_params_validation(validation_type="None")
     def load_all_remote_log_files(self, model_name: str, model_checkpoint_local_dir: str):
         """
         load_all_remote_checkpoint_files
@@ -45,12 +45,10 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
             :param model_checkpoint_local_dir:
             :return:
         """
-        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir,
-                                       logging_type='tensorboard')
-        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir,
-                                       logging_type='text')
+        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="tensorboard")
+        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="text")

-    @explicit_params_validation(validation_type='None')
+    @explicit_params_validation(validation_type="None")
     def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_local_dir: str, log_file_name: str):
         """
         save_all_remote_checkpoint_files - Saves all of the local Checkpoint data into Remote Repo
@@ -64,9 +62,10 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
         self.save_remote_checkpoints_file(model_name, model_checkpoint_local_dir, log_file_name)
         self.save_remote_tensorboard_event_files(model_name, model_checkpoint_local_dir)

-    @explicit_params_validation(validation_type='None')
-    def load_remote_checkpoints_file(self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str,
-                                     ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False) -> str:
+    @explicit_params_validation(validation_type="None")
+    def load_remote_checkpoints_file(
+        self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str, ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False
+    ) -> str:
         """
         """
         load_remote_checkpoints_file - Loads a model's checkpoint from local/cloud file
         load_remote_checkpoints_file - Loads a model's checkpoint from local/cloud file
             :param ckpt_source_remote_dir:               The source folder to download from
             :param ckpt_source_remote_dir:               The source folder to download from
@@ -76,27 +75,26 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
                                                             is to overwrite a previous version of the same files
             :return: Model Checkpoint File Path -> Depends on model architecture
         """
-        ckpt_file_local_full_path = ckpt_destination_local_dir + '/' + ckpt_file_name
+        ckpt_file_local_full_path = ckpt_destination_local_dir + "/" + ckpt_file_name

-        if self.data_connection_source == 's3':
+        if self.data_connection_source == "s3":
             if overwrite_local_checkpoints_file:
                 # DELETE THE LOCAL VERSION ON THE MACHINE
                 if os.path.exists(ckpt_file_local_full_path):
                     os.remove(ckpt_file_local_full_path)

-            key_to_download = ckpt_source_remote_dir + '/' + ckpt_file_name
-            download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path,
-                                                              key_to_download=key_to_download)
+            key_to_download = ckpt_source_remote_dir + "/" + ckpt_file_name
+            download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path, key_to_download=key_to_download)

             if not download_success:
-                failed_download_path = 's3://' + self.model_repo_bucket_name + '/' + key_to_download
-                error_msg = 'Failed to Download Model Checkpoint from ' + failed_download_path
+                failed_download_path = "s3://" + self.model_repo_bucket_name + "/" + key_to_download
+                error_msg = "Failed to Download Model Checkpoint from " + failed_download_path
                 self._logger.error(error_msg)
                 raise ModelCheckpointNotFoundException(error_msg)

         return ckpt_file_local_full_path

-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name: str, logging_type: str):
         """
         load_remote_tensorboard_event_files - Downloads all of the tb_events Files from remote repository
@@ -106,24 +104,23 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
             :return:
         """
         if not os.path.isdir(model_checkpoint_dir_name):
-            raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
+            raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")
 
 
         # LOADS THE DATA FROM THE REMOTE REPOSITORY
         # LOADS THE DATA FROM THE REMOTE REPOSITORY
         s3_bucket_path_prefix = model_name
         s3_bucket_path_prefix = model_name
-        if logging_type == 'tensorboard':
-            if self.data_connection_source == 's3':
-                self.s3_connector.download_keys_by_prefix(s3_bucket_path_prefix=s3_bucket_path_prefix,
-                                                          local_download_dir=model_checkpoint_dir_name,
-                                                          s3_file_path_prefix=self.tb_events_file_prefix)
-        elif logging_type == 'text':
-            if self.data_connection_source == 's3':
-                self.s3_connector.download_keys_by_prefix(s3_bucket_path_prefix=s3_bucket_path_prefix,
-                                                          local_download_dir=model_checkpoint_dir_name,
-                                                          s3_file_path_prefix=self.log_file_prefix)
-
-    @explicit_params_validation(validation_type='NoneOrEmpty')
-    def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str,
-                                     checkpoints_file_name: str) -> bool:
+        if logging_type == "tensorboard":
+            if self.data_connection_source == "s3":
+                self.s3_connector.download_keys_by_prefix(
+                    s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.tb_events_file_prefix
+                )
+        elif logging_type == "text":
+            if self.data_connection_source == "s3":
+                self.s3_connector.download_keys_by_prefix(
+                    s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.log_file_prefix
+                )
+
+    @explicit_params_validation(validation_type="NoneOrEmpty")
+    def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str, checkpoints_file_name: str) -> bool:
         """
         """
         save_remote_checkpoints_file - Saves a Checkpoints file in the Remote Repo
         save_remote_checkpoints_file - Saves a Checkpoints file in the Remote Repo
             :param model_name:                      The Model Name for S3 Prefix
             :param model_name:                      The Model Name for S3 Prefix
@@ -132,14 +129,14 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
             :return: True/False for Operation Success/Failure
         """
         # LOAD THE LOCAL VERSION
-        model_checkpoint_file_full_path = model_checkpoint_local_dir + '/' + checkpoints_file_name
+        model_checkpoint_file_full_path = model_checkpoint_local_dir + "/" + checkpoints_file_name

         # SAVE ON THE REMOTE S3 REPOSITORY
-        if self.data_connection_source == 's3':
-            model_checkpoint_s3_in_bucket_path = model_name + '/' + checkpoints_file_name
+        if self.data_connection_source == "s3":
+            model_checkpoint_s3_in_bucket_path = model_name + "/" + checkpoints_file_name
             return self.__update_or_upload_s3_key(model_checkpoint_file_full_path, model_checkpoint_s3_in_bucket_path)

-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def save_remote_tensorboard_event_files(self, model_name: str, model_checkpoint_dir_name: str):
         """
         save_remote_tensorboard_event_files - Saves all of the tensorboard files remotely
@@ -147,18 +144,18 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
             :param model_checkpoint_dir_name: The directory where the files are stored in
         """
         if not os.path.isdir(model_checkpoint_dir_name):
-            raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
+            raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")
 
 
         for tb_events_file_name in os.listdir(model_checkpoint_dir_name):
         for tb_events_file_name in os.listdir(model_checkpoint_dir_name):
             if tb_events_file_name.startswith(self.tb_events_file_prefix):
             if tb_events_file_name.startswith(self.tb_events_file_prefix):
-                upload_success = self.save_remote_checkpoints_file(model_name=model_name,
-                                                                   model_checkpoint_local_dir=model_checkpoint_dir_name,
-                                                                   checkpoints_file_name=tb_events_file_name)
+                upload_success = self.save_remote_checkpoints_file(
+                    model_name=model_name, model_checkpoint_local_dir=model_checkpoint_dir_name, checkpoints_file_name=tb_events_file_name
+                )

                 if not upload_success:
-                    self._logger.error('Failed to upload tb_events_file: ' + tb_events_file_name)
+                    self._logger.error("Failed to upload tb_events_file: " + tb_events_file_name)

-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
         """
         __update_or_upload_s3_key - Uploads/Updates an S3 Key based on a local file path
@@ -169,10 +166,10 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
             # DELETE KEY TO UPDATE THE FILE IN S3
             delete_response = self.s3_connector.delete_key(s3_key_path)
             if delete_response:
-                self._logger.info('Removed previous checkpoint from S3')
+                self._logger.info("Removed previous checkpoint from S3")

         upload_success = self.s3_connector.upload_file(local_file_path, s3_key_path)
         if not upload_success:
-            self._logger.error('Failed to upload model checkpoint')
+            self._logger.error("Failed to upload model checkpoint")

         return upload_success
@@ -3,6 +3,5 @@ from super_gradients.training.utils.callbacks import CALLBACKS


 class CallbacksFactory(BaseFactory):
-
     def __init__(self):
         super().__init__(CALLBACKS)
@@ -4,7 +4,6 @@ from super_gradients.common.factories.base_factory import AbstractFactory


 class ListFactory(AbstractFactory):
-
     def __init__(self, factry: AbstractFactory):
         self.factry = factry

@@ -3,6 +3,5 @@ from super_gradients.training.losses import LOSSES


 class LossesFactory(BaseFactory):
-
     def __init__(self):
         super().__init__(LOSSES)
@@ -3,6 +3,5 @@ from super_gradients.training.metrics import METRICS


 class MetricsFactory(BaseFactory):
-
     def __init__(self):
         super().__init__(METRICS)
@@ -3,6 +3,5 @@ from super_gradients.training.datasets.samplers import SAMPLERS


 class SamplersFactory(BaseFactory):
-
     def __init__(self):
         super().__init__(SAMPLERS)