convert_model.lua

-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
-- Usage: convert_model.lua <model_epoch1.th7>
require 'torch'
local fairseq = require 'fairseq'

model = torch.load(arg[1])

-- Returns the nn.WeightNorm wrapper that owns `module`, if any.
function find_weight_norm(container, module)
    for _, wn in ipairs(container:listModules()) do
        if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then
            return wn
        end
    end
end

-- Copies the parameters of `module` into `dict` under `key`, splitting
-- weight-normalized layers into separate weight_v / weight_g tensors.
function push_state(dict, key, module)
    if torch.type(module) == 'nn.Linear' then
        local wn = find_weight_norm(model.module, module)
        assert(wn)
        dict[key .. '.weight_v'] = wn.v:float()
        dict[key .. '.weight_g'] = wn.g:float()
    elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then
        local wn = find_weight_norm(model.module, module)
        assert(wn)
        local v = wn.v:float():view(wn.viewOut):transpose(2, 3)
        dict[key .. '.weight_v'] = v
        dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1)
    else
        dict[key .. '.weight'] = module.weight:float()
    end
    if module.bias then
        dict[key .. '.bias'] = module.bias:float()
    end
end

encoder_dict = {}
decoder_dict = {}
combined_dict = {}

-- Collects embeddings, fully connected layers, convolutions and input
-- projections from the encoder.
function encoder_state(encoder)
    luts = encoder:findModules('nn.LookupTable')
    push_state(encoder_dict, 'embed_tokens', luts[1])
    push_state(encoder_dict, 'embed_positions', luts[2])

    fcs = encoder:findModules('nn.Linear')
    assert(#fcs >= 2)
    local nInputPlane = fcs[1].weight:size(1)
    push_state(encoder_dict, 'fc1', table.remove(fcs, 1))
    push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs))

    for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do
        push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module)
        if nInputPlane ~= module.weight:size(3) / 2 then
            push_state(encoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
        end
        nInputPlane = module.weight:size(3) / 2
    end
    assert(#fcs == 0)
end

-- Collects embeddings, fully connected layers, convolutions, projections and
-- per-layer attention projections from the decoder.
function decoder_state(decoder)
    luts = decoder:findModules('nn.LookupTable')
    push_state(decoder_dict, 'embed_tokens', luts[1])
    push_state(decoder_dict, 'embed_positions', luts[2])

    fcs = decoder:findModules('nn.Linear')
    local nInputPlane = fcs[1].weight:size(1)
    push_state(decoder_dict, 'fc1', table.remove(fcs, 1))
    push_state(decoder_dict, 'fc2', fcs[#fcs - 1])
    push_state(decoder_dict, 'fc3', fcs[#fcs])
    table.remove(fcs, #fcs)
    table.remove(fcs, #fcs)

    for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do
        if nInputPlane ~= module.weight:size(3) / 2 then
            push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
        end
        nInputPlane = module.weight:size(3) / 2

        local prefix = 'attention.' .. tostring(i - 1)
        push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1))
        push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1))
        push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module)
    end
    assert(#fcs == 0)
end

_encoder = model.module.modules[2]
_decoder = model.module.modules[3]

encoder_state(_encoder)
decoder_state(_decoder)

for k, v in pairs(encoder_dict) do
    combined_dict['encoder.' .. k] = v
end
for k, v in pairs(decoder_dict) do
    combined_dict['decoder.' .. k] = v
end

torch.save('state_dict.t7', combined_dict)
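
The script writes the combined parameters to state_dict.t7 as a flat table of FloatTensors keyed by names such as encoder.embed_tokens.weight and decoder.attention.0.in_projection.weight_v. As a quick sanity check, the sketch below (not part of the original script; the file name sanity_check.lua is purely illustrative) reloads that file in Lua Torch and prints every key with its tensor shape.

-- sanity_check.lua (illustrative name): reload the converted state dict and
-- print every parameter name together with its tensor shape.
require 'torch'

local dict = torch.load('state_dict.t7')

-- sort the keys so encoder.* and decoder.* entries print in a stable order
local keys = {}
for k in pairs(dict) do table.insert(keys, k) end
table.sort(keys)

for _, k in ipairs(keys) do
    local t = dict[k]
    local dims = {}
    for d = 1, t:dim() do table.insert(dims, t:size(d)) end
    print(string.format('%-50s %s', k, table.concat(dims, ' x ')))
end

Run the converter first (th convert_model.lua <model_epoch1.th7>) and then this check from the same directory, since the output path 'state_dict.t7' is hard-coded.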