Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

distillation.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
  1. from sklearn.base import RegressorMixin, BaseEstimator, is_regressor
  2. class DistilledRegressor(BaseEstimator, RegressorMixin):
  3. """
  4. Class to implement distillation. Currently only supports regression.
  5. Params
  6. ------
  7. teacher: initial model to be trained
  8. must be a regressor or a binary classifier
  9. student: model to be distilled from teacher's predictions
  10. must be a regressor
  11. """
  12. def __init__(self, teacher: BaseEstimator, student: BaseEstimator,
  13. n_iters_teacher: int=1):
  14. self.teacher = teacher
  15. self.student = student
  16. self.n_iters_teacher = n_iters_teacher
  17. self._validate_student()
  18. self._check_teacher_type()
  19. def _validate_student(self):
  20. if is_regressor(self.student):
  21. pass
  22. else:
  23. if not hasattr(self.student, "prediction_task"):
  24. raise ValueError("Student must be either a scikit-learn or imodels regressor")
  25. elif self.student.prediction_task == "classification":
  26. raise ValueError("Student must be a regressor")
  27. def _check_teacher_type(self):
  28. if hasattr(self.teacher, "prediction_task"):
  29. self.teacher_type = self.teacher.prediction_task
  30. elif hasattr(self.teacher, "_estimator_type"):
  31. if is_regressor(self.teacher):
  32. self.teacher_type = "regression"
  33. else:
  34. self.teacher_type = "classification"
  35. def set_teacher_params(self, **params):
  36. self.teacher.set_params(**params)
  37. def set_student_params(self, **params):
  38. self.student.set_params(**params)
  39. def fit(self, X, y, **kwargs):
  40. # fit teacher
  41. for iter_teacher in range(self.n_iters_teacher):
  42. self.teacher.fit(X, y, **kwargs)
  43. if self.teacher_type == "regression":
  44. y = self.teacher.predict(X)
  45. else:
  46. y = self.teacher.predict_proba(X)[:, 1] # assumes binary classifier
  47. # fit student
  48. self.student.fit(X, y)
  49. def predict(self, X):
  50. return self.student.predict(X)
Tip!

Press p to see the previous file or n to see the next file

Comments

Loading...