figs_demo.py
# # Setup
# +
# %load_ext autoreload
# %autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import plot_tree, DecisionTreeClassifier
from sklearn import metrics
# TODO: remove when package is updated
import sys, os
sys.path.append(os.path.expanduser('~/imodels'))
# installable with: `pip install imodels`
from imodels import FIGSClassifier
import demo_helper
np.random.seed(13)
# -
# Let's start by loading some data...
# Note: we still need to load the regression dataset first to get the same train/test splits as in `imodels_demo.ipynb`, since loading it advances the seeded random number generator
# +
# ames housing dataset: https://www.openml.org/search?type=data&status=active&id=43926
X_train_reg, X_test_reg, y_train_reg, y_test_reg, feat_names_reg = demo_helper.get_ames_data()
# diabetes dataset: https://www.openml.org/search?type=data&sort=runs&id=37&status=active
X_train, X_test, y_train, y_test, feat_names = demo_helper.get_diabetes_data()
# feat_names meanings:
# ["#Pregnant", "Glucose concentration test", "Blood pressure(mmHg)",
#  "Triceps skin fold thickness(mm)",
#  "2-Hour serum insulin (mu U/ml)", "Body mass index", "Diabetes pedigree function", "Age (years)"]
# load some data
# print('Regression data training', X_train_reg.shape, 'Classification data training', X_train.shape)
# -
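# Not part of the original demo: a quick peek at the classification split sizes and class balance
# (assumes `y_train`/`y_test` come back from `demo_helper` as 0/1 numpy arrays)
# +
print('train', X_train.shape, 'test', X_test.shape)
print(f'positive rate (train): {np.mean(y_train):.3f}')
# -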
# ***
# # FIGS
# FIGS (Fast Interpretable Greedy-tree Sums) greedily grows a small sum of trees;
# `max_rules` caps the total number of splits across all trees in the sum
model_figs = FIGSClassifier(max_rules=7, max_trees=3)
model_figs.fit(X_train, y_train, feature_names=feat_names);
print(model_figs)
print(model_figs.print_tree(X_train, y_train))
model_figs.plot(fig_size=7)
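# Not part of the original demo: a quick look at held-out performance with the already-imported
# `metrics` module, assuming the standard sklearn `predict`/`predict_proba` API on FIGSClassifier
# +
preds_test = model_figs.predict(X_test)
probs_test = model_figs.predict_proba(X_test)[:, 1]
print(f'test accuracy: {metrics.accuracy_score(y_test, preds_test):.3f}')
print(f'test ROC AUC:  {metrics.roc_auc_score(y_test, probs_test):.3f}')
# -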
# ## Gini Importance
dfp_importance = pd.DataFrame({'feat_names': feat_names})
dfp_importance['feature'] = dfp_importance.index
dfp_importance_gini = pd.DataFrame({'importance_gini': model_figs.feature_importances_})
dfp_importance_gini['feature'] = dfp_importance_gini.index
dfp_importance_gini['importance_gini_pct'] = dfp_importance_gini['importance_gini'].rank(pct=True)
dfp_importance = pd.merge(dfp_importance, dfp_importance_gini, on='feature', how='left')
dfp_importance = dfp_importance.sort_values(by=['importance_gini', 'feature'], ascending=[False, True]).reset_index(drop=True)
display(dfp_importance)
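# A bar chart of the same Gini importances (a small addition, not in the original; it only uses the
# `dfp_importance` frame built above)
# +
plt.figure(figsize=(6, 4))
plt.barh(dfp_importance['feat_names'], dfp_importance['importance_gini'])
plt.gca().invert_yaxis()  # largest importance at the top, matching the sorted frame
plt.xlabel('Gini importance')
plt.tight_layout()
plt.show()
# -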
# ***
# # `dtreeviz` Integration
# dtreeviz renders one tree at a time; tree 0 is shown here (a loop over all trees is sketched at the end of this section)
# +
import dtreeviz
from imodels.tree.viz_utils import extract_sklearn_tree_from_figs
dt = extract_sklearn_tree_from_figs(model_figs, tree_num=0, n_classes=2)
viz_model = dtreeviz.model(dt, X_train=X_train, y_train=y_train, feature_names=feat_names, target_name='y', class_names=[0, 1])
# -
color_params = {'classes': dtreeviz.colors.mpl_colors, 'hist_bar': 'C0', 'legend_edge': None}
for key in ['axis_label', 'title', 'legend_title', 'text', 'arrow', 'node_label', 'tick_label', 'leaf_label', 'wedge', 'text_wedge']:
    color_params[key] = 'black'
dtv_params_gen = {'colors': color_params, 'fontname': 'Arial', 'figsize': (4, 3)}
dtv_params = {
    'leaftype': 'barh',
    'label_fontsize': 10,
    'colors': dtv_params_gen['colors'],
    'fontname': dtv_params_gen['fontname'],
}
viz_model.view(**dtv_params)
x_example = X_train[13]
display(pd.DataFrame([{col: value for col, value in zip(feat_names, x_example)}]))
print(viz_model.explain_prediction_path(x=x_example))
viz_model.view(**dtv_params, x=x_example)
viz_model.view(**dtv_params, show_node_labels=True, fancy=False)
viz_model.ctree_leaf_distributions(**dtv_params_gen)
viz_model.leaf_purity(display_type='plot', **dtv_params_gen)
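# Not in the original demo: a sketch for rendering every tree in the sum, assuming FIGS stores its
# fitted trees in `model_figs.trees_`; each tree is extracted and displayed in turn
# +
for i in range(len(model_figs.trees_)):
    dt_i = extract_sklearn_tree_from_figs(model_figs, tree_num=i, n_classes=2)
    viz_i = dtreeviz.model(dt_i, X_train=X_train, y_train=y_train, feature_names=feat_names,
                           target_name='y', class_names=[0, 1])
    display(viz_i.view(**dtv_params))
# -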
# ***
# # `SKompiler` Integration
# SKompiler also works on one tree at a time; tree 0 is shown here
# +
from skompiler import skompile
from imodels.tree.viz_utils import extract_sklearn_tree_from_figs
dt = extract_sklearn_tree_from_figs(model_figs, tree_num=0, n_classes=2)
expr = skompile(dt.predict_proba, feat_names)
# +
# Currently broken, see https://github.com/konstantint/SKompiler/issues/16
# print(expr.to('sqlalchemy/sqlite', component=1, assign_to='tree_0'))
# -
print(expr.to('python/code'))