train.py (5.1 KB)

Training script for a CNN-based binary image classifier (Mario vs. Wario), with experiment tracking logged to MLflow and DagsHub.

# script for training a CNN classifier
import os

import tensorflow.keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import mlflow
from dagshub import dagshub_logger

from config import PROCESSED_IMAGES_DIR, MODELS_DIR
from scrt import *  # local secrets module, expected to provide USER_NAME and PASSWORD

mlflow.set_tracking_uri("https://dagshub.com/eryk.lewinson/mario_vs_wario_v2.mlflow")
os.environ["MLFLOW_TRACKING_USERNAME"] = USER_NAME
os.environ["MLFLOW_TRACKING_PASSWORD"] = PASSWORD


def get_datasets(validation_ratio=0.2, target_img_size=64, batch_size=32):
    """
    Train/valid/test split based on this SO answer:
    https://stackoverflow.com/questions/42443936/keras-split-train-test-set-when-using-imagedatagenerator
    """
    train_datagen = ImageDataGenerator(rescale=1./255,
                                       zoom_range=[0.5, 1.5],
                                       validation_split=validation_ratio)
    valid_datagen = ImageDataGenerator(rescale=1./255,
                                       validation_split=validation_ratio)
    test_datagen = ImageDataGenerator(rescale=1./255)

    training_set = train_datagen.flow_from_directory(f"{PROCESSED_IMAGES_DIR}/train",
                                                     target_size=(target_img_size, target_img_size),
                                                     color_mode="grayscale",
                                                     batch_size=batch_size,
                                                     class_mode="binary",
                                                     shuffle=True,
                                                     subset="training")
    valid_set = valid_datagen.flow_from_directory(f"{PROCESSED_IMAGES_DIR}/train",
                                                  target_size=(target_img_size, target_img_size),
                                                  color_mode="grayscale",
                                                  batch_size=batch_size,
                                                  class_mode="binary",
                                                  shuffle=False,
                                                  subset="validation")
    test_set = test_datagen.flow_from_directory(f"{PROCESSED_IMAGES_DIR}/test",
                                                target_size=(target_img_size, target_img_size),
                                                color_mode="grayscale",
                                                batch_size=batch_size,
                                                class_mode="binary")
    return training_set, valid_set, test_set


def get_model(input_img_size, lr):
    """
    Returns a compiled model.
    The architecture is fixed; the arguments only set the image size and the learning rate.
    """
    # Initializing
    model = Sequential()
    # 1st conv. layer
    model.add(Conv2D(32, (3, 3), input_shape=(input_img_size, input_img_size, 1), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # 2nd conv. layer
    model.add(Conv2D(32, (3, 3), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # 3rd conv. layer
    model.add(Conv2D(64, (3, 3), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # Flattening
    model.add(Flatten())
    # Full connection
    model.add(Dense(units=64, activation="relu"))
    model.add(Dropout(0.5))
    model.add(Dense(units=1, activation="sigmoid"))

    model.compile(optimizer=tensorflow.keras.optimizers.Adam(learning_rate=lr),
                  loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model


if __name__ == "__main__":
    mlflow.tensorflow.autolog()

    IMG_SIZE = 128
    LR = 0.001
    EPOCHS = 10

    with mlflow.start_run():
        training_set, valid_set, test_set = get_datasets(validation_ratio=0.2,
                                                         target_img_size=IMG_SIZE,
                                                         batch_size=32)
        model = get_model(IMG_SIZE, LR)

        print("Training the model...")
        model.fit(training_set,
                  validation_data=valid_set,
                  epochs=EPOCHS)
        print("Training completed.")

        print("Evaluating the model...")
        test_loss, test_accuracy = model.evaluate(test_set)
        print("Evaluating completed.")

        with dagshub_logger() as logger:
            logger.log_metrics(loss=test_loss, accuracy=test_accuracy)
            logger.log_hyperparams({
                "img_size": IMG_SIZE,
                "learning_rate": LR,
                "epochs": EPOCHS
            })

        mlflow.log_params({
            "img_size": IMG_SIZE,
            "learning_rate": LR,
            "epochs": EPOCHS
        })
        mlflow.log_metrics({
            "test_set_loss": test_loss,
            "test_set_accuracy": test_accuracy,
        })

        print("Saving the model...")
        model.save(MODELS_DIR)
        print("done.")
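
The script pulls PROCESSED_IMAGES_DIR and MODELS_DIR from a local config module and takes USER_NAME and PASSWORD from scrt via a star import; neither file appears in this view. A minimal sketch of what they could contain, where every path and credential value is a placeholder assumption rather than the project's actual contents:

# config.py -- hypothetical sketch, adjust the paths to your own repo layout
from pathlib import Path

BASE_DIR = Path(__file__).resolve().parent
# flow_from_directory above expects train/ and test/ subfolders here
PROCESSED_IMAGES_DIR = str(BASE_DIR / "data" / "processed")
MODELS_DIR = str(BASE_DIR / "models")

# scrt.py -- hypothetical sketch, keep this file out of version control
USER_NAME = "your-dagshub-username"  # placeholder
PASSWORD = "your-dagshub-token"      # placeholder

With those two files in place, a run presumably boils down to python train.py; alternatively, the MLflow credentials could be exported in the shell as MLFLOW_TRACKING_USERNAME and MLFLOW_TRACKING_PASSWORD instead of being set inside the script.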