Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

data_loaders.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
  1. import yaml
  2. import torch
  3. from functools import partial
  4. from torch.nn.utils.rnn import pad_sequence
  5. from torch.utils.data import DataLoader
  6. from utils.preprocess import Preprocessor
  7. with open('params.yaml', 'r') as f:
  8. PARAMS = yaml.safe_load(f)
  9. if torch.cuda.is_available():
  10. DEVICE = torch.device('cuda', PARAMS.get('gpu', 0))
  11. else:
  12. DEVICE = torch.device('cpu')
  13. class DataFrameDataLoader(DataLoader):
  14. def __init__(self, df, use_bag=True, use_eos=True, max_len=None, *args, **kwargs):
  15. # order is text, label
  16. self._preprocessor = None
  17. self.init()
  18. self._data_iter = list(zip(df['review'], df['sentiment']))
  19. collate_batch = partial(self.collate_batch, use_bag=use_bag, use_eos=use_eos, max_len=max_len)
  20. super(DataFrameDataLoader, self).__init__(self._data_iter, collate_fn=collate_batch, *args, **kwargs)
  21. def init(self):
  22. self._preprocessor = Preprocessor(torch.load('outputs/vocab.plk'))
  23. def collate_batch(self, batch, use_bag, use_eos, max_len):
  24. label_list, text_list, offsets = [], [], [0]
  25. for (_text, _label) in batch:
  26. label_list.append(_label)
  27. processed_text = self._preprocessor.text_pipeline(_text)
  28. if use_eos:
  29. processed_text += [self.vocab[PARAMS['eos_token']]]
  30. if max_len:
  31. processed_text = processed_text[:max_len]
  32. processed_text = torch.tensor(processed_text, dtype=torch.int64)
  33. text_list.append(processed_text)
  34. offsets.append(processed_text.size(0))
  35. label_list = torch.tensor(label_list, dtype=torch.float32)
  36. if use_bag:
  37. offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
  38. text_list = torch.cat(text_list)
  39. else:
  40. offsets = torch.tensor(offsets[1:], dtype=torch.int64)
  41. text_list = pad_sequence(text_list, padding_value=self.vocab[PARAMS['pad_token']])
  42. return label_list.to(DEVICE), text_list.to(DEVICE), offsets.to(DEVICE)
  43. @property
  44. def vocab(self):
  45. return self._preprocessor.vocab
  46. @property
  47. def vocab_size(self):
  48. return len(self._preprocessor)
Tip!

Press p (or the left arrow key) to see the previous file, or n (or the right arrow key) to see the next file

Comments

Loading...