wrappers.py

  1. """
  2. Original source code: https://www.tensorflow.org/hub/tutorials/biggan_generation_with_tf_hub
  3. """
  4. import tensorflow as tf
  5. import numpy as np
  6. from scipy.stats import truncnorm
  7. import tensorflow_hub as hub
  8. class BigGAN:
  9. def __init__(self):
  10. # Load a BigGAN generator module
  11. self.module = hub.Module('https://tfhub.dev/deepmind/biggan-deep-512/1')
  12. self.sess = tf.Session()
  13. self.inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
  14. for k, v in self.module.get_input_info_dict().items()}
  15. self.output = self.module(self.inputs)
  16. # Define some functions for sampling and displaying BigGAN images
  17. self.input_z = self.inputs['z']
  18. self.input_y = self.inputs['y']
  19. self.input_trunc = self.inputs['truncation']
  20. self.dim_z = self.input_z.shape.as_list()[1]
  21. self.vocab_size = self.input_y.shape.as_list()[1]
  22. # Create a TensorFlow session and initialize variables
  23. initializer = tf.global_variables_initializer()
  24. self.sess.run(initializer)
  25. @property
  26. def get_y(self):
  27. return self.input_y
  28. @property
  29. def get_z(self):
  30. return self.input_z
  31. @property
  32. def get_trunc(self):
  33. return self.input_trunc
  34. def one_hot(self, index, vocab_size):
  35. index = np.asarray(index)
  36. if len(index.shape) == 0:
  37. index = np.asarray([index])
  38. assert len(index.shape) == 1
  39. num = index.shape[0]
  40. output = np.zeros((num, vocab_size), dtype=np.float32)
  41. output[np.arange(num), index] = 1
  42. return output
  43. def one_hot_if_needed(self, label, vocab_size):
  44. label = np.asarray(label)
  45. if len(label.shape) <= 1:
  46. label = self.one_hot(label, vocab_size)
  47. assert len(label.shape) == 2
  48. return label
  49. def sample(self, noise, label, truncation=1., batch_size=1):
  50. # batch_size=8 was used by default
  51. noise = np.asarray(noise)
  52. label = np.asarray(label)
  53. num = noise.shape[0]
  54. if len(label.shape) == 0:
  55. label = np.asarray([label] * num)
  56. if label.shape[0] != num:
  57. raise ValueError('Got # noise samples ({}) != # label samples ({})'
  58. .format(noise.shape[0], label.shape[0]))
  59. label = self.one_hot_if_needed(label, self.vocab_size)
  60. ims = []
  61. for batch_start in range(0, num, batch_size):
  62. s = slice(batch_start, min(num, batch_start + batch_size))
  63. feed_dict = {self.input_z: noise[s], self.input_y: label[s], self.input_trunc: truncation}
  64. ims.append(self.sess.run(self.output, feed_dict=feed_dict))
  65. ims = np.concatenate(ims, axis=0)
  66. assert ims.shape[0] == num
  67. ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255)
  68. ims = np.uint8(ims)
  69. return ims
  70. def sample_latent(self, seed, truncation, batch_size=1):
  71. state = None if seed is None else np.random.RandomState(seed)
  72. values = truncnorm.rvs(-2, 2, size=(batch_size, self.dim_z), random_state=state)
  73. return truncation * values
  74. def truncated_z_sample(self, batch_size, truncation=1., seed=None):
  75. state = None if seed is None else np.random.RandomState(seed)
  76. values = truncnorm.rvs(-2, 2, size=(batch_size, self.dim_z), random_state=state)
  77. return truncation * values
  78. def get_latent_dims(self):
  79. return self.dim_z
  80. def partial_forward(self, z, y, truncation):
  81. # TODO: Ideally this should work with batch > 1. However it seems to throw an Invalid Input error.
  82. # seed = tf.get_default_graph().get_tensor_by_name('module/Generator_1/GenZ/G_linear/add_8:0')
  83. # feed_dict = {self.input_z: np.asarray(z), self.input_y: self.one_hot_if_needed(np.asarray(y), self.vocab_size), self.input_trunc: truncation}
  84. # return seed.eval(feed_dict=feed_dict, session=self.sess)
  85. seed = tf.get_default_graph().get_tensor_by_name('module_apply_default/Generator_1/GenZ/G_linear/add_8:0')
  86. feed_dict = {self.input_z: np.asarray(z), self.input_y: self.one_hot_if_needed(np.asarray(y), self.vocab_size), self.input_trunc: truncation}
  87. return seed.eval(feed_dict=feed_dict, session=self.sess)
  88. def write_layers(self):
  89. with open('biggan_layers.txt', 'w') as f:
  90. for item in tf.get_default_graph().get_operations(): # 147646
  91. f.write("%s\n" % str(item.values()))
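
For reference, a minimal usage sketch of the wrapper above. It is not part of wrappers.py: the Pillow import, the output filenames, and the ImageNet class index 207 are illustrative assumptions; it assumes a TensorFlow 1.x environment with the biggan-deep-512 module downloadable from TF Hub.

# Minimal usage sketch (not part of wrappers.py), assuming TF 1.x and Pillow.
import numpy as np
from PIL import Image

gan = BigGAN()  # loads biggan-deep-512 and starts a tf.Session

# Draw 4 truncated-normal latent vectors and pick an arbitrary ImageNet class
# index in [0, gan.vocab_size); 207 is used here purely as an example.
z = gan.truncated_z_sample(batch_size=4, truncation=0.5, seed=0)
y = np.asarray([207] * 4)

# sample() returns a uint8 array of shape (4, 512, 512, 3) for this module
images = gan.sample(z, y, truncation=0.5, batch_size=2)

for i, im in enumerate(images):
    Image.fromarray(im).save('sample_{}.png'.format(i))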