Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

keras.py 3.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
  1. import operator
  2. import os
  3. from wandb import history
  4. from wandb import summary
  5. # Fully implemented here so we don't have to pull in keras as a dependency.
  6. # However, if the user is using this, they necessarily have Keras installed. So we
  7. # could probably selectively build this class only when the user requests it,
  8. # knowing that keras is available.
  9. #
  10. # Or have a separate lib "wandb-keras", then we could use the appropriate Keras
  11. # pieces
  class WandBKerasCallback(object):
      """WandB Keras Callback.

      Appends one row per epoch (the epoch index plus every metric in the
      epoch ``logs``) to a wandb history. Each time the monitored metric
      improves, it copies that epoch's metrics into the wandb summary and
      saves the model (or only its weights) to ``<out_dir>/model-best.h5``.

      This class does not subclass ``keras.callbacks.Callback``, so keras is
      not a dependency. It therefore defines itself every hook a Keras
      callback list may invoke.
      """

      def __init__(self, out_dir='.', monitor='val_loss', verbose=0, mode='auto',
                   save_weights_only=False):
          """Constructor.

          Args:
              out_dir: Directory to save history/summary files and the best
                  model in.
              monitor: Key in the per-epoch ``logs`` dict used to decide
                  whether the model improved.
              verbose: If > 0, print a message every time the model is saved.
              mode: 'auto', 'min' or 'max'. Unknown values fall back to 'auto'.
                  In 'auto' mode, a monitor name containing 'acc' or starting
                  with 'fmeasure' is treated as higher-is-better. Every other
                  name is treated as lower-is-better.
              save_weights_only: If True, call ``model.save_weights`` instead
                  of ``model.save``.

          See keras.callbacks.ModelCheckpoint, from which these arguments
          and the mode handling are taken.
          """
          self.validation_data = None
          # Normally filled in by Keras via set_model/set_params. They are
          # initialised here so the attributes always exist.
          self.model = None
          self.params = None
          self.out_dir = out_dir
          self.history = None
          self.summary = None
          self.monitor = monitor
          self.verbose = verbose
          self.save_weights_only = save_weights_only
          self.filepath = os.path.join(out_dir, 'model-best.h5')

          # From Keras
          if mode not in ['auto', 'min', 'max']:
              print('WandBKerasCallback mode %s is unknown, '
                    'fallback to auto mode.' % (mode))
              mode = 'auto'
          if mode == 'auto':
              higher_is_better = ('acc' in self.monitor or
                                  self.monitor.startswith('fmeasure'))
          else:
              higher_is_better = (mode == 'max')
          if higher_is_better:
              self.monitor_op = operator.gt
              self.best = float('-inf')
          else:
              self.monitor_op = operator.lt
              self.best = float('inf')

      def set_params(self, params):
          """Store the training parameters passed in by Keras."""
          self.params = params

      def set_model(self, model):
          """Store the model being trained; it is saved when it improves."""
          self.model = model

      def on_epoch_begin(self, epoch, logs=None):
          pass

      def on_epoch_end(self, epoch, logs=None):
          """Record this epoch's metrics and save the model if it improved.

          Args:
              epoch: Index of the epoch that just finished.
              logs: Dict of metric name -> value for this epoch. ``None`` is
                  treated as an empty dict.
          """
          logs = logs or {}

          # history: the column set is fixed by the first epoch's metrics.
          if self.history is None:
              self.history = history.History(
                  ['epoch'] + sorted(logs.keys()),
                  out_dir=self.out_dir)
          row = {'epoch': epoch}
          row.update(logs)
          self.history.add(row)

          # summary / best model
          current = logs.get(self.monitor)
          if current is None:
              # Without the monitored metric there is nothing to compare, so
              # skip instead of comparing None against self.best.
              print('Can save best model only with %s available, '
                    'skipping.' % (self.monitor))
              return
          if not self.monitor_op(current, self.best):
              return

          if self.summary is None:
              # on_train_begin was not called (e.g. hooks driven manually).
              self.summary = summary.Summary(self.out_dir)
          # Build a separate dict instead of popping 'epoch' from `row`.
          # `row` was handed to the history and must not be mutated.
          self.summary.update(
              {key: value for key, value in row.items() if key != 'epoch'})
          if self.verbose > 0:
              print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                    ' saving model to %s'
                    % (epoch, self.monitor, self.best,
                       current, self.filepath))
          self.best = current
          if self.save_weights_only:
              self.model.save_weights(self.filepath, overwrite=True)
          else:
              self.model.save(self.filepath, overwrite=True)

      def on_batch_begin(self, batch, logs=None):
          pass

      def on_batch_end(self, batch, logs=None):
          pass

      # Newer Keras versions call the mode-specific hooks below on every
      # registered callback. As in keras.callbacks.Callback, the train-batch
      # hooks forward to the generic batch hooks; the rest are no-ops.
      def on_train_batch_begin(self, batch, logs=None):
          self.on_batch_begin(batch, logs=logs)

      def on_train_batch_end(self, batch, logs=None):
          self.on_batch_end(batch, logs=logs)

      def on_test_batch_begin(self, batch, logs=None):
          pass

      def on_test_batch_end(self, batch, logs=None):
          pass

      def on_predict_batch_begin(self, batch, logs=None):
          pass

      def on_predict_batch_end(self, batch, logs=None):
          pass

      def on_test_begin(self, logs=None):
          pass

      def on_test_end(self, logs=None):
          pass

      def on_predict_begin(self, logs=None):
          pass

      def on_predict_end(self, logs=None):
          pass

      def on_train_begin(self, logs=None):
          """Create the wandb summary for ``out_dir``."""
          self.summary = summary.Summary(self.out_dir)

      def on_train_end(self, logs=None):
          """Close the history if at least one epoch was recorded."""
          if self.history is not None:
              self.history.close()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...