Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 3.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
  1. import tensorflow as tf
  2. import json
  3. from .utiles.functions import print_data, load_dataset
  4. from .const.train_const import *
  5. from .const.general_const import *
  def preprocess_data_layers():
      """Create the image preprocessing layers shared by training and data previews.

      Returns:
          tuple: ``(augmentation, rescale)`` where ``augmentation`` is a
          ``tf.keras.Sequential`` applying a random horizontal flip, a random
          rotation (factor 0.05) and a random zoom (factor 0.05), and
          ``rescale`` is a ``Rescaling`` layer that multiplies pixels by 1/255.
      """
      preprocessing = tf.keras.layers.experimental.preprocessing

      augmentation = tf.keras.Sequential([
          preprocessing.RandomFlip("horizontal"),
          preprocessing.RandomRotation(0.05),
          preprocessing.RandomZoom(0.05),
      ])
      rescale = preprocessing.Rescaling(1. / 255)

      return augmentation, rescale
  def build_model(data_augmentation, rescale, img_shape=IMG_SHAPE,
                  learning_rate=LEARNING_RATE, num_class=2, dropout_rate=0.2):
      """Build and compile a transfer-learning classifier on a frozen VGG19 backbone.

      Graph: ``data_augmentation`` -> ``rescale`` -> VGG19 (ImageNet weights,
      no classification head, frozen) -> global average pooling -> dropout ->
      softmax ``Dense(num_class)``.

      The function depends only on its arguments (and the module constants
      used as defaults); it does not read any dataset, so it can be called
      from other modules.

      Args:
          data_augmentation: Keras layer/model applied to the raw inputs.
          rescale: Keras layer applied after augmentation, before the backbone.
          img_shape: Input shape given to ``tf.keras.Input`` and to VGG19.
          learning_rate: Learning rate for the Adam optimizer.
          num_class: Number of softmax output units.
          dropout_rate: Dropout rate applied before the classification head.

      Returns:
          tf.keras.Model: compiled with Adam, ``categorical_crossentropy`` loss
          and ``accuracy`` metric. Categorical cross-entropy expects one-hot
          labels (presumably ``LABEL_MODE`` is ``'categorical'`` -- confirm).
      """
      base_model = tf.keras.applications.VGG19(include_top=False,
                                               weights='imagenet',
                                               input_shape=img_shape)
      # Freeze the pretrained backbone; only the new head is trained.
      base_model.trainable = False

      # No sample batch is needed to build the head: calling each layer on the
      # symbolic tensors below creates its weights with the right shapes.
      inputs = tf.keras.Input(shape=img_shape)
      x = data_augmentation(inputs)
      x = rescale(x)
      # NOTE(review): Keras documents that VGG19 ImageNet weights expect
      # ``tf.keras.applications.vgg19.preprocess_input`` rather than a plain
      # 1/255 rescale -- confirm the chosen ``rescale`` layer is intended.
      # training=False runs the backbone in inference mode, which stays
      # correct if the backbone is later unfrozen for fine-tuning.
      x = base_model(x, training=False)
      x = tf.keras.layers.GlobalAveragePooling2D()(x)
      x = tf.keras.layers.Dropout(dropout_rate)(x)
      outputs = tf.keras.layers.Dense(num_class, activation='softmax')(x)

      model = tf.keras.Model(inputs, outputs)
      model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])
      return model
  def get_callbacks(csv_logger_path, checkpoint_filepath):
      """Create the Keras callbacks passed to ``model.fit``.

      Args:
          csv_logger_path: File that ``CSVLogger`` writes per-epoch metrics to.
          checkpoint_filepath: Destination path for ``ModelCheckpoint``.

      Returns:
          tuple: ``(csv_logger, early_stopping, model_checkpoint)`` --
          a CSVLogger, an EarlyStopping monitoring ``val_accuracy`` with
          patience 4 (mode left at the Keras default), and a ModelCheckpoint
          that saves only when ``val_accuracy`` reaches a new maximum.
      """
      callbacks = tf.keras.callbacks

      logger = callbacks.CSVLogger(csv_logger_path)
      stopper = callbacks.EarlyStopping(monitor='val_accuracy', patience=4)
      saver = callbacks.ModelCheckpoint(
          filepath=checkpoint_filepath,
          monitor='val_accuracy',
          mode='max',
          save_best_only=True,
      )

      return logger, stopper, saver
  if __name__ == '__main__':
      # Script entry point: load data, preview/save processed samples, train the
      # model, and record the training parameters.
      import os

      train_dataset = load_dataset(RAW_TRAIN_PATH, BATCH_SIZE, IMG_SIZE, LABEL_MODE)
      validation_dataset = load_dataset(RAW_TEST_PATH, BATCH_SIZE, IMG_SIZE, LABEL_MODE)

      # Write the class names, comma-separated, in the training dataset's order.
      class_names = train_dataset.class_names
      with open(CLASS_NAME_PATH, "w") as textfile:
          textfile.write(",".join(class_names))

      if NOTEBOOK:
          print_data(validation_dataset, class_names, notebook=NOTEBOOK, process=False, save=False, predict=False)

      # NOTE(review): original indentation was lost; the layers below are needed
      # by build_model, so this section is assumed to run unconditionally.
      data_augmentation_layer, rescale_layer = preprocess_data_layers()
      print_data(validation_dataset, class_names, notebook=NOTEBOOK,
                 process=True, save=True, predict=False,
                 save_path='eval/processed_data', rescale_layer=rescale_layer,
                 data_augmentation_layer=data_augmentation_layer)

      model = build_model(data_augmentation_layer, rescale_layer, IMG_SHAPE, LEARNING_RATE)
      csv_logger, early_stopping, model_checkpoint = get_callbacks(CSV_LOG_PATH, CHECKPOINT_PATH)
      history = model.fit(train_dataset,
                          epochs=INIT_EPOCHS,
                          validation_data=validation_dataset,
                          callbacks=[csv_logger, early_stopping, model_checkpoint])

      history_params_path = "model/TF-Model-Checkpoint/history_params.json"
      # Don't rely on ModelCheckpoint having created the directory.
      os.makedirs(os.path.dirname(history_params_path), exist_ok=True)
      # Context manager guarantees the JSON is flushed and the handle closed.
      with open(history_params_path, 'w') as params_file:
          json.dump(history.params, params_file)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...