
test_model.py 1.6 KB

```python
#!/usr/bin/env python
# coding: utf-8
from __future__ import print_function

import numpy as np
import PIL.Image as Image
import tensorflow as tf  # uses the TensorFlow 1.x API (tf.GraphDef, tf.Session)


def recognize(png_path, pb_file_path):
    """
    Function: run a prediction with the trained network model.

    Parameters
    ----------
    png_path: path to the image to classify.
    pb_file_path: path to the .pb file.
    """
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        # The .pb file is a serialized protobuf, so read it in binary mode ("rb").
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")

        with tf.Session() as sess:
            # If the .pb is a frozen graph, weights are stored as constants and
            # there are no variables to initialize, so this line has no effect.
            tf.global_variables_initializer().run()

            # Look up tensors by the names they were given when the model was built.
            input_x = sess.graph.get_tensor_by_name("input:0")
            print(input_x)
            out_softmax = sess.graph.get_tensor_by_name("out_softmax:0")
            print(out_softmax)
            keep_prob = sess.graph.get_tensor_by_name("keep_prob_placeholder:0")
            print(keep_prob)
            out_label = sess.graph.get_tensor_by_name("output:0")
            print(out_label)

            # Load the PNG as an 8-bit grayscale ('L') array of shape (height, width).
            # Reshape it if the "input" placeholder expects a different shape.
            img_datas = np.array(Image.open(png_path).convert('L'))
            # keep_prob = 1.0 disables dropout at inference time.
            img_out_softmax = sess.run(out_softmax, feed_dict={
                input_x: img_datas,
                keep_prob: 1.0,
            })

            print("img_out_softmax:", img_out_softmax)
            prediction_labels = np.argmax(img_out_softmax, axis=1)
            print("label:", prediction_labels)

# recognize("/home/tsiangleo/mnist_test_set/1.png", "output/mnist-tf1.0.1.pb")
```
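The script feeds the raw 2-D grayscale array straight into `input:0` and takes the `argmax` of the softmax output. If the exported graph expects a batch instead, the image has to be reshaped before it is fed. One example would be a flattened `[None, 784]` MNIST input with pixels scaled to [0, 1]. The sketch below assumes that layout. The shape, the 1/255 scaling and the helper names `to_batch` and `top_k` are hypothetical and would need to match how the model was actually trained.

```python
import numpy as np


def to_batch(gray, flatten=True):
    """Turn a 2-D uint8 grayscale image into a float32 batch of size 1.

    Assumption (hypothetical): the model was trained on pixels scaled to [0, 1].
    With flatten=True the result has shape (1, H*W), e.g. (1, 784) for MNIST;
    otherwise it has shape (1, H, W, 1), a typical convolutional input layout.
    """
    arr = np.asarray(gray, dtype=np.float32) / 255.0
    if flatten:
        return arr.reshape(1, -1)
    return arr.reshape(1, arr.shape[0], arr.shape[1], 1)


def top_k(softmax_row, k=3):
    """Return the k most probable (label, probability) pairs, best first."""
    probs = np.asarray(softmax_row, dtype=np.float64)
    order = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(int(i), float(probs[i])) for i in order]


if __name__ == "__main__":
    # Synthetic 28x28 image with a bright vertical stroke, standing in for a PNG.
    fake_image = np.zeros((28, 28), dtype=np.uint8)
    fake_image[10:18, 12:16] = 255
    batch = to_batch(fake_image)
    print(batch.shape, batch.max())  # (1, 784) 1.0

    # Synthetic softmax row for a 4-class problem.
    fake_softmax = np.array([0.05, 0.80, 0.10, 0.05])
    print(top_k(fake_softmax, k=2))  # [(1, 0.8), (2, 0.1)]
```

Inside `recognize`, you could pass `to_batch(img_datas)` in the `feed_dict`, and `top_k(img_out_softmax[0])` would list the runner-up labels alongside the `argmax` prediction.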