1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
|
- import yaml
- import os
- import sys
- import logging
- from .api import Error
- from wandb import __stage_dir__
def boolify(s):
    """Convert the strings "none", "true" and "false" to Python values.

    The comparison is case-insensitive.

    Args:
        s: The string to interpret. It must provide a ``lower()`` method.

    Returns:
        ``None`` for "none", ``True`` for "true" and ``False`` for "false".

    Raises:
        ValueError: If ``s`` is none of the three recognised words.
    """
    # Lowercase once instead of once per comparison.
    lowered = s.lower()
    if lowered == 'none':
        return None
    if lowered in ('true', 'false'):
        return lowered == 'true'
    raise ValueError("Not a boolean")
class Config(dict):
    """Creates a W&B config object.

    The object is an enhanced `dict`. You can access keys via instance
    attributes or as you would a regular `dict`. On construction the values
    are layered in this order, later sources winning:

    1. `config-defaults.yaml` in the current directory,
    2. environment variables pre-pended with "WANDB_",
    3. `key=value` / `--key=value` command line arguments, for keys that
       already exist,
    4. the keys of the `config` argument.

    Using the config object enables W&B to track all configuration parameters
    used in your training runs.

    Internal state (`_descriptions`, `_external`) is stored in the dict
    itself under underscore-prefixed keys, because attribute assignment is
    routed through `__setitem__`. Those keys are hidden from `keys`,
    `__dict__`, `repr()` and `str()`.

    Args:
        config(:obj:`dict`, optional): Key value pairs from your existing code
            that you would like to track. You can also pass in objects that
            respond to `__dict__` such as from argparse.
    """

    def __init__(self, config=None):
        # ``None`` stands in for "no extra keys" so that no mutable default
        # object is shared between calls.
        if config is None:
            config = {}
        elif not isinstance(config, dict):
            config = self._config_to_dict(config)
        dict.__init__(self, {})
        self._descriptions = {}
        # we only persist when _external is True
        self._external = False
        self.load_defaults()
        self.load_env()
        self.load_overrides()
        for key in config:
            self[key] = config[key]
        self._external = True
        self.persist(overrides=True)

    @staticmethod
    def _config_to_dict(config):
        """Return the key/value mapping behind a non-dict `config` object.

        Raises:
            TypeError: If `config` exposes neither "__flags" nor `__dict__`.
        """
        try:
            # for tensorflow flags
            if "__flags" in dir(config):
                # Use getattr with string names: writing ``config.__parsed``
                # inside a class body is name-mangled to
                # ``config._Config__parsed`` and would raise AttributeError.
                if not getattr(config, "__parsed"):
                    config._parse_flags()
                return getattr(config, "__flags")
            return vars(config)
        except TypeError:
            raise TypeError(
                "config must be a dict or have a __dict__ attribute.")

    @property
    def config_dir(self):
        """The config directory holding the latest configuration"""
        return os.path.join(os.getcwd(), __stage_dir__)

    @property
    def defaults_path(self):
        """Where to find the default configuration"""
        return os.path.join(os.getcwd(), "config-defaults.yaml")

    @property
    def keys(self):
        """All keys in the current configuration

        Note that this property replaces `dict.keys`; it returns a list and
        leaves out the underscore-prefixed internal keys.
        """
        return [key for key in self if not key.startswith("_")]

    def desc(self, key):
        """The description of a given key, or None when there is none"""
        return self._descriptions.get(key)

    def convert(self, ob):
        """Type casting for Boolean, None, Int and Float

        Dicts and lists are returned untouched. Anything else is turned into
        a string and parsed as boolean/None, then int, then float; the string
        itself is returned when none of those parse.
        """
        # TODO: deeper type casting
        if isinstance(ob, (dict, list)):
            return ob
        text = str(ob)
        for fn in (boolify, int, float):
            try:
                return fn(text)
            except ValueError:
                pass
        return text

    def load_json(self, json):
        """Loads existing config from JSON

        Args:
            json: Mapping of key to a mapping with optional 'value' and
                'desc' entries.
        """
        for key in json:
            self[key] = json[key].get('value')
            self._descriptions[key] = json[key].get('desc')

    def load_defaults(self):
        """Load defaults from YAML

        Raises:
            Error: If `config-defaults.yaml` exists but is not valid YAML.
        """
        path = self.defaults_path
        if not os.path.exists(path):
            logging.info(
                "Couldn't load default config, run `wandb config init` in this directory")
            return
        try:
            # The context manager closes the file even when parsing fails.
            # safe_load is used on purpose: the file holds plain
            # key -> {value, desc} data, and a bare yaml.load() can build
            # arbitrary objects and is rejected by newer PyYAML releases
            # when no Loader is given.
            with open(path) as defaults_file:
                defaults = yaml.safe_load(defaults_file)
        except yaml.YAMLError:
            # YAMLError also covers scanner errors, not only parser errors.
            raise Error("Invalid YAML in config-defaults.yaml")
        if not defaults:
            return
        for key in defaults:
            if key == "wandb_version":
                continue
            self[key] = defaults[key].get('value')
            self._descriptions[key] = defaults[key].get('desc')

    def load_overrides(self):
        """Load overrides from command line arguments

        Every `key=value` or `--key=value` argument whose key already exists
        in the config replaces that key's value.
        """
        for arg in sys.argv:
            # Split on the first "=" only so the value may contain "=".
            name, separator, raw_value = arg.partition("=")
            if not separator:
                continue
            if name.startswith("--"):
                name = name[2:]
            # Membership test rather than truthiness: keys currently set to
            # 0, False or "" must stay overridable. Internal "_" keys are
            # never touched from the command line.
            if name.startswith("_") or name not in self:
                continue
            self[name] = self.convert(raw_value)

    def load_env(self):
        """Load overrides from the environment

        `WANDB_FOO=1` sets the key `foo`. Only the leading "WANDB_" is
        removed from the variable name.
        """
        prefix = "WANDB_"
        for env_key in [k for k in os.environ if k.startswith(prefix)]:
            name = env_key[len(prefix):].lower()
            # Skip empty names and names that would clobber internal state.
            if not name or name.startswith("_"):
                continue
            self[name] = self.convert(os.environ[env_key])

    def persist(self, overrides=False):
        """Stores the current configuration for pushing to W&B

        Args:
            overrides: When true write `latest.yaml` inside `config_dir`,
                otherwise overwrite the defaults file.

        Returns:
            True when the file was written, False when it could not be
            opened or written.
        """
        if overrides:
            path = "{dir}/latest.yaml".format(dir=self.config_dir)
        else:
            path = self.defaults_path
        try:
            with open(path, "w") as defaults_file:
                defaults_file.write(str(self))
            return True
        except IOError:
            # logging.warn is a deprecated alias of logging.warning.
            logging.warning(
                "Unable to persist config, no wandb directory exists. Run `wandb config init` in this directory.")
            return False

    def __getitem__(self, name):
        return super(Config, self).__getitem__(name)

    def __setitem__(self, key, value):
        """Store `value` under `key`, or record a description.

        A key ending in "_desc" sets the description of the key named by the
        part before that suffix. Any other key stores the converted value.
        Once construction has finished, setting a public key rewrites
        `latest.yaml`.
        """
        # TODO: this feels gross
        suffix = "_desc"
        if key.endswith(suffix):
            self._descriptions[key[:-len(suffix)]] = str(value)
        else:
            # TODO: maybe don't convert, but otherwise python3 dumps unicode
            super(Config, self).__setitem__(key, self.convert(value))
        if not key.startswith("_") and self._external:
            self.persist(overrides=True)
        return value

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; unknown names
        # yield None instead of raising AttributeError.
        return self.get(name)

    # Attribute assignment is item assignment, internal "_" names included.
    __setattr__ = __setitem__

    @property
    def __dict__(self):
        """Public keys mapped to their {'value': ..., 'desc': ...} entry"""
        defaults = {}
        for key in self.keys:
            defaults[key] = {'value': self[key],
                             'desc': self._descriptions.get(key)}
        return defaults

    def __repr__(self):
        rep = "\n".join([str({
            'key': key,
            'desc': self._descriptions.get(key),
            'value': self[key]
        }) for key in self.keys])
        return rep

    def __str__(self):
        """The YAML document written by `persist`"""
        s = "wandb_version: 1\n\n"
        if self.__dict__:
            s += yaml.dump(self.__dict__, default_flow_style=False)
        return s
|