cifar10_model.py 880 B

import tensorflow as tf
from tensorflow.keras import layers, Model


class Cifar10Model(Model):
    # Model from https://www.tensorflow.org/tutorials/images/cnn
    def __init__(self, n_classes=10):
        super().__init__()
        # TODO: move the model's hyperparameters out to the training script
        model = tf.keras.Sequential()
        # Convolutional feature extractor for 32x32 RGB CIFAR-10 images
        model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        # Classifier head; the last Dense layer has no activation, so it outputs raw logits
        model.add(layers.Flatten())
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(n_classes))
        self.model = model

    def call(self, images, training=None):
        # Forward pass; `training` is forwarded to the wrapped Sequential model
        return self.model(images, training=training)
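To make the architecture concrete, here is a minimal sketch in plain Python (no TensorFlow needed) that works out each layer's output shape and parameter count for the network above. It assumes Keras defaults: 'valid' padding and stride 1 for `Conv2D`, and a stride equal to the pool size for `MaxPooling2D`. `conv_valid`, `pool`, and `layer_summary` are hypothetical helpers for illustration, not part of this repository.

```python
# Hypothetical helpers: reproduce Keras shape rules by hand for Cifar10Model.

def conv_valid(size, kernel=3):
    # 'valid' padding, stride 1: the spatial size shrinks by kernel - 1
    return size - kernel + 1


def pool(size, window=2):
    # MaxPooling2D default: stride equals the pool size, remainder is dropped
    return size // window


def layer_summary(n_classes=10, image_size=32, channels=3):
    """Return (layer kind, output shape, trainable params) for each layer."""
    rows = []
    h, c_in = image_size, channels
    # Layer sequence mirroring Cifar10Model: (kind, number of filters)
    for kind, filters in [("conv", 32), ("pool", None), ("conv", 64),
                          ("pool", None), ("conv", 64)]:
        if kind == "conv":
            h = conv_valid(h)
            # 3x3 kernel weights per input/output channel pair, plus one bias per filter
            params = 3 * 3 * c_in * filters + filters
            c_in = filters
        else:
            h = pool(h)
            params = 0
        rows.append((kind, (h, h, c_in), params))
    flat = h * h * c_in
    rows.append(("flatten", (flat,), 0))
    rows.append(("dense", (64,), flat * 64 + 64))
    rows.append(("dense", (n_classes,), 64 * n_classes + n_classes))
    return rows


if __name__ == "__main__":
    rows = layer_summary()
    for kind, shape, params in rows:
        print(f"{kind:8s} {str(shape):14s} {params:>7d}")
    print("total params:", sum(p for _, _, p in rows))
```

With the default `n_classes=10`, the spatial size goes 32 → 30 → 15 → 13 → 6 → 4, the flattened vector has 4 × 4 × 64 = 1024 features, and the total comes to 122,570 trainable parameters. Comparing this against `Cifar10Model().model.summary()` should be a quick sanity check after changing any hyperparameter.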