checks.py 2.2 KB

from inspect import isclass


def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
    """Perform is_fitted validation for estimator.

    Checks whether the estimator is fitted by verifying the presence of
    fitted attributes (names ending with a trailing underscore) and returns
    the result as a boolean. No ``NotFittedError`` is raised.

    If an estimator does not set any attributes with a trailing underscore,
    it can define a ``__sklearn_is_fitted__`` method returning a boolean to
    specify whether the estimator is fitted.

    Parameters
    ----------
    estimator : estimator instance
        Estimator instance for which the check is performed.

    attributes : str, list or tuple of str, default=None
        Attribute name(s) given as a string or a list/tuple of strings,
        e.g. ``["coef_", "estimator_", ...]`` or ``"coef_"``.

        If `None`, `estimator` is considered fitted if there exists an
        attribute that ends with an underscore and does not start with a
        double underscore.

    msg : str, default=None
        Error message template. Currently unused, because this
        implementation returns a boolean instead of raising an error.
        The intended default message is "This %(name)s instance is not
        fitted yet. Call 'fit' with appropriate arguments before using this
        estimator."

        For custom messages, if "%(name)s" is present in the message string,
        it is substituted with the estimator name,
        e.g. "Estimator, %(name)s, must be fitted before sparsifying".

    all_or_any : callable, {all, any}, default=all
        Specify whether all or any of the given attributes must exist.
        Only used when ``attributes`` is not None.

    Returns
    -------
    fitted : bool
        True if the estimator is considered fitted, False otherwise.

    Raises
    ------
    TypeError
        If ``estimator`` is a class rather than an instance, or has no
        ``fit`` method.
    """
    if isclass(estimator):
        raise TypeError("{} is a class, not an instance.".format(estimator))
    if not hasattr(estimator, "fit"):
        # Wrap in a 1-tuple so %-formatting also works if estimator is a tuple.
        raise TypeError("%s is not an estimator instance." % (estimator,))

    if attributes is not None:
        # Normalize a single attribute name to a one-element list.
        if not isinstance(attributes, (list, tuple)):
            attributes = [attributes]
        return all_or_any([hasattr(estimator, attr) for attr in attributes])
    elif hasattr(estimator, "__sklearn_is_fitted__"):
        # Let the estimator report its own fitted status.
        return estimator.__sklearn_is_fitted__()
    else:
        # Fallback: any attribute ending in "_" but not starting with "__".
        return any(
            v.endswith("_") and not v.startswith("__") for v in vars(estimator)
        )
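
When neither `attributes` nor `__sklearn_is_fitted__` is available, `check_is_fitted` falls back to scanning `vars(estimator)` for names that end with an underscore and do not start with a double underscore. The sketch below restates only that fallback branch as a standalone helper, `looks_fitted` (a hypothetical name, not part of `checks.py`), and applies it to a toy `MeanRegressor` class that is likewise invented for illustration.

```python
class MeanRegressor:
    """Hypothetical toy estimator: predicts the (shrunk) mean of the targets."""

    def __init__(self, shrink=0.0):
        # Hyperparameters set in __init__ carry no trailing underscore.
        self.shrink = shrink

    def fit(self, X, y):
        # Learned state follows the trailing-underscore convention.
        self.mean_ = (1.0 - self.shrink) * sum(y) / len(y)
        self.n_features_in_ = len(X[0])
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


def looks_fitted(estimator):
    """Hypothetical restatement of the fallback branch of check_is_fitted."""
    return any(
        name.endswith("_") and not name.startswith("__")
        for name in vars(estimator)
    )


model = MeanRegressor(shrink=0.1)
print(looks_fitted(model))  # False: only 'shrink' is set
model.fit([[1.0], [2.0], [3.0]], [2.0, 4.0, 6.0])
print(looks_fitted(model))  # True: 'mean_' and 'n_features_in_' now exist
```

Because the scan only inspects attribute names, an attribute such as `self.cache_ = {}` set in `__init__` would make an unfitted estimator look fitted.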

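An estimator that keeps its learned state without trailing underscores can opt into the second branch of `check_is_fitted` by defining `__sklearn_is_fitted__`. The sketch below uses a hypothetical `LookupTransformer` class and a trimmed, hypothetical `is_fitted` helper that mirrors only the dispatch order of `check_is_fitted` when `attributes` is None (the protocol method first, then the trailing-underscore scan), and shows the two checks disagreeing.

```python
class LookupTransformer:
    """Hypothetical estimator that stores learned state in a private attribute."""

    def __init__(self):
        self._table = None  # no trailing underscore, so the name scan ignores it

    def fit(self, X, y=None):
        # Map each distinct value to its rank in sorted order.
        self._table = {value: index for index, value in enumerate(sorted(set(X)))}
        return self

    def transform(self, X):
        return [self._table[value] for value in X]

    def __sklearn_is_fitted__(self):
        # Report fitted status explicitly instead of relying on naming.
        return self._table is not None


def scan_for_trailing_underscore(estimator):
    return any(
        name.endswith("_") and not name.startswith("__")
        for name in vars(estimator)
    )


def is_fitted(estimator):
    """Hypothetical mirror of check_is_fitted's dispatch when attributes is None."""
    if hasattr(estimator, "__sklearn_is_fitted__"):
        return estimator.__sklearn_is_fitted__()
    return scan_for_trailing_underscore(estimator)


enc = LookupTransformer().fit(["b", "a", "b"])
print(scan_for_trailing_underscore(enc))  # False: no name ends with "_"
print(is_fitted(enc))  # True: the protocol method answers instead
print(enc.transform(["a", "b"]))  # [0, 1]
```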

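The first branch of `check_is_fitted` checks explicit attribute names: a single string is wrapped in a list, and `all_or_any` (the builtin `all` or `any`) decides whether every name or just one must be present. The standalone sketch below restates that branch as a hypothetical `has_fitted_attributes` helper and runs it on a throwaway `types.SimpleNamespace` object standing in for a partially fitted estimator.

```python
from types import SimpleNamespace


def has_fitted_attributes(estimator, attributes, all_or_any=all):
    """Hypothetical restatement of the `attributes is not None` branch."""
    if not isinstance(attributes, (list, tuple)):
        attributes = [attributes]  # a bare string becomes a one-element list
    return all_or_any([hasattr(estimator, attr) for attr in attributes])


# Stand-in for a partially fitted estimator: `coef_` exists, `intercept_` does not.
est = SimpleNamespace(coef_=[0.5, -1.2])

print(has_fitted_attributes(est, "coef_"))  # True
print(has_fitted_attributes(est, ["coef_", "intercept_"]))  # False (all)
print(has_fitted_attributes(est, ["coef_", "intercept_"], all_or_any=any))  # True
```

One consequence of delegating to the builtins: an empty list passes under `all` (since `all([])` is `True`) and fails under `any`.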

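`check_is_fitted` as written returns a boolean and never reads `msg`. If a caller wants the raising behavior that the `msg` description refers to, a thin wrapper can format the template itself. The sketch below assumes a hypothetical `require_fitted` helper and a locally defined `NotFittedError` (scikit-learn ships its own class of that name in `sklearn.exceptions`, which is not imported here). The `check` argument is any callable returning a bool, such as `check_is_fitted` from `checks.py`; the demo passes a small stand-in so the snippet runs on its own.

```python
DEFAULT_MSG = (
    "This %(name)s instance is not fitted yet. Call 'fit' with "
    "appropriate arguments before using this estimator."
)


class NotFittedError(ValueError):
    """Hypothetical local stand-in for a 'not fitted' exception."""


def require_fitted(estimator, check, msg=None):
    """Hypothetical wrapper: raise NotFittedError if check(estimator) is False."""
    if not check(estimator):
        template = DEFAULT_MSG if msg is None else msg
        # %-formatting with a mapping fills every %(name)s placeholder and
        # leaves a template without placeholders unchanged.
        raise NotFittedError(template % {"name": type(estimator).__name__})
    return estimator


class Sparsifier:
    """Hypothetical estimator used only for this demo."""

    def fit(self, X, y=None):
        self.mask_ = [value != 0 for value in X]
        return self


def has_mask(estimator):
    # Stand-in for check_is_fitted(estimator, "mask_").
    return hasattr(estimator, "mask_")


try:
    require_fitted(
        Sparsifier(),
        has_mask,
        msg="Estimator, %(name)s, must be fitted before sparsifying",
    )
except NotFittedError as err:
    print(err)  # Estimator, Sparsifier, must be fitted before sparsifying

print(require_fitted(Sparsifier().fit([0, 3, 0]), has_mask).mask_)  # [False, True, False]
```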