Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

config.py 6.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
  1. import yaml
  2. import os
  3. import sys
  4. import logging
  5. from .api import Error
  6. from wandb import __stage_dir__
  7. def boolify(s):
  8. if s.lower() == 'none':
  9. return None
  10. if s.lower() == 'true':
  11. return True
  12. if s.lower() == 'false':
  13. return False
  14. raise ValueError("Not a boolean")
  15. class Config(dict):
  16. """Creates a W&B config object.
  17. The object is an enhanced `dict`. You can access keys via instance methods or
  18. as you would a regular `dict`. The object first looks for a `config-defaults.yaml` file
  19. in the current directory. It then looks for environment variables pre-pended
  20. with "WANDB_". Lastly it overrides any key found in command line arguments.
  21. Using the config objects enables W&B to track all configuration parameters used
  22. in your training runs.
  23. Args:
  24. config(:obj:`dict`, optional): Key value pairs from your existing code that
  25. you would like to track. You can also pass in objects that respond to `__dict__`
  26. such as from argparse.
  27. """
  28. def __init__(self, config={}):
  29. if not isinstance(config, dict):
  30. try:
  31. # for tensorflow flags
  32. if "__flags" in dir(config):
  33. if not config.__parsed:
  34. config._parse_flags()
  35. config = config.__flags
  36. else:
  37. config = vars(config)
  38. except TypeError:
  39. raise TypeError(
  40. "config must be a dict or have a __dict__ attribute.")
  41. dict.__init__(self, {})
  42. self._descriptions = {}
  43. # we only persist when _external is True
  44. self._external = False
  45. self.load_defaults()
  46. self.load_env()
  47. self.load_overrides()
  48. for key in config:
  49. self[key] = config[key]
  50. self._external = True
  51. self.persist(overrides=True)
  52. @property
  53. def config_dir(self):
  54. """The config directory holding the latest configuration"""
  55. return os.path.join(os.getcwd(), __stage_dir__)
  56. @property
  57. def defaults_path(self):
  58. """Where to find the default configuration"""
  59. return os.getcwd() + "/config-defaults.yaml"
  60. @property
  61. def keys(self):
  62. """All keys in the current configuration"""
  63. return [key for key in self if not key.startswith("_")]
  64. def desc(self, key):
  65. """The description of a given key"""
  66. return self._descriptions.get(key)
  67. def convert(self, ob):
  68. """Type casting for Boolean, None, Int and Float"""
  69. # TODO: deeper type casting
  70. if isinstance(ob, dict) or isinstance(ob, list):
  71. return ob
  72. for fn in (boolify, int, float):
  73. try:
  74. return fn(str(ob))
  75. except ValueError:
  76. pass
  77. return str(ob)
  78. def load_json(self, json):
  79. """Loads existing config from JSON"""
  80. for key in json:
  81. self[key] = json[key].get('value')
  82. self._descriptions[key] = json[key].get('desc')
  83. def load_defaults(self):
  84. """Load defaults from YAML"""
  85. if os.path.exists(self.defaults_path):
  86. try:
  87. defaults = yaml.load(open(self.defaults_path))
  88. except yaml.parser.ParserError:
  89. raise Error("Invalid YAML in config-defaults.yaml")
  90. if defaults:
  91. for key in defaults:
  92. if key == "wandb_version":
  93. continue
  94. self[key] = defaults[key].get('value')
  95. self._descriptions[key] = defaults[key].get('desc')
  96. else:
  97. logging.info(
  98. "Couldn't load default config, run `wandb config init` in this directory")
  99. def load_overrides(self):
  100. """Load overrides from command line arguments"""
  101. for arg in sys.argv:
  102. key_value = arg.split("=")
  103. if len(key_value) == 2:
  104. key = key_value[0].replace("--", "")
  105. if self.get(key):
  106. self[key] = self.convert(key_value[1])
  107. def load_env(self):
  108. """Load overrides from the environment"""
  109. for key in [key for key in os.environ if key.startswith("WANDB")]:
  110. value = os.environ[key]
  111. self[key.replace("WANDB_", "").lower()] = self.convert(value)
  112. def persist(self, overrides=False):
  113. """Stores the current configuration for pushing to W&B"""
  114. if overrides:
  115. path = "{dir}/latest.yaml".format(dir=self.config_dir)
  116. else:
  117. path = self.defaults_path
  118. try:
  119. with open(path, "w") as defaults_file:
  120. defaults_file.write(str(self))
  121. return True
  122. except IOError:
  123. logging.warn(
  124. "Unable to persist config, no wandb directory exists. Run `wandb config init` in this directory.")
  125. return False
  126. def __getitem__(self, name):
  127. return super(Config, self).__getitem__(name)
  128. def __setitem__(self, key, value):
  129. # TODO: this feels gross
  130. if key.endswith("_desc"):
  131. parts = key.split("_")
  132. parts.pop()
  133. self._descriptions["_".join(parts)] = str(value)
  134. else:
  135. # TODO: maybe don't convert, but otherwise python3 dumps unicode
  136. super(Config, self).__setitem__(key, self.convert(value))
  137. if not key.startswith("_") and self._external:
  138. self.persist(overrides=True)
  139. return value
  140. def __getattr__(self, name):
  141. return self.get(name)
  142. __setattr__ = __setitem__
  143. @property
  144. def __dict__(self):
  145. defaults = {}
  146. for key in self.keys:
  147. defaults[key] = {'value': self[key],
  148. 'desc': self._descriptions.get(key)}
  149. return defaults
  150. def __repr__(self):
  151. rep = "\n".join([str({
  152. 'key': key,
  153. 'desc': self._descriptions.get(key),
  154. 'value': self[key]
  155. }) for key in self.keys])
  156. return rep
  157. def __str__(self):
  158. s = "wandb_version: 1\n\n"
  159. if self.__dict__:
  160. s += yaml.dump(self.__dict__, default_flow_style=False)
  161. return s
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...