style_predict.py

import os
import json
import collections
import re

from keras.models import model_from_json
from keras.preprocessing.sequence import pad_sequences
from tqdm import tqdm

from python.style_extract import tokenizer


def load_model(folder):
    """Loads model files, and returns the Keras model and its parameters."""
    with open(folder + '/model_arch.json', 'r', encoding="utf-8", errors="ignore") as f:
        model = model_from_json(f.read())
    model.load_weights(folder + '/model_weights.h5')
    # Build the predict function up front (older Keras API)
    model._make_predict_function()
    with open(folder + '/model_params.json', 'r', encoding="utf-8", errors="ignore") as f:
        data = json.load(f)
    # Unknown words fall back to index 1
    word2ind = collections.defaultdict(lambda: 1, data["word2ind"])
    ind2word = {i: w for w, i in word2ind.items()}
    label2ind = data["label2ind"]
    ind2label = {i: l for l, i in label2ind.items()}
    ind2label[0] = 'n'  # padding index maps to "no tag"
    maxlen = data["max_length"]
    return model, {'word2ind': word2ind, 'ind2word': ind2word,
                   'label2ind': label2ind, 'ind2label': ind2label,
                   'maxlen': maxlen}


def predict_on_token_array(X, model, params):
    """Predicts a label for each token in a single tokenized line."""
    X_enc = [[params['word2ind'][x] for x in X]]
    X_enc = pad_sequences(X_enc, maxlen=params['maxlen'])
    y_enc = model.predict(X_enc).argmax(2)
    # Padding is prepended, so the last len(X) predictions belong to the real tokens
    y_enc = list(y_enc)[0][-len(X):]
    return [params["ind2label"][y] for y in y_enc]


def predict_on_test_file(filename, model, params):
    """Runs prediction line by line on a text file."""
    ret = []
    with open(filename, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f.readlines():
            line_prediction = ' '.join(predict_on_token_array(line.split(), model, params))
            ret.append(line_prediction)
    return ret


def predict_on_test_dir(dirname, model, params):
    """Runs prediction on every .txt file in a directory."""
    ret = {}
    for fname in tqdm(os.listdir(dirname)):
        if fname.endswith('.txt'):
            ret[fname.replace('.txt', '')] = predict_on_test_file(dirname + '/' + fname, model, params)
    return ret


def autotag(text, model, params):
    """Gets text, model and params, and outputs formatted HTML."""
    # Convert each line to an X_enc vector, and predict y_enc
    X = [tokenizer(line.strip(), lower=False, enum=False, numeric=False) for line in text.split('\n')]
    X_enc = [[params['word2ind'][tokenizer(c, split=False, enum=True, numeric=True)] for c in x] for x in X]
    X_enc = pad_sequences(X_enc, maxlen=params['maxlen'])
    y_enc = model.predict(X_enc).argmax(2)
    # Turn predictions into HTML
    lines = []
    for row in zip(X, y_enc):
        lines.append([])
        # Iterate in reverse so the padding at the front of y_enc is ignored
        for word, label in zip(reversed(row[0]), reversed(row[1])):
            tag = params['ind2label'][label]
            lines[-1].insert(0, "<{t}>{w}</{t}>".format(t=tag, w=word) if tag != 'n' else word)
    html = "<br>".join([' '.join(line) for line in lines])
    # Rejoin adjacent words that share a tag, to get a cleaner view
    for tag in params['ind2label'].values():
        html = html.replace("</{t}> <{t}>".format(t=tag), " ")
    html = re.compile(r'" (\w+) "').sub(r'"\1"', html)
    html = re.compile(r' ([.,:;]) ').sub(r'\1 ', html)
    return html


if __name__ == "__main__":
    model, params = load_model("../model")
    X = "governing law . the parties shall obide to".split()
    y = predict_on_token_array(X, model, params)
    print(list(zip(X, y)))
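
Example usage (a minimal sketch, not part of the file above): assuming this module is importable as style_predict and a trained model directory exists at ../model, as in the __main__ block, autotag can be called on raw multi-line text to produce tagged HTML. The sample text below is illustrative only.

    # usage_sketch.py -- hypothetical driver script
    from style_predict import load_model, autotag

    model, params = load_model("../model")  # assumed model directory, as in __main__
    text = "governing law\nthe parties shall abide by the laws of the state"
    html = autotag(text, model, params)
    print(html)  # tokens wrapped in predicted tags, lines joined with <br>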