Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

conll_trainer.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
  1. """
  2. Step 1: train a simple model.
  3. Step 2: train a simple model using active learning.
  4. Step 3: understand something.
  5. Step 4: feel productive.
  6. Step 5: ...
  7. Step 6: profit.
  8. """
  9. import os
  10. import matplotlib.pyplot as plt
  11. from torchnlp.word_to_vector import GloVe
  12. import torch
  13. import torch.nn.functional as F
  14. import tqdm
  15. from torch import nn
  16. from torch.utils.data import DataLoader
  17. from sklearn.metrics import f1_score
  18. from bald import data_dir,vectors_dir,load_ner_dataset
  19. from bald.dataset import ConllDataset
  20. from bald.simple_model import ConllModel
  21. from bald.utils import epoch_run
  22. vectors = GloVe(cache=vectors_dir)
  23. train_path = os.path.join(data_dir,"raw","CoNLL2003","eng.train")
  24. train_ds = ConllDataset(data_path=train_path,vectors=vectors,emb_dim=300)
  25. test_path = os.path.join(data_dir,"raw","CoNLL2003","eng.testa")
  26. test_ds = ConllDataset(data_path=test_path,vectors=vectors,emb_dim=300)
  27. max_seq_len = max(train_ds.max_seq_len,test_ds.max_seq_len)
  28. train_ds.set_max_seq_len(max_seq_len)
  29. test_ds.set_max_seq_len(max_seq_len)
  30. train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
  31. test_dl = DataLoader(test_ds, batch_size=32, shuffle=False)
  32. model = ConllModel(
  33. max_seq_len = max_seq_len,
  34. num_labels = train_ds.num_labels,
  35. emb_dim = train_ds.emb_dim
  36. )
  37. def loss_fun(input,target):
  38. batch_len,seq_len = target.size()
  39. target = target.view(batch_len*seq_len)
  40. return F.cross_entropy(input=input,target=target,ignore_index=0)
  41. def score_fun(input,target):
  42. dims,labels = input.size()
  43. y_pred = F.softmax(input,dim=1)
  44. y_pred = torch.argmax(y_pred,dim=1)
  45. target = target.view(dims)
  46. return f1_score(
  47. y_true = target.cpu().data.numpy(),
  48. y_pred = y_pred.cpu().data.numpy(),
  49. labels = list(range(1,6)),
  50. average = "weighted",
  51. )
  52. # optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  53. optimizer = torch.optim.Adam(model.parameters())
  54. # set number of epochs from command line
  55. # num_epochs = int(input("Enter number of epochs: "))
  56. num_epochs = 30
  57. train_losses = []
  58. test_losses = []
  59. for epoch in range(num_epochs):
  60. print(f"\nEpoch {epoch+1}.")
  61. print("Training.")
  62. run_d = epoch_run(
  63. model = model,
  64. data_loader = train_dl,
  65. criterion = loss_fun,
  66. score_fun = score_fun,
  67. trainer_mode = True,
  68. optimizer = optimizer,
  69. )
  70. train_losses.append(run_d["loss"])
  71. print(f"Train f1 score is {run_d['score']}.")
  72. print("Evaluating.")
  73. run_d = epoch_run(
  74. model = model,
  75. data_loader = train_dl,
  76. criterion = loss_fun,
  77. score_fun = score_fun,
  78. trainer_mode = False,
  79. )
  80. test_losses.append(run_d["loss"])
  81. print(f"Test f1 score is {run_d['score']}.")
  82. plt.plot(train_losses, label="train")
  83. plt.plot(test_losses, label="test")
  84. plt.legend()
  85. plt.show()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...