#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union

import numpy as np
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio


def remove_encodec_weight_norm(model):
    """Remove weight normalization from EnCodec's conv layers, folding the
    weight_g/weight_v parametrization back into plain weight tensors."""
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    encoder = model.encoder.model
    for key in encoder._modules:
        if isinstance(encoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
            block_modules = encoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(encoder._modules[key], SConv1d):
            remove_weight_norm(encoder._modules[key].conv.conv)

    decoder = model.decoder.model
    for key in decoder._modules:
        if isinstance(decoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
            block_modules = decoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(decoder._modules[key], SConvTranspose1d):
            remove_weight_norm(decoder._modules[key].convtr.convtr)
        elif isinstance(decoder._modules[key], SConv1d):
            remove_weight_norm(decoder._modules[key].conv.conv)


class AudioTokenizer:
    """Wrapper around a pretrained EnCodec model that turns waveforms into
    discrete codes and decodes them back to audio."""

    def __init__(
        self,
        device: Any = None,
    ) -> None:
        # Instantiate a pretrained EnCodec model
        model = EncodecModel.encodec_model_24khz()
        model.set_target_bandwidth(6.0)
        remove_encodec_weight_norm(model)

        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")

        self._device = device
        self.codec = model.to(device)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

    @property
    def device(self):
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        return self.codec.encode(wav.to(self.device))

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        return self.codec.decode(frames)


def tokenize_audio(tokenizer: AudioTokenizer, audio):
    # Load and pre-process the audio waveform
    if isinstance(audio, str):
        wav, sr = torchaudio.load(audio)
    else:
        wav, sr = audio
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    wav = wav.unsqueeze(0)

    # Extract discrete codes from EnCodec
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames
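

# A minimal usage sketch of the pieces above: build a tokenizer, encode a clip,
# and decode it back to a waveform. "example.wav" is a placeholder path; the
# shape note assumes the 6 kbps / 24 kHz configuration set in AudioTokenizer.
#
#   tokenizer = AudioTokenizer()
#   encoded_frames = tokenize_audio(tokenizer, "example.wav")
#   codes = encoded_frames[0][0]  # (batch, n_codebooks, frames) integer codes
#   waveform = tokenizer.decode(encoded_frames)  # audio at tokenizer.sample_rate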


if __name__ == "__main__":
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)

    samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
        torch.float32
    )
    codes_raw = model.encode(samples)

    remove_encodec_weight_norm(model)
    codes_norm = model.encode(samples)

    # Sanity check: removing weight norm must not change the encoded codes
    assert torch.allclose(codes_raw[0][0], codes_norm[0][0])