keras_checkpoint_saver_callback.py
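This file defines a small set of Keras callbacks for monitoring and checkpointing training: ModelTrainingStatus and ModelTrainingStatusTrackerCallback keep a running count of completed epochs, ModelCheckpointSaverCallback saves the model every N epochs through a wrapper object, MultiBatchCallback aggregates consecutive batches into groups, and ModelTrainingProgressLoggerCallback builds on it to log throughput, ETA, and averaged loss.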

import time
import datetime
import logging
from typing import Optional, Dict
from collections import defaultdict

from tensorflow.python.keras.callbacks import Callback

from config import Config


class ModelTrainingStatus:
    """Mutable record of how far training has progressed."""
    def __init__(self):
        self.nr_epochs_trained: int = 0
        self.trained_full_last_epoch: bool = False


class ModelTrainingStatusTrackerCallback(Callback):
    """Keeps a ModelTrainingStatus in sync with the actual training loop."""
    def __init__(self, training_status: ModelTrainingStatus):
        self.training_status: ModelTrainingStatus = training_status
        super(ModelTrainingStatusTrackerCallback, self).__init__()

    def on_epoch_begin(self, epoch, logs=None):
        self.training_status.trained_full_last_epoch = False

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes 0-based epoch indices, so the index must match the running count.
        assert self.training_status.nr_epochs_trained == epoch
        self.training_status.nr_epochs_trained += 1
        self.training_status.trained_full_last_epoch = True


class ModelCheckpointSaverCallback(Callback):
    """
    Saves the model every `nr_epochs_to_save` epochs.
    `model_wrapper` should have a `.save()` method.
    """
    def __init__(self, model_wrapper, nr_epochs_to_save: int = 1,
                 logger: Optional[logging.Logger] = None):
        self.model_wrapper = model_wrapper
        self.nr_epochs_to_save: int = nr_epochs_to_save
        self.logger = logger if logger is not None else logging.getLogger()
        self.last_saved_epoch: Optional[int] = None
        super(ModelCheckpointSaverCallback, self).__init__()

    def on_epoch_begin(self, epoch, logs=None):
        # On the first observed epoch (training may resume from epoch > 0),
        # treat the preceding epoch as the last saved one.
        if self.last_saved_epoch is None:
            self.last_saved_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        nr_epochs_trained = epoch + 1
        nr_non_saved_epochs = nr_epochs_trained - self.last_saved_epoch
        if nr_non_saved_epochs >= self.nr_epochs_to_save:
            self.logger.info('Saving model after {} epochs.'.format(nr_epochs_trained))
            self.model_wrapper.save()
            self.logger.info('Done saving model.')
            self.last_saved_epoch = nr_epochs_trained


class MultiBatchCallback(Callback):
    """Groups consecutive batches into "multi-batches" of `multi_batch_size`
    batches each, and fires `on_multi_batch_end` once per group, optionally
    with the logs averaged over the group."""
    def __init__(self, multi_batch_size: int, average_logs: bool = False):
        self.multi_batch_size = multi_batch_size
        self.average_logs = average_logs
        self._multi_batch_start_time: float = 0.0
        self._multi_batch_logs_sum: Dict[str, float] = defaultdict(float)
        super(MultiBatchCallback, self).__init__()

    def on_batch_begin(self, batch, logs=None):
        # First batch of a multi-batch: start the timer and reset the sums.
        if self.multi_batch_size == 1 or (batch + 1) % self.multi_batch_size == 1:
            self._multi_batch_start_time = time.time()
            if self.average_logs:
                self._multi_batch_logs_sum = defaultdict(float)

    def on_batch_end(self, batch, logs=None):
        if self.average_logs:
            assert isinstance(logs, dict)
            for log_key, log_value in logs.items():
                self._multi_batch_logs_sum[log_key] += log_value
        # Last batch of a multi-batch: report elapsed time and (averaged) logs.
        if self.multi_batch_size == 1 or (batch + 1) % self.multi_batch_size == 0:
            multi_batch_elapsed = time.time() - self._multi_batch_start_time
            if self.average_logs:
                multi_batch_logs = {log_key: log_value / self.multi_batch_size
                                    for log_key, log_value in self._multi_batch_logs_sum.items()}
            else:
                multi_batch_logs = logs
            self.on_multi_batch_end(batch, multi_batch_logs, multi_batch_elapsed)

    def on_multi_batch_end(self, batch, logs, multi_batch_elapsed):
        pass  # to be overridden by subclasses


class ModelTrainingProgressLoggerCallback(MultiBatchCallback):
    """Logs throughput, epoch ETA and averaged loss every
    `config.NUM_BATCHES_TO_LOG_PROGRESS` batches."""
    def __init__(self, config: Config, training_status: ModelTrainingStatus):
        self.config = config
        self.training_status = training_status
        self.avg_throughput: Optional[float] = None
        super(ModelTrainingProgressLoggerCallback, self).__init__(
            self.config.NUM_BATCHES_TO_LOG_PROGRESS, average_logs=True)

    def on_train_begin(self, logs=None):
        self.config.log('Starting training...')

    def on_epoch_end(self, epoch, logs=None):
        self.config.log('Completed epoch #{}: {}'.format(epoch + 1, logs))

    def on_multi_batch_end(self, batch, logs, multi_batch_elapsed):
        nr_samples_in_multi_batch = self.config.TRAIN_BATCH_SIZE * \
                                    self.config.NUM_BATCHES_TO_LOG_PROGRESS
        throughput = nr_samples_in_multi_batch / multi_batch_elapsed
        # Exponential moving average smooths out per-multi-batch noise.
        if self.avg_throughput is None:
            self.avg_throughput = throughput
        else:
            self.avg_throughput = 0.5 * throughput + 0.5 * self.avg_throughput
        remained_batches = self.config.train_steps_per_epoch - (batch + 1)
        remained_samples = remained_batches * self.config.TRAIN_BATCH_SIZE
        remained_time_sec = remained_samples / self.avg_throughput
        self.config.log(
            'Train: during epoch #{epoch} batch {batch}/{tot_batches} ({batch_percentage}%) -- '
            'throughput (#samples/sec): {throughput} -- epoch ETA: {epoch_ETA} -- loss: {loss:.4f}'.format(
                epoch=self.training_status.nr_epochs_trained + 1,
                batch=batch + 1,
                batch_percentage=int(((batch + 1) / self.config.train_steps_per_epoch) * 100),
                tot_batches=self.config.train_steps_per_epoch,
                throughput=int(throughput),
                epoch_ETA=str(datetime.timedelta(seconds=int(remained_time_sec))),
                loss=logs['loss']))
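For orientation, here is a minimal usage sketch showing how these callbacks might be wired into a training loop. It assumes a Config object exposing the fields referenced above (NUM_BATCHES_TO_LOG_PROGRESS, TRAIN_BATCH_SIZE, train_steps_per_epoch, and a log() method) plus a hypothetical NUM_TRAIN_EPOCHS field; keras_model and train_dataset stand for any compiled Keras model and input pipeline, and the wrapper class is likewise hypothetical, since ModelCheckpointSaverCallback only requires an object with a `.save()` method.

class WeightsSavingWrapper:
    # Hypothetical wrapper: anything with a `.save()` method satisfies
    # ModelCheckpointSaverCallback.
    def __init__(self, keras_model, checkpoint_path: str):
        self.keras_model = keras_model
        self.checkpoint_path = checkpoint_path

    def save(self):
        self.keras_model.save_weights(self.checkpoint_path)


config = Config()  # assumed constructible; must provide the fields used above
training_status = ModelTrainingStatus()
model_wrapper = WeightsSavingWrapper(keras_model, 'checkpoints/model')

keras_model.fit(
    train_dataset,
    epochs=config.NUM_TRAIN_EPOCHS,  # hypothetical config field
    steps_per_epoch=config.train_steps_per_epoch,
    callbacks=[
        # Keras invokes callbacks in list order, so the status tracker is
        # listed before the callbacks that read the shared training status.
        ModelTrainingStatusTrackerCallback(training_status),
        ModelTrainingProgressLoggerCallback(config, training_status),
        ModelCheckpointSaverCallback(model_wrapper, nr_epochs_to_save=2),
    ])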