Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

gam_test.py 1.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
  1. import os
  2. import random
  3. from functools import partial
  4. import numpy as np
  5. import pandas as pd
  6. from sklearn.tree import DecisionTreeRegressor
  7. from sklearn import metrics
  8. import imodels
  9. from imodels import TreeGAMClassifier
  10. from sklearn.model_selection import train_test_split
  11. # def test_gam_hyperparams():
  12. # X, y, feat_names = imodels.get_clean_dataset("heart")
  13. # X, _, y, _ = train_test_split(X, y, test_size=0.9, random_state=13)
  14. # roc = 0.5
  15. # for n_boosting_rounds in [1, 2, 3]:
  16. # m = TreeGAMClassifier(
  17. # n_boosting_rounds=n_boosting_rounds,
  18. # max_leaf_nodes=2,
  19. # random_state=42,
  20. # n_boosting_rounds_marginal=0,
  21. # )
  22. # m.fit(X, y, learning_rate=0.1)
  23. # roc_new = metrics.roc_auc_score(y, m.predict_proba(X)[:, 1])
  24. # assert roc_new >= roc
  25. # roc = roc_new
  26. # roc = 0.5
  27. # for n_boosting_rounds_marginal in [1, 2, 3]:
  28. # m = TreeGAMClassifier(
  29. # n_boosting_rounds=0,
  30. # random_state=42,
  31. # n_boosting_rounds_marginal=n_boosting_rounds_marginal,
  32. # )
  33. # m.fit(X, y, learning_rate=0.1)
  34. # roc_new = metrics.roc_auc_score(y, m.predict_proba(X)[:, 1])
  35. # assert roc_new >= roc
  36. # roc = roc_new
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...