Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

model_manager.py 3.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
  1. """
  2. f.learning model manager that trains and evaluates model
  3. """
  4. import numpy as np
  5. import tensorflow as tf
  6. from src.utils.utils import (
  7. batch_sample_indices,
  8. )
  9. from typing import Callable
  10. class ClassifierModelManager():
  11. def __init__(self,
  12. get_model_fn: Callable,
  13. n_train_epochs: int,
  14. batch_size: int=32,
  15. is_debug: bool=False,
  16. ):
  17. self.get_model_fn = get_model_fn
  18. self.n_train_epochs = n_train_epochs
  19. self.batch_size= batch_size
  20. self.is_debug = is_debug,
  21. self.optimizer = self._get_optimizer()
  22. self.loss_fn = self._get_loss()
  23. self.model = get_model_fn()
  24. def reset_model(self, clear_backend=False):
  25. if clear_backend:
  26. tf.keras.backend.clear_session()
  27. del self.model
  28. self.model = self.get_model_fn()
  29. def _get_optimizer(self):
  30. # TODO add optimizer args
  31. return tf.keras.optimizers.Adam()
  32. def _get_loss(self):
  33. # TODO weight a class loss
  34. return tf.keras.losses.CategoricalCrossentropy(
  35. from_logits=True,
  36. reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
  37. def train_model(
  38. self,
  39. train_x,
  40. train_y,
  41. n_train_epochs=None,
  42. batch_size=None) -> None:
  43. """
  44. trains model for n epochs
  45. """
  46. model = self.model
  47. if n_train_epochs is None:
  48. n_train_epochs = self.n_train_epochs
  49. if batch_size is None:
  50. batch_size = self.batch_size
  51. optimizer = self.optimizer
  52. loss_fn = self.loss_fn
  53. data_size = train_x.shape[0]
  54. for i in range(n_train_epochs):
  55. for idx, batch in enumerate(
  56. batch_sample_indices(data_size, batch_size=batch_size)):
  57. batch_x, batch_y = train_x[batch], train_y[batch]
  58. with tf.GradientTape() as tape:
  59. prediction = model(batch_x)
  60. loss = loss_fn(batch_y, prediction)
  61. grads = tape.gradient(loss, model.trainable_variables)
  62. optimizer.apply_gradients(zip(grads, model.trainable_variables))
  63. # TODO debug logging
  64. def evaluate_model(
  65. self,
  66. test_x,
  67. test_y,
  68. batch_size=None) -> dict:
  69. """
  70. generator for evaluating model and input, prediction, and true label
  71. """
  72. model = self.model
  73. if batch_size is None:
  74. batch_size = self.batch_size
  75. data_size = test_x.shape[0]
  76. for idx, test_batch in enumerate(
  77. batch_sample_indices(data_size, batch_size=batch_size)):
  78. batch_x, batch_y = test_x[test_batch], test_y[test_batch]
  79. raw_prediction = model(batch_x, training=False)
  80. batch_loss = self.loss_fn(batch_y, raw_prediction)
  81. yield batch_x, batch_y, raw_prediction, batch_loss
  82. def save_model(self, model_fpath: str):
  83. self.model.save_weights(model_fpath)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...