Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

config.py 11 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
  1. from math import ceil
  2. from typing import Optional
  3. import logging
  4. from argparse import ArgumentParser
  5. import sys
  6. import os
  7. class Config:
  8. @classmethod
  9. def arguments_parser(cls) -> ArgumentParser:
  10. parser = ArgumentParser()
  11. parser.add_argument("-d", "--data", dest="data_path",
  12. help="path to preprocessed dataset", required=False)
  13. parser.add_argument("-te", "--test", dest="test_path",
  14. help="path to test file", metavar="FILE", required=False, default='')
  15. parser.add_argument("-s", "--save", dest="save_path",
  16. help="path to save the model file", metavar="FILE", required=False)
  17. parser.add_argument("-w2v", "--save_word2v", dest="save_w2v",
  18. help="path to save the tokens embeddings file", metavar="FILE", required=False)
  19. parser.add_argument("-t2v", "--save_target2v", dest="save_t2v",
  20. help="path to save the targets embeddings file", metavar="FILE", required=False)
  21. parser.add_argument("-l", "--load", dest="load_path",
  22. help="path to load the model from", metavar="FILE", required=False)
  23. parser.add_argument('--save_w2v', dest='save_w2v', required=False,
  24. help="save word (token) vectors in word2vec format")
  25. parser.add_argument('--save_t2v', dest='save_t2v', required=False,
  26. help="save target vectors in word2vec format")
  27. parser.add_argument('--export_code_vectors', action='store_true', required=False,
  28. help="export code vectors for the given examples")
  29. parser.add_argument('--release', action='store_true',
  30. help='if specified and loading a trained model, release the loaded model for a lower model '
  31. 'size.')
  32. parser.add_argument('--predict', action='store_true',
  33. help='execute the interactive prediction shell')
  34. parser.add_argument("-fw", "--framework", dest="dl_framework", choices=['keras', 'tensorflow'],
  35. default='tensorflow', help="deep learning framework to use.")
  36. parser.add_argument("-v", "--verbose", dest="verbose_mode", type=int, required=False, default=1,
  37. help="verbose mode (should be in {0,1,2}).")
  38. parser.add_argument("-lp", "--logs-path", dest="logs_path", metavar="FILE", required=False,
  39. help="path to store logs into. if not given logs are not saved to file.")
  40. parser.add_argument('-tb', '--tensorboard', dest='use_tensorboard', action='store_true',
  41. help='use tensorboard during training')
  42. return parser
  43. def set_defaults(self):
  44. self.NUM_TRAIN_EPOCHS = 50
  45. self.SAVE_EVERY_EPOCHS = 1
  46. self.TRAIN_BATCH_SIZE = 512
  47. self.TEST_BATCH_SIZE = self.TRAIN_BATCH_SIZE
  48. self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 3
  49. self.NUM_BATCHES_TO_LOG_PROGRESS = 100
  50. self.NUM_TRAIN_BATCHES_TO_EVALUATE = 100
  51. self.READER_NUM_PARALLEL_BATCHES = 6 # cpu cores [for tf.contrib.data.map_and_batch() in the reader]
  52. self.SHUFFLE_BUFFER_SIZE = 10000
  53. self.CSV_BUFFER_SIZE = 150 * 1024 * 1024 # 100 MB
  54. self.MAX_TO_KEEP = 10
  55. # model hyper-params
  56. self.MAX_CONTEXTS = 200
  57. self.MAX_TOKEN_VOCAB_SIZE = 50000
  58. self.MAX_TARGET_VOCAB_SIZE = 3
  59. self.MAX_PATH_VOCAB_SIZE = 50000
  60. self.DEFAULT_EMBEDDINGS_SIZE = 128
  61. self.TOKEN_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
  62. self.PATH_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
  63. self.CODE_VECTOR_SIZE = self.context_vector_size
  64. self.TARGET_EMBEDDINGS_SIZE = self.CODE_VECTOR_SIZE
  65. self.DROPOUT_KEEP_RATE = 0.5
  66. self.SEPARATE_OOV_AND_PAD = False
  67. def load_from_args(self):
  68. args = self.arguments_parser().parse_args()
  69. # Automatically filled, do not edit:
  70. self.PREDICT = args.predict
  71. self.MODEL_SAVE_PATH = args.save_path
  72. self.MODEL_LOAD_PATH = args.load_path
  73. self.TRAIN_DATA_PATH_PREFIX = args.data_path
  74. self.TEST_DATA_PATH = args.test_path
  75. self.RELEASE = args.release
  76. self.EXPORT_CODE_VECTORS = args.export_code_vectors
  77. self.SAVE_W2V = args.save_w2v
  78. self.SAVE_T2V = args.save_t2v
  79. self.VERBOSE_MODE = args.verbose_mode
  80. self.LOGS_PATH = args.logs_path
  81. self.DL_FRAMEWORK = 'tensorflow' if not args.dl_framework else args.dl_framework
  82. self.USE_TENSORBOARD = args.use_tensorboard
  83. def __init__(self, set_defaults: bool = False, load_from_args: bool = False, verify: bool = False):
  84. self.NUM_TRAIN_EPOCHS: int = 0
  85. self.SAVE_EVERY_EPOCHS: int = 0
  86. self.TRAIN_BATCH_SIZE: int = 0
  87. self.TEST_BATCH_SIZE: int = 0
  88. self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION: int = 0
  89. self.NUM_BATCHES_TO_LOG_PROGRESS: int = 0
  90. self.NUM_TRAIN_BATCHES_TO_EVALUATE: int = 0
  91. self.READER_NUM_PARALLEL_BATCHES: int = 0
  92. self.SHUFFLE_BUFFER_SIZE: int = 0
  93. self.CSV_BUFFER_SIZE: int = 0
  94. self.MAX_TO_KEEP: int = 0
  95. # model hyper-params
  96. self.MAX_CONTEXTS: int = 0
  97. self.MAX_TOKEN_VOCAB_SIZE: int = 0
  98. self.MAX_TARGET_VOCAB_SIZE: int = 0
  99. self.MAX_PATH_VOCAB_SIZE: int = 0
  100. self.DEFAULT_EMBEDDINGS_SIZE: int = 0
  101. self.TOKEN_EMBEDDINGS_SIZE: int = 0
  102. self.PATH_EMBEDDINGS_SIZE: int = 0
  103. self.CODE_VECTOR_SIZE: int = 0
  104. self.TARGET_EMBEDDINGS_SIZE: int = 0
  105. self.DROPOUT_KEEP_RATE: float = 0
  106. self.SEPARATE_OOV_AND_PAD: bool = False
  107. # Automatically filled by `args`.
  108. self.PREDICT: bool = False # TODO: update README;
  109. self.MODEL_SAVE_PATH: Optional[str] = None
  110. self.MODEL_LOAD_PATH: Optional[str] = None
  111. self.TRAIN_DATA_PATH_PREFIX: Optional[str] = None
  112. self.TEST_DATA_PATH: Optional[str] = ''
  113. self.RELEASE: bool = False
  114. self.EXPORT_CODE_VECTORS: bool = False
  115. self.SAVE_W2V: Optional[str] = None # TODO: update README;
  116. self.SAVE_T2V: Optional[str] = None # TODO: update README;
  117. self.VERBOSE_MODE: int = 0
  118. self.LOGS_PATH: Optional[str] = None
  119. self.DL_FRAMEWORK: str = '' # in {'keras', 'tensorflow'}
  120. self.USE_TENSORBOARD: bool = False
  121. # Automatically filled by `Code2VecModelBase._init_num_of_examples()`.
  122. self.NUM_TRAIN_EXAMPLES: int = 0
  123. self.NUM_TEST_EXAMPLES: int = 0
  124. self.__logger: Optional[logging.Logger] = None
  125. if set_defaults:
  126. self.set_defaults()
  127. if load_from_args:
  128. self.load_from_args()
  129. if verify:
  130. self.verify()
  131. @property
  132. def context_vector_size(self) -> int:
  133. # The context vector is actually a concatenation of the embedded
  134. # source & target vectors and the embedded path vector.
  135. return self.PATH_EMBEDDINGS_SIZE + 2 * self.TOKEN_EMBEDDINGS_SIZE
  136. @property
  137. def is_training(self) -> bool:
  138. return bool(self.TRAIN_DATA_PATH_PREFIX)
  139. @property
  140. def is_loading(self) -> bool:
  141. return bool(self.MODEL_LOAD_PATH)
  142. @property
  143. def is_saving(self) -> bool:
  144. return bool(self.MODEL_SAVE_PATH)
  145. @property
  146. def is_testing(self) -> bool:
  147. return bool(self.TEST_DATA_PATH)
  148. @property
  149. def train_steps_per_epoch(self) -> int:
  150. return ceil(self.NUM_TRAIN_EXAMPLES / self.TRAIN_BATCH_SIZE) if self.TRAIN_BATCH_SIZE else 0
  151. @property
  152. def test_steps(self) -> int:
  153. return ceil(self.NUM_TEST_EXAMPLES / self.TEST_BATCH_SIZE) if self.TEST_BATCH_SIZE else 0
  154. def data_path(self, is_evaluating: bool = False):
  155. return self.TEST_DATA_PATH if is_evaluating else self.train_data_path
  156. def batch_size(self, is_evaluating: bool = False):
  157. return self.TEST_BATCH_SIZE if is_evaluating else self.TRAIN_BATCH_SIZE # take min with NUM_TRAIN_EXAMPLES?
  158. @property
  159. def train_data_path(self) -> Optional[str]:
  160. if not self.is_training:
  161. return None
  162. return '{}.train.c2v'.format(self.TRAIN_DATA_PATH_PREFIX)
  163. @property
  164. def word_freq_dict_path(self) -> Optional[str]:
  165. if not self.is_training:
  166. return None
  167. return '{}.dict.c2v'.format(self.TRAIN_DATA_PATH_PREFIX)
  168. @classmethod
  169. def get_vocabularies_path_from_model_path(cls, model_file_path: str) -> str:
  170. vocabularies_save_file_name = "dictionaries.bin"
  171. return '/'.join(model_file_path.split('/')[:-1] + [vocabularies_save_file_name])
  172. @classmethod
  173. def get_entire_model_path(cls, model_path: str) -> str:
  174. return model_path + '__entire-model'
  175. @classmethod
  176. def get_model_weights_path(cls, model_path: str) -> str:
  177. return model_path + '__only-weights'
  178. @property
  179. def model_load_dir(self):
  180. return '/'.join(self.MODEL_LOAD_PATH.split('/')[:-1])
  181. @property
  182. def entire_model_load_path(self) -> Optional[str]:
  183. if not self.is_loading:
  184. return None
  185. return self.get_entire_model_path(self.MODEL_LOAD_PATH)
  186. @property
  187. def model_weights_load_path(self) -> Optional[str]:
  188. if not self.is_loading:
  189. return None
  190. return self.get_model_weights_path(self.MODEL_LOAD_PATH)
  191. @property
  192. def entire_model_save_path(self) -> Optional[str]:
  193. if not self.is_saving:
  194. return None
  195. return self.get_entire_model_path(self.MODEL_SAVE_PATH)
  196. @property
  197. def model_weights_save_path(self) -> Optional[str]:
  198. if not self.is_saving:
  199. return None
  200. return self.get_model_weights_path(self.MODEL_SAVE_PATH)
  201. def verify(self):
  202. if not self.is_training and not self.is_loading:
  203. raise ValueError("Must train or load a model.")
  204. if self.is_loading and not os.path.isdir(self.model_load_dir):
  205. raise ValueError("Model load dir `{model_load_dir}` does not exist.".format(
  206. model_load_dir=self.model_load_dir))
  207. if self.DL_FRAMEWORK not in {'tensorflow', 'keras'}:
  208. raise ValueError("config.DL_FRAMEWORK must be in {'tensorflow', 'keras'}.")
  209. def __iter__(self):
  210. for attr_name in dir(self):
  211. if attr_name.startswith("__"):
  212. continue
  213. try:
  214. attr_value = getattr(self, attr_name, None)
  215. except:
  216. attr_value = None
  217. if callable(attr_value):
  218. continue
  219. yield attr_name, attr_value
  220. def get_logger(self) -> logging.Logger:
  221. if self.__logger is None:
  222. self.__logger = logging.getLogger('code2vec')
  223. self.__logger.setLevel(logging.INFO)
  224. self.__logger.handlers = []
  225. self.__logger.propagate = 0
  226. formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
  227. if self.VERBOSE_MODE >= 1:
  228. ch = logging.StreamHandler(sys.stdout)
  229. ch.setLevel(logging.INFO)
  230. ch.setFormatter(formatter)
  231. self.__logger.addHandler(ch)
  232. if self.LOGS_PATH:
  233. fh = logging.FileHandler(self.LOGS_PATH)
  234. fh.setLevel(logging.INFO)
  235. fh.setFormatter(formatter)
  236. self.__logger.addHandler(fh)
  237. return self.__logger
  238. def log(self, msg):
  239. self.get_logger().info(msg)
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...