train.py 1.4 KB

import datetime
import json
import time

import tensorflow as tf
import yaml

# Read run parameters; the file handle is closed once loading finishes.
with open('params.yaml') as params_file:
    params = yaml.safe_load(params_file)
epochs = params['epochs']
log_file = params['log_file']

# Load MNIST and scale pixel values from [0, 255] to [0, 1].
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


def create_model():
    # Single softmax layer over the flattened 28x28 image.
    return tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])


model = create_model()
model.compile(optimizer='RMSprop',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Write TensorBoard logs to a timestamped directory and per-epoch metrics to CSV.
log_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
csv_logger = tf.keras.callbacks.CSVLogger(log_file)

# Measure both wall-clock time and CPU time of the training run.
start_real = time.time()
start_process = time.process_time()
history = model.fit(x=x_train,
                    y=y_train,
                    epochs=epochs,
                    validation_data=(x_test, y_test),
                    callbacks=[csv_logger, tensorboard_callback])
end_real = time.time()
end_process = time.process_time()

# Save the final-epoch training metrics and the timings.
with open("summary.json", "w") as fd:
    json.dump({
        "accuracy": float(history.history["accuracy"][-1]),
        "loss": float(history.history["loss"][-1]),
        "time_real": end_real - start_real,
        "time_process": end_process - start_process
    }, fd)