main.py
import argparse
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import (
    roc_auc_score, average_precision_score, accuracy_score,
    precision_score, recall_score, f1_score,
)
from sklearn.model_selection import train_test_split
import joblib
import dagshub
import torch

# Consts
CLASS_LABEL = 'label'
train_df_path = 'data/train.csv.zip'
test_df_path = 'data/test.csv.zip'

class HSDataset(torch.utils.data.Dataset):
    """Wraps tokenizer encodings and labels so the HF Trainer can batch them."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # The default data collator renames 'label' to 'labels' for the model
        item['label'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)
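
# Illustrative usage of HSDataset (not part of the pipeline) with a
# hypothetical two-example batch, assuming the same DistilBERT tokenizer
# used in fit_tokenizer below:
#
#   tok = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
#   enc = tok(['some text', 'more text'], truncation=True, padding=True)
#   ds = HSDataset(enc, pd.Series([0, 1]))
#   ds[0]  # -> {'input_ids': tensor(...), 'attention_mask': tensor(...), 'label': tensor(0)}
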
def feature_engineering(raw_df):
    df = raw_df.copy()
    # Fill missing comments before computing lengths, so 'len' is never NaN
    df['comment'] = df['comment'].fillna('')
    df['len'] = df.comment.str.len()
    df = df.drop(columns=['isHate'])
    return df

def fit_tfidf(train_df, test_df):
    tfidf = TfidfVectorizer(max_features=25000)
    tfidf.fit(train_df['comment'])
    train_tfidf = tfidf.transform(train_df['comment'])
    test_tfidf = tfidf.transform(test_df['comment'])
    return train_tfidf, test_tfidf, tfidf

def fit_tokenizer(train_df, test_df):
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    train_tokens = tokenizer(train_df['comment'].values.tolist(), truncation=True, padding=True)
    test_tokens = tokenizer(test_df['comment'].values.tolist(), truncation=True, padding=True)
    return train_tokens, test_tokens
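
# The tokenizer returns a BatchEncoding holding 'input_ids' and
# 'attention_mask' lists, padded to the longest sequence in the batch,
# which HSDataset above indexes per example. A quick sanity check,
# for illustration:
#
#   train_tokens, test_tokens = fit_tokenizer(train_df, test_df)
#   train_tokens.keys()  # dict_keys(['input_ids', 'attention_mask'])
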
def fit_model(train_ds, test_ds, training_args):
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
    trainer = Trainer(
        model=model,             # the instantiated 🤗 Transformers model to be trained
        args=training_args,      # training arguments, passed in by train() below
        train_dataset=train_ds,  # training dataset
        eval_dataset=test_ds,    # evaluation dataset
    )
    trainer.train()
    return trainer

def eval_model(clf, ds):
    # y_proba = clf.predict_proba(X)[:, 1]
    y = ds.labels
    # Trainer.predict returns a PredictionOutput; its .predictions field holds
    # the raw logits, so take the argmax over classes to get hard predictions.
    # (Comparing y against .label_ids would just compare the labels to themselves.)
    y_pred = clf.predict(ds).predictions.argmax(axis=1)
    return {
        # 'roc_auc': roc_auc_score(y, y_proba),
        # 'average_precision': average_precision_score(y, y_proba),
        'accuracy': accuracy_score(y, y_pred),
        'precision': precision_score(y, y_pred),
        'recall': recall_score(y, y_pred),
        'f1': f1_score(y, y_pred),
    }
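
# For reference, PredictionOutput also carries .label_ids (the ground-truth
# labels collected from the dataset) and .metrics (e.g. eval loss and runtime),
# so callers could cross-check y against y_pred without touching the dataset.
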
def split(random_state=42):
    print('Loading data...')
    df = pd.read_csv('data/Ethos_Dataset_Binary.csv', delimiter=';')
    # Binarize the continuous isHate score into the classification target
    df[CLASS_LABEL] = df.isHate.apply(lambda x: float(x >= 0.5))
    # df[CLASS_LABEL] = df['Tags'].str.contains('machine-learning').fillna(False)
    train_df, test_df = train_test_split(df, random_state=random_state, stratify=df[CLASS_LABEL])
    print('Saving split data...')
    train_df.to_csv(train_df_path)
    test_df.to_csv(test_df_path)

def train():
    print('Loading data...')
    train_df = pd.read_csv(train_df_path)
    test_df = pd.read_csv(test_df_path)

    # print('Engineering features...')
    # train_df = feature_engineering(train_df)
    # test_df = feature_engineering(test_df)

    with dagshub.dagshub_logger() as logger:
        print('Fitting tokenizer...')
        train_tokens, test_tokens = fit_tokenizer(train_df, test_df)
        # print('Saving TFIDF object...')
        # joblib.dump(tfidf, 'outputs/tfidf.joblib')
        # logger.log_hyperparams({'tfidf': tfidf.get_params()})

        print('Training model...')
        train_y = train_df[CLASS_LABEL]
        test_y = test_df[CLASS_LABEL]
        train_dataset = HSDataset(train_tokens, train_y.astype(int))
        test_dataset = HSDataset(test_tokens, test_y.astype(int))
        training_args = TrainingArguments(
            output_dir='./results',          # output directory
            num_train_epochs=3,              # total number of training epochs
            per_device_train_batch_size=16,  # batch size per device during training
            per_device_eval_batch_size=64,   # batch size for evaluation
            warmup_steps=500,                # number of warmup steps for learning rate scheduler
            weight_decay=0.01,               # strength of weight decay
            logging_dir='./logs',            # directory for storing logs
            logging_steps=10,
        )
        # fit_model returns the fitted Trainer, not the bare model
        model = fit_model(train_dataset, test_dataset, training_args)
        # print('Saving trained model...')
        # joblib.dump(model, 'outputs/model.joblib')
        logger.log_hyperparams(model_class=type(model).__name__)
        # log a plain dict, since TrainingArguments isn't directly serializable
        logger.log_hyperparams({'model': model.args.to_dict()})

        print('Evaluating model...')
        train_metrics = eval_model(model, train_dataset)
        print('Train metrics:')
        print(train_metrics)
        logger.log_metrics({f'train__{k}': v for k, v in train_metrics.items()})

        test_metrics = eval_model(model, test_dataset)
        print('Test metrics:')
        print(test_metrics)
        logger.log_metrics({f'test__{k}': v for k, v in test_metrics.items()})

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(title='Split or Train step:', dest='step')
    subparsers.required = True
    split_parser = subparsers.add_parser('split')
    split_parser.set_defaults(func=split)
    train_parser = subparsers.add_parser('train')
    train_parser.set_defaults(func=train)
    parser.parse_args().func()
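
# Typical invocation, assuming data/Ethos_Dataset_Binary.csv is present:
#
#   python main.py split   # writes data/train.csv.zip and data/test.csv.zip
#   python main.py train   # fine-tunes DistilBERT and logs params/metrics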