optimization_test.py
```python
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import optimization
import tensorflow as tf


class OptimizationTest(tf.test.TestCase):

  def test_adam(self):
    with self.test_session() as sess:
      w = tf.get_variable(
          "w",
          shape=[3],
          initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
      x = tf.constant([0.4, 0.2, -0.5])
      loss = tf.reduce_mean(tf.square(x - w))
      tvars = tf.trainable_variables()
      grads = tf.gradients(loss, tvars)
      global_step = tf.train.get_or_create_global_step()
      optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
      train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
      init_op = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())
      sess.run(init_op)
      for _ in range(100):
        sess.run(train_op)
      w_np = sess.run(w)
      self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
  tf.test.main()
```
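The test checks that `AdamWeightDecayOptimizer` with `learning_rate=0.2` moves `w` from `[0.1, -0.2, -0.1]` to within 1e-2 of `x = [0.4, 0.2, -0.5]` in 100 steps. Because the loss averages over three elements, its gradient is `2 * (w - x) / 3`. The `optimization` module is not shown here, so the NumPy sketch below replays the same problem under an assumed update rule: Adam moment estimates with `beta_1=0.9`, `beta_2=0.999`, `epsilon=1e-6`, no bias correction, and weight decay added to the update rather than to the gradient (as we understand BERT's optimizer to do). `weight_decay_rate` defaults to `0.0`, so the test does not exercise it. The function names `adam_weight_decay_step` and `run_test_problem` are hypothetical.

```python
import numpy as np


def adam_weight_decay_step(param, grad, m, v, learning_rate,
                           weight_decay_rate=0.0, beta_1=0.9,
                           beta_2=0.999, epsilon=1e-6):
    """One Adam step with decoupled weight decay and no bias correction.

    Assumption: this mirrors the update rule we believe BERT's
    AdamWeightDecayOptimizer uses; it is a sketch, not that class.
    """
    m = beta_1 * m + (1.0 - beta_1) * grad               # first moment
    v = beta_2 * v + (1.0 - beta_2) * np.square(grad)    # second moment
    update = m / (np.sqrt(v) + epsilon)
    # Decoupled weight decay: added to the update, so it bypasses the
    # m / v moving averages instead of being folded into the gradient.
    update = update + weight_decay_rate * param
    return param - learning_rate * update, m, v


def run_test_problem(steps=100, learning_rate=0.2):
    """Replays the test: minimise mean((x - w)**2) from the same w0."""
    w = np.array([0.1, -0.2, -0.1])
    x = np.array([0.4, 0.2, -0.5])
    m = np.zeros_like(w)
    v = np.zeros_like(w)
    for _ in range(steps):
        # d/dw mean((x - w)^2) = 2 * (w - x) / n
        grad = 2.0 * (w - x) / w.size
        w, m, v = adam_weight_decay_step(w, grad, m, v, learning_rate)
    return w


if __name__ == "__main__":
    w_final = run_test_problem()
    print(w_final)
    print(np.allclose(w_final, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2))
```

Because the step divides the first moment by the root of the second, its size is largely independent of the gradient's scale; the first steps overshoot and then the oscillation damps, which is why the test allows 100 iterations and a 1e-2 tolerance rather than checking a single step.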