Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

sync.py 9.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
  1. import psutil
  2. import os
  3. import stat
  4. import sys
  5. import time
  6. import traceback
  7. from tempfile import NamedTemporaryFile
  8. from watchdog.observers import Observer
  9. from watchdog.events import PatternMatchingEventHandler
  10. from shortuuid import ShortUUID
  11. import atexit
  12. from .config import Config
  13. import logging
  14. import threading
  15. from six.moves import queue
  16. import socket
  17. import click
  18. from wandb import __stage_dir__, Error
  19. from wandb import streaming_log
  20. from wandb import util
  21. from .run import Run
  22. logger = logging.getLogger(__name__)
  def editor(content='', marker='# Before we start this run, enter a brief description. (to skip, direct stdin to dev/null: `python train.py < /dev/null`)\n'):
      """Ask the user for a run description through ``click.edit``.

      The editor buffer is ``content``, a blank line, then ``marker``.

      Args:
          content: text to pre-fill the buffer with.
          marker: instruction line appended below the content; everything
              from its first occurrence onwards is discarded from the result.

      Returns:
          The text above the marker with trailing newlines removed, or
          ``None`` when ``click.edit`` returns ``None``.
      """
      edited = click.edit(content + '\n\n' + marker)
      if edited is not None:
          # Keep only what precedes the first occurrence of the marker.
          above_marker, _, _ = edited.partition(marker)
          return above_marker.rstrip('\n')
      return None
  class OutStreamTee(object):
      """Tees a writable filelike object.

      Writes go synchronously to the wrapped stream and are queued for a
      background thread that forwards them, in order, to a second stream.
      Flushes only reach the wrapped stream.
      """

      def __init__(self, stream, second_stream):
          """Constructor.

          Args:
              stream: stream to tee.
              second_stream: stream to duplicate writes to; only its
                  ``write`` method is used, from the background thread.
          """
          self._orig_stream = stream
          self._second_stream = second_stream
          self._queue = queue.Queue()
          self._thread = threading.Thread(target=self._thread_body)
          # Daemon thread: a tee that is never closed must not keep the
          # interpreter alive.
          self._thread.daemon = True
          self._thread.start()

      def _thread_body(self):
          """Forward queued messages to the second stream until the ``None`` sentinel."""
          while True:
              item = self._queue.get()
              if item is None:
                  break
              self._second_stream.write(item)

      def write(self, message):
          """Write ``message`` to the wrapped stream and queue it for the second one."""
          self._orig_stream.write(message)
          self._queue.put(message)

      def flush(self):
          """Flush the wrapped stream."""
          self._orig_stream.flush()

      def close(self):
          """Stop the forwarding thread once it has delivered every queued write.

          The sentinel is queued behind all pending messages and the thread
          is then joined. Without the join, a daemon thread could be killed
          at interpreter exit with output still queued. Calling ``close``
          again is harmless.
          """
          self._queue.put(None)
          # Never join from the writer thread itself (that would raise).
          if threading.current_thread() is not self._thread:
              self._thread.join()
  class ExitHooks(object):
      """Records how the process ends by wrapping ``sys.exit`` and ``sys.excepthook``.

      After ``hook()``, ``exit_code`` holds the last code passed to
      ``sys.exit`` and ``exception`` the last exception given to
      ``sys.excepthook``. Both are ``None`` until then.
      """

      def __init__(self):
          self.exit_code = None
          self.exception = None
          self._hooked = False

      def hook(self):
          """Install the wrappers, remembering the previous handlers.

          Idempotent: a second call would otherwise save this instance's own
          wrappers as the "originals" and make ``exit``/``excepthook``
          recurse forever.
          """
          if self._hooked:
              return
          self._orig_exit = sys.exit
          self._orig_excepthook = sys.excepthook
          sys.exit = self.exit
          sys.excepthook = self.excepthook
          self._hooked = True

      def exit(self, code=0):
          """Record ``code`` and delegate to the previous ``sys.exit``."""
          self.exit_code = code
          self._orig_exit(code)

      def excepthook(self, exc_type, exc, *args):
          """Record ``exc`` and delegate to the previous ``sys.excepthook``."""
          self.exception = exc
          self._orig_excepthook(exc_type, exc, *args)
  class Sync(object):
      """Watches for files to change and automatically pushes them.

      Also registers the run with the API, tees stdout/stderr into a streamed
      ``output.log`` and finishes the file stream when the script exits.
      """

      def __init__(self, api, run=None, project=None, tags=None, datasets=None,
                   config=None, description=None, dir=None):
          """Set up run identifiers, description, exit hooks and the observer.

          Args:
              api: API client used for settings, run upserts and file pushes.
              run: run id; a random 6 character id is generated when omitted.
              project: project name; defaults to ``api.settings("project")``.
              tags: stored on the instance; defaults to a new empty list.
              datasets: accepted but not used anywhere in this class.
              config: passed to ``Config``; defaults to a new empty dict.
              description: run description. Falls back to the staged
                  ``description.md``, then ``WANDB_DESCRIPTION``, then (on a
                  tty and unless ``DEBUG`` is set) an interactive editor.
              dir: directory to watch; defaults to
                  ``<stage dir>/run-<run id>``, which is created if missing.
          """
          # Mutable defaults are created per call instead of being shared.
          tags = [] if tags is None else tags
          config = {} if config is None else config

          # 36**6 (about 2.2 billion) possible 6 character ids
          run_gen = ShortUUID(alphabet=list(
              "0123456789abcdefghijklmnopqrstuvwxyz"))
          self.run_id = run or run_gen.random(6)
          self._project = project or api.settings("project")
          self._entity = api.settings("entity")
          logger.debug("Initialized sync for %s/%s", self._project, self.run_id)

          self._dpath = os.path.join(__stage_dir__, 'description.md')
          if not description and os.path.exists(self._dpath):
              # Close the description file instead of leaking the handle.
              with open(self._dpath) as description_file:
                  description = description_file.read()
          self._description = description or os.getenv('WANDB_DESCRIPTION')

          try:
              self.tty = sys.stdin.isatty() and os.getpgrp() == os.tcgetpgrp(
                  sys.stdout.fileno())
          except OSError:
              self.tty = False
          if not os.getenv('DEBUG') and not self._description and self.tty:
              self._description = editor()
              if self._description is None:
                  sys.stderr.write('No description provided, aborting run.\n')
                  sys.exit(1)

          self._config = Config(config)
          self._proc = psutil.Process(os.getpid())
          self._api = api
          self._tags = tags
          self._handler = PatternMatchingEventHandler()
          self._handler.on_created = self.add
          self._handler.on_modified = self.push

          base_url = api.settings('base_url')
          if base_url.endswith('.dev'):
              base_url = 'http://app.dev'
          self.url = "{base}/{entity}/{project}/runs/{run}".format(
              project=self._project,
              entity=self._entity,
              run=self.run_id,
              base=base_url
          )

          self._hooks = ExitHooks()
          self._hooks.hook()
          # Created in watch(); stop() must cope with them never being set up.
          self._stdout_stream = None
          self._stderr_stream = None

          self._observer = Observer()
          if dir is None:
              self._watch_dir = os.path.join(
                  __stage_dir__, 'run-%s' % self.run_id)
              util.mkdir_exists_ok(self._watch_dir)
          else:
              self._watch_dir = os.path.abspath(dir)
          self._observer.schedule(self._handler, self._watch_dir, recursive=True)

          self.run = Run(self.run_id, self._watch_dir, self._config)
          self._api.set_current_run(self.run_id)

      def watch(self, files):
          """Register the run, start observing ``files`` and tee stdout/stderr.

          Args:
              files: iterable of path patterns, relative to the watch
                  directory, whose creation/modification triggers a push.

          ``KeyboardInterrupt`` triggers ``stop()``; a wandb ``Error`` is
          printed and logged rather than raised.
          """
          try:
              # TODO: better failure handling
              self._api.upsert_run(name=self.run_id, project=self._project,
                                   entity=self._entity,
                                   config=self._config.__dict__,
                                   description=self._description,
                                   host=socket.gethostname())
              self._handler._patterns = [
                  os.path.join(self._watch_dir, os.path.normpath(f)) for f in files]
              # Ignore hidden files/folders
              self._handler._ignore_patterns = ['*/.*']

              # Build the path like description.md's, so it does not depend
              # on the stage dir ending with a separator.
              diff_path = os.path.join(__stage_dir__, "diff.patch")
              if os.path.exists(diff_path):
                  with open(diff_path, "rb") as patch:
                      self._api.push("{project}/{run}".format(
                          project=self._project,
                          run=self.run_id
                      ), {"diff.patch": patch})

              self._observer.start()
              print("Syncing %s" % self.url)

              # Tee stdout/stderr into our TextOutputStream, which will push lines to the cloud.
              self._stdout_stream = streaming_log.TextStreamPusher(
                  self._api.get_file_stream_api(), 'output.log', prepend_timestamp=True)
              sys.stdout = OutStreamTee(sys.stdout, self._stdout_stream)
              self._stderr_stream = streaming_log.TextStreamPusher(
                  self._api.get_file_stream_api(), 'output.log', line_prepend='ERROR',
                  prepend_timestamp=True)
              sys.stderr = OutStreamTee(sys.stderr, self._stderr_stream)

              # First log entry: the command line of this process.
              self._stdout_stream.write(" ".join(psutil.Process(
                  os.getpid()).cmdline()) + "\n\n")
              logger.debug("Swapped stdout/stderr")
              atexit.register(self.stop)
          except KeyboardInterrupt:
              self.stop()
          except Error:
              exc_type, exc_value, exc_traceback = sys.exc_info()
              print("!!! Fatal W&B Error: %s" % exc_value)
              lines = traceback.format_exception(
                  exc_type, exc_value, exc_traceback)
              logger.error('\n'.join(lines))

      def stop(self):
          """Wait for late file events, close the log streams and stop the observer."""
          # This is a heuristic delay to catch files that were written just before
          # the end of the script. This is unverified, but theoretically the file
          # change notification process used by watchdog (maybe inotify?) is
          # asynchronous. It's possible we could miss files if 10s isn't long enough.
          # TODO: Guarantee that all files will be saved.
          print("Script ended, waiting for final file modifications.")
          time.sleep(10.0)
          print("Pushing log")
          if os.path.exists(self._dpath):
              os.remove(self._dpath)
          print("Synced %s" % self.url)
          # The streams only exist once watch() got far enough to create them.
          for stream in (self._stdout_stream, self._stderr_stream):
              if stream is not None:
                  stream.close()
          self._api.get_file_stream_api().finish(self._hooks.exception)
          try:
              self._observer.stop()
              self._observer.join()
          except (TypeError, SystemError):
              # TODO: py2 TypeError: PyCObject_AsVoidPtr called with null pointer
              # TODO: py3 SystemError: <built-in function stop> returned a result with an error set
              # Deliberately best-effort; keep a trace for debugging.
              logger.debug("Ignoring error while stopping the observer",
                           exc_info=True)

      # TODO: limit / throttle the number of adds / pushes
      def add(self, event):
          """Handle a file-created event exactly like a modification."""
          self.push(event)

      # TODO: is this blocking the main thread?
      def push(self, event):
          """Upload the file behind ``event`` unless it is empty, a directory or gone.

          Args:
              event: object with a ``src_path`` attribute naming the file.
          """
          try:
              if os.stat(event.src_path).st_size == 0 or os.path.isdir(event.src_path):
                  return None
          except OSError:
              # The file vanished between the notification and this handler.
              return None
          file_name = os.path.relpath(event.src_path, self._watch_dir)
          # Progress goes to the first parent-logger handler's stream, if any;
          # indexing an empty handler list used to raise IndexError.
          handlers = logger.parent.handlers
          debug_log = getattr(handlers[0], 'stream', None) if handlers else None
          print("Pushing %s" % file_name)
          with open(event.src_path, 'rb') as f:
              self._api.push(self._project, {file_name: f}, run=self.run_id,
                             description=self._description, progress=debug_log)

      @property
      def source_proc(self):
          """First child of this process's parent when stdin is a FIFO.

          Returns ``None`` when stdin is not a FIFO or when that child is this
          process itself.
          """
          mode = os.fstat(0).st_mode
          if not stat.S_ISFIFO(mode):
              # stdin is not a pipe
              return None
          source = self._proc.parent().children()[0]
          return None if source == self._proc else source

      def echo(self):
          """Print everything read from stdin."""
          print(sys.stdin.read())
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...