utils_embedding.py

import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

from tqdm import tqdm
from transformers import PreTrainedTokenizer, is_torch_available

logger = logging.getLogger(__name__)


@dataclass
class InputFeatures:
    """
    A single set of features of data.
    Property names are the same names as the corresponding inputs to a model.
    """

    input_ids: List[int]
    attention_mask: List[int]
    token_type_ids: Optional[List[int]] = None
    metadata: Optional[dict] = None


def read_texts_from_file(file_path: str) -> List[str]:
    with open(file_path, encoding="utf-8") as f:
        texts = f.readlines()
    return texts


def convert_texts_to_features(
    texts: List[str],
    max_seq_length: int,
    tokenizer: PreTrainedTokenizer,
    cls_token="[CLS]",
    sep_token="[SEP]",
) -> List[InputFeatures]:
    """Convert the lines of a .txt file into model input features."""
    features = []
    for t_idx, text in tqdm(enumerate(texts), total=len(texts)):
        tokens = tokenizer.tokenize(text)
        # Reserve room for special tokens ([CLS]/[SEP] for BERT; more for some models),
        # as reported by the tokenizer itself.
        special_tokens_count = tokenizer.num_special_tokens_to_add()
        if len(tokens) > max_seq_length - special_tokens_count:
            tokens = tokens[: (max_seq_length - special_tokens_count)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids:   0   0    0    0    0       0    0   0   1  1  1   1  1   1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0    0   0   0   0   0   0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens += [sep_token]
        segment_ids = [0] * len(tokens)

        # CLS token
        tokens = [cls_token] + tokens
        segment_ids = [0] + segment_ids

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_seq_length - len(input_ids)
        input_ids += [0] * padding_length
        input_mask += [0] * padding_length
        segment_ids += [0] * padding_length

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        if t_idx < 5:
            logger.info("*** Example ***")
            logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))

        features.append(
            InputFeatures(
                input_ids=input_ids,
                attention_mask=input_mask,
                token_type_ids=segment_ids,
                metadata={
                    "text": text.strip(),
                    "text_id": str(t_idx),
                    "tokens": tokens,
                },
            )
        )
    return features


if is_torch_available():
    import torch
    from torch.utils.data.dataset import Dataset

    class EmbeddingDataset(Dataset):
        """
        This will be superseded by a framework-agnostic approach soon.
        """

        features: List[InputFeatures]

        def __init__(
            self,
            data_path: str,
            tokenizer: PreTrainedTokenizer,
            max_seq_length: int,
        ):
            logger.info(f"Creating features from dataset file at {data_path}")
            texts = read_texts_from_file(data_path)
            self.features = convert_texts_to_features(
                texts=texts,
                max_seq_length=max_seq_length,
                tokenizer=tokenizer,
            )

        def __len__(self):
            return len(self.features)

        def __getitem__(self, i) -> InputFeatures:
            return self.features[i]

    def data_collator(features: List[InputFeatures]) -> Dict[str, torch.Tensor]:
        """Collate a list of InputFeatures into a batch of stacked tensors; metadata is kept as a plain list."""
        first = features[0]
        batch = {}
        for k, v in first.__dict__.items():
            if k == "metadata":
                batch[k] = [f.__dict__[k] for f in features]
            else:
                batch[k] = torch.tensor([f.__dict__[k] for f in features], dtype=torch.long)
        return batch
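
For context, a minimal usage sketch (not part of the original file): it assumes a BERT-style checkpoint such as bert-base-uncased and a hypothetical local file texts.txt with one input per line, and shows how EmbeddingDataset and data_collator plug into a standard PyTorch DataLoader to extract [CLS] embeddings.

import torch
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

# Checkpoint name and data path below are illustrative assumptions.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

dataset = EmbeddingDataset(data_path="texts.txt", tokenizer=tokenizer, max_seq_length=128)
loader = DataLoader(dataset, batch_size=8, collate_fn=data_collator)

with torch.no_grad():
    for batch in loader:
        metadata = batch.pop("metadata")  # left as a list of dicts by data_collator
        outputs = model(**batch)          # input_ids, attention_mask, token_type_ids
        cls_embeddings = outputs.last_hidden_state[:, 0]  # one [CLS] vector per text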