DNN_IMDB.py
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras

for module in mpl, np, pd, tf, keras:
    print(module.__name__, module.__version__)
# keep only the 10,000 most frequent words
vocab_size = 10000
# ids below 3 are reserved for special tokens
index_from = 3
# use the IMDB dataset bundled with keras
imdb = keras.datasets.imdb
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(
    num_words=vocab_size, index_from=index_from)
print(train_data.shape)
print(train_labels.shape)
print(train_data[0], train_labels[0])
print(type(train_data))
print(type(train_labels))
print(np.unique(train_labels))
print(test_data.shape)
print(test_labels.shape)

word_index = imdb.get_word_index()
print(len(word_index))
print(type(word_index))
print(word_index.get("footwork"))
# the raw word index starts at 1, so shift every id up by 3
word_index = {k: (v + 3) for k, v in word_index.items()}
# define ids 0-3 ourselves
word_index["<PAD>"] = 0    # padding token
word_index["<START>"] = 1  # start-of-sequence token
word_index["<UNK>"] = 2    # unknown words map to UNK
word_index["<END>"] = 3    # end-of-sequence token
# invert to the conventional form: ids as keys, words as values
reverse_word_index = {v: k for k, v in word_index.items()}
# check the decoding
print(reverse_word_index)
# {34710: 'fawn', 52015: 'tsukino', 52016: 'nunnery', 16825: 'sonja', 63960: 'vani', 1417: 'woods', ...}
# debug
print(reverse_word_index[34707])
# footwork
print(word_index.get("footwork"))
# 34707
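
# Quick sanity check (a sketch added here, not part of the original script):
# decode the first training review back into words via reverse_word_index.
# Ids missing from the table fall back to "<UNK>".
def decode_review(ids):
    return " ".join(reverse_word_index.get(i, "<UNK>") for i in ids)

print(decode_review(train_data[0]))
# should begin with something like "<START> this film was just brilliant ..."
# (the exact text depends on the dataset version shipped with keras)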
# spot-check a few sample lengths: they vary
print(len(train_data[0]), len(train_data[1]), len(train_data[100]))
# 218 189 158

# fix the input sequence length: reviews shorter than 500 are padded,
# longer ones truncated
max_length = 500
# pad_sequences arguments:
#   value:   what to pad with
#   padding: which end to pad, one of two modes: "pre" or "post"
train_data = keras.preprocessing.sequence.pad_sequences(
    train_data, value=word_index["<PAD>"], padding="pre", maxlen=max_length)
# give the test set the same structure as the training set
test_data = keras.preprocessing.sequence.pad_sequences(
    test_data, value=word_index["<PAD>"], padding="pre", maxlen=max_length)
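
# Tiny illustration (not in the original script) of where pad_sequences
# puts the zeros in each padding mode:
toy = [[1, 2, 3], [4, 5]]
print(keras.preprocessing.sequence.pad_sequences(toy, maxlen=4, padding="pre"))
# [[0 1 2 3]
#  [0 0 4 5]]
print(keras.preprocessing.sequence.pad_sequences(toy, maxlen=4, padding="post"))
# [[1 2 3 0]
#  [4 5 0 0]]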
# each word is embedded as a 16-dimensional vector
embedding_dim = 16
batch_size = 128

# define the model
# Embedding learns a matrix of shape [vocab_size, embedding_dim]
# GlobalAveragePooling1D averages over the max_length axis, reducing each
# review to a single 1x16 vector; whichever axis you pool over disappears
# binary classification, so the final activation is sigmoid
model = keras.models.Sequential([
    keras.layers.Embedding(vocab_size, embedding_dim, input_length=max_length),
    keras.layers.GlobalAveragePooling1D(),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.summary()
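
# Shape flow through the network, as model.summary() should report it:
#   Embedding:              (None, 500, 16)  -> 10000 * 16   = 160,000 params
#   GlobalAveragePooling1D: (None, 16)       -> 0 params
#   Dense(64, relu):        (None, 64)       -> 16 * 64 + 64 = 1,088 params
#   Dense(1, sigmoid):      (None, 1)        -> 64 * 1 + 1   = 65 params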
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# the dataset has only train and test splits, no validation set, so use
# validation_split to hold out 20% of the training data for validation
history = model.fit(train_data, train_labels, epochs=30,
                    batch_size=batch_size, validation_split=0.2)

# plot the learning curves
def plot_(history, label):
    plt.plot(history.history[label])
    plt.plot(history.history["val_" + label])
    plt.title("model " + label)
    plt.ylabel(label)
    plt.xlabel("epoch")
    plt.legend(["train", "validation"], loc="upper left")
    plt.show()

# tf.keras records this metric under "accuracy" ("acc" in older Keras versions)
plot_(history, "accuracy")
plot_(history, "loss")
score = model.evaluate(
    test_data, test_labels,
    batch_size=batch_size,
    verbose=1,
)
print("Test loss:", score[0])
print("Test accuracy:", score[1])
'''
Test loss: 0.6628
Test accuracy: 0.8588
'''
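
# A possible follow-up (hypothetical, not in the original script): score a
# hand-written review. Words go through the same word_index used in training;
# ids at or above vocab_size must be clipped to <UNK>, since the Embedding
# layer only knows the top-10,000 ids.
sample = "this movie was brilliant and the acting was wonderful"
ids = [word_index["<START>"]] + [word_index.get(w, word_index["<UNK>"])
                                 for w in sample.split()]
ids = [i if i < vocab_size else word_index["<UNK>"] for i in ids]
padded = keras.preprocessing.sequence.pad_sequences(
    [ids], value=word_index["<PAD>"], padding="pre", maxlen=max_length)
print(model.predict(padded))  # a value near 1.0 suggests positive sentiment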