1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
|
import os


def _require_env(name, requirement="must be set"):
    """Return the value of environment variable *name*.

    Raises:
        EnvironmentError: if the variable is unset, with a message of the
            form "Environment variable '<name>' <requirement>".
    """
    # NOTE: mirrors the original checks — an empty string is accepted,
    # only a missing variable raises.
    value = os.environ.get(name)
    if value is None:
        raise EnvironmentError(f"Environment variable '{name}' {requirement}")
    return value


# Required credentials/coordinates for DagsHub and its MLflow server.
DAGSHUB_TOKEN = _require_env('DAGSHUB_TOKEN', "must be set with valid token")
DAGSHUB_USER_NAME = _require_env('DAGSHUB_USER_NAME')
DAGSHUB_REPO_NAME = _require_env('DAGSHUB_REPO_NAME')

# =============
# Setup DagsHub
# =============
import dagshub
dagshub.auth.add_app_token(DAGSHUB_TOKEN)
from dagshub.streaming import install_hooks
# Stream the LAION dataset from DagsHub instead of downloading it up front.
install_hooks(project_root='.', repo_url='https://dagshub.com/DagsHub-Datasets/LAION-Aesthetics-V2-6.5plus', branch='main')

# ============
# Setup MLflow
# ============
# MLflow reads these at import/usage time; they must be set before any
# mlflow call is made.
os.environ['MLFLOW_TRACKING_URI'] = f"https://dagshub.com/{DAGSHUB_USER_NAME}/{DAGSHUB_REPO_NAME}.mlflow"
os.environ['MLFLOW_TRACKING_USERNAME'] = DAGSHUB_USER_NAME
os.environ['MLFLOW_TRACKING_PASSWORD'] = DAGSHUB_TOKEN
# Standard library
import logging
from datetime import datetime

# Third party
import mlflow
import torch.utils.data
from torch import optim

# MinImagen
from minimagen.Imagen import Imagen
from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest
from minimagen.generate import load_minimagen, load_params
from minimagen.training import (
    get_minimagen_parser,
    get_minimagen_dl_opts,
    create_directory,
    get_model_params,
    get_model_size,
    save_training_info,
    get_default_args,
    MinimagenTrain,
    load_restart_training_parameters,
    load_testing_parameters,
)

# Project local
from data import train_valid_split
from loss import calculate_loss_per_unet
from helper import get_experiment_id

# Timestamped, millisecond-resolution log lines.
logging.basicConfig(format='%(asctime)s.%(msecs)03d - %(levelname)-8s %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')
# Get device: first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Command line argument parser. See `training.get_minimagen_parser()`.
parser = get_minimagen_parser()
# Extra arguments used when this script is driven via `main.py`.
parser.add_argument("-ts", "--TIMESTAMP", dest="timestamp", help="Timestamp for training directory", type=str,
                    default=None)
parser.add_argument("-id", "--INPUT_DIRECTORY", dest="INPUT_DIRECTORY", help="Input directory with images and labels.tsv file", type=str, default='data')
parser.add_argument("-od", "--OUTPUT_DIRECTORY", dest="OUTPUT_DIRECTORY", help="Output directory with training info and model", type=str, default='/working/')
args = parser.parse_args()
timestamp = args.timestamp

# Get training timestamp for when running train.py as main rather than via main.py
if timestamp is None:
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Create training directory.
# FIX: dropped the redundant "./" prefix — os.path.join(out, "./x") yields
# "out/./x"; the clean form gives "out/x" (same location, tidier path, and
# this path is later logged as an MLflow artifact).
dir_path = os.path.join(args.OUTPUT_DIRECTORY, f"training_{timestamp}")
training_dir = create_directory(dir_path)
# Resume settings: a full restart directory takes precedence over a bare
# parameters (config) directory.
if args.RESTART_DIRECTORY is not None:
    args = load_restart_training_parameters(args)
elif args.PARAMETERS is not None:
    args = load_restart_training_parameters(args, justparams=True)

# In testing mode, lower parameter values to reduce computational load and
# shrink the amount of data being used.
testing = bool(args.TESTING)
if testing:
    args = load_testing_parameters(args)
