Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

文本生成.py 7.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
  1. # -*- coding: utf-8 -*-
  2. """文本生成.ipynb
  3. Automatically generated by Colaboratory.
  4. Original file is located at
  5. https://colab.research.google.com/drive/1JFyRrdUS6eRqL3h2pLPCCUP3_iBh7hkV
  6. """
  7. import matplotlib as mpl
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. import pandas as pd
  11. import tensorflow as tf
  12. from tensorflow import keras
  13. import sklearn
  14. import os
  15. import sys
  16. import time
  17. print(tf.__version__)
  18. print(sys.version_info)
  19. for module in mpl,np,pd,sklearn,tf,keras:
  20. print(module.__name__,module.__version__)
  21. # !wget https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
  22. # 莎士比亚文集路径
  23. input_filepath="/content/shakespeare.txt"
  24. # 读取文集
  25. text=open(input_filepath,'r').read()
  26. print(len(text))
  27. print(type(text))
  28. print(text[0:100])
  29. """100W+个数据集对于NN训练是很小的"""
  30. # 对文集去重并排序
  31. vocab=sorted(set(text))
  32. print(len(vocab))
  33. print(type(vocab))
  34. print(vocab)
  35. """char-level:区分字母大小写,各种符号"""
  36. # 给字符编号char:idx
  37. char2idx={char:idx for idx,char in enumerate(vocab)}
  38. print(type(char2idx))
  39. print(char2idx)
  40. # idx:char-vocab的下标就是天然的索引
  41. # np的array速度比python的list快
  42. idx2char=np.array(vocab)
  43. print(type(idx2char))
  44. print(idx2char)
  45. # 用{字符:编号}把文集中的字符都转成数字编号
  46. text_as_int=np.array([char2idx[c] for c in text])
  47. print(len(text_as_int))
  48. print(type(text_as_int))
  49. print(text[0:10])
  50. print(text_as_int[0:10])
  51. """文本First Citi对应的字符索引是[18 47 56 57 58 1 15 47 58 47]
  52. 莎士比亚文集:char-level
  53. """
  54. # 对文本划分好输入、输出
  55. # e.g.对文本abcde划分得到输入abcd,输出bcde
  56. def split_input_target(id_text):
  57. return id_text[0:-1],id_text[1:]
  58. # 把id text作为数据集
  59. char_dataset=tf.data.Dataset.from_tensor_slices(text_as_int)
  60. # 给100个序列,则划分成100个输入,100个输出
  61. seq_length=100
  62. # 分批处理,得到多组输入、输出
  63. seq_dataset=char_dataset.batch(seq_length+1,drop_remainder=True)
  64. # 拿出来两个字符,看看是什么具体字符
  65. for ch_id in char_dataset.take(2):
  66. print(ch_id,idx2char[ch_id.numpy()])
  67. # 拿出来两个句子,看看是什么具体句子
  68. for seq_id in seq_dataset.take(2):
  69. print(seq_id)
  70. print(repr(''.join(idx2char[seq_id.numpy()])))
  71. print(repr(' '.join(idx2char[seq_id.numpy()])))
  72. """seq_id是用char-level表示的
  73. 字符可以预测出空格,就作为单词的分隔了
  74. """
  75. # 划分单词的输入、输出
  76. # map实现调用函数
  77. seq_dataset=seq_dataset.map(split_input_target)
  78. # 拿出来两个序列,划分输入输出
  79. for item_input,item_output in seq_dataset.take(1):
  80. print(item_input.numpy())
  81. print(item_output.numpy())
  82. print(seq_dataset)
  83. """看例子,18输出47,47输出56,56输出57..."""
  84. # batch处理
  85. batch_size=64
  86. # 洗牌
  87. buffer_size=10000
  88. # shuffle()-随机重新排列此数据集的元素-理解:如果数据集包含10,000个元素但buffer_size设置为1,000个,则shuffle最初将仅从缓冲区中的前1,000个元素中选择一个随机元素。选择一个元素后,其缓冲区中的空间将被下一个(即1,001个)元素替换,并保留1,000个元素缓冲区。
  89. seq_dataset=seq_dataset.shuffle(buffer_size).batch(batch_size,drop_remainder=True)
  90. print(seq_dataset)
  91. vocab_size=len(vocab)
  92. # 资料小,就升维
  93. embedding_dim=256
  94. rnn_units=1024
  95. # 搭建模型写成一个函数
  96. def build_model(vocab_size,embedding_dim,rnn_units,batch_size):
  97. model=keras.models.Sequential([
  98. keras.layers.Embedding(vocab_size,embedding_dim,batch_input_shape=[batch_size,None]),
  99. # stateful=True 是否要把最后返回的状态添加到输出
  100. # recurrent_initializer='glorot_uniform'-RNN的权值初始值,均值为0,以0为中心的对称区间均匀分布的随机数
  101. keras.layers.SimpleRNN(units=rnn_units,stateful=True,recurrent_initializer='glorot_uniform',return_sequences=True),
  102. # 让输出是65个vocab_size的一个
  103. keras.layers.Dense(vocab_size),
  104. ])
  105. return model
  106. model=build_model(vocab_size=vocab_size,embedding_dim=embedding_dim,rnn_units=rnn_units,batch_size=batch_size)
  107. model.summary()
  108. model.variables
  109. # 从序列中拿出一个take(1)看一看模型对输入处理完的输出,简单验证能否做预测
  110. for input_example_batch,target_example_batch in seq_dataset.take(1):
  111. # 把model当成函数来用了,其实是调用了call方法-使实例(对象)像函数一样被调用
  112. example_batch_predictions=model(input_example_batch)
  113. print(example_batch_predictions.shape)
  114. print(example_batch_predictions)
  115. """浮点数理解-softmax输出就是概率值,再对应到具体的一类中,这里就是字符类"""
  116. # 输入是100个char,输出也是100个char
  117. # categorical()从分类分布中抽取样本
  118. # logits:2-D Tensor with shape [batch_size, num_classes],这里是[100,1]的2维张量
  119. sample_indices=tf.random.categorical(logits=example_batch_predictions[0],num_samples=1)
  120. print(sample_indices)
  121. # sqeeze:转为100的向量
  122. sample_indices=tf.squeeze(sample_indices,axis=-1)
  123. print(sample_indices)
  124. # 看一组输入、输出
  125. print("Input:",repr("".join(idx2char[input_example_batch[0]])))
  126. print()
  127. print("Output:",repr("".join(idx2char[target_example_batch[0]])))
  128. print()
  129. print("Predictions:",repr("".join(idx2char[sample_indices])))
  130. """预测效果不咋地"""
  131. def loss(labels,logits):
  132. return keras.losses.sparse_categorical_crossentropy(labels,logits,from_logits=True)
  133. model.compile(optimizer='adam',loss=loss)
  134. # 真实标签和预测标签的区别
  135. example_loss=loss(target_example_batch,example_batch_predictions)
  136. print(example_loss.shape)
  137. print(example_loss.numpy().mean())
  138. # 保存模型-经过训练之后的模型效果会比上边那个没有经过训练的效果要好
  139. # 定义一个文件夹来保存模型
  140. output_dir="/content/text_generation_checkpoints"
  141. if not os.path.exists(output_dir):
  142. os.mkdir(output_dir)
  143. # 保存最后一次迭代的模型
  144. checkpoint_prefix=os.path.join(output_dir,'ckpt_{epoch}')
  145. checkpoint_callback=keras.callbacks.ModelCheckpoint(
  146. filepath=checkpoint_prefix,
  147. # 只保存权重值
  148. save_weights_only=True
  149. )
  150. epochs=100
  151. history=model.fit(seq_dataset,epochs=epochs,callbacks=[checkpoint_callback])
  152. # 看看最好的模型
  153. tf.train.latest_checkpoint(output_dir)
  154. # 再创建一个新模型,去加载已经保存过的权重值
  155. output_dir="/content/text_generation_checkpoints"
  156. model2=build_model(vocab_size,embedding_dim,rnn_units,batch_size=1)
  157. model2.load_weights(tf.train.latest_checkpoint(output_dir))
  158. # 1个样本,None表示变长序列
  159. model2.build(tf.TensorShape([1,None]))
  160. model2.summary()
  161. # 实现文本生成
  162. def generate_text(model,start_string,num_generate=1000):
  163. # 输入
  164. input_eval=[char2idx[ch] for ch in start_string]
  165. print(input_eval)
  166. # 升维-在axis=0方向上
  167. input_eval=tf.expand_dims(input_eval,0)
  168. print(input_eval)
  169. text_generated=[]
  170. # 连续调用模型用reset_states()
  171. model.reset_states()
  172. # 逐个预测输出字符
  173. for _ in range(num_generate):
  174. predictions=model(input_eval)
  175. # squeeze降维:消掉batch_size维度,变成predictions:[input_eval_len,vocab_size]
  176. predictions=tf.squeeze(predictions,0)
  177. print(predictions)
  178. # 倒序、抽取样本、降维成1维
  179. predicted_id=tf.random.categorical(predictions,num_samples=1)[-1,0].numpy()
  180. print(predicted_id)
  181. # 得到一个预测字符,append到生成文本中
  182. text_generated.append(idx2char[predicted_id])
  183. # 得到的预测再去作为下一次的输入字符
  184. input_eval=tf.expand_dims([predicted_id],0)
  185. return start_string+''.join(text_generated)
  186. new_text=generate_text(model2,"amazing:")
  187. print(new_text)
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...