Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 6.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
  1. import os
  2. DAGSHUB_TOKEN = os.environ.get('DAGSHUB_TOKEN', None)
  3. DAGSHUB_USER_NAME = os.environ.get('DAGSHUB_USER_NAME', None)
  4. DAGSHUB_REPO_NAME = os.environ.get('DAGSHUB_REPO_NAME', None)
  5. if DAGSHUB_TOKEN is None:
  6. raise EnvironmentError("Environment variable 'DAGSHUB_TOKEN' must be set with valid token")
  7. if DAGSHUB_USER_NAME is None:
  8. raise EnvironmentError("Environment variable 'DAGSHUB_USER_NAME' must be set")
  9. if DAGSHUB_REPO_NAME is None:
  10. raise EnvironmentError("Environment variable 'DAGSHUB_REPO_NAME' must be set")
  11. # =============
  12. # Setup DagsHub
  13. # =============
  14. import dagshub
  15. dagshub.auth.add_app_token(DAGSHUB_TOKEN)
  16. from dagshub.streaming import install_hooks
  17. install_hooks(project_root='.', repo_url='https://dagshub.com/DagsHub-Datasets/LAION-Aesthetics-V2-6.5plus', branch='main')
  18. # ============
  19. # Setup MLflow
  20. # ============
  21. os.environ['MLFLOW_TRACKING_URI']=f"https://dagshub.com/{DAGSHUB_USER_NAME}/{DAGSHUB_REPO_NAME}.mlflow"
  22. os.environ['MLFLOW_TRACKING_USERNAME'] = DAGSHUB_USER_NAME
  23. os.environ['MLFLOW_TRACKING_PASSWORD'] = DAGSHUB_TOKEN
  24. import mlflow
  25. import logging
  26. from datetime import datetime
  27. import torch.utils.data
  28. from torch import optim
  29. from minimagen.Imagen import Imagen
  30. from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest
  31. from minimagen.generate import load_minimagen, load_params
  32. from minimagen.training import get_minimagen_parser, get_minimagen_dl_opts, \
  33. create_directory, get_model_params, get_model_size, save_training_info, get_default_args, MinimagenTrain, \
  34. load_restart_training_parameters, load_testing_parameters
  35. from data import train_valid_split
  36. from loss import calculate_loss_per_unet
  37. from helper import get_experiment_id
  38. logging.basicConfig(format='%(asctime)s.%(msecs)03d - %(levelname)-8s %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')
  39. # Get device
  40. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  41. # Command line argument parser. See `training.get_minimagen_parser()`.
  42. parser = get_minimagen_parser()
  43. # Add argument for when using `main.py`
  44. parser.add_argument("-ts", "--TIMESTAMP", dest="timestamp", help="Timestamp for training directory", type=str,
  45. default=None)
  46. parser.add_argument("-id", "--INPUT_DIRECTORY", dest="INPUT_DIRECTORY", help="Input directory with images and labels.tsv file", type=str, default='data')
  47. parser.add_argument("-od", "--OUTPUT_DIRECTORY", dest="OUTPUT_DIRECTORY", help="Output directory with training info and model", type=str, default='/working/')
  48. args = parser.parse_args()
  49. timestamp = args.timestamp
  50. # Get training timestamp for when running train.py as main rather than via main.py
  51. if timestamp is None:
  52. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  53. # Create training directory
  54. dir_path = os.path.join(args.OUTPUT_DIRECTORY, f"./training_{timestamp}")
  55. training_dir = create_directory(dir_path)
  56. # If loading from a parameters/training directory
  57. if args.RESTART_DIRECTORY is not None:
  58. args = load_restart_training_parameters(args)
  59. elif args.PARAMETERS is not None:
  60. args = load_restart_training_parameters(args, justparams=True)
  61. # If testing, lower parameter values to lower computational load and also to lower amount of data being used.
  62. if args.TESTING:
  63. args = load_testing_parameters(args)
  64. train_dataset, valid_dataset = train_valid_split(args, smalldata=True)
  65. else:
  66. train_dataset, valid_dataset = train_valid_split(args, smalldata=False)
  67. logging.info(f'Training size: {len(train_dataset)}')
  68. logging.info(f'Validation size: {len(valid_dataset)}')
  69. # Create dataloaders
  70. dl_opts = {**get_minimagen_dl_opts(device), 'batch_size': args.BATCH_SIZE, 'num_workers': args.NUM_WORKERS}
  71. train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts)
  72. valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts)
  73. # Create Unets
  74. if args.RESTART_DIRECTORY is None:
  75. imagen_params = dict(
  76. image_sizes=(int(args.IMG_SIDE_LEN / 2), args.IMG_SIDE_LEN),
  77. timesteps=args.TIMESTEPS,
  78. cond_drop_prob=0.15,
  79. text_encoder_name=args.T5_NAME
  80. )
  81. # If not loading a training from a checkpoint
  82. if args.TESTING:
  83. # If testing, use tiny MinImagen for low computational load
  84. unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)]
  85. # Else if not loading Unet/Imagen settings from a config (parameters) folder, use defaults
  86. elif not args.PARAMETERS:
  87. # If no parameters provided, use params from minimagen.Imagen.Base and minimagen.Imagen.Super built-in classes
  88. unets_params = [get_default_args(Base), get_default_args(Super)]
  89. # Else load unet/Imagen configs from config (parameters) folder (override imagen+params)
  90. else:
  91. # If parameters are provided, load them
  92. unets_params, imagen_params = get_model_params(args.PARAMETERS)
  93. # Create Unets accoridng to unets_params
  94. unets = [Unet(**unet_params).to(device) for unet_params in unets_params]
  95. # Create Imagen from UNets with specified imagen parameters
  96. imagen = Imagen(unets=unets, **imagen_params).to(device)
  97. else:
  98. # If training is being resumed from a previous one, load all relevant models/info (load config AND state dicts)
  99. orig_train_dir = os.path.join(os.getcwd(), args.RESTART_DIRECTORY)
  100. unets_params, imagen_params = load_params(orig_train_dir)
  101. imagen = load_minimagen(orig_train_dir).to(device)
  102. unets = imagen.unets
  103. # Fill in unspecified arguments with defaults for complete config (parameters) file
  104. unets_params = [{**get_default_args(Unet), **i} for i in unets_params]
  105. imagen_params = {**get_default_args(Imagen), **imagen_params}
  106. # Get the size of the Imagen model in megabytes
  107. model_size_MB = get_model_size(imagen)
  108. # Save all training info (config files, model size, etc.)
  109. save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir)
  110. # Create optimizer
  111. optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR)
  112. # Train the MinImagen instance
  113. MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader, training_dir, optimizer, timeout=30)
  114. avg_losses = calculate_loss_per_unet(unets, imagen, valid_dataloader)
  115. experiment_id = get_experiment_id('minimagen')
  116. with mlflow.start_run(experiment_id=experiment_id):
  117. mlflow.log_params(args.__dict__)
  118. mlflow.log_params({
  119. "unet0_params": unets_params[0],
  120. "unet1_params": unets_params[1],
  121. "imagen_params": imagen_params
  122. })
  123. mlflow.log_param("model_size_MB", model_size_MB)
  124. mlflow.log_metrics({
  125. "unet0_avg_loss": avg_losses[0],
  126. "unet1_avg_loss": avg_losses[1],
  127. })
  128. mlflow.log_artifact(dir_path)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...