|
@@ -22,25 +22,26 @@ except (ModuleNotFoundError, ImportError, NameError):
|
|
|
WANDB_ID_PREFIX = 'wandb_id.'
|
|
|
WANDB_INCLUDE_FILE_NAME = '.wandbinclude'
|
|
|
|
|
|
+
|
|
|
class WandBSGLogger(BaseSGLogger):
|
|
|
|
|
|
- def __init__(self, project_name: str, experiment_name: str, storage_location: str, resumed: bool, training_params: dict, checkpoints_dir_path: str, tb_files_user_prompt: bool = False,
|
|
|
- launch_tensorboard: bool = False, tensorboard_port: int = None, save_checkpoints_remote: bool = True, save_tensorboard_remote: bool = True,
|
|
|
- save_logs_remote: bool = True, entity: Optional[str] = None, api_server: Optional[str] = None, save_code: bool = False, **kwargs):
|
|
|
+ def __init__(self, project_name: str, experiment_name: str, storage_location: str, resumed: bool, training_params: dict, checkpoints_dir_path: str,
|
|
|
+ tb_files_user_prompt: bool = False, launch_tensorboard: bool = False, tensorboard_port: int = None, save_checkpoints_remote: bool = True,
|
|
|
+ save_tensorboard_remote: bool = True, save_logs_remote: bool = True, entity: Optional[str] = None, api_server: Optional[str] = None,
|
|
|
+ save_code: bool = False, **kwargs):
|
|
|
"""
|
|
|
|
|
|
- :param experiment_name: Used for logging and loading purposes
|
|
|
- :param s3_path: If set to 's3' (i.e. s3://my-bucket) saves the Checkpoints in AWS S3 otherwise saves the Checkpoints Locally
|
|
|
- :param checkpoint_loaded: if true, then old tensorboard files will *not* be deleted when tb_files_user_prompt=True
|
|
|
- :param max_epochs: the number of epochs planned for this training
|
|
|
- :param tb_files_user_prompt: Asks user for Tensorboard deletion prompt.
|
|
|
- :param launch_tensorboard: Whether to launch a TensorBoard process.
|
|
|
- :param tensorboard_port: Specific port number for the tensorboard to use when launched (when set to None, some free port
|
|
|
- number will be used
|
|
|
+ :param experiment_name: Used for logging and loading purposes
|
|
|
+ :param s3_path: If set to 's3' (i.e. s3://my-bucket) saves the Checkpoints in AWS S3 otherwise saves the Checkpoints Locally
|
|
|
+ :param checkpoint_loaded: If true, then old tensorboard files will *not* be deleted when tb_files_user_prompt=True
|
|
|
+ :param max_epochs: Number of epochs planned for this training
|
|
|
+ :param tb_files_user_prompt: Asks user for Tensorboard deletion prompt.
|
|
|
+ :param launch_tensorboard: Whether to launch a TensorBoard process.
|
|
|
+ :param tensorboard_port: Specific port number for the tensorboard to use when launched (when set to None, some free port number will be used)
|
|
|
:param save_checkpoints_remote: Saves checkpoints in s3.
|
|
|
:param save_tensorboard_remote: Saves tensorboard in s3.
|
|
|
- :param save_logs_remote: Saves log files in s3.
|
|
|
- :param save_code: save current code to wandb
|
|
|
+ :param save_logs_remote: Saves log files in s3.
|
|
|
+ :param save_code: Save current code to wandb
|
|
|
"""
|
|
|
self.s3_location_available = storage_location.startswith('s3')
|
|
|
super().__init__(project_name, experiment_name, storage_location, resumed, training_params,
|
|
@@ -103,7 +104,6 @@ class WandBSGLogger(BaseSGLogger):
|
|
|
else:
|
|
|
wandb.run.log_code(".", include_fn=include_fn)
|
|
|
|
|
|
-
|
|
|
@multi_process_safe
|
|
|
def add_config(self, tag: str, config: dict):
|
|
|
super(WandBSGLogger, self).add_config(tag=tag, config=config)
|
|
@@ -206,7 +206,7 @@ class WandBSGLogger(BaseSGLogger):
|
|
|
def _get_tensorboard_file_name(self):
|
|
|
try:
|
|
|
tb_file_path = self.tensorboard_writer.file_writer.event_writer._file_name
|
|
|
- except RuntimeError as e:
|
|
|
+ except RuntimeError:
|
|
|
logger.warning('tensorboard file could not be located for ')
|
|
|
return None
|
|
|
|
|
@@ -272,7 +272,7 @@ class WandBSGLogger(BaseSGLogger):
|
|
|
return os.path.join(cur_dir, file_name)
|
|
|
else:
|
|
|
cur_dir = os.path.dirname(cur_dir)
|
|
|
- except RuntimeError as e:
|
|
|
+ except RuntimeError:
|
|
|
return None
|
|
|
|
|
|
return None
|