Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

tokenizer.py 4.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
  1. #!/usr/bin/env python3
  2. # Copyright 2023 (authors: Feiteng Li)
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import re
  16. from dataclasses import asdict, dataclass
  17. from typing import Any, Dict, List, Optional, Pattern, Union
  18. import numpy as np
  19. import torch
  20. import torchaudio
  21. from encodec import EncodecModel
  22. from encodec.utils import convert_audio
  23. def remove_encodec_weight_norm(model):
  24. from encodec.modules import SConv1d
  25. from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
  26. from torch.nn.utils import remove_weight_norm
  27. encoder = model.encoder.model
  28. for key in encoder._modules:
  29. if isinstance(encoder._modules[key], SEANetResnetBlock):
  30. remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
  31. block_modules = encoder._modules[key].block._modules
  32. for skey in block_modules:
  33. if isinstance(block_modules[skey], SConv1d):
  34. remove_weight_norm(block_modules[skey].conv.conv)
  35. elif isinstance(encoder._modules[key], SConv1d):
  36. remove_weight_norm(encoder._modules[key].conv.conv)
  37. decoder = model.decoder.model
  38. for key in decoder._modules:
  39. if isinstance(decoder._modules[key], SEANetResnetBlock):
  40. remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
  41. block_modules = decoder._modules[key].block._modules
  42. for skey in block_modules:
  43. if isinstance(block_modules[skey], SConv1d):
  44. remove_weight_norm(block_modules[skey].conv.conv)
  45. elif isinstance(decoder._modules[key], SConvTranspose1d):
  46. remove_weight_norm(decoder._modules[key].convtr.convtr)
  47. elif isinstance(decoder._modules[key], SConv1d):
  48. remove_weight_norm(decoder._modules[key].conv.conv)
  49. class AudioTokenizer:
  50. """EnCodec audio."""
  51. def __init__(
  52. self,
  53. device: Any = None,
  54. ) -> None:
  55. # Instantiate a pretrained EnCodec model
  56. model = EncodecModel.encodec_model_24khz()
  57. model.set_target_bandwidth(6.0)
  58. remove_encodec_weight_norm(model)
  59. if not device:
  60. device = torch.device("cpu")
  61. if torch.cuda.is_available():
  62. device = torch.device("cuda:0")
  63. self._device = device
  64. self.codec = model.to(device)
  65. self.sample_rate = model.sample_rate
  66. self.channels = model.channels
  67. @property
  68. def device(self):
  69. return self._device
  70. def encode(self, wav: torch.Tensor) -> torch.Tensor:
  71. return self.codec.encode(wav.to(self.device))
  72. def decode(self, frames: torch.Tensor) -> torch.Tensor:
  73. return self.codec.decode(frames)
  74. def tokenize_audio(tokenizer: AudioTokenizer, audio):
  75. # Load and pre-process the audio waveform
  76. if isinstance(audio, str):
  77. wav, sr = torchaudio.load(audio)
  78. else:
  79. wav, sr = audio
  80. wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
  81. wav = wav.unsqueeze(0)
  82. # Extract discrete codes from EnCodec
  83. with torch.no_grad():
  84. encoded_frames = tokenizer.encode(wav)
  85. return encoded_frames
  86. if __name__ == "__main__":
  87. model = EncodecModel.encodec_model_24khz()
  88. model.set_target_bandwidth(6.0)
  89. samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
  90. torch.float32
  91. )
  92. codes_raw = model.encode(samples)
  93. remove_encodec_weight_norm(model)
  94. codes_norm = model.encode(samples)
  95. assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...