Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

prepare.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
  1. import os
  2. import random
  3. import re
  4. import sys
  5. import xml.etree.ElementTree
  6. import yaml
  7. def process_posts(input_lines, fd_out_train, fd_out_test, target_tag, split):
  8. """
  9. Process the input lines and write the output to the output files.
  10. Args:
  11. input_lines (list): List of input lines.
  12. fd_out_train (file): Output file for the training data set.
  13. fd_out_test (file): Output file for the test data set.
  14. target_tag (str): Target tag.
  15. split (float): Test data set split ratio.
  16. """
  17. num = 1
  18. for line in input_lines:
  19. try:
  20. fd_out = fd_out_train if random.random() > split else fd_out_test
  21. attr = xml.etree.ElementTree.fromstring(line).attrib
  22. pid = attr.get("Id", "")
  23. label = 1 if target_tag in attr.get("Tags", "") else 0
  24. title = re.sub(r"\s+", " ", attr.get("Title", "")).strip()
  25. body = re.sub(r"\s+", " ", attr.get("Body", "")).strip()
  26. text = title + " " + body
  27. fd_out.write("{}\t{}\t{}\n".format(pid, label, text))
  28. num += 1
  29. except Exception as ex:
  30. sys.stderr.write(f"Skipping the broken line {num}: {ex}\n")
  31. def main():
  32. params = yaml.safe_load(open("params.yaml"))["prepare"]
  33. if len(sys.argv) != 2:
  34. sys.stderr.write("Arguments error. Usage:\n")
  35. sys.stderr.write("\tpython prepare.py data-file\n")
  36. sys.exit(1)
  37. # Test data set split ratio
  38. split = params["split"]
  39. random.seed(params["seed"])
  40. input = sys.argv[1]
  41. output_train = os.path.join("data", "prepared", "train.tsv")
  42. output_test = os.path.join("data", "prepared", "test.tsv")
  43. os.makedirs(os.path.join("data", "prepared"), exist_ok=True)
  44. input_lines = []
  45. with open(input) as fd_in:
  46. input_lines = fd_in.readlines()
  47. fd_out_train = open(output_train, "w", encoding="utf-8")
  48. fd_out_test = open(output_test, "w", encoding="utf-8")
  49. process_posts(
  50. input_lines=input_lines,
  51. fd_out_train=fd_out_train,
  52. fd_out_test=fd_out_test,
  53. target_tag="<r>",
  54. split=split,
  55. )
  56. fd_out_train.close()
  57. fd_out_test.close()
  58. if __name__ == "__main__":
  59. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...