Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_callbacks.py 4.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
  1. import wandb
  2. from wandb import wandb_run
  3. from wandb.keras import WandbCallback
  4. import pytest
  5. from click.testing import CliRunner
  6. import os
  7. import json
  8. from .utils import git_repo
  9. from keras.layers import Dense, Flatten, Reshape
  10. from keras.models import Sequential
  11. import sys
  12. import glob
  13. @pytest.fixture
  14. def dummy_model(request):
  15. multi = request.node.get_marker('multiclass')
  16. image_output = request.node.get_marker('image_output')
  17. if multi:
  18. nodes = 10
  19. loss = 'categorical_crossentropy'
  20. else:
  21. nodes = 1
  22. loss = 'binary_crossentropy'
  23. nodes = 1 if not multi else 10
  24. if image_output:
  25. nodes = 300
  26. model = Sequential()
  27. model.add(Flatten(input_shape=(10, 10, 3)))
  28. model.add(Dense(nodes, activation='sigmoid'))
  29. if image_output:
  30. model.add(Dense(nodes, activation="relu"))
  31. model.add(Reshape((10, 10, 3)))
  32. model.compile(optimizer='adam',
  33. loss=loss,
  34. metrics=['accuracy'])
  35. return model
  36. @pytest.fixture
  37. def dummy_data(request):
  38. multi = request.node.get_marker('multiclass')
  39. image_output = request.node.get_marker('image_output')
  40. cats = 10 if multi else 1
  41. import numpy as np
  42. data = np.random.randint(255, size=(100, 10, 10, 3))
  43. labels = np.random.randint(2, size=(100, cats))
  44. if image_output:
  45. labels = data
  46. return (data, labels)
  47. @pytest.fixture
  48. def run():
  49. return wandb_run.Run.from_environment_or_defaults()
  50. def test_basic_keras(dummy_model, dummy_data, git_repo, run):
  51. wandb.run = run
  52. dummy_model.fit(*dummy_data, epochs=2, batch_size=36,
  53. callbacks=[WandbCallback()])
  54. wandb.run.summary.load()
  55. assert run.history.rows[0]["epoch"] == 0
  56. assert run.summary["acc"] > 0
  57. assert len(run.summary["graph"]["nodes"]) == 2
  58. def test_keras_image_bad_data(dummy_model, dummy_data, git_repo, run):
  59. wandb.run = run
  60. error = False
  61. data, labels = dummy_data
  62. try:
  63. dummy_model.fit(*dummy_data, epochs=2, batch_size=36, validation_data=(data.reshape(10), labels),
  64. callbacks=[WandbCallback(data_type="image")])
  65. except ValueError:
  66. error = True
  67. assert error
  68. def test_keras_image_binary(dummy_model, dummy_data, git_repo, run):
  69. wandb.run = run
  70. dummy_model.fit(*dummy_data, epochs=2, batch_size=36, validation_data=dummy_data,
  71. callbacks=[WandbCallback(data_type="image")])
  72. assert len(run.history.rows[0]["examples"]['captions']) == 36
  73. def test_keras_image_binary_captions(dummy_model, dummy_data, git_repo, run):
  74. wandb.run = run
  75. dummy_model.fit(*dummy_data, epochs=2, batch_size=36, validation_data=dummy_data,
  76. callbacks=[WandbCallback(data_type="image", predictions=10, labels=["Rad", "Nice"])])
  77. print(run.history.rows[0])
  78. assert run.history.rows[0]["examples"]['captions'][0] in ["Rad", "Nice"]
  79. @pytest.mark.multiclass
  80. def test_keras_image_multiclass(dummy_model, dummy_data, git_repo, run):
  81. wandb.run = run
  82. dummy_model.fit(*dummy_data, epochs=2, batch_size=36, validation_data=dummy_data,
  83. callbacks=[WandbCallback(data_type="image", predictions=10)])
  84. print(run.history.rows[0])
  85. assert len(run.history.rows[0]["examples"]['captions']) == 10
  86. @pytest.mark.multiclass
  87. def test_keras_image_multiclass_captions(dummy_model, dummy_data, git_repo, run):
  88. wandb.run = run
  89. dummy_model.fit(*dummy_data, epochs=2, batch_size=36, validation_data=dummy_data,
  90. callbacks=[WandbCallback(data_type="image", predictions=10, labels=["Rad", "Nice", "Fun", "Rad", "Nice", "Fun", "Rad", "Nice", "Fun", "Rad"])])
  91. print(run.history.rows[0])
  92. assert run.history.rows[0]["examples"]['captions'][0] in [
  93. "Rad", "Nice", "Fun"]
  94. @pytest.mark.image_output
  95. def test_keras_image_output(dummy_model, dummy_data, git_repo, run):
  96. wandb.run = run
  97. dummy_model.fit(*dummy_data, epochs=2, batch_size=36, validation_data=dummy_data,
  98. callbacks=[WandbCallback(data_type="image", predictions=10)])
  99. assert run.history.rows[0]["examples"]['count'] == 30
  100. assert run.history.rows[0]["examples"]['grouping'] == 3
  101. def test_keras_log_weights(dummy_model, dummy_data, git_repo, run):
  102. wandb.run = run
  103. dummy_model.fit(*dummy_data, epochs=2, batch_size=36, validation_data=dummy_data,
  104. callbacks=[WandbCallback(data_type="image", log_weights=True)])
  105. print("WHOA", run.history.rows[0].keys())
  106. assert run.history.rows[0]['dense_9.weights']['_type'] == "histogram"
  107. def test_keras_save_model(dummy_model, dummy_data, git_repo, run):
  108. wandb.run = run
  109. dummy_model.fit(*dummy_data, epochs=2, batch_size=36, validation_data=dummy_data,
  110. callbacks=[WandbCallback(data_type="image", save_model=True)])
  111. assert len(glob.glob(run.dir + "/model-best.h5")) == 1
Tip!

Press p or the left arrow key to see the previous file, or n or the right arrow key to see the next file

Comments

Loading...