Hierarchical shrinkage: improving the accuracy and interpretability of tree-based methods

Abhineet Agarwal*, Yan Shuo Tan*, Omer Ronen, Chandan Singh, Bin Yu


📄 Paper (ICML 2022), 🗂 Doc, 📌 Citation

Hierarchical shrinkage (HS) is an extremely fast post-hoc regularization method that works on any decision tree (or tree-based ensemble, such as a random forest). It does not modify the tree structure; instead, it regularizes the tree by shrinking the prediction at each node towards the sample means of its ancestors, controlled by a single regularization parameter. Experiments over a wide variety of datasets show that hierarchical shrinkage substantially increases the predictive performance of individual decision trees and decision-tree ensembles.
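Written out for a single prediction, the shrinkage rule looks as follows (our transcription of the HS update from the paper; the notation here is ours). Let t_0, t_1, …, t_L be the nodes on the decision path of a sample x, from the root t_0 down to its leaf t_L; let N(t) be the number of training samples in node t, E_t[y] the mean training response in t, and λ ≥ 0 the regularization parameter:

$$
\hat{f}_\lambda(x) = \mathbb{E}_{t_0}[y] + \sum_{l=1}^{L} \frac{\mathbb{E}_{t_l}[y] - \mathbb{E}_{t_{l-1}}[y]}{1 + \lambda / N(t_{l-1})}
$$

With λ = 0 the sum telescopes to the leaf mean E_{t_L}[y], recovering the original tree. Larger λ damps every parent-to-child increment, most strongly where the parent node holds few training samples.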

How does hierarchical shrinkage work?

Fig 1. HS applies post-hoc regularization to any decision tree by shrinking each node towards its parent. This is done after the tree has been trained. The amount of shrinkage is controlled by a regularization parameter, which works best when chosen via cross-validation.
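To make the node-to-parent shrinkage concrete, here is a minimal pure-Python sketch that applies it along a single root-to-leaf path. It assumes each node is summarized by a `(mean, count)` pair; the function name `hs_path_prediction` and the example path are hypothetical and not part of the imodels API.

```python
# Hypothetical sketch of the HS update along one decision path
# (illustrative only; not the imodels implementation).
from typing import List, Tuple


def hs_path_prediction(path: List[Tuple[float, int]], reg_param: float) -> float:
    """HS prediction for a sample whose root-to-leaf path is `path`.

    `path` lists (node_mean, node_count) pairs from the root down to the leaf.
    Each parent-to-child change in mean is divided by 1 + reg_param / parent_count.
    """
    if reg_param < 0:
        raise ValueError("reg_param must be non-negative")
    prediction = path[0][0]  # start from the root mean
    for (parent_mean, parent_count), (child_mean, _) in zip(path, path[1:]):
        prediction += (child_mean - parent_mean) / (1 + reg_param / parent_count)
    return prediction


# Hypothetical path: root (mean 0.5, 100 samples) -> internal node (0.8, 40) -> leaf (1.0, 10)
path = [(0.5, 100), (0.8, 40), (1.0, 10)]
for lam in (0.0, 10.0, 1e6):
    print(lam, round(hs_path_prediction(path, lam), 4))
```

With λ = 0 this returns the leaf mean 1.0 (the unshrunk tree); λ = 10 gives about 0.933; a very large λ pulls the prediction back to roughly the root mean 0.5. Because the damping factor depends on the parent's sample count, increments from splits made on few samples are shrunk hardest.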

An example using HS

HS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the fit and predict methods. Here's a full example of using it on a sample clinical dataset.
    
```python
from imodels import HSTreeClassifierCV, get_clean_dataset
from sklearn.model_selection import train_test_split
from sklearn.tree import plot_tree

# prepare data (in this case, a sample clinical dataset)
X, y, feat_names = get_clean_dataset('csi_pecarn_pred')
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

# fit the model
model = HSTreeClassifierCV(max_leaf_nodes=7)  # initialize a model
model.fit(X_train, y_train)  # fit model
preds = model.predict(X_test)  # discrete predictions: shape is (n_test,)
preds_proba = model.predict_proba(X_test)  # predicted probabilities: shape is (n_test, n_classes)

# visualize the model
plot_tree(model.estimator_, feature_names=feat_names)
```
        


Fig 2. Simple model learned by HS for predicting risk of cervical spinal injury.

Examples with HS on synthetic data

The figures below show how hierarchical shrinkage behaves on one-dimensional functions fitted with a CART decision tree.

Fig 3. Step function.

Fig 4. Linear function.
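A small from-scratch sketch can reproduce the spirit of these synthetic experiments: grow a CART-style regression tree on noisy samples of a step function and a linear function, then predict with and without shrinkage. Everything here is an assumption made for illustration (the helper names `fit_tree`, `predict`, and `mse_vs_truth`, 100 training points, Gaussian noise with standard deviation 0.3, depth 6, and the λ grid); it is not the code behind the figures and not the imodels implementation.

```python
# Hypothetical from-scratch sketch (not the imodels implementation):
# a 1-D CART-style regression tree with squared-error splits, plus
# hierarchical shrinkage (HS) applied at prediction time.
import random


def sse(values):
    """Sum of squared deviations from the mean."""
    m = sum(values) / len(values)
    return sum((v - m) ** 2 for v in values)


def fit_tree(xs, ys, max_depth, depth=0):
    """Grow a regression tree; every node stores its sample mean and count."""
    node = {"mean": sum(ys) / len(ys), "n": len(ys)}
    if depth == max_depth or len(ys) < 2:
        return node
    pairs = sorted(zip(xs, ys))
    best = None  # (cost, threshold, number of samples sent left)
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold separates identical x values
        cost = sse([y for _, y in pairs[:i]]) + sse([y for _, y in pairs[i:]])
        if best is None or cost < best[0]:
            best = (cost, (pairs[i - 1][0] + pairs[i][0]) / 2, i)
    if best is None:
        return node  # all x values identical: keep as a leaf
    _, threshold, i = best
    node["threshold"] = threshold
    left, right = pairs[:i], pairs[i:]
    node["left"] = fit_tree([x for x, _ in left], [y for _, y in left], max_depth, depth + 1)
    node["right"] = fit_tree([x for x, _ in right], [y for _, y in right], max_depth, depth + 1)
    return node


def predict(node, x, reg_param=0.0):
    """Walk from root to leaf, damping each parent-to-child change in mean by
    1 + reg_param / n(parent). reg_param=0 reproduces the plain tree."""
    pred = node["mean"]
    while "threshold" in node:
        child = node["left"] if x <= node["threshold"] else node["right"]
        pred += (child["mean"] - node["mean"]) / (1 + reg_param / node["n"])
        node = child
    return pred


def mse_vs_truth(tree, f, reg_param, grid):
    """Mean squared error of the (shrunk) tree against the noise-free function."""
    return sum((predict(tree, x, reg_param) - f(x)) ** 2 for x in grid) / len(grid)


rng = random.Random(0)  # fixed seed so the run is reproducible
functions = {"step": lambda x: 1.0 if x > 0.5 else 0.0, "linear": lambda x: x}
xs = [rng.random() for _ in range(100)]
grid = [i / 200 for i in range(201)]
for name, f in functions.items():
    ys = [f(x) + rng.gauss(0, 0.3) for x in xs]
    tree = fit_tree(xs, ys, max_depth=6)
    for lam in (0.0, 1.0, 10.0, 100.0):
        print(f"{name:6s} lambda={lam:6.1f} test MSE={mse_vs_truth(tree, f, lam, grid):.4f}")
```

Running it prints the error against the noise-free function for each λ. Which λ does best depends on the noise level and the tree depth, which is why tuning the regularization parameter by cross-validation is recommended.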