Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

explain_errors.py 2.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
  1. import numpy as np
  2. import pandas as pd
  3. from sklearn.base import BaseEstimator
  4. from sklearn.utils.validation import check_X_y
  5. import imodels
  6. def explain_classification_errors(X, predictions, y,
  7. feature_names: list = None,
  8. target_name: str = None,
  9. classifier: BaseEstimator = imodels.GreedyTreeClassifier(),
  10. target_one_hot_encode: bool = False,
  11. print_rules: bool = True):
  12. """Explains the classification errors of a model by fitting an interpretable model to them.
  13. Currently only supports binary classification.
  14. Parameters
  15. ----------
  16. X: array_like
  17. (n, n_features)
  18. predictions: array_like
  19. (n, 1) predictions
  20. y
  21. (n, 1) targets with integer values representing class
  22. feature_names
  23. n_features
  24. Returns
  25. -------
  26. model: BaseEstimator
  27. """
  28. # deal with names
  29. if feature_names is None:
  30. if isinstance(X, pd.DataFrame):
  31. feature_names = X.columns.tolist()
  32. else:
  33. feature_names = [f'X{i + 1}' for i in range(X.shape[1])]
  34. if target_name is None:
  35. if isinstance(y, pd.DataFrame):
  36. target_name = y.columns[0]
  37. elif isinstance(y, pd.Series):
  38. target_name = y.name
  39. else:
  40. target_name = 'target'
  41. if isinstance(predictions, pd.Series) or isinstance(predictions, pd.DataFrame):
  42. predictions = predictions.values
  43. X, y = check_X_y(X, y) # converts to np
  44. if len(y.shape) == 1:
  45. y = y.reshape(-1, 1)
  46. if len(predictions.shape) == 1:
  47. predictions = predictions.reshape(-1, 1)
  48. errors = np.array(predictions != y).astype(int)
  49. features = pd.DataFrame(np.hstack((X, y)))
  50. features.columns = [*feature_names, target_name]
  51. classifier.fit(features, errors.flatten())
  52. if print_rules:
  53. print(classifier)
  54. return classifier, features.columns
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...