Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

mnist-keras.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
  1. from mlutils import git_version, mlversion, timestamp
  2. import git
  3. from datetime import datetime
  4. from keras.utils import to_categorical
  5. from keras import layers
  6. from keras import models
  7. from keras.datasets import mnist
  8. from matplotlib import pyplot as plt
  9. import tensorflow as tf
  10. import keras
  11. import platform
  12. import sys
  13. import subprocess
  14. import os
  15. from tensorflow.python.keras.utils.vis_utils import plot_model
  16. import matplotlib.image as mpimg
  17. import logging
  18. import coloredlogs
  19. coloredlogs.install(level='DEBUG')
  20. log = logging.getLogger('mnist-keras')
  21. MODEL_VERSION = timestamp() + git_version() + "-V1.0"
  22. log.warning(MODEL_VERSION)
  23. log.error(MODEL_VERSION)
  24. mlversion()
  25. # sys.exit()
  26. (train_images, train_labels), (test_images, test_labels) = mnist.load_data(
  27. os.getcwd() + "/datasets/mnist.npz")
  28. train_images.shape
  29. train_labels.shape
  30. test_images.shape
  31. test_labels.shape
  32. img = train_images[100]
  33. plt.imshow(img, cmap=plt.cm.binary)
  34. def build_mode():
  35. model = keras.Sequential(name=MODEL_VERSION)
  36. model.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
  37. model.add(layers.Dense(10, activation='softmax'))
  38. model.compile(optimizer='rmsprop',
  39. loss='categorical_crossentropy', metrics=['accuracy'])
  40. plot_model(model, to_file='models1.png',
  41. show_shapes=True, show_layer_names=True)
  42. # plt.imshow(mpimg.imread("models1.png"))
  43. return model
  44. train_images = train_images.reshape(60000, 28*28)
  45. train_images = train_images.astype('float32')/255
  46. train_labels = to_categorical(train_labels)
  47. test_images = test_images.reshape(10000, 28*28)
  48. test_images = test_images.astype('float32')/255
  49. test_labels = to_categorical(test_labels)
  50. with tf.distribute.MirroredStrategy().scope():
  51. model = build_mode()
  52. model.fit(train_images, train_labels, epochs=2, batch_size=128)
  53. model_path = os.path.join("./models/")
  54. if not os.path.exists(model_path):
  55. os.makedirs(model_path)
  56. model_file = os.path.join(model_path, MODEL_VERSION + ".h5")
  57. model.save(model_file)
  58. model.summary()
  59. test_loss, test_acc = model.evaluate(test_images, test_labels)
  60. print("test_loss : ", test_loss)
  61. print("test_acc : ", test_acc)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...