Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

cnn.py 2.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) 2021. Jeffrey Nirschl. All rights reserved.
  3. #
  4. # Licensed under the MIT license. See the LICENSE file in the project
  5. # root directory for license information.
  6. #
  7. # Time-stamp: <>
  8. # ======================================================================
  9. import os
  10. import tensorflow as tf
  11. from keras.layers import Dense, Activation, Flatten, Dropout, BatchNormalization, Conv2D, MaxPooling2D
  12. def conv_relu_bn(filters, kernel_size=(3,3), strides=(1, 1), padding="valid"):
  13. return [tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size,
  14. strides=strides, padding=padding,
  15. activation=tf.nn.relu),
  16. tf.keras.layers.BatchNormalization(),
  17. ]
  18. def simple_mnist(base_filter=32, fc_width=512, dropout_rate=0.5,
  19. image_size=(28,28,1), n_class=10, learn_rate=0.01,
  20. optimizer="adam"):
  21. """Simple CNN implementation for MNIST"""
  22. assert (base_filter > 0), ValueError
  23. assert (fc_width > 0), ValueError
  24. assert 0 <= dropout_rate < 1, ValueError
  25. model = tf.keras.Sequential([
  26. tf.keras.Input(shape=image_size, name="Input"),
  27. *conv_relu_bn(filters=base_filter, kernel_size=(3, 3), strides=(1, 1)),
  28. *conv_relu_bn(filters=2*base_filter, kernel_size=(3, 3), strides=(1, 1)),
  29. MaxPooling2D(pool_size=(2, 2)),
  30. *conv_relu_bn(filters=4 * base_filter, kernel_size=(3, 3), strides=(1, 1)),
  31. *conv_relu_bn(filters=4 * base_filter, kernel_size=(3, 3), strides=(1, 1)),
  32. MaxPooling2D(pool_size=(2, 2)),
  33. BatchNormalization(),
  34. * conv_relu_bn(filters=4 * base_filter, kernel_size=(3, 3), strides=(1, 1)),
  35. MaxPooling2D(pool_size=(2, 2)),
  36. Flatten(),
  37. BatchNormalization(),
  38. Dense(fc_width, activation="relu"),
  39. tf.keras.layers.Dropout(dropout_rate),
  40. Dense(n_class, activation="softmax")
  41. ])
  42. if optimizer.lower() == "adam":
  43. opt = tf.keras.optimizers.Adam(learning_rate=learn_rate)
  44. model.compile(loss="categorical_crossentropy",
  45. optimizer=opt, metrics=["accuracy"])
  46. else:
  47. model.compile(loss="categorical_crossentropy",
  48. optimizer="adam", metrics=["accuracy"])
  49. return model
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...