1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
|
- #!/usr/bin/env python3
- import argparse
- import collections
- import torch
- import os
- import re
def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    All non-'model' entries of the returned dict are taken from the first
    checkpoint. Parameters are summed with a running accumulator, so memory
    use does not grow with the number of checkpoints. float16/bfloat16
    tensors are accumulated (and returned) as float32. Integer tensors are
    floor-divided so they keep their integer dtype. Every other value is
    averaged with true division.

    Args:
        inputs: An iterable of string paths of checkpoints to load from.

    Returns:
        A dict of string keys mapping to various values. The 'model' key
        from the returned dict should correspond to an OrderedDict mapping
        string parameter names to torch Tensors.

    Raises:
        ValueError: If ``inputs`` yields no checkpoints.
        KeyError: If a checkpoint's parameter names (or their order) differ
            from those of the first checkpoint.
    """
    int_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    low_precision_dtypes = (torch.float16, torch.bfloat16)

    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    # Count as we go so that generators are accepted as ``inputs``.
    num_models = 0

    for f in inputs:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        num_models += 1
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']
        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            is_tensor = isinstance(p, torch.Tensor)
            if is_tensor and p.dtype in low_precision_dtypes:
                # Sum low-precision params in float32 to limit rounding error.
                p = p.float()
            if k not in params_dict:
                # clone(): the accumulator is updated in place below, and p may
                # share storage with another key (tied weights) or new_state.
                params_dict[k] = p.clone() if is_tensor else p
            else:
                params_dict[k] += p

    if new_state is None:
        raise ValueError('average_checkpoints() requires at least one checkpoint')

    averaged_params = collections.OrderedDict()
    for k, summed in params_dict.items():
        if isinstance(summed, torch.Tensor) and summed.dtype in int_dtypes:
            # True division would silently turn integer buffers into floats.
            averaged_params[k] = summed // num_models
        else:
            averaged_params[k] = summed / num_models
    new_state['model'] = averaged_params
    return new_state
def last_n_checkpoints(paths, n, update_based, upper_bound=None):
    """Return the paths of the ``n`` highest-numbered checkpoints in a directory.

    Args:
        paths: A sequence holding exactly one directory path to scan.
        n: How many checkpoint paths to return.
        update_based: If True, match files named ``checkpoint_<a>_<b>.pt`` and
            rank them by ``<b>``. Otherwise match ``checkpoint<a>.pt`` and rank
            them by ``<a>``.
        upper_bound: If given, ignore checkpoints whose rank number exceeds it.

    Returns:
        A list of ``n`` paths (directory joined with file name), highest rank
        number first.

    Raises:
        ValueError: If ``paths`` does not contain exactly one entry.
        Exception: If fewer than ``n`` matching checkpoint files are found.
    """
    if len(paths) != 1:
        raise ValueError(
            'Expected exactly one input directory, got {}'.format(len(paths))
        )
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')

    entries = []
    for f in os.listdir(path):
        m = pt_regexp.fullmatch(f)
        if m is None:
            continue
        sort_key = int(m.group(1))
        if upper_bound is None or sort_key <= upper_bound:
            entries.append((sort_key, m.group(0)))

    if len(entries) < n:
        # The message was previously passed unformatted alongside its args.
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    return [os.path.join(path, name) for _, name in sorted(entries, reverse=True)[:n]]
def main():
    """Command-line entry point: average checkpoints and save the result.

    The checkpoints come from one of two sources:
    - the file paths given verbatim in ``--inputs``, or
    - when ``--num-epoch-checkpoints`` or ``--num-update-checkpoints`` is set,
      the newest matching files in the single directory given in ``--inputs``.
    """
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )
    # fmt: off
    parser.add_argument('--inputs', required=True, nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output', required=True, metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpointxx.pt in the path specified by input, '
                           'and average last this many of them.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
                           'and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use; '
                        'when using --num-update-checkpoints, an upper bound on which update to use, '
                        'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    # Validate with parser.error (not assert) so checks survive `python -O`
    # and the user sees a usage message. The mutually exclusive group already
    # rejects passing both --num-* options together.
    if args.checkpoint_upper_bound is not None and num is None:
        parser.error('--checkpoint-upper-bound requires --num-epoch-checkpoints '
                     'or --num-update-checkpoints')
    if num is not None and len(args.inputs) != 1:
        parser.error('--num-epoch-checkpoints/--num-update-checkpoints require '
                     '--inputs to be a single directory')

    if num is not None:
        args.inputs = last_n_checkpoints(
            args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
        )
        print('averaging checkpoints: ', args.inputs)

    new_state = average_checkpoints(args.inputs)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))


if __name__ == '__main__':
    main()
|