
train.py 8.1 KB

'''
This file contains the code to train a model to predict the origin of a blood clot.
'''
import glob
import math
import os

import cv2
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import random_rotation, random_shift, random_brightness
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dropout, Dense, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping
from tensorflow.keras.applications import EfficientNetB5
from tensorflow.keras.optimizers import Adam

from config import *
from utils import *
from dagshub.upload import Repo

print("Tensorflow Version is")
print(tf.__version__)
tf.debugging.set_log_device_placement(True)  ## Log which device each op is placed on - either CPU or GPU
print("List of Tensorflow Devices")
gpus = tf.config.experimental.list_physical_devices("GPU")
print(gpus)

os.environ["AZUREML_ARTIFACTS_DEFAULT_TIMEOUT"] = "3000"
print("AZUREML_ARTIFACTS_DEFAULT_TIMEOUT Timeout Set is", os.environ["AZUREML_ARTIFACTS_DEFAULT_TIMEOUT"])

import mlflow

mlflow.autolog()
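
# mlflow.autolog() enables autologging for the Keras training loop below, so
# the run's parameters, per-epoch metrics and the final model are recorded to
# the active MLflow tracking server without explicit logging calls in this script.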
def step_decay(epoch):
    '''Step-decay learning-rate schedule: halve the rate every `epochs_drop` epochs.'''
    initial_lrate = 0.001
    drop = 0.5
    epochs_drop = 10.0
    lrate = initial_lrate * math.pow(drop, math.floor(epoch / epochs_drop))
    return lrate
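
# For reference, with the defaults above the schedule yields: epochs 0-9 -> 0.001,
# epochs 10-19 -> 0.0005, epochs 20-29 -> 0.00025, and so on.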
def model_EfficentNetB5(efficient_net_weights, lr=0.001, dr_rate=0.15):
    '''Build a frozen EfficientNetB5 backbone with a small trainable classification head.'''
    model = EfficientNetB5(include_top=False, weights=efficient_net_weights)
    model.trainable = False
    # Rebuild the top of the network
    x = GlobalAveragePooling2D()(model.output)
    x = BatchNormalization()(x)
    x = Dropout(dr_rate)(x)
    dense_1 = Dense(64, activation="relu")(x)
    dense_2 = Dense(32, activation="relu")(dense_1)
    outputs = Dense(1, activation="sigmoid")(dense_2)
    # Compile
    model = Model(model.inputs, outputs, name="EfficientNet")
    optimizer = Adam(learning_rate=lr)
    model.compile(
        optimizer=optimizer, loss="binary_crossentropy", metrics=["binary_accuracy"]
    )
    return model
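
# Quick shape sanity check (illustrative only, not part of the training run;
# passing weights=None skips downloading the pretrained file):
#   m = model_EfficentNetB5(None)
#   print(m(tf.zeros((1, 512, 512, 3))).shape)  # (1, 1): a single sigmoid probability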
'''
This data generator reads the data from a given path and generates batches of data.
'''
class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'

    def __init__(self, dataframe, data_directory, dimensions=(512, 512), batch_size: int = 16,
                 shuffle=True, num_channels=3, mode="train", rotation_range=None,
                 width_shift_range=None, height_shift_range=None, brightness_range=None,
                 horizontal_flip=False):
        '''
        Initialise the generator.
        '''
        self.batch_size = batch_size
        self.dim = dimensions
        self.data_directory = data_directory
        self.shuffle = shuffle
        self.rotation_range = rotation_range
        self.horizontal_flip = horizontal_flip
        self.brightness_range = brightness_range
        self.height_shift_range = height_shift_range
        self.width_shift_range = width_shift_range
        self.fs = create_streaming_client()
        if mode == "train":
            dat = dataframe[dataframe['is_train'] == "train"]
        else:
            dat = dataframe[dataframe['is_train'] == "val"]
        dat = dat.reset_index(drop=True)
        self.images = dat['image_id'].tolist()
        self.labels = dat['int_labels'].tolist()
        unique_labels = set(self.labels)
        self.n_channels = num_channels
        self.n_classes = len(unique_labels)
        print("Number of Images for " + mode + " is " + str(len(self.images)))
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.images) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Indexes of this batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        # Generate data
        X, y = self.__data_generation(indexes)
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch. If shuffle is true, this will shuffle the dataset.'
        self.indexes = np.arange(len(self.images))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, indexes):
        'Generates data containing batch_size samples'  # X : (n_samples, *dim, n_channels)
        # Initialization
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty(self.batch_size, dtype=int)
        # Generate data based on the indexes
        for i, idx in enumerate(indexes):
            image_id = self.images[idx]
            img_path = os.path.join(self.data_directory, image_id + ".png")
            # Read the image using the streaming client
            img = read_images(self.fs, img_path)
            img = img / 255
            img = cv2.resize(img, self.dim)
            # Augmentations: images are channels-last (H, W, C), so the row/col/channel
            # axes are passed explicitly (the Keras defaults assume channels-first).
            if self.rotation_range is not None:
                img = random_rotation(img, self.rotation_range,
                                      row_axis=0, col_axis=1, channel_axis=2)
            if self.width_shift_range is not None:
                img = random_shift(img, wrg=self.width_shift_range, hrg=0,
                                   row_axis=0, col_axis=1, channel_axis=2)
            if self.height_shift_range is not None:
                img = random_shift(img, hrg=self.height_shift_range, wrg=0,
                                   row_axis=0, col_axis=1, channel_axis=2)
            if self.brightness_range is not None:
                img = random_brightness(img, brightness_range=self.brightness_range)
            if self.horizontal_flip:
                img = cv2.flip(img, 1)
            X[i,] = img
            # Store the label
            y[i] = self.labels[idx]
        return X, y
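
# Illustrative standalone usage (assumes the dataframe produced by
# get_train_dataframe/train_split below, with 'image_id', 'int_labels'
# and 'is_train' columns):
#   gen = DataGenerator(dataframe=data, data_directory=TRAIN_DATA_PATH,
#                       mode="train", batch_size=8)
#   X, y = gen[0]  # X: (8, 512, 512, 3), y: (8,)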
### Now let us clone the Git repo first
print("Cloning the Repo")
gitclone()
print("Done Cloning")
fs = create_streaming_client()
data = get_train_dataframe(fs, "train.csv")
data = train_split(data)  ## Split the data into train and val
### Create the training and validation generators
train_data_generator = DataGenerator(dataframe=data, data_directory=TRAIN_DATA_PATH, mode="train",
                                     rotation_range=10, width_shift_range=0.2, height_shift_range=0.2,
                                     horizontal_flip=True, brightness_range=[0.2, 1.2], batch_size=32)
### The validation generator is left unaugmented so that val_loss measures clean data
validation_data_generator = DataGenerator(dataframe=data, data_directory=TRAIN_DATA_PATH,
                                          mode="validation", batch_size=32)
lrate = LearningRateScheduler(step_decay)
earstop = EarlyStopping(monitor='val_loss', min_delta=0, patience=3)
### Load the EfficientNet weights and build the model
efficientWeight = download_EfficientNet(fs, 'efficientnet-b5_tf24_imagenet_1000_notop.h5')
efficentB5 = model_EfficentNetB5(efficientWeight)
### Train
history_0 = efficentB5.fit(
    train_data_generator,
    epochs=4,
    validation_data=validation_data_generator,
    verbose=1,
    callbacks=[lrate, earstop]
)
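
# The per-epoch metrics are captured by mlflow.autolog() and are also available
# locally via history_0.history (e.g. history_0.history["val_loss"]).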
print("Saving model")
efficentB5.save('outputs/efficientNet_Model')
print("Going to Upload Files to Dagshub")
### Upload the files from outputs/efficientNet_Model to the DagsHub repo
## Step 1: connect to the repo
repo = Repo("aiswaryasrinivas", DAGSHUB_REPO_NAME, username=DAGSHUB_USERNAME, password=DAGSHUB_TOKEN)
### Upload all the files into the model folder. The SavedModel format writes
### saved_model.pb plus a variables/ directory, hence the two globs below.
for file_path in glob.glob("outputs/efficientNet_Model/*.pb"):
    filename = os.path.basename(file_path)
    repo.upload(file=file_path, path="model/efficientNet_Model/" + filename,
                commit_message="file added " + filename, versioning="dvc")
for file_path in glob.glob("outputs/efficientNet_Model/variables/*"):
    filename = os.path.basename(file_path)
    repo.upload(file=file_path, path="model/efficientNet_Model/variables/" + filename,
                commit_message="file added " + filename, versioning="dvc")