Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dataset.py 1.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
  1. import torch
  2. from torch.utils.data import Dataset
  3. from bald.load_ner_dataset import load_ner_dataset
  4. class ConllDataset(Dataset):
  5. def __init__(self,data_path,vectors,emb_dim):
  6. self.data = load_ner_dataset(data_path)
  7. self.encoding = {
  8. 'B-PER':1,
  9. 'I-PER':1,
  10. 'B-ORG':2,
  11. 'I-ORG':2,
  12. 'B-LOC':3,
  13. 'I-LOC':3,
  14. 'B-MISC':4,
  15. 'I-MISC':4,
  16. 'O':5,
  17. }
  18. self.max_seq_len = self.compute_max_seq_len()
  19. self.num_labels = len(self.encoding)
  20. self.vectors = vectors
  21. self.emb_dim = emb_dim
  22. def compute_max_seq_len(self):
  23. return max(len(d["tag"]) for d in self.data)
  24. def set_max_seq_len(self,val: int):
  25. assert val > 0
  26. self.max_seq_len = val
  27. def __len__(self):
  28. return len(self.data)
  29. def __getitem__(self,i):
  30. sample = self.data[i]
  31. x_seq = sample["text"]
  32. y_seq = sample["tag"]
  33. x_seq = [self.vectors[tok] for tok in x_seq]
  34. rest = self.max_seq_len - len(x_seq)
  35. assert rest >= 0
  36. x_seq.extend([torch.zeros(self.emb_dim) for _ in range(rest)])
  37. x_seq = torch.stack(x_seq)
  38. y_seq = [self.encoding[tok] for tok in y_seq]
  39. y_seq.extend([0 for _ in range(rest)])
  40. assert len(y_seq) == self.max_seq_len
  41. y_seq = torch.tensor(y_seq)
  42. return x_seq,y_seq
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...