train_dataset, valid_dataset = train_valid_split(args, smalldata=testing)

logging.info(f'Training size: {len(train_dataset)}')
logging.info(f'Validation size: {len(valid_dataset)}')
# Create train/validation dataloaders from the shared MinImagen options,
# overriding batch size and worker count from the CLI arguments.
loader_kwargs = dict(get_minimagen_dl_opts(device))
loader_kwargs['batch_size'] = args.BATCH_SIZE
loader_kwargs['num_workers'] = args.NUM_WORKERS
train_dataloader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **loader_kwargs)
# ---------------------------------------------------------------------------
# Create the Unets and the Imagen model — either from scratch or by resuming
# a previous training run.
# ---------------------------------------------------------------------------
if args.RESTART_DIRECTORY is None:
    # If not loading a training from a checkpoint.
    imagen_params = dict(
        # The base Unet trains at half the final side length; the super Unet
        # upscales to the full IMG_SIDE_LEN. FIX: floor division instead of
        # int(x / 2) — stays in integer arithmetic, no float round-trip
        # (identical for the positive int side lengths used here).
        image_sizes=(args.IMG_SIDE_LEN // 2, args.IMG_SIDE_LEN),
        timesteps=args.TIMESTEPS,
        cond_drop_prob=0.15,
        text_encoder_name=args.T5_NAME
    )
    if args.TESTING:
        # If testing, use tiny MinImagen for low computational load.
        unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)]
    elif not args.PARAMETERS:
        # If no parameters provided, use params from minimagen.Imagen.Base
        # and minimagen.Imagen.Super built-in classes.
        unets_params = [get_default_args(Base), get_default_args(Super)]
    else:
        # Load Unet/Imagen configs from the config (parameters) folder —
        # this overrides the imagen_params assembled above.
        unets_params, imagen_params = get_model_params(args.PARAMETERS)
    # Create Unets according to unets_params.
    unets = [Unet(**unet_params).to(device) for unet_params in unets_params]
    # Create Imagen from Unets with the specified imagen parameters.
    imagen = Imagen(unets=unets, **imagen_params).to(device)
else:
    # If training is being resumed from a previous one, load all relevant
    # models/info (load config AND state dicts).
    orig_train_dir = os.path.join(os.getcwd(), args.RESTART_DIRECTORY)
    unets_params, imagen_params = load_params(orig_train_dir)
    imagen = load_minimagen(orig_train_dir).to(device)
    unets = imagen.unets

# Fill in unspecified arguments with defaults so the saved config
# (parameters) file is complete.
unets_params = [{**get_default_args(Unet), **params} for params in unets_params]
imagen_params = {**get_default_args(Imagen), **imagen_params}

# Get the size of the Imagen model in megabytes and save all training info
# (config files, model size, etc.) into the training directory.
model_size_MB = get_model_size(imagen)
save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir)
# Create optimizer over all Imagen parameters.
optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR)

# Train the MinImagen instance.
MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader, training_dir, optimizer, timeout=30)

# Average validation loss per Unet after training.
avg_losses = calculate_loss_per_unet(unets, imagen, valid_dataloader)

# Record the run to MLflow (tracking server configured via env vars above).
experiment_id = get_experiment_id('minimagen')
with mlflow.start_run(experiment_id=experiment_id):
    mlflow.log_params(vars(args))
    # FIX: generalized from hard-coded unet0/unet1 keys — log one entry per
    # Unet so a configuration with a different Unet count is recorded fully
    # instead of raising IndexError / silently dropping Unets. Keys are
    # identical to before for the standard two-Unet setup.
    mlflow.log_params({f"unet{i}_params": params for i, params in enumerate(unets_params)})
    mlflow.log_params({"imagen_params": imagen_params})
    mlflow.log_param("model_size_MB", model_size_MB)
    mlflow.log_metrics({f"unet{i}_avg_loss": loss for i, loss in enumerate(avg_losses)})
    # Attach the whole training directory (configs, checkpoints) as artifacts.
    mlflow.log_artifact(dir_path)
|