01_fit_dnn_best.py 2.2 KB

import os
from os.path import join as oj
import sys
sys.path.append('../src')
import numpy as np
import torch
import scipy
from matplotlib import pyplot as plt
from sklearn import metrics
import data
from config import *
from tqdm import tqdm
import pickle as pkl
import train_reg
from copy import deepcopy
import config
import models
import pandas as pd
import features
import outcomes
import neural_networks
from sklearn.model_selection import KFold
from torch import nn, optim
from torch.nn import functional as F
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import LinearRegression, RidgeCV
from sklearn.svm import SVR
from collections import defaultdict
from sklearn.model_selection import train_test_split

if __name__ == '__main__':
    print("loading data")

    ############ get data ######################
    df_train, df_test, _ = data.get_snf_mt_vs_wt()
    # splits = ['train', 'test']
    length = 40
    # padding = 'end'
    # feat_name = 'X_same_length_extended'  # include buffer X_same_length_normalized
    feat_name = 'X_same_length'  # include buffer X_same_length_normalized
    # outcome = 'Y_sig_mean_normalized'
    outcome = 'mt'
    epoch = 100

    # keep only rows where both the feature track and the outcome are present
    df_full = df_train
    print('before dropping', df_full.shape)
    df_full = df_full[[feat_name, outcome]].dropna()
    print('after dropping', df_full.shape)
    print('vals', df_full['mt'].value_counts())  # 1791 class 0, 653 class 1
    ############ finish getting data ######################

    np.random.seed(42)
    # checkpoint_fname = f'../models/dnn_vps_fit_extended_lifetimes>{lifetime_threshold}.pkl'
    checkpoint_fname = f'../models/vps_distingish_mt_vs_wt_epoch={epoch}.pkl'

    # LSTM-based network wrapped in an sklearn-style interface
    dnn = neural_networks.neural_net_sklearn(
        D_in=length, H=20, p=0, arch='lstm', lr=0.0001,
        epochs=epoch, track_name=feat_name
    )
    print('track_name', vars(dnn))

    # fit on the training tracks, checkpointing to disk as training proceeds
    dnn.fit(df_full[[feat_name]],
            df_full[outcome].values,
            verbose=True,
            checkpoint_fname=checkpoint_fname, device='cpu')

    # save the final weights
    pkl.dump({'model_state_dict': dnn.model.cpu().state_dict()},
             open(checkpoint_fname, 'wb'))
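
For downstream evaluation, the saved checkpoint can be restored by rebuilding the same wrapper and loading the pickled weights into its underlying PyTorch module. A minimal sketch, assuming the same neural_net_sklearn constructor arguments as above and that .model is a torch nn.Module (as implied by the state_dict call in the script):

import pickle as pkl
import neural_networks

length, epoch, feat_name = 40, 100, 'X_same_length'
checkpoint_fname = f'../models/vps_distingish_mt_vs_wt_epoch={epoch}.pkl'

# rebuild the wrapper with the same architecture used for training
dnn = neural_networks.neural_net_sklearn(
    D_in=length, H=20, p=0, arch='lstm', lr=0.0001,
    epochs=epoch, track_name=feat_name
)

# load the saved weights into the underlying torch module
state = pkl.load(open(checkpoint_fname, 'rb'))
dnn.model.load_state_dict(state['model_state_dict'])
dnn.model.eval()  # switch to inference mode before evaluating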