Upload 23 files
- README.md +4 -3
- __init__.py +0 -0
- config.json +11 -0
- configuration_bigcodec.py +19 -0
- modeling_xcodec2.py +165 -0
- module.py +0 -0
- test.py +21 -0
- vq/__init__.py +4 -0
- vq/activations.py +120 -0
- vq/alias_free_torch/__init__.py +6 -0
- vq/alias_free_torch/act.py +28 -0
- vq/alias_free_torch/filter.py +95 -0
- vq/alias_free_torch/resample.py +49 -0
- vq/blocks.py +183 -0
- vq/bs_roformer5.py +123 -0
- vq/codec_decoder.py +304 -0
- vq/codec_decoder_vocos.py +638 -0
- vq/codec_encoder.py +335 -0
- vq/factorized_vector_quantize.py +109 -0
- vq/module.py +420 -0
- vq/residual_vq.py +53 -0
- vq/unet.py +210 -0
README.md
CHANGED
@@ -17,9 +17,9 @@ language:
 - ko
 ---
 
-# 🗣️ XCodec2
+# 🗣️ XCodec2 Trained on 100K Hours of Multilingual Data
 
-This model
+This [model](https://huggingface.co/HKUSTAudio/xcodec2) is trained on a 100K-hour multilingual dataset across 7 languages. It is optimized for speech representation learning, compression, and high-fidelity reconstruction — particularly useful for TTS and bandwidth-efficient speech synthesis.
 
 ---
 
@@ -110,7 +110,7 @@ Reconstruction metrics are computed over 100 samples for English, Japanese, and
 - Japanese: 100 Examples (Emilia Dataset @ 24 kHz)
 - Bangla: 100 Examples (Inhouse TTS Dataset @ 22.05 kHz)
 
-| Model             | Lang | MCD ↓  | MSE ↑  |
+| Model             | Lang | MCD ↓  | MSE ↑  | SpeechBERTScore ↑ | SpeechBLEU ↑ | SpeechTokenDist ↑ |
 |-------------------|------|--------|--------|-------------------|--------------|-------------------|
 | **XCODEC**        | BN   | 2.823  | 0.003  | 0.939             | 0.500        | 0.816             |
 |                   | EN   | 3.166  | 0.012  | 0.962             | 0.660        | 0.856             |
@@ -129,6 +129,7 @@ Reconstruction metrics are computed over 100 samples for English, Japanese, and
 |                   | JA   | 2.677  | 0.003  | 0.955             | 0.614        | 0.853             |
 | **Overall**       |      | 2.597  | 0.003  | 0.960             | 0.636        | 0.863             |
 
+#### SpeechBERTScore, SpeechBLEU and SpeechTokenDistance are calculated using https://github.com/Takaaki-Saeki/DiscreteSpeechMetrics
 ---
 
 ## ✅ Intended Use
__init__.py
ADDED
File without changes
config.json
ADDED
@@ -0,0 +1,11 @@
{
  "model_type": "xcodec2",
  "semantic_hidden_size": 1024,
  "codec_encoder_hidden_size": 1024,
  "codec_decoder_hidden_size": 1024,
  "use_vocos": true,
  "architectures": [
    "XCodec2Model"
  ]
}
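The fields above mirror the keyword arguments of BigCodecConfig in configuration_bigcodec.py below. A minimal sketch of inspecting them with the standard library (assumes the local config.json from this commit is in the working directory):

import json

with open("config.json") as f:
    cfg = json.load(f)
print(cfg["model_type"], cfg["semantic_hidden_size"], cfg["use_vocos"])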
configuration_bigcodec.py
ADDED
@@ -0,0 +1,19 @@
from transformers import PretrainedConfig

class BigCodecConfig(PretrainedConfig):
    model_type = "bigcodec"

    def __init__(
        self,
        # The following are just example hyperparameters
        semantic_hidden_size=1024,
        codec_encoder_hidden_size=1024,
        codec_decoder_hidden_size=1024,
        use_vocos=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.semantic_hidden_size = semantic_hidden_size
        self.codec_encoder_hidden_size = codec_encoder_hidden_size
        self.codec_decoder_hidden_size = codec_decoder_hidden_size
        self.use_vocos = use_vocos
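A minimal round-trip sketch for the config class above, using the save_pretrained / from_pretrained API inherited from PretrainedConfig; the output directory name is a placeholder:

from configuration_bigcodec import BigCodecConfig

cfg = BigCodecConfig(semantic_hidden_size=1024, use_vocos=True)
cfg.save_pretrained("bigcodec_config_demo")            # writes a config.json
reloaded = BigCodecConfig.from_pretrained("bigcodec_config_demo")
print(reloaded.semantic_hidden_size, reloaded.use_vocos)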
modeling_xcodec2.py
ADDED
@@ -0,0 +1,165 @@
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from configuration_bigcodec import BigCodecConfig

# Make sure these module paths are correct
from vq.codec_encoder import CodecEncoder_Transformer
from vq.codec_decoder_vocos import CodecDecoderVocos
from vq.module import SemanticEncoder
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel

class XCodec2Model(PreTrainedModel):
    config_class = BigCodecConfig

    def __init__(self, config: BigCodecConfig):
        super().__init__(config)

        # 1) Semantic model
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
            "facebook/w2v-bert-2.0",
            output_hidden_states=True
        )
        self.semantic_model.eval()

        self.SemanticEncoder_module = SemanticEncoder(
            config.semantic_hidden_size,
            config.semantic_hidden_size,
            config.semantic_hidden_size
        )

        # 2) Codec encoder
        self.CodecEnc = CodecEncoder_Transformer()

        # 3) Codec decoder
        self.generator = CodecDecoderVocos()

        # 4) Two fully connected layers
        self.fc_prior = nn.Linear(2048, 2048)
        self.fc_post_a = nn.Linear(2048, 1024)
        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
        self.feature_extractor = feature_extractor

    def forward(self, input_waveform, sample_rate=16000):
        """
        This method does not have to be called `forward`; it could be split into
        other methods. However, to stay compatible with the pipeline API, the
        core logic should live in `forward`.

        Args:
            input_waveform: [batch_size, waveform_length]
            sample_rate: defaults to 16000
        Returns:
            the reconstructed audio (Tensor)
        """
        # 1) Feature extraction
        # If padding is needed, it can be applied here
        input_features = self.feature_extractor(
            input_waveform,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).input_features.to(self.device)  # [batch, frames, feat_dim]

        # 2) Semantic layer
        semantic_output = self.semantic_model(input_features)
        semantic_hidden_16 = semantic_output.hidden_states[16]  # take the 16th layer
        semantic_hidden_16 = semantic_hidden_16.transpose(1, 2)  # [batch, hidden_dim, frames]
        semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)

        # 3) Codec encoder
        wav = input_waveform.unsqueeze(1).to(self.device)  # shape: [batch, 1, time]
        vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024], for example
        vq_emb = vq_emb.transpose(1, 2)  # -> [batch, 1024, frames]

        # Align the number of time frames with the semantic features; this is only an
        # example. A real implementation may need to align the dimensions beforehand.
        if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
            # Simple truncation or zero padding both work; choose what fits your case
            min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
            vq_emb = vq_emb[:, :, :min_len]
            semantic_encoded = semantic_encoded[:, :, :min_len]

        # 4) Concatenate
        concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)  # [batch, 1024 + 1024, frames]

        # 5) fc_prior
        concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)

        # 6) Quantization inside the decoder
        _, vq_code, _ = self.generator(concat_emb, vq=True)
        vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
        vq_post_emb = vq_post_emb.transpose(1, 2)

        # 7) fc_post_a
        vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)

        # 8) Decode back to a waveform
        recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]
        # recon_audio: [batch, time]
        return recon_audio

    def encode_code(self, input_waveform, sample_rate=16000):
        """
        Encode the input audio into code indices.

        Args:
            input_waveform: [batch_size, waveform_length]
            sample_rate: defaults to 16000
        Returns:
            the encoded codes (Tensor)
        """
        with torch.no_grad():
            # 1) Feature extraction
            input_features = self.feature_extractor(
                input_waveform,
                sampling_rate=sample_rate,
                return_tensors="pt"
            ).input_features.to(self.device)  # [batch, frames, feat_dim]

            # 2) Semantic layer
            semantic_output = self.semantic_model(input_features)
            semantic_hidden_16 = semantic_output.hidden_states[16]  # take the 16th layer
            semantic_hidden_16 = semantic_hidden_16.transpose(1, 2)  # [batch, hidden_dim, frames]
            semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)

            # 3) Codec encoder
            wav = input_waveform.unsqueeze(1).to(self.device)  # shape: [batch, 1, time]
            vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024], for example
            vq_emb = vq_emb.transpose(1, 2)  # -> [batch, 1024, frames]

            # Align the number of time frames with the semantic features (example only)
            if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
                min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
                vq_emb = vq_emb[:, :, :min_len]
                semantic_encoded = semantic_encoded[:, :, :min_len]

            # 4) Concatenate
            concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)  # [batch, 2048, frames]

            # 5) fc_prior
            concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)

            # 6) Quantization inside the decoder to obtain the codes
            _, vq_code, _ = self.generator(concat_emb, vq=True)
            # vq_code: [batch, frames]
            return vq_code

    def decode_code(self, vq_code):
        """
        Decode code indices back into audio.

        Args:
            vq_code: encoded codes (Tensor) [batch, frames]
        Returns:
            the decoded audio (Tensor) [batch, waveform_length]
        """
        with torch.no_grad():
            # Get the quantized embeddings
            vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
            vq_post_emb = vq_post_emb.transpose(1, 2)  # [batch, 1024, frames]

            # 7) fc_post_a
            vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)  # [batch, 1024, frames]

            # 8) Decode back to a waveform
            recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]  # [batch, time]
            return recon_audio
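For the code-based path (encode_code then decode_code), see test.py below. A minimal sketch of the direct reconstruction path through forward, assuming the pretrained weights are available locally (the path is a placeholder) and a CUDA device is present:

import torch
from modeling_xcodec2 import XCodec2Model

model = XCodec2Model.from_pretrained("path/to/xcodec2")  # placeholder path
model.eval().cuda()

wav = torch.randn(1, 16000)              # one second of 16 kHz audio, [batch, time]
with torch.no_grad():
    recon = model(wav, sample_rate=16000)
print(recon.shape)                        # reconstructed waveform tensor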
module.py
ADDED
File without changes
test.py
ADDED
@@ -0,0 +1,21 @@
import torch
import soundfile as sf
from transformers import AutoConfig

from modeling_xcodec2 import XCodec2Model

model_path = "/data/zheny/xcodec2"  # the name of your repository on the Hugging Face Hub

model = XCodec2Model.from_pretrained(model_path)
model.eval().cuda()

# Prepare an audio clip
wav, sr = sf.read("test.flac")
wav_tensor = torch.from_numpy(wav).float().unsqueeze(0)  # [1, time]

with torch.no_grad():
    vq_code = model.encode_code(input_waveform=wav_tensor)
    print(vq_code)
    recon_wav = model.decode_code(vq_code).cpu()

sf.write("reconstructed.wav", recon_wav[0, 0, :].numpy(), sr)
vq/__init__.py
ADDED
@@ -0,0 +1,4 @@
from vq.codec_encoder import CodecEncoder
from vq.codec_decoder import CodecDecoder
from vq.codec_decoder_vocos import CodecDecoderVocos
from vq.codec_encoder import CodecEncoder_Transformer, CodecEncoder_only_Transformer
vq/activations.py
ADDED
@@ -0,0 +1,120 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    '''
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        '''
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.bias = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.bias = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.bias.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.bias.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
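A minimal usage sketch for the activations above: both act elementwise on (B, C, T) tensors, with one alpha (and, for SnakeBeta, one beta) per channel. Importing the vq package pulls in its other modules, so the repository's dependencies are assumed to be installed.

import torch
from vq.activations import Snake, SnakeBeta

x = torch.randn(2, 64, 100)                    # (batch, channels, time)
y = Snake(64)(x)                               # same shape as the input
z = SnakeBeta(64, alpha_logscale=True)(x)
print(y.shape, z.shape)                        # torch.Size([2, 64, 100]) twice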
vq/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *
vq/alias_free_torch/act.py
ADDED
@@ -0,0 +1,28 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(self,
                 activation,
                 up_ratio: int = 2,
                 down_ratio: int = 2,
                 up_kernel_size: int = 12,
                 down_kernel_size: int = 12):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B, C, T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
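Activation1d applies the wrapped activation at a higher internal sample rate (upsample, activate, downsample) to reduce aliasing introduced by the nonlinearity. A small sketch combining it with SnakeBeta from vq/activations.py; with the default 2x ratios the output keeps the input shape:

import torch
from vq.activations import SnakeBeta
from vq.alias_free_torch.act import Activation1d

x = torch.randn(2, 64, 100)                    # (B, C, T)
act = Activation1d(activation=SnakeBeta(64, alpha_logscale=True))
print(act(x).shape)                            # torch.Size([2, 64, 100])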
vq/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,95 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if 'sinc' in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(x == 0,
                           torch.tensor(1., device=x.device, dtype=x.dtype),
                           torch.sin(math.pi * x) / math.pi / x)


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = (torch.arange(-half_size, half_size) + 0.5)
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(self,
                 cutoff=0.5,
                 half_width=0.6,
                 stride: int = 1,
                 padding: bool = True,
                 padding_mode: str = 'replicate',
                 kernel_size: int = 12):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < -0.:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = (kernel_size % 2 == 0)
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right),
                      mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1),
                       stride=self.stride, groups=C)

        return out
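Cutoff and half_width are expressed as fractions of the sampling rate (so 0.5 corresponds to the Nyquist frequency). A short usage sketch; with stride=1 and padding enabled the filtered signal keeps its length:

import torch
from vq.alias_free_torch.filter import LowPassFilter1d, kaiser_sinc_filter1d

x = torch.randn(1, 4, 256)                                  # (B, C, T)
lpf = LowPassFilter1d(cutoff=0.25, half_width=0.1, stride=1, kernel_size=12)
print(lpf(x).shape)                                         # torch.Size([1, 4, 256])

taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.1, kernel_size=12)
print(taps.shape)                                           # torch.Size([1, 1, 12])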
vq/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,49 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
                                      half_width=0.6 / ratio,
                                      kernel_size=self.kernel_size)
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode='replicate')
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
        x = x[..., self.pad_left:-self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
                                       half_width=0.6 / ratio,
                                       stride=ratio,
                                       kernel_size=self.kernel_size)

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
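UpSample1d and DownSample1d change the time dimension by the given ratio using the windowed-sinc filter above, so a 2x up/down round trip restores the original length. A quick sketch:

import torch
from vq.alias_free_torch.resample import UpSample1d, DownSample1d

x = torch.randn(1, 8, 120)              # (B, C, T)
y = UpSample1d(ratio=2)(x)              # torch.Size([1, 8, 240])
z = DownSample1d(ratio=2)(y)            # torch.Size([1, 8, 120])
print(y.shape, z.shape)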
vq/blocks.py
ADDED
@@ -0,0 +1,183 @@
from typing import Callable, Sequence, Type, Union

import numpy as np
import torch
import torch.nn as nn

ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]]


class FeedForwardModule(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.net = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class Residual(nn.Module):

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.module(x) + x


class DilatedConvolutionalUnit(FeedForwardModule):

    def __init__(
            self,
            hidden_dim: int,
            dilation: int,
            kernel_size: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.Conv1d(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=((kernel_size - 1) * dilation) // 2,
                )),
            activation(),
            nn.Conv1d(in_channels=hidden_dim,
                      out_channels=hidden_dim,
                      kernel_size=1),
        )


class UpsamplingUnit(FeedForwardModule):

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.ConvTranspose1d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=2 * stride,
                    stride=stride,
                    padding=stride // 2 + stride % 2,
                    output_padding=1 if stride % 2 != 0 else 0
                )))


class DownsamplingUnit(FeedForwardModule):

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.Conv1d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=2 * stride,
                    stride=stride,
                    padding=stride // 2 + stride % 2,
                )))


class DilatedResidualEncoder(FeedForwardModule):

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            downsampling_unit: Type[DownsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        channels = capacity * 2**np.arange(len(ratios) + 1)

        dilations_list = self.normalize_dilations(dilations, ratios)

        net = [normalization(pre_network_conv(out_channels=channels[0]))]

        for ratio, dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            for dilation in dilations:
                net.append(Residual(dilated_unit(input_dim, dilation)))
            net.append(downsampling_unit(input_dim, output_dim, ratio))

        net.append(post_network_conv(in_channels=output_dim))

        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        if isinstance(dilations[0], int):
            dilations = [dilations for _ in ratios]
        return dilations


class DilatedResidualDecoder(FeedForwardModule):

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            upsampling_unit: Type[UpsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        channels = capacity * 2**np.arange(len(ratios) + 1)
        channels = channels[::-1]

        dilations_list = self.normalize_dilations(dilations, ratios)
        dilations_list = dilations_list[::-1]

        net = [pre_network_conv(out_channels=channels[0])]

        for ratio, dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            net.append(upsampling_unit(input_dim, output_dim, ratio))
            for dilation in dilations:
                net.append(Residual(dilated_unit(output_dim, dilation)))

        net.append(normalization(post_network_conv(in_channels=output_dim)))

        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        if isinstance(dilations[0], int):
            dilations = [dilations for _ in ratios]
        return dilations
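The encoder/decoder constructors above only pass channel counts, strides and dilations to the unit factories, so the remaining hyperparameters have to be bound beforehand (vq/codec_decoder.py does this with bound methods). A sketch using functools.partial; the capacity, ratios and channel numbers here are illustrative, not the values used by the released model:

from functools import partial
import torch
import torch.nn as nn
from vq import blocks

encoder = blocks.DilatedResidualEncoder(
    capacity=32,
    dilated_unit=partial(blocks.DilatedConvolutionalUnit,
                         kernel_size=3, activation=nn.ReLU),
    downsampling_unit=partial(blocks.DownsamplingUnit, activation=nn.ReLU),
    ratios=(2, 2),
    dilations=(1, 3, 9),
    pre_network_conv=partial(nn.Conv1d, in_channels=1, kernel_size=3, padding=1),
    post_network_conv=partial(nn.Conv1d, out_channels=1024, kernel_size=1),
)

x = torch.randn(1, 1, 1600)             # (B, 1, T) mono waveform
z = encoder(x)                          # (1, 1024, 400): T divided by prod(ratios)
print(z.shape)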
vq/bs_roformer5.py
ADDED
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList
import torchaudio
from einops import rearrange
import numpy as np
# from rotary_embedding_torch import RotaryEmbedding

from torchtune.modules import RotaryPositionalEmbeddings


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        r"""https://github.com/meta-llama/llama/blob/main/llama/model.py"""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm_x = torch.mean(x ** 2, dim=-1, keepdim=True)
        output = x * torch.rsqrt(norm_x + self.eps) * self.weight
        return output


class MLP(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()

        self.fc1 = nn.Linear(dim, 4 * dim, bias=False)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.silu(x)
        x = self.fc2(x)
        return x


class Attention(nn.Module):

    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()

        assert dim % n_heads == 0

        self.n_heads = n_heads
        self.dim = dim
        self.rotary_embed = rotary_embed

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        assert self.flash, "Must have flash attention."

        self.c_attn = nn.Linear(dim, 3 * dim, bias=False)
        self.c_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        r"""
        Args:
            x: (b, t, h*d)

        Constants:
            b: batch_size
            t: time steps
            r: 3
            h: heads_num
            d: heads_dim
        """
        B, T, C = x.size()

        q, k, v = rearrange(self.c_attn(x), 'b t (r h d) -> r b h t d', r=3, h=self.n_heads)
        # q, k, v: (b, h, t, d)

        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=False)

        y = rearrange(y, 'b h t d -> b t (h d)')

        y = self.c_proj(y)
        # shape: (b, t, h*d)

        return y


class TransformerBlock(nn.Module):
    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads

        self.att_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)
        self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed)
        self.mlp = MLP(dim=dim)

    def forward(
        self,
        x: torch.Tensor,
    ):
        x = x + self.att(self.att_norm(x))
        x = x + self.mlp(self.ffn_norm(x))
        return x


if __name__ == '__main__':
    rotary_embed_128 = RotaryPositionalEmbeddings(dim=128)
    transformer_block = TransformerBlock(
        dim=1024,
        n_heads=8,
        rotary_embed=rotary_embed_128
    )
    x = torch.randn(2, 128, 1024)
    y = transformer_block(x)
    print(y.shape)
    c = 1
vq/codec_decoder.py
ADDED
@@ -0,0 +1,304 @@
import sys

import numpy as np
import torch
import torch.nn as nn
from vq.residual_vq import ResidualVQ
from vq.module import WNConv1d, DecoderBlock, ResLSTM
from vq.alias_free_torch import *
from vq import activations
import vq.blocks as blocks
from torch.nn import utils

from vq.bs_roformer5 import TransformerBlock

from torchtune.modules import RotaryPositionalEmbeddings

def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

class CodecDecoder(nn.Module):
    def __init__(self,
                 in_channels=1024,
                 upsample_initial_channel=1536,
                 ngf=48,
                 use_rnn=True,
                 rnn_bidirectional=False,
                 rnn_num_layers=2,
                 up_ratios=(5, 4, 4, 4, 2),
                 dilations=(1, 3, 9),
                 vq_num_quantizers=1,
                 vq_dim=2048,
                 vq_commit_weight=0.25,
                 vq_weight_init=False,
                 vq_full_commit_loss=False,
                 codebook_size=16384,
                 codebook_dim=32,
                 ):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios

        self.quantizer = ResidualVQ(
            num_quantizers=vq_num_quantizers,
            dim=vq_dim,  # double the dim for acoustic and semantic
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            threshold_ema_dead_code=2,
            commitment=vq_commit_weight,
            weight_init=vq_weight_init,
            full_commit_loss=vq_full_commit_loss,
        )
        channels = upsample_initial_channel
        layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]

        if use_rnn:
            layers += [
                ResLSTM(channels,
                        num_layers=rnn_num_layers,
                        bidirectional=rnn_bidirectional
                        )
            ]

        for i, stride in enumerate(up_ratios):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride, dilations)]

        layers += [
            Activation1d(activation=activations.SnakeBeta(output_dim, alpha_logscale=True)),
            WNConv1d(output_dim, 1, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

        self.reset_parameters()

    def forward(self, x, vq=True):
        if vq is True:
            x, q, commit_loss = self.quantizer(x)
            return x, q, commit_loss
        x = self.model(x)
        return x

    def vq2emb(self, vq):
        self.quantizer = self.quantizer.eval()
        x = self.quantizer.vq2emb(vq)
        return x

    def get_emb(self):
        self.quantizer = self.quantizer.eval()
        embs = self.quantizer.get_emb()
        return embs

    def inference_vq(self, vq):
        x = vq[None, :, :]
        x = self.model(x)
        return x

    def inference_0(self, x):
        x, q, loss, perp = self.quantizer(x)
        x = self.model(x)
        return x, None

    def inference(self, x):
        x = self.model(x)
        return x, None

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)


class CodecDecoder_oobleck_Transformer(nn.Module):
    def __init__(self,
                 ngf=32,
                 up_ratios=(5, 4, 4, 4, 2),
                 dilations=(1, 3, 9),
                 vq_num_quantizers=1,
                 vq_dim=1024,
                 vq_commit_weight=0.25,
                 vq_weight_init=False,
                 vq_full_commit_loss=False,
                 codebook_size=16384,
                 codebook_dim=16,
                 hidden_dim=1024,
                 depth=12,
                 heads=16,
                 pos_meb_dim=64,
                 ):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.capacity = ngf
        self.up_ratios = up_ratios
        self.hidden_dim = hidden_dim
        self.quantizer = ResidualVQ(
            num_quantizers=vq_num_quantizers,
            dim=vq_dim,  # double the dim for acoustic and semantic
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            threshold_ema_dead_code=2,
            commitment=vq_commit_weight,
            weight_init=vq_weight_init,
            full_commit_loss=vq_full_commit_loss,
        )

        time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)

        transformer_blocks = [
            TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
            for _ in range(depth)
        ]

        self.transformers = nn.Sequential(*transformer_blocks)

        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)

        self.conv_blocks = blocks.DilatedResidualDecoder(
            capacity=self.capacity,
            dilated_unit=self.dilated_unit,
            upsampling_unit=self.upsampling_unit,
            ratios=up_ratios,  # reverse of the encoder's downsampling ratios
            dilations=dilations,
            pre_network_conv=self.pre_conv,
            post_network_conv=self.post_conv,
        )

        self.reset_parameters()

    def forward(self, x, vq=True):
        if vq is True:
            x, q, commit_loss = self.quantizer(x)
            return x, q, commit_loss
        x = self.transformers(x)
        x = self.final_layer_norm(x)
        x = x.permute(0, 2, 1)
        x = self.conv_blocks(x)
        return x

    def vq2emb(self, vq):
        self.quantizer = self.quantizer.eval()
        x = self.quantizer.vq2emb(vq)
        return x

    def get_emb(self):
        self.quantizer = self.quantizer.eval()
        embs = self.quantizer.get_emb()
        return embs

    def inference_vq(self, vq):
        x = vq[None, :, :]
        x = self.model(x)
        return x

    def inference_0(self, x):
        x, q, loss, perp = self.quantizer(x)
        x = self.model(x)
        return x, None

    def inference(self, x):
        x = self.model(x)
        return x, None

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)

    def pre_conv(self, out_channels):
        return nn.Conv1d(in_channels=self.hidden_dim, out_channels=out_channels, kernel_size=1)

    # Post-processing convolution that maps the model output to the final number of output channels
    def post_conv(self, in_channels):
        return nn.Conv1d(in_channels=in_channels, out_channels=1, kernel_size=1)

    def dilated_unit(self, hidden_dim, dilation):
        return blocks.DilatedConvolutionalUnit(
            hidden_dim=hidden_dim,
            dilation=dilation,
            kernel_size=3,
            activation=nn.ReLU,
            normalization=utils.weight_norm
        )

    # Upsampling unit
    def upsampling_unit(self, input_dim, output_dim, stride):
        return blocks.UpsamplingUnit(
            input_dim=input_dim,
            output_dim=output_dim,
            stride=stride,
            activation=nn.ReLU,
            normalization=utils.weight_norm
        )


def main():
    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize the model
    model = CodecDecoder_oobleck_Transformer().to(device)
    print("Model initialized.")

    # Create a test input: batch_size x sequence_length x in_channels
    batch_size = 2
    in_channels = 1024
    sequence_length = 100  # example length; adjust as needed
    dummy_input = torch.randn(batch_size, sequence_length, in_channels).to(device)
    print(f"Dummy input shape: {dummy_input.shape}")

    # Put the model in eval mode
    model.eval()

    output_no_vq = model(dummy_input, vq=False)
    c = 1

if __name__ == "__main__":
    main()
vq/codec_decoder_vocos.py
ADDED
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv_transformer_vocos')
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from vq.residual_vq import ResidualVQ
|
7 |
+
from vq.module import WNConv1d, DecoderBlock, ResLSTM
|
8 |
+
from vq.alias_free_torch import *
|
9 |
+
from vq import activations
|
10 |
+
from typing import Optional
|
11 |
+
from vq.module import ConvNeXtBlock, AdaLayerNorm
|
12 |
+
from vq.bs_roformer5 import TransformerBlock
|
13 |
+
# from rotary_embedding_torch import RotaryEmbedding
|
14 |
+
from torchtune.modules import RotaryPositionalEmbeddings
|
15 |
+
from vector_quantize_pytorch import ResidualFSQ
|
16 |
+
from torch.nn import Module, ModuleList
|
17 |
+
class ISTFT(nn.Module):
|
18 |
+
"""
|
19 |
+
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
|
20 |
+
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
|
21 |
+
See issue: https://github.com/pytorch/pytorch/issues/62323
|
22 |
+
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
|
23 |
+
The NOLA constraint is met as we trim padded samples anyway.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
n_fft (int): Size of Fourier transform.
|
27 |
+
hop_length (int): The distance between neighboring sliding window frames.
|
28 |
+
win_length (int): The size of window frame and STFT filter.
|
29 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
|
33 |
+
super().__init__()
|
34 |
+
if padding not in ["center", "same"]:
|
35 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
36 |
+
self.padding = padding
|
37 |
+
self.n_fft = n_fft
|
38 |
+
self.hop_length = hop_length
|
39 |
+
self.win_length = win_length
|
40 |
+
window = torch.hann_window(win_length)
|
41 |
+
self.register_buffer("window", window)
|
42 |
+
|
43 |
+
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
44 |
+
"""
|
45 |
+
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
|
49 |
+
N is the number of frequency bins, and T is the number of time frames.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
|
53 |
+
"""
|
54 |
+
if self.padding == "center":
|
55 |
+
# Fallback to pytorch native implementation
|
56 |
+
return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
|
57 |
+
elif self.padding == "same":
|
58 |
+
pad = (self.win_length - self.hop_length) // 2
|
59 |
+
else:
|
60 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
61 |
+
|
62 |
+
assert spec.dim() == 3, "Expected a 3D tensor as input"
|
63 |
+
B, N, T = spec.shape
|
64 |
+
|
65 |
+
# Inverse FFT
|
66 |
+
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
|
67 |
+
ifft = ifft * self.window[None, :, None]
|
68 |
+
|
69 |
+
# Overlap and Add
|
70 |
+
output_size = (T - 1) * self.hop_length + self.win_length
|
71 |
+
y = torch.nn.functional.fold(
|
72 |
+
ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
|
73 |
+
)[:, 0, 0, pad:-pad]
|
74 |
+
|
75 |
+
# Window envelope
|
76 |
+
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
|
77 |
+
window_envelope = torch.nn.functional.fold(
|
78 |
+
window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
|
79 |
+
).squeeze()[pad:-pad]
|
80 |
+
|
81 |
+
# Normalize
|
82 |
+
assert (window_envelope > 1e-11).all()
|
83 |
+
y = y / window_envelope
|
84 |
+
|
85 |
+
return y
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
class FourierHead(nn.Module):
|
90 |
+
"""Base class for inverse fourier modules."""
|
91 |
+
|
92 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
93 |
+
"""
|
94 |
+
Args:
|
95 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
96 |
+
L is the sequence length, and H denotes the model dimension.
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
100 |
+
"""
|
101 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
102 |
+
|
103 |
+
|
104 |
+
class ISTFTHead(FourierHead):
|
105 |
+
"""
|
106 |
+
ISTFT Head module for predicting STFT complex coefficients.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
dim (int): Hidden dimension of the model.
|
110 |
+
n_fft (int): Size of Fourier transform.
|
111 |
+
hop_length (int): The distance between neighboring sliding window frames, which should align with
|
112 |
+
the resolution of the input features.
|
113 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
114 |
+
"""
|
115 |
+
|
116 |
+
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
|
117 |
+
super().__init__()
|
118 |
+
out_dim = n_fft + 2
|
119 |
+
self.out = torch.nn.Linear(dim, out_dim)
|
120 |
+
self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
|
121 |
+
|
122 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
123 |
+
"""
|
124 |
+
Forward pass of the ISTFTHead module.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
128 |
+
L is the sequence length, and H denotes the model dimension.
|
129 |
+
|
130 |
+
Returns:
|
131 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
132 |
+
"""
|
133 |
+
x_pred = self.out(x )
|
134 |
+
# x_pred = x
|
135 |
+
x_pred = x_pred.transpose(1, 2)
|
136 |
+
mag, p = x_pred.chunk(2, dim=1)
|
137 |
+
mag = torch.exp(mag)
|
138 |
+
mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes
|
139 |
+
# wrapping happens here. These two lines produce real and imaginary value
|
140 |
+
x = torch.cos(p)
|
141 |
+
y = torch.sin(p)
|
142 |
+
# recalculating phase here does not produce anything new
|
143 |
+
# only costs time
|
144 |
+
# phase = torch.atan2(y, x)
|
145 |
+
# S = mag * torch.exp(phase * 1j)
|
146 |
+
# better directly produce the complex value
|
147 |
+
S = mag * (x + 1j * y)
|
148 |
+
audio = self.istft(S)
|
149 |
+
return audio.unsqueeze(1),x_pred
|
150 |
+
|
151 |
+
|
152 |
+
def nonlinearity(x):
|
153 |
+
# swish
|
154 |
+
return x * torch.sigmoid(x)
|
155 |
+
|
156 |
+
|
157 |
+
def Normalize(in_channels, num_groups=32):
|
158 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
159 |
+
|
160 |
+
|
161 |
+
class ResnetBlock(nn.Module):
|
162 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
163 |
+
dropout, temb_channels=512):
|
164 |
+
super().__init__()
|
165 |
+
self.in_channels = in_channels
|
166 |
+
out_channels = in_channels if out_channels is None else out_channels
|
167 |
+
self.out_channels = out_channels
|
168 |
+
self.use_conv_shortcut = conv_shortcut
|
169 |
+
|
170 |
+
self.norm1 = Normalize(in_channels)
|
171 |
+
self.conv1 = torch.nn.Conv1d(in_channels,
|
172 |
+
out_channels,
|
173 |
+
kernel_size=3,
|
174 |
+
stride=1,
|
175 |
+
padding=1)
|
176 |
+
if temb_channels > 0:
|
177 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
178 |
+
out_channels)
|
179 |
+
self.norm2 = Normalize(out_channels)
|
180 |
+
self.dropout = torch.nn.Dropout(dropout)
|
181 |
+
self.conv2 = torch.nn.Conv1d(out_channels,
|
182 |
+
out_channels,
|
183 |
+
kernel_size=3,
|
184 |
+
stride=1,
|
185 |
+
padding=1)
|
186 |
+
if self.in_channels != self.out_channels:
|
187 |
+
if self.use_conv_shortcut:
|
188 |
+
self.conv_shortcut = torch.nn.Conv1d(in_channels,
|
189 |
+
out_channels,
|
190 |
+
kernel_size=3,
|
191 |
+
stride=1,
|
192 |
+
padding=1)
|
193 |
+
else:
|
194 |
+
self.nin_shortcut = torch.nn.Conv1d(in_channels,
|
195 |
+
out_channels,
|
196 |
+
kernel_size=1,
|
197 |
+
stride=1,
|
198 |
+
padding=0)
|
199 |
+
|
200 |
+
def forward(self, x, temb=None):
|
201 |
+
h = x
|
202 |
+
h = self.norm1(h)
|
203 |
+
h = nonlinearity(h)
|
204 |
+
h = self.conv1(h)
|
205 |
+
|
206 |
+
if temb is not None:
|
207 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None]  # 1D feature maps: broadcast over time only
|
208 |
+
|
209 |
+
h = self.norm2(h)
|
210 |
+
h = nonlinearity(h)
|
211 |
+
h = self.dropout(h)
|
212 |
+
h = self.conv2(h)
|
213 |
+
|
214 |
+
if self.in_channels != self.out_channels:
|
215 |
+
if self.use_conv_shortcut:
|
216 |
+
x = self.conv_shortcut(x)
|
217 |
+
else:
|
218 |
+
x = self.nin_shortcut(x)
|
219 |
+
|
220 |
+
return x + h
|
221 |
+
|
222 |
+
class AttnBlock(nn.Module):
|
223 |
+
def __init__(self, in_channels):
|
224 |
+
super().__init__()
|
225 |
+
self.in_channels = in_channels
|
226 |
+
|
227 |
+
self.norm = Normalize(in_channels)
|
228 |
+
self.q = torch.nn.Conv1d(in_channels,
|
229 |
+
in_channels,
|
230 |
+
kernel_size=1,
|
231 |
+
stride=1,
|
232 |
+
padding=0)
|
233 |
+
self.k = torch.nn.Conv1d(in_channels,
|
234 |
+
in_channels,
|
235 |
+
kernel_size=1,
|
236 |
+
stride=1,
|
237 |
+
padding=0)
|
238 |
+
self.v = torch.nn.Conv1d(in_channels,
|
239 |
+
in_channels,
|
240 |
+
kernel_size=1,
|
241 |
+
stride=1,
|
242 |
+
padding=0)
|
243 |
+
self.proj_out = torch.nn.Conv1d(in_channels,
|
244 |
+
in_channels,
|
245 |
+
kernel_size=1,
|
246 |
+
stride=1,
|
247 |
+
padding=0)
|
248 |
+
|
249 |
+
def forward(self, x):
|
250 |
+
h_ = x
|
251 |
+
h_ = self.norm(h_)
|
252 |
+
q = self.q(h_)
|
253 |
+
k = self.k(h_)
|
254 |
+
v = self.v(h_)
|
255 |
+
|
256 |
+
# compute attention
|
257 |
+
b, c, h = q.shape
|
258 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
259 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
260 |
+
w_ = w_ * (int(c) ** (-0.5))
|
261 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
262 |
+
|
263 |
+
# attend to values
|
264 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
265 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
266 |
+
|
267 |
+
h_ = self.proj_out(h_)
|
268 |
+
|
269 |
+
return x + h_
|
270 |
+
|
271 |
+
def make_attn(in_channels, attn_type="vanilla"):
|
272 |
+
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
273 |
+
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
274 |
+
if attn_type == "vanilla":
|
275 |
+
return AttnBlock(in_channels)
|
276 |
+
|
277 |
+
|
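`AttnBlock` is single-head self-attention over the time axis, with 1x1 convolutions standing in for the usual linear projections. A small standalone sketch (not from the file) showing that its `bmm`/`permute` arithmetic matches the familiar softmax(QᵀK/√C)·V formulation:

```python
import torch

B, C, T = 2, 64, 50
q, k, v = (torch.randn(B, C, T) for _ in range(3))   # outputs of the 1x1 convs

# AttnBlock's computation:
w = torch.bmm(q.permute(0, 2, 1), k) * C ** -0.5      # (B, T, T) query-key logits
w = torch.softmax(w, dim=2)                           # softmax over key positions
out = torch.bmm(v, w.permute(0, 2, 1))                # (B, C, T) attended values

# The same thing in the usual (B, T, C) layout:
ref = torch.softmax(q.transpose(1, 2) @ k * C ** -0.5, dim=-1) @ v.transpose(1, 2)
print(torch.allclose(out, ref.transpose(1, 2), atol=1e-5))   # True
```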
278 |
+
class Backbone(nn.Module):
|
279 |
+
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
280 |
+
|
281 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
282 |
+
"""
|
283 |
+
Args:
|
284 |
+
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
285 |
+
C denotes output features, and L is the sequence length.
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
289 |
+
and H denotes the model dimension.
|
290 |
+
"""
|
291 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
292 |
+
|
293 |
+
|
294 |
+
class VocosBackbone(Backbone):
|
295 |
+
"""
|
296 |
+
Vocos-style backbone built from ResnetBlock pre-/post-nets and TransformerBlocks with rotary positional embeddings.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
hidden_dim (int): Hidden dimension of the model.
|
300 |
+
depth (int): Number of TransformerBlock layers.
|
301 |
+
heads (int): Number of attention heads in each TransformerBlock.
|
302 |
+
pos_meb_dim (int): Dimension of the rotary positional embeddings.
|
303 |
+
|
304 |
+
|
305 |
+
|
306 |
+
"""
|
307 |
+
|
308 |
+
def __init__(
|
309 |
+
self, hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64):
|
310 |
+
super().__init__()
|
311 |
+
|
312 |
+
self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3)
|
313 |
+
|
314 |
+
|
315 |
+
|
316 |
+
self.temb_ch = 0
|
317 |
+
block_in = hidden_dim
|
318 |
+
dropout = 0.1
|
319 |
+
|
320 |
+
prior_net : tp.List[nn.Module] = [
|
321 |
+
ResnetBlock(in_channels=block_in,out_channels=block_in,
|
322 |
+
temb_channels=self.temb_ch,dropout=dropout),
|
323 |
+
ResnetBlock(in_channels=block_in,out_channels=block_in,
|
324 |
+
temb_channels=self.temb_ch,dropout=dropout),
|
325 |
+
]
|
326 |
+
self.prior_net = nn.Sequential(*prior_net)
|
327 |
+
|
328 |
+
depth = depth
|
329 |
+
time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
|
330 |
+
|
331 |
+
|
332 |
+
transformer_blocks = [
|
333 |
+
TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
|
334 |
+
for _ in range(depth)
|
335 |
+
]
|
336 |
+
|
337 |
+
|
338 |
+
self.transformers = nn.Sequential(*transformer_blocks)
|
339 |
+
self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
|
340 |
+
post_net : tp.List[nn.Module] = [
|
341 |
+
ResnetBlock(in_channels=block_in,out_channels=block_in,
|
342 |
+
temb_channels=self.temb_ch,dropout=dropout),
|
343 |
+
ResnetBlock(in_channels=block_in,out_channels=block_in,
|
344 |
+
temb_channels=self.temb_ch,dropout=dropout),
|
345 |
+
]
|
346 |
+
self.post_net = nn.Sequential(*post_net)
|
347 |
+
|
348 |
+
def forward(self, x: torch.Tensor ) -> torch.Tensor:
|
349 |
+
x = x.transpose(1, 2)
|
350 |
+
x = self.embed(x)
|
351 |
+
x = self.prior_net(x)
|
352 |
+
x = x.transpose(1, 2)
|
353 |
+
x= self.transformers(x)
|
354 |
+
x = x.transpose(1, 2)
|
355 |
+
x = self.post_net(x)
|
356 |
+
x = x.transpose(1, 2)
|
357 |
+
x = self.final_layer_norm(x)
|
358 |
+
return x
|
359 |
+
|
360 |
+
|
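A short shape walk-through of `VocosBackbone.forward` (illustrative sizes): the convolutional pre/post stages run channel-first, the transformer stack runs time-first, so the tensor is transposed at each hand-off.

```python
import torch

B, T, H = 2, 50, 1024
x = torch.randn(B, T, H)      # (batch, frames, hidden): what forward() receives

x = x.transpose(1, 2)         # (B, H, T)  embed Conv1d + prior_net ResnetBlocks
# ... embed / prior_net keep the shape at (B, H, T)
x = x.transpose(1, 2)         # (B, T, H)  TransformerBlocks attend over frames
# ... transformers keep the shape at (B, T, H)
x = x.transpose(1, 2)         # (B, H, T)  post_net ResnetBlocks
x = x.transpose(1, 2)         # (B, T, H)  final LayerNorm over the hidden dim
print(x.shape)                # torch.Size([2, 50, 1024])
```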
361 |
+
|
362 |
+
|
363 |
+
|
364 |
+
|
365 |
+
|
366 |
+
def init_weights(m):
|
367 |
+
if isinstance(m, nn.Conv1d):
|
368 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
369 |
+
nn.init.constant_(m.bias, 0)
|
370 |
+
|
371 |
+
class CodecDecoderVocos(nn.Module):
|
372 |
+
def __init__(self,
|
373 |
+
hidden_dim=1024,
|
374 |
+
depth=12,
|
375 |
+
heads=16,
|
376 |
+
pos_meb_dim=64,
|
377 |
+
hop_length=320,
|
378 |
+
vq_num_quantizers=1,
|
379 |
+
vq_dim=2048, #1024 2048
|
380 |
+
vq_commit_weight=0.25,
|
381 |
+
vq_weight_init=False,
|
382 |
+
vq_full_commit_loss=False,
|
383 |
+
codebook_size=16384,
|
384 |
+
codebook_dim=16,
|
385 |
+
):
|
386 |
+
super().__init__()
|
387 |
+
self.hop_length = hop_length
|
388 |
+
|
389 |
+
self.quantizer = ResidualFSQ(
|
390 |
+
dim = vq_dim,
|
391 |
+
levels = [4, 4, 4, 4, 4,4,4,4],
|
392 |
+
num_quantizers = 1
|
393 |
+
)
|
394 |
+
|
395 |
+
# self.quantizer = ResidualVQ(
|
396 |
+
# num_quantizers=vq_num_quantizers,
|
397 |
+
# dim=vq_dim,
|
398 |
+
# codebook_size=codebook_size,
|
399 |
+
# codebook_dim=codebook_dim,
|
400 |
+
# threshold_ema_dead_code=2,
|
401 |
+
# commitment=vq_commit_weight,
|
402 |
+
# weight_init=vq_weight_init,
|
403 |
+
# full_commit_loss=vq_full_commit_loss,
|
404 |
+
# )
|
405 |
+
|
406 |
+
|
407 |
+
self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim)
|
408 |
+
|
409 |
+
self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same")
|
410 |
+
|
411 |
+
self.reset_parameters()
|
412 |
+
|
413 |
+
def forward(self, x, vq=True):
|
414 |
+
if vq is True:
|
415 |
+
# x, q, commit_loss = self.quantizer(x)
|
416 |
+
x = x.permute(0, 2, 1)
|
417 |
+
x, q = self.quantizer(x)
|
418 |
+
x = x.permute(0, 2, 1)
|
419 |
+
q = q.permute(0, 2, 1)
|
420 |
+
return x, q, None
|
421 |
+
x = self.backbone(x)
|
422 |
+
x,_ = self.head(x)
|
423 |
+
|
424 |
+
return x ,_
|
425 |
+
|
426 |
+
def vq2emb(self, vq):
|
427 |
+
self.quantizer = self.quantizer.eval()
|
428 |
+
x = self.quantizer.vq2emb(vq)
|
429 |
+
return x
|
430 |
+
|
431 |
+
def get_emb(self):
|
432 |
+
self.quantizer = self.quantizer.eval()
|
433 |
+
embs = self.quantizer.get_emb()
|
434 |
+
return embs
|
435 |
+
|
436 |
+
def inference_vq(self, vq):
|
437 |
+
x = vq[None,:,:]
|
438 |
+
x = self.model(x)
|
439 |
+
return x
|
440 |
+
|
441 |
+
def inference_0(self, x):
|
442 |
+
x, q, loss, perp = self.quantizer(x)
|
443 |
+
x = self.model(x)
|
444 |
+
return x, None
|
445 |
+
|
446 |
+
def inference(self, x):
|
447 |
+
x = self.model(x)
|
448 |
+
return x, None
|
449 |
+
|
450 |
+
|
451 |
+
def remove_weight_norm(self):
|
452 |
+
"""Remove weight normalization module from all of the layers."""
|
453 |
+
|
454 |
+
def _remove_weight_norm(m):
|
455 |
+
try:
|
456 |
+
torch.nn.utils.remove_weight_norm(m)
|
457 |
+
except ValueError: # this module didn't have weight norm
|
458 |
+
return
|
459 |
+
|
460 |
+
self.apply(_remove_weight_norm)
|
461 |
+
|
462 |
+
def apply_weight_norm(self):
|
463 |
+
"""Apply weight normalization module from all of the layers."""
|
464 |
+
|
465 |
+
def _apply_weight_norm(m):
|
466 |
+
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
|
467 |
+
torch.nn.utils.weight_norm(m)
|
468 |
+
|
469 |
+
self.apply(_apply_weight_norm)
|
470 |
+
|
471 |
+
def reset_parameters(self):
|
472 |
+
self.apply(init_weights)
|
473 |
+
|
474 |
+
|
475 |
+
|
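A minimal usage sketch for `CodecDecoderVocos`, assuming the repo and its dependencies (e.g. `vector-quantize-pytorch`, `torchtune`) are installed; the import path and tensor shapes are illustrative. `forward(..., vq=True)` runs only the FSQ quantizer on `(B, vq_dim, T)` features, while `forward(..., vq=False)` renders `(B, T, hidden_dim)` features to a waveform through the backbone and the ISTFT head.

```python
import torch
# from vq.codec_decoder_vocos import CodecDecoderVocos   # assumed import path

decoder = CodecDecoderVocos(hidden_dim=1024, vq_dim=2048, hop_length=320).eval()

with torch.no_grad():
    # Quantization pass: (B, vq_dim, T) in, quantized features and code indices out.
    feats = torch.randn(2, 2048, 50)
    quantized, codes, _ = decoder(feats, vq=True)

    # Waveform pass: (B, T, hidden_dim) in, (B, 1, ~T * hop_length) audio out.
    audio, spec = decoder(torch.randn(2, 50, 1024), vq=False)
print(codes.shape, audio.shape)
```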
476 |
+
class CodecDecoderVocos_transpose(nn.Module):
|
477 |
+
def __init__(self,
|
478 |
+
hidden_dim=1024,
|
479 |
+
depth=12,
|
480 |
+
heads=16,
|
481 |
+
pos_meb_dim=64,
|
482 |
+
hop_length=320,
|
483 |
+
vq_num_quantizers=1,
|
484 |
+
vq_dim=1024, #1024 2048
|
485 |
+
vq_commit_weight=0.25,
|
486 |
+
vq_weight_init=False,
|
487 |
+
vq_full_commit_loss=False,
|
488 |
+
codebook_size=16384,
|
489 |
+
codebook_dim=16,
|
490 |
+
):
|
491 |
+
super().__init__()
|
492 |
+
self.hop_length = hop_length
|
493 |
+
|
494 |
+
|
495 |
+
self.quantizer = ResidualVQ(
|
496 |
+
num_quantizers=vq_num_quantizers,
|
497 |
+
dim=vq_dim,
|
498 |
+
codebook_size=codebook_size,
|
499 |
+
codebook_dim=codebook_dim,
|
500 |
+
threshold_ema_dead_code=2,
|
501 |
+
commitment=vq_commit_weight,
|
502 |
+
weight_init=vq_weight_init,
|
503 |
+
full_commit_loss=vq_full_commit_loss,
|
504 |
+
)
|
505 |
+
|
506 |
+
|
507 |
+
self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim)
|
508 |
+
|
509 |
+
self.inverse_mel_conv = nn.Sequential(
|
510 |
+
nn.GELU(),
|
511 |
+
nn.ConvTranspose1d(
|
512 |
+
in_channels=hidden_dim,
|
513 |
+
out_channels=hidden_dim,
|
514 |
+
kernel_size=3,
|
515 |
+
stride=2,
|
516 |
+
padding=1,
|
517 |
+
output_padding=1  # ensure the output length matches the pre-encoding length
|
518 |
+
),
|
519 |
+
nn.GELU(),
|
520 |
+
nn.ConvTranspose1d(
|
521 |
+
in_channels=hidden_dim,
|
522 |
+
out_channels=hidden_dim,
|
523 |
+
kernel_size=3,
|
524 |
+
padding=1
|
525 |
+
)
|
526 |
+
)
|
527 |
+
|
528 |
+
self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same")
|
529 |
+
|
530 |
+
self.reset_parameters()
|
531 |
+
|
532 |
+
def forward(self, x, vq=True):
|
533 |
+
if vq is True:
|
534 |
+
x, q, commit_loss = self.quantizer(x)
|
535 |
+
return x, q, commit_loss
|
536 |
+
x = self.backbone(x)
|
537 |
+
x,_ = self.head(x)
|
538 |
+
|
539 |
+
return x ,_
|
540 |
+
|
541 |
+
def vq2emb(self, vq):
|
542 |
+
self.quantizer = self.quantizer.eval()
|
543 |
+
x = self.quantizer.vq2emb(vq)
|
544 |
+
return x
|
545 |
+
|
546 |
+
def get_emb(self):
|
547 |
+
self.quantizer = self.quantizer.eval()
|
548 |
+
embs = self.quantizer.get_emb()
|
549 |
+
return embs
|
550 |
+
|
551 |
+
def inference_vq(self, vq):
|
552 |
+
x = vq[None,:,:]
|
553 |
+
x = self.model(x)
|
554 |
+
return x
|
555 |
+
|
556 |
+
def inference_0(self, x):
|
557 |
+
x, q, loss, perp = self.quantizer(x)
|
558 |
+
x = self.model(x)
|
559 |
+
return x, None
|
560 |
+
|
561 |
+
def inference(self, x):
|
562 |
+
x = self.model(x)
|
563 |
+
return x, None
|
564 |
+
|
565 |
+
|
566 |
+
def remove_weight_norm(self):
|
567 |
+
"""Remove weight normalization module from all of the layers."""
|
568 |
+
|
569 |
+
def _remove_weight_norm(m):
|
570 |
+
try:
|
571 |
+
torch.nn.utils.remove_weight_norm(m)
|
572 |
+
except ValueError: # this module didn't have weight norm
|
573 |
+
return
|
574 |
+
|
575 |
+
self.apply(_remove_weight_norm)
|
576 |
+
|
577 |
+
def apply_weight_norm(self):
|
578 |
+
"""Apply weight normalization module from all of the layers."""
|
579 |
+
|
580 |
+
def _apply_weight_norm(m):
|
581 |
+
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
|
582 |
+
torch.nn.utils.weight_norm(m)
|
583 |
+
|
584 |
+
self.apply(_apply_weight_norm)
|
585 |
+
|
586 |
+
def reset_parameters(self):
|
587 |
+
self.apply(init_weights)
|
588 |
+
|
589 |
+
|
590 |
+
|
591 |
+
|
592 |
+
def main():
|
593 |
+
# Select the device
|
594 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
595 |
+
print(f"Using device: {device}")
|
596 |
+
|
597 |
+
# Initialize the model
|
598 |
+
model = CodecDecoderVocos_transpose().to(device)
|
599 |
+
print("Model initialized.")
|
600 |
+
|
601 |
+
# Create a test input: batch_size x sequence_length x in_channels (the backbone expects (B, T, C))
|
602 |
+
batch_size = 2
|
603 |
+
in_channels = 1024
|
604 |
+
sequence_length = 50  # example length; adjust as needed
|
605 |
+
dummy_input = torch.randn(batch_size, sequence_length, in_channels).to(device)
|
606 |
+
print(f"Dummy input shape: {dummy_input.shape}")
|
607 |
+
|
608 |
+
# Put the model in evaluation mode
|
609 |
+
model.eval()
|
610 |
+
|
611 |
+
# Forward pass (with VQ)
|
612 |
+
# with torch.no_grad():
|
613 |
+
# try:
|
614 |
+
# output, q, commit_loss = model(dummy_input, vq=True)
|
615 |
+
# print("Forward pass with VQ:")
|
616 |
+
# print(f"Output shape: {output.shape}")
|
617 |
+
# print(f"Quantized codes shape: {q.shape}")
|
618 |
+
# print(f"Commitment loss: {commit_loss}")
|
619 |
+
# except Exception as e:
|
620 |
+
# print(f"Error during forward pass with VQ: {e}")
|
621 |
+
|
622 |
+
# Forward pass (without VQ)
|
623 |
+
with torch.no_grad():
|
624 |
+
# try:
|
625 |
+
output_no_vq, _ = model(dummy_input, vq=False)
|
626 |
+
print("\nForward pass without VQ:")
|
627 |
+
print(f"Output shape: {output_no_vq.shape}")
|
628 |
+
c=1
|
629 |
+
# except Exception as e:
|
630 |
+
# print(f"Error during forward pass without VQ: {e}")
|
631 |
+
|
632 |
+
|
633 |
+
# model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
|
634 |
+
# model_size_mb = model_size_bytes / (1024 ** 2)
|
635 |
+
# print(f"Model size: {model_size_bytes} bytes ({model_size_mb:.2f} MB)")
|
636 |
+
|
637 |
+
if __name__ == "__main__":
|
638 |
+
main()
|
vq/codec_encoder.py
ADDED
@@ -0,0 +1,335 @@
1 |
+
import sys
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import numpy as np
|
6 |
+
from vq.module import WNConv1d, EncoderBlock, ResLSTM
|
7 |
+
from vq.alias_free_torch import *
|
8 |
+
from vq import activations
|
9 |
+
from vq.bs_roformer5 import TransformerBlock
|
10 |
+
|
11 |
+
from torchtune.modules import RotaryPositionalEmbeddings
|
12 |
+
import vq.blocks as blocks
|
13 |
+
from torch.nn import utils
|
14 |
+
def init_weights(m):
|
15 |
+
if isinstance(m, nn.Conv1d):
|
16 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
17 |
+
nn.init.constant_(m.bias, 0)
|
18 |
+
|
19 |
+
class CodecEncoder(nn.Module):
|
20 |
+
def __init__(self,
|
21 |
+
ngf=48,
|
22 |
+
use_rnn=True,
|
23 |
+
rnn_bidirectional=False,
|
24 |
+
rnn_num_layers=2,
|
25 |
+
up_ratios=(2, 2, 4, 4, 5),
|
26 |
+
dilations=(1, 3, 9),
|
27 |
+
out_channels=1024):
|
28 |
+
super().__init__()
|
29 |
+
self.hop_length = np.prod(up_ratios)
|
30 |
+
self.ngf = ngf
|
31 |
+
self.up_ratios = up_ratios
|
32 |
+
|
33 |
+
# Create first convolution
|
34 |
+
d_model = ngf
|
35 |
+
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
|
36 |
+
|
37 |
+
# Create EncoderBlocks that double channels as they downsample by `stride`
|
38 |
+
for i, stride in enumerate(up_ratios):
|
39 |
+
d_model *= 2
|
40 |
+
self.block += [EncoderBlock(d_model, stride=stride, dilations=dilations)]
|
41 |
+
# RNN
|
42 |
+
if use_rnn:
|
43 |
+
self.block += [
|
44 |
+
ResLSTM(d_model,
|
45 |
+
num_layers=rnn_num_layers,
|
46 |
+
bidirectional=rnn_bidirectional
|
47 |
+
)
|
48 |
+
]
|
49 |
+
# Create last convolution
|
50 |
+
self.block += [
|
51 |
+
Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)),
|
52 |
+
WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
|
53 |
+
]
|
54 |
+
|
55 |
+
# Wrap block into nn.Sequential
|
56 |
+
self.block = nn.Sequential(*self.block)
|
57 |
+
self.enc_dim = d_model
|
58 |
+
|
59 |
+
self.reset_parameters()
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
out = self.block(x)
|
63 |
+
return out
|
64 |
+
|
65 |
+
def inference(self, x):
|
66 |
+
return self.block(x)
|
67 |
+
|
68 |
+
def remove_weight_norm(self):
|
69 |
+
"""Remove weight normalization module from all of the layers."""
|
70 |
+
|
71 |
+
def _remove_weight_norm(m):
|
72 |
+
try:
|
73 |
+
torch.nn.utils.remove_weight_norm(m)
|
74 |
+
except ValueError: # this module didn't have weight norm
|
75 |
+
return
|
76 |
+
|
77 |
+
self.apply(_remove_weight_norm)
|
78 |
+
|
79 |
+
def apply_weight_norm(self):
|
80 |
+
"""Apply weight normalization module from all of the layers."""
|
81 |
+
|
82 |
+
def _apply_weight_norm(m):
|
83 |
+
if isinstance(m, nn.Conv1d):
|
84 |
+
torch.nn.utils.weight_norm(m)
|
85 |
+
|
86 |
+
self.apply(_apply_weight_norm)
|
87 |
+
|
88 |
+
def reset_parameters(self):
|
89 |
+
self.apply(init_weights)
|
90 |
+
|
91 |
+
|
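A usage sketch for `CodecEncoder` (assuming the repo is importable; the import path is an assumption): the stride schedule `up_ratios=(2, 2, 4, 4, 5)` multiplies to `hop_length = 320`, so 16 kHz audio is encoded at 50 frames per second with `out_channels` features per frame.

```python
import torch
# from vq.codec_encoder import CodecEncoder   # assumed import path

enc = CodecEncoder(ngf=48, up_ratios=(2, 2, 4, 4, 5), out_channels=1024).eval()

wav = torch.randn(1, 1, 16000)        # one second of 16 kHz mono audio
with torch.no_grad():
    z = enc(wav)                      # channel-first latents
print(enc.hop_length, z.shape)        # 320  torch.Size([1, 1024, 50])
```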
92 |
+
class Transpose(nn.Module):
|
93 |
+
def __init__(self, dim1, dim2):
|
94 |
+
super(Transpose, self).__init__()
|
95 |
+
self.dim1 = dim1
|
96 |
+
self.dim2 = dim2
|
97 |
+
|
98 |
+
def forward(self, x):
|
99 |
+
return x.transpose(self.dim1, self.dim2)
|
100 |
+
|
101 |
+
class CodecEncoder_Transformer(nn.Module):
|
102 |
+
def __init__(self,
|
103 |
+
ngf=48,
|
104 |
+
up_ratios=[2, 2, 4, 4, 5],
|
105 |
+
dilations=(1, 3, 9),
|
106 |
+
hidden_dim=1024,
|
107 |
+
depth=12,
|
108 |
+
heads=12,
|
109 |
+
pos_meb_dim=64,
|
110 |
+
):
|
111 |
+
super().__init__()
|
112 |
+
self.hop_length = np.prod(up_ratios)
|
113 |
+
self.ngf =ngf
|
114 |
+
self.up_ratios = up_ratios
|
115 |
+
|
116 |
+
d_model = ngf
|
117 |
+
self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
|
118 |
+
|
119 |
+
|
120 |
+
for i, stride in enumerate(up_ratios):
|
121 |
+
d_model *= 2
|
122 |
+
self.conv_blocks += [EncoderBlock(d_model, stride=stride, dilations=dilations)]
|
123 |
+
|
124 |
+
self.conv_blocks = nn.Sequential(*self.conv_blocks)
|
125 |
+
|
126 |
+
|
127 |
+
# time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
|
128 |
+
|
129 |
+
|
130 |
+
# transformer_blocks = [
|
131 |
+
# TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
|
132 |
+
# for _ in range(depth)
|
133 |
+
# ]
|
134 |
+
|
135 |
+
|
136 |
+
# self.transformers = nn.Sequential(*transformer_blocks)
|
137 |
+
|
138 |
+
# self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
|
139 |
+
|
140 |
+
self.conv_final_block = [
|
141 |
+
Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)),
|
142 |
+
WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1),
|
143 |
+
]
|
144 |
+
self.conv_final_block = nn.Sequential(*self.conv_final_block)
|
145 |
+
|
146 |
+
self.reset_parameters()
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
x = self.conv_blocks(x)
|
150 |
+
# x = x.permute(0, 2, 1)
|
151 |
+
# x= self.transformers(x)
|
152 |
+
# x = self.final_layer_norm(x)
|
153 |
+
# x = x.permute(0, 2, 1)
|
154 |
+
x = self.conv_final_block (x)
|
155 |
+
x = x.permute(0, 2, 1)
|
156 |
+
return x
|
157 |
+
|
158 |
+
def inference(self, x):
|
159 |
+
return self.forward(x)
|
160 |
+
|
161 |
+
def remove_weight_norm(self):
|
162 |
+
"""Remove weight normalization module from all of the layers."""
|
163 |
+
|
164 |
+
def _remove_weight_norm(m):
|
165 |
+
try:
|
166 |
+
torch.nn.utils.remove_weight_norm(m)
|
167 |
+
except ValueError: # this module didn't have weight norm
|
168 |
+
return
|
169 |
+
|
170 |
+
self.apply(_remove_weight_norm)
|
171 |
+
|
172 |
+
def apply_weight_norm(self):
|
173 |
+
"""Apply weight normalization module from all of the layers."""
|
174 |
+
|
175 |
+
def _apply_weight_norm(m):
|
176 |
+
if isinstance(m, nn.Conv1d):
|
177 |
+
torch.nn.utils.weight_norm(m)
|
178 |
+
|
179 |
+
self.apply(_apply_weight_norm)
|
180 |
+
|
181 |
+
def reset_parameters(self):
|
182 |
+
self.apply(init_weights)
|
183 |
+
|
184 |
+
|
185 |
+
|
186 |
+
class Codec_oobleck_Transformer(nn.Module):
|
187 |
+
def __init__(self,
|
188 |
+
ngf=32,
|
189 |
+
up_ratios=(2, 2,4,4, 5),
|
190 |
+
dilations=(1, 3, 9),
|
191 |
+
hidden_dim=1024,
|
192 |
+
depth=12,
|
193 |
+
heads=16,
|
194 |
+
pos_meb_dim=64,
|
195 |
+
):
|
196 |
+
super().__init__()
|
197 |
+
self.hop_length = np.prod(up_ratios)
|
198 |
+
self.ngf =ngf
|
199 |
+
self.up_ratios = up_ratios
|
200 |
+
self.hidden_dim = hidden_dim
|
201 |
+
|
202 |
+
|
203 |
+
self.conv_blocks = blocks.DilatedResidualEncoder(
|
204 |
+
capacity=ngf,
|
205 |
+
dilated_unit=self.dilated_unit,
|
206 |
+
downsampling_unit=self.downsampling_unit,
|
207 |
+
ratios=up_ratios,
|
208 |
+
dilations=dilations,
|
209 |
+
pre_network_conv=self.pre_conv,
|
210 |
+
post_network_conv=self.post_conv,
|
211 |
+
)
|
212 |
+
|
213 |
+
|
214 |
+
time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
|
215 |
+
|
216 |
+
transformer_blocks = [
|
217 |
+
TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
|
218 |
+
for _ in range(depth)
|
219 |
+
]
|
220 |
+
|
221 |
+
self.transformers = nn.Sequential(*transformer_blocks)
|
222 |
+
|
223 |
+
self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
|
224 |
+
|
225 |
+
|
226 |
+
self.reset_parameters()
|
227 |
+
|
228 |
+
def forward(self, x):
|
229 |
+
x = self.conv_blocks(x)
|
230 |
+
x = x.permute(0, 2, 1)
|
231 |
+
x= self.transformers(x)
|
232 |
+
x = self.final_layer_norm(x)
|
233 |
+
return x
|
234 |
+
|
235 |
+
def inference(self, x):
|
236 |
+
return self.forward(x)
|
237 |
+
|
238 |
+
def remove_weight_norm(self):
|
239 |
+
"""Remove weight normalization module from all of the layers."""
|
240 |
+
|
241 |
+
def _remove_weight_norm(m):
|
242 |
+
try:
|
243 |
+
torch.nn.utils.remove_weight_norm(m)
|
244 |
+
except ValueError: # this module didn't have weight norm
|
245 |
+
return
|
246 |
+
|
247 |
+
self.apply(_remove_weight_norm)
|
248 |
+
|
249 |
+
def apply_weight_norm(self):
|
250 |
+
"""Apply weight normalization module from all of the layers."""
|
251 |
+
|
252 |
+
def _apply_weight_norm(m):
|
253 |
+
if isinstance(m, nn.Conv1d):
|
254 |
+
torch.nn.utils.weight_norm(m)
|
255 |
+
|
256 |
+
self.apply(_apply_weight_norm)
|
257 |
+
|
258 |
+
def reset_parameters(self):
|
259 |
+
self.apply(init_weights)
|
260 |
+
|
261 |
+
def dilated_unit(self,hidden_dim, dilation):
|
262 |
+
return blocks.DilatedConvolutionalUnit(hidden_dim,
|
263 |
+
dilation,
|
264 |
+
kernel_size=3,
|
265 |
+
activation=nn.ReLU,
|
266 |
+
normalization=utils.weight_norm)
|
267 |
+
|
268 |
+
def downsampling_unit(self, input_dim: int, output_dim: int, stride: int):
|
269 |
+
return blocks.DownsamplingUnit(input_dim,
|
270 |
+
output_dim,
|
271 |
+
stride,
|
272 |
+
nn.ReLU,
|
273 |
+
normalization=utils.weight_norm)
|
274 |
+
|
275 |
+
def pre_conv(self,out_channels):
|
276 |
+
return nn.Conv1d(1, out_channels, 1)
|
277 |
+
|
278 |
+
def post_conv(self,in_channels):
|
279 |
+
return nn.Conv1d(in_channels, self.hidden_dim, 1)
|
280 |
+
|
281 |
+
|
282 |
+
|
283 |
+
|
284 |
+
|
285 |
+
class CodecEncoder_only_Transformer(nn.Module):
|
286 |
+
def __init__(self,hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64):
|
287 |
+
super().__init__()
|
288 |
+
# self.embed = nn.Linear(input_dim, hidden_dim )input_dim=300,
|
289 |
+
|
290 |
+
depth = depth
|
291 |
+
time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
|
292 |
+
|
293 |
+
|
294 |
+
transformer_blocks = [
|
295 |
+
TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
|
296 |
+
for _ in range(depth)
|
297 |
+
]
|
298 |
+
|
299 |
+
|
300 |
+
self.transformers = nn.Sequential(*transformer_blocks)
|
301 |
+
|
302 |
+
self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
|
303 |
+
|
304 |
+
def forward(self, x: torch.Tensor ) -> torch.Tensor:
|
305 |
+
# x = self.embed(x)
|
306 |
+
|
307 |
+
|
308 |
+
x= self.transformers(x)
|
309 |
+
x = self.final_layer_norm(x)
|
310 |
+
|
311 |
+
return x
|
312 |
+
|
313 |
+
|
314 |
+
|
315 |
+
|
316 |
+
|
317 |
+
|
318 |
+
|
319 |
+
def get_model_size(model):
|
320 |
+
# Count the total number of parameters
|
321 |
+
total_params = sum(p.numel() for p in model.parameters())
|
322 |
+
|
323 |
+
# Assume each parameter is a 32-bit float and compute the model size in bytes
|
324 |
+
model_size_bytes = total_params * 4  # 4 bytes per parameter
|
325 |
+
|
326 |
+
# Convert to a more readable unit (MB)
|
327 |
+
model_size_mb = model_size_bytes / (1024 ** 2)
|
328 |
+
|
329 |
+
return total_params, model_size_mb
|
330 |
+
|
331 |
+
if __name__ == '__main__':
|
332 |
+
model = Codec_oobleck_Transformer()
|
333 |
+
x = torch.randn(1, 1, 16000) # example input tensor
|
334 |
+
output = model(x)
|
335 |
+
print("Output shape:", output.shape)
|
vq/factorized_vector_quantize.py
ADDED
@@ -0,0 +1,109 @@
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from torch.nn.utils import weight_norm
|
9 |
+
|
10 |
+
class FactorizedVectorQuantize(nn.Module):
|
11 |
+
def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
|
12 |
+
super().__init__()
|
13 |
+
self.codebook_size = codebook_size
|
14 |
+
self.codebook_dim = codebook_dim
|
15 |
+
self.commitment = commitment
|
16 |
+
|
17 |
+
if dim != self.codebook_dim:
|
18 |
+
self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
|
19 |
+
self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
|
20 |
+
else:
|
21 |
+
self.in_proj = nn.Identity()
|
22 |
+
self.out_proj = nn.Identity()
|
23 |
+
self._codebook = nn.Embedding(codebook_size, self.codebook_dim)
|
24 |
+
|
25 |
+
@property
|
26 |
+
def codebook(self):
|
27 |
+
return self._codebook
|
28 |
+
|
29 |
+
def forward(self, z):
|
30 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
31 |
+
the corresponding codebook vectors
|
32 |
+
|
33 |
+
Parameters
|
34 |
+
----------
|
35 |
+
z : Tensor[B x D x T]
|
36 |
+
|
37 |
+
Returns
|
38 |
+
-------
|
39 |
+
Tensor[B x D x T]
|
40 |
+
Quantized continuous representation of input
|
41 |
+
Tensor[B x T]
|
42 |
+
Codebook indices (quantized discrete representation of input)
|
43 |
+
Tensor[B]
|
44 |
+
Commitment plus codebook loss per batch item (zeros in eval mode), used to
|
45 |
+
pull encoder outputs toward their matched codebook entries
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
"""
|
51 |
+
# transpose since we use linear
|
52 |
+
|
53 |
+
z = rearrange(z, "b d t -> b t d")
|
54 |
+
|
55 |
+
# Factorized codes project input into low-dimensional space
|
56 |
+
z_e = self.in_proj(z) # z_e : (B x T x D)
|
57 |
+
z_e = rearrange(z_e, "b t d -> b d t")
|
58 |
+
z_q, indices = self.decode_latents(z_e)
|
59 |
+
|
60 |
+
|
61 |
+
if self.training:
|
62 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction='none').mean([1, 2]) * self.commitment
|
63 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction='none').mean([1, 2])
|
64 |
+
commit_loss = commitment_loss + codebook_loss
|
65 |
+
else:
|
66 |
+
commit_loss = torch.zeros(z.shape[0], device = z.device)
|
67 |
+
|
68 |
+
z_q = (
|
69 |
+
z_e + (z_q - z_e).detach()
|
70 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
71 |
+
|
72 |
+
z_q = rearrange(z_q, "b d t -> b t d")
|
73 |
+
z_q = self.out_proj(z_q)
|
74 |
+
z_q = rearrange(z_q, "b t d -> b d t")
|
75 |
+
|
76 |
+
return z_q, indices, commit_loss
|
77 |
+
|
78 |
+
def vq2emb(self, vq, proj=True):
|
79 |
+
emb = self.embed_code(vq)
|
80 |
+
if proj:
|
81 |
+
emb = self.out_proj(emb)
|
82 |
+
return emb
|
83 |
+
|
84 |
+
def get_emb(self):
|
85 |
+
return self.codebook.weight
|
86 |
+
|
87 |
+
def embed_code(self, embed_id):
|
88 |
+
return F.embedding(embed_id, self.codebook.weight)
|
89 |
+
|
90 |
+
def decode_code(self, embed_id):
|
91 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
92 |
+
|
93 |
+
def decode_latents(self, latents):
|
94 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
95 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
96 |
+
|
97 |
+
# L2 normalize encodings and codebook
|
98 |
+
encodings = F.normalize(encodings)
|
99 |
+
codebook = F.normalize(codebook)
|
100 |
+
|
101 |
+
# Compute euclidean distance with codebook
|
102 |
+
dist = (
|
103 |
+
encodings.pow(2).sum(1, keepdim=True)
|
104 |
+
- 2 * encodings @ codebook.t()
|
105 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
106 |
+
)
|
107 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
108 |
+
z_q = self.decode_code(indices)
|
109 |
+
return z_q, indices
|
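Two ideas in `FactorizedVectorQuantize` are worth spelling out: codes are matched by nearest neighbour in an L2-normalized, low-dimensional code space, and gradients flow through the quantizer with the straight-through trick. A standalone sketch of both (names and sizes are illustrative, not taken from the file):

```python
import torch
import torch.nn.functional as F

codebook = torch.randn(16384, 8)                  # (codebook_size, codebook_dim)
z_e = torch.randn(4, 8, 100, requires_grad=True)  # (B, codebook_dim, T) latents

# Nearest-neighbour lookup on L2-normalized vectors (cosine-style matching).
flat = F.normalize(z_e.permute(0, 2, 1).reshape(-1, 8), dim=1)
emb = F.normalize(codebook, dim=1)
dist = flat.pow(2).sum(1, keepdim=True) - 2 * flat @ emb.t() + emb.pow(2).sum(1)
indices = dist.argmin(dim=1).view(4, 100)         # (B, T) discrete codes
z_q = codebook[indices].permute(0, 2, 1)          # (B, codebook_dim, T)

# Straight-through estimator: forward uses z_q, backward treats it as identity,
# so the encoder still receives gradients despite the non-differentiable argmin.
z_st = z_e + (z_q - z_e).detach()
z_st.sum().backward()
print(indices.shape, z_e.grad.shape)
```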
vq/module.py
ADDED
@@ -0,0 +1,420 @@
1 |
+
import torch
import torch.nn as nn
|
2 |
+
from einops import rearrange
|
3 |
+
from . import activations
|
4 |
+
from .alias_free_torch import *
|
5 |
+
from torch.nn.utils import weight_norm
|
6 |
+
|
7 |
+
from typing import Optional, Tuple
|
8 |
+
|
9 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
10 |
+
|
11 |
+
|
12 |
+
def WNConv1d(*args, **kwargs):
|
13 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
14 |
+
|
15 |
+
|
16 |
+
def WNConvTranspose1d(*args, **kwargs):
|
17 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
18 |
+
|
19 |
+
class ResidualUnit(nn.Module):
|
20 |
+
def __init__(self, dim: int = 16, dilation: int = 1):
|
21 |
+
super().__init__()
|
22 |
+
pad = ((7 - 1) * dilation) // 2
|
23 |
+
self.block = nn.Sequential(
|
24 |
+
Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)),
|
25 |
+
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
26 |
+
Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)),
|
27 |
+
WNConv1d(dim, dim, kernel_size=1),
|
28 |
+
)
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
return x + self.block(x)
|
32 |
+
|
33 |
+
class EncoderBlock(nn.Module):
|
34 |
+
def __init__(self, dim: int = 16, stride: int = 1, dilations = (1, 3, 9)):
|
35 |
+
super().__init__()
|
36 |
+
runits = [ResidualUnit(dim // 2, dilation=d) for d in dilations]
|
37 |
+
self.block = nn.Sequential(
|
38 |
+
*runits,
|
39 |
+
Activation1d(activation=activations.SnakeBeta(dim//2, alpha_logscale=True)),
|
40 |
+
WNConv1d(
|
41 |
+
dim // 2,
|
42 |
+
dim,
|
43 |
+
kernel_size=2 * stride,
|
44 |
+
stride=stride,
|
45 |
+
padding=stride // 2 + stride % 2,
|
46 |
+
),
|
47 |
+
)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
return self.block(x)
|
51 |
+
|
52 |
+
class DecoderBlock(nn.Module):
|
53 |
+
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, dilations = (1, 3, 9)):
|
54 |
+
super().__init__()
|
55 |
+
self.block = nn.Sequential(
|
56 |
+
Activation1d(activation=activations.SnakeBeta(input_dim, alpha_logscale=True)),
|
57 |
+
WNConvTranspose1d(
|
58 |
+
input_dim,
|
59 |
+
output_dim,
|
60 |
+
kernel_size=2 * stride,
|
61 |
+
stride=stride,
|
62 |
+
padding=stride // 2 + stride % 2,
|
63 |
+
output_padding= stride % 2,
|
64 |
+
)
|
65 |
+
)
|
66 |
+
self.block.extend([ResidualUnit(output_dim, dilation=d) for d in dilations])
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
return self.block(x)
|
70 |
+
|
71 |
+
class ResLSTM(nn.Module):
|
72 |
+
def __init__(self, dimension: int,
|
73 |
+
num_layers: int = 2,
|
74 |
+
bidirectional: bool = False,
|
75 |
+
skip: bool = True):
|
76 |
+
super().__init__()
|
77 |
+
self.skip = skip
|
78 |
+
self.lstm = nn.LSTM(dimension, dimension if not bidirectional else dimension // 2,
|
79 |
+
num_layers, batch_first=True,
|
80 |
+
bidirectional=bidirectional)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
"""
|
84 |
+
Args:
|
85 |
+
x: [B, F, T]
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
y: [B, F, T]
|
89 |
+
"""
|
90 |
+
x = rearrange(x, "b f t -> b t f")
|
91 |
+
y, _ = self.lstm(x)
|
92 |
+
if self.skip:
|
93 |
+
y = y + x
|
94 |
+
y = rearrange(y, "b t f -> b f t")
|
95 |
+
return y
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
class ConvNeXtBlock(nn.Module):
|
100 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
dim (int): Number of input channels.
|
104 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
105 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
106 |
+
Defaults to None.
|
107 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
108 |
+
None means non-conditional LayerNorm. Defaults to None.
|
109 |
+
"""
|
110 |
+
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
dim: int,
|
114 |
+
intermediate_dim: int,
|
115 |
+
layer_scale_init_value: float,
|
116 |
+
adanorm_num_embeddings: Optional[int] = None,
|
117 |
+
):
|
118 |
+
super().__init__()
|
119 |
+
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
120 |
+
self.adanorm = adanorm_num_embeddings is not None
|
121 |
+
if adanorm_num_embeddings:
|
122 |
+
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
123 |
+
else:
|
124 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
125 |
+
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
126 |
+
self.act = nn.GELU()
|
127 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
128 |
+
self.gamma = (
|
129 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
130 |
+
if layer_scale_init_value > 0
|
131 |
+
else None
|
132 |
+
)
|
133 |
+
|
134 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
|
135 |
+
residual = x
|
136 |
+
x = self.dwconv(x)
|
137 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
138 |
+
if self.adanorm:
|
139 |
+
assert cond_embedding_id is not None
|
140 |
+
x = self.norm(x, cond_embedding_id)
|
141 |
+
else:
|
142 |
+
x = self.norm(x)
|
143 |
+
x = self.pwconv1(x)
|
144 |
+
x = self.act(x)
|
145 |
+
x = self.pwconv2(x)
|
146 |
+
if self.gamma is not None:
|
147 |
+
x = self.gamma * x
|
148 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
149 |
+
|
150 |
+
x = residual + x
|
151 |
+
return x
|
152 |
+
|
153 |
+
|
154 |
+
class AdaLayerNorm(nn.Module):
|
155 |
+
"""
|
156 |
+
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
157 |
+
|
158 |
+
Args:
|
159 |
+
num_embeddings (int): Number of embeddings.
|
160 |
+
embedding_dim (int): Dimension of the embeddings.
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
|
164 |
+
super().__init__()
|
165 |
+
self.eps = eps
|
166 |
+
self.dim = embedding_dim
|
167 |
+
self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
|
168 |
+
self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
|
169 |
+
torch.nn.init.ones_(self.scale.weight)
|
170 |
+
torch.nn.init.zeros_(self.shift.weight)
|
171 |
+
|
172 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
|
173 |
+
scale = self.scale(cond_embedding_id)
|
174 |
+
shift = self.shift(cond_embedding_id)
|
175 |
+
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
176 |
+
x = x * scale + shift
|
177 |
+
return x
|
178 |
+
|
179 |
+
|
180 |
+
class ResBlock1(nn.Module):
|
181 |
+
"""
|
182 |
+
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
183 |
+
but without upsampling layers.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
dim (int): Number of input channels.
|
187 |
+
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
188 |
+
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
189 |
+
Defaults to (1, 3, 5).
|
190 |
+
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
191 |
+
Defaults to 0.1.
|
192 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
193 |
+
Defaults to None.
|
194 |
+
"""
|
195 |
+
|
196 |
+
def __init__(
|
197 |
+
self,
|
198 |
+
dim: int,
|
199 |
+
kernel_size: int = 3,
|
200 |
+
dilation: Tuple[int, int, int] = (1, 3, 5),
|
201 |
+
lrelu_slope: float = 0.1,
|
202 |
+
layer_scale_init_value: Optional[float] = None,
|
203 |
+
):
|
204 |
+
super().__init__()
|
205 |
+
self.lrelu_slope = lrelu_slope
|
206 |
+
self.convs1 = nn.ModuleList(
|
207 |
+
[
|
208 |
+
weight_norm(
|
209 |
+
nn.Conv1d(
|
210 |
+
dim,
|
211 |
+
dim,
|
212 |
+
kernel_size,
|
213 |
+
1,
|
214 |
+
dilation=dilation[0],
|
215 |
+
padding=self.get_padding(kernel_size, dilation[0]),
|
216 |
+
)
|
217 |
+
),
|
218 |
+
weight_norm(
|
219 |
+
nn.Conv1d(
|
220 |
+
dim,
|
221 |
+
dim,
|
222 |
+
kernel_size,
|
223 |
+
1,
|
224 |
+
dilation=dilation[1],
|
225 |
+
padding=self.get_padding(kernel_size, dilation[1]),
|
226 |
+
)
|
227 |
+
),
|
228 |
+
weight_norm(
|
229 |
+
nn.Conv1d(
|
230 |
+
dim,
|
231 |
+
dim,
|
232 |
+
kernel_size,
|
233 |
+
1,
|
234 |
+
dilation=dilation[2],
|
235 |
+
padding=self.get_padding(kernel_size, dilation[2]),
|
236 |
+
)
|
237 |
+
),
|
238 |
+
]
|
239 |
+
)
|
240 |
+
|
241 |
+
self.convs2 = nn.ModuleList(
|
242 |
+
[
|
243 |
+
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
|
244 |
+
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
|
245 |
+
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
|
246 |
+
]
|
247 |
+
)
|
248 |
+
|
249 |
+
self.gamma = nn.ParameterList(
|
250 |
+
[
|
251 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
|
252 |
+
if layer_scale_init_value is not None
|
253 |
+
else None,
|
254 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
|
255 |
+
if layer_scale_init_value is not None
|
256 |
+
else None,
|
257 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
|
258 |
+
if layer_scale_init_value is not None
|
259 |
+
else None,
|
260 |
+
]
|
261 |
+
)
|
262 |
+
|
263 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
264 |
+
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
265 |
+
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
266 |
+
xt = c1(xt)
|
267 |
+
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
268 |
+
xt = c2(xt)
|
269 |
+
if gamma is not None:
|
270 |
+
xt = gamma * xt
|
271 |
+
x = xt + x
|
272 |
+
return x
|
273 |
+
|
274 |
+
def remove_weight_norm(self):
|
275 |
+
for l in self.convs1:
|
276 |
+
remove_weight_norm(l)
|
277 |
+
for l in self.convs2:
|
278 |
+
remove_weight_norm(l)
|
279 |
+
|
280 |
+
@staticmethod
|
281 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
282 |
+
return int((kernel_size * dilation - dilation) / 2)
|
283 |
+
|
284 |
+
|
285 |
+
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
|
286 |
+
"""
|
287 |
+
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
x (Tensor): Input tensor.
|
291 |
+
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
|
292 |
+
|
293 |
+
Returns:
|
294 |
+
Tensor: Element-wise logarithm of the input tensor with clipping applied.
|
295 |
+
"""
|
296 |
+
return torch.log(torch.clip(x, min=clip_val))
|
297 |
+
|
298 |
+
|
299 |
+
def symlog(x: torch.Tensor) -> torch.Tensor:
|
300 |
+
return torch.sign(x) * torch.log1p(x.abs())
|
301 |
+
|
302 |
+
|
303 |
+
def symexp(x: torch.Tensor) -> torch.Tensor:
|
304 |
+
return torch.sign(x) * (torch.exp(x.abs()) - 1)
|
305 |
+
|
306 |
+
|
307 |
+
|
308 |
+
class SemanticEncoder(nn.Module):
|
309 |
+
def __init__(
|
310 |
+
self,
|
311 |
+
input_channels: int,
|
312 |
+
code_dim: int,
|
313 |
+
encode_channels: int,
|
314 |
+
kernel_size: int = 3,
|
315 |
+
bias: bool = True,
|
316 |
+
):
|
317 |
+
super(SemanticEncoder, self).__init__()
|
318 |
+
|
319 |
+
# Initial convolution mapping input_channels to encode_channels
|
320 |
+
self.initial_conv = nn.Conv1d(
|
321 |
+
in_channels=input_channels,
|
322 |
+
out_channels=encode_channels,
|
323 |
+
kernel_size=kernel_size,
|
324 |
+
stride=1,
|
325 |
+
padding=(kernel_size - 1) // 2,
|
326 |
+
bias=False
|
327 |
+
)
|
328 |
+
|
329 |
+
# Residual blocks
|
330 |
+
self.residual_blocks = nn.Sequential(
|
331 |
+
nn.ReLU(inplace=True),
|
332 |
+
nn.Conv1d(
|
333 |
+
encode_channels,
|
334 |
+
encode_channels,
|
335 |
+
kernel_size=kernel_size,
|
336 |
+
stride=1,
|
337 |
+
padding=(kernel_size - 1) // 2,
|
338 |
+
bias=bias
|
339 |
+
),
|
340 |
+
nn.ReLU(inplace=True),
|
341 |
+
nn.Conv1d(
|
342 |
+
encode_channels,
|
343 |
+
encode_channels,
|
344 |
+
kernel_size=kernel_size,
|
345 |
+
stride=1,
|
346 |
+
padding=(kernel_size - 1) // 2,
|
347 |
+
bias=bias
|
348 |
+
)
|
349 |
+
)
|
350 |
+
|
351 |
+
# Final convolution mapping encode_channels to code_dim
|
352 |
+
self.final_conv = nn.Conv1d(
|
353 |
+
in_channels=encode_channels,
|
354 |
+
out_channels=code_dim,
|
355 |
+
kernel_size=kernel_size,
|
356 |
+
stride=1,
|
357 |
+
padding=(kernel_size - 1) // 2,
|
358 |
+
bias=False
|
359 |
+
)
|
360 |
+
|
361 |
+
def forward(self, x):
|
362 |
+
"""
|
363 |
+
Forward pass.
|
364 |
+
|
365 |
+
Args:
|
366 |
+
x (Tensor): Input tensor of shape (Batch, Input_channels, Length)
|
367 |
+
|
368 |
+
Returns:
|
369 |
+
Tensor: Encoded tensor of shape (Batch, Code_dim, Length)
|
370 |
+
"""
|
371 |
+
x = self.initial_conv(x) # (Batch, Encode_channels, Length)
|
372 |
+
x = self.residual_blocks(x) + x  # residual connection
|
373 |
+
x = self.final_conv(x) # (Batch, Code_dim, Length)
|
374 |
+
return x
|
375 |
+
|
376 |
+
class SemanticDecoder(nn.Module):
|
377 |
+
def __init__(
|
378 |
+
self,
|
379 |
+
code_dim: int,
|
380 |
+
output_channels: int,
|
381 |
+
decode_channels: int,
|
382 |
+
kernel_size: int = 3,
|
383 |
+
bias: bool = True,
|
384 |
+
):
|
385 |
+
super(SemanticDecoder, self).__init__()
|
386 |
+
|
387 |
+
# Initial convolution to map code_dim to decode_channels
|
388 |
+
self.initial_conv = nn.Conv1d(
|
389 |
+
in_channels=code_dim,
|
390 |
+
out_channels=decode_channels,
|
391 |
+
kernel_size=kernel_size,
|
392 |
+
stride=1,
|
393 |
+
padding=(kernel_size - 1) // 2,
|
394 |
+
bias=False
|
395 |
+
)
|
396 |
+
|
397 |
+
# Residual Blocks
|
398 |
+
self.residual_blocks = nn.Sequential(
|
399 |
+
nn.ReLU(inplace=True),
|
400 |
+
nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias),
|
401 |
+
nn.ReLU(inplace=True),
|
402 |
+
nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias)
|
403 |
+
)
|
404 |
+
|
405 |
+
# Final convolution to map decode_channels to output_channels
|
406 |
+
self.final_conv = nn.Conv1d(
|
407 |
+
in_channels=decode_channels,
|
408 |
+
out_channels=output_channels,
|
409 |
+
kernel_size=kernel_size,
|
410 |
+
stride=1,
|
411 |
+
padding=(kernel_size - 1) // 2,
|
412 |
+
bias=False
|
413 |
+
)
|
414 |
+
|
415 |
+
def forward(self, z):
|
416 |
+
# z: (Batch, Code_dim, Length)
|
417 |
+
x = self.initial_conv(z) # (Batch, Decode_channels, Length)
|
418 |
+
x = self.residual_blocks(x) + x # Residual connection
|
419 |
+
x = self.final_conv(x) # (Batch, Output_channels, Length)
|
420 |
+
return x
|
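A small round-trip sketch for the `SemanticEncoder`/`SemanticDecoder` pair defined above (assuming the repo is importable; the channel sizes and import path are illustrative): both operate frame-wise on `(B, C, T)` tensors and preserve the time axis.

```python
import torch
# from vq.module import SemanticEncoder, SemanticDecoder   # assumed import path

enc = SemanticEncoder(input_channels=1024, code_dim=1024, encode_channels=1536)
dec = SemanticDecoder(code_dim=1024, output_channels=1024, decode_channels=1536)

feats = torch.randn(2, 1024, 50)   # (B, input_channels, T) semantic features
code = enc(feats)                  # (B, code_dim, T)
recon = dec(code)                  # (B, output_channels, T)
print(code.shape, recon.shape)     # both (2, 1024, 50)
```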
vq/residual_vq.py
ADDED
@@ -0,0 +1,53 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from .factorized_vector_quantize import FactorizedVectorQuantize
|
5 |
+
|
6 |
+
class ResidualVQ(nn.Module):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
*,
|
10 |
+
num_quantizers,
|
11 |
+
codebook_size,
|
12 |
+
**kwargs
|
13 |
+
):
|
14 |
+
super().__init__()
|
15 |
+
VQ = FactorizedVectorQuantize
|
16 |
+
if type(codebook_size) == int:
|
17 |
+
codebook_size = [codebook_size] * num_quantizers
|
18 |
+
self.layers = nn.ModuleList([VQ(codebook_size=size, **kwargs) for size in codebook_size])
|
19 |
+
self.num_quantizers = num_quantizers
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
quantized_out = 0.
|
23 |
+
residual = x
|
24 |
+
|
25 |
+
all_losses = []
|
26 |
+
all_indices = []
|
27 |
+
|
28 |
+
for idx, layer in enumerate(self.layers):
|
29 |
+
quantized, indices, loss = layer(residual)
|
30 |
+
|
31 |
+
residual = residual - quantized
|
32 |
+
|
33 |
+
quantized_out = quantized_out + quantized
|
34 |
+
|
35 |
+
loss = loss.mean()
|
36 |
+
|
37 |
+
all_indices.append(indices)
|
38 |
+
all_losses.append(loss)
|
39 |
+
all_losses, all_indices = map(torch.stack, (all_losses, all_indices))
|
40 |
+
return quantized_out, all_indices, all_losses
|
41 |
+
|
42 |
+
def vq2emb(self, vq, proj=True):
|
43 |
+
# [B, T, num_quantizers]
|
44 |
+
quantized_out = 0.
|
45 |
+
for idx, layer in enumerate(self.layers):
|
46 |
+
quantized = layer.vq2emb(vq[:, :, idx], proj=proj)
|
47 |
+
quantized_out = quantized_out + quantized
|
48 |
+
return quantized_out
|
49 |
+
def get_emb(self):
|
50 |
+
embs = []
|
51 |
+
for idx, layer in enumerate(self.layers):
|
52 |
+
embs.append(layer.get_emb())
|
53 |
+
return embs
|
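The loop in `ResidualVQ.forward` quantizes in stages: each layer only sees what the previous layers failed to explain, and the stage outputs sum to the final approximation. A toy illustration with progressively finer rounding standing in for the real `FactorizedVectorQuantize` layers:

```python
import torch

def toy_stage(x, step):
    # Stand-in for one VQ layer: snap values to a grid of the given step size.
    return torch.round(x / step) * step

x = torch.randn(2, 16, 10)                 # (B, D, T) input latents
residual, quantized_out = x, torch.zeros_like(x)

for step in (0.5, 0.25, 0.125):            # one entry per quantizer stage
    q = toy_stage(residual, step)          # quantize the remaining residual
    residual = residual - q                # pass the leftover to the next stage
    quantized_out = quantized_out + q      # running sum approximates x
    print(step, (x - quantized_out).abs().max().item())
```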
vq/unet.py
ADDED
@@ -0,0 +1,210 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class EncoderBlock(nn.Module):
|
9 |
+
def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
|
10 |
+
super(EncoderBlock, self).__init__()
|
11 |
+
|
12 |
+
self.pool_size = 2
|
13 |
+
|
14 |
+
self.conv_block = ConvBlock(in_channels, out_channels, kernel_size)
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
latent = self.conv_block(x)
|
18 |
+
output = F.avg_pool2d(latent, kernel_size=self.pool_size)
|
19 |
+
return output, latent
|
20 |
+
|
21 |
+
class DecoderBlock(nn.Module):
|
22 |
+
def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
|
23 |
+
super(DecoderBlock, self).__init__()
|
24 |
+
|
25 |
+
stride = 2
|
26 |
+
|
27 |
+
self.upsample = nn.ConvTranspose2d(
|
28 |
+
in_channels=in_channels,
|
29 |
+
out_channels=in_channels,
|
30 |
+
kernel_size=stride,
|
31 |
+
stride=stride,
|
32 |
+
padding=(0, 0),
|
33 |
+
bias=False,
|
34 |
+
)
|
35 |
+
|
36 |
+
self.conv_block = ConvBlock(in_channels * 2, out_channels, kernel_size)
|
37 |
+
|
38 |
+
def forward(self, x, latent):
|
39 |
+
x = self.upsample(x)
|
40 |
+
x = torch.cat((x, latent), dim=1)
|
41 |
+
output = self.conv_block(x)
|
42 |
+
return output
|
43 |
+
|
44 |
+
|
45 |
+
class UNet(nn.Module):
|
46 |
+
def __init__(self,freq_dim=1281,out_channel=1024):
|
47 |
+
super(UNet, self).__init__()
|
48 |
+
|
49 |
+
self.downsample_ratio = 16
|
50 |
+
|
51 |
+
|
52 |
+
in_channels = 1 #self.audio_channels * self.cmplx_num
|
53 |
+
|
54 |
+
self.encoder_block1 = EncoderBlock(in_channels, 16)
|
55 |
+
self.encoder_block2 = EncoderBlock(16, 64)
|
56 |
+
self.encoder_block3 = EncoderBlock(64, 256)
|
57 |
+
self.encoder_block4 = EncoderBlock(256, 1024)
|
58 |
+
self.middle = EncoderBlock(1024, 1024)
|
59 |
+
self.decoder_block1 = DecoderBlock(1024, 256)
|
60 |
+
self.decoder_block2 = DecoderBlock(256, 64)
|
61 |
+
self.decoder_block3 = DecoderBlock(64, 16)
|
62 |
+
self.decoder_block4 = DecoderBlock(16, 16)
|
63 |
+
|
64 |
+
self.fc = nn.Linear(freq_dim*16, out_channel)
|
65 |
+
|
66 |
+
def forward(self, x_ori):
|
67 |
+
"""
|
68 |
+
Args:
|
69 |
+
x_ori: (batch_size, channels_num, time_steps, freq_bins) spectrogram-like tensor
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
output: (batch_size, time_steps, out_channel) frame-level features
|
73 |
+
"""
|
74 |
+
|
75 |
+
|
76 |
+
x= self.process_image(x_ori)
|
77 |
+
x1, latent1 = self.encoder_block1(x)
|
78 |
+
x2, latent2 = self.encoder_block2(x1)
|
79 |
+
x3, latent3 = self.encoder_block3(x2)
|
80 |
+
x4, latent4 = self.encoder_block4(x3)
|
81 |
+
_, h = self.middle(x4)
|
82 |
+
x5 = self.decoder_block1(h, latent4)
|
83 |
+
x6 = self.decoder_block2(x5, latent3)
|
84 |
+
x7 = self.decoder_block3(x6, latent2)
|
85 |
+
x8 = self.decoder_block4(x7, latent1)
|
86 |
+
x= self.unprocess_image(x8,x_ori.shape[2])
|
87 |
+
x = x.permute(0, 2, 1, 3).contiguous()  # (B, C, T, F) -> (B, T, C, F)
|
88 |
+
x = x.view(x.size(0), x.size(1), -1)
|
89 |
+
x= self.fc(x)
|
90 |
+
|
91 |
+
return x
|
92 |
+
|
93 |
+
def process_image(self, x):
|
94 |
+
"""
|
95 |
+
Pad the spectrogram so it is divisible by downsample_ratio (and trim the last frequency bin).
|
96 |
+
|
97 |
+
Args:
|
98 |
+
x: (B, C, T, F)
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
output: (B, C, T_padded, F_reduced)
|
102 |
+
"""
|
103 |
+
|
104 |
+
B, C, T, Freq = x.shape
|
105 |
+
|
106 |
+
pad_len = (
|
107 |
+
int(np.ceil(T / self.downsample_ratio)) * self.downsample_ratio
|
108 |
+
- T
|
109 |
+
)
|
110 |
+
x = F.pad(x, pad=(0, 0, 0, pad_len))
|
111 |
+
|
112 |
+
output = x[:, :, :, 0 : Freq - 1]
|
113 |
+
|
114 |
+
return output
|
115 |
+
|
116 |
+
def unprocess_image(self, x,time_steps):
|
117 |
+
"""
|
118 |
+
Restore the spectrogram to its original shape.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
x: (B, C, T_padded, F_reduced)
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
output: (B, C, T_original, F_original)
|
125 |
+
"""
|
126 |
+
x = F.pad(x, pad=(0, 1))
|
127 |
+
|
128 |
+
output = x[:, :,0:time_steps, :]
|
129 |
+
|
130 |
+
return output
|
131 |
+
|
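For orientation, an illustrative shape trace through `UNet.forward` with the default `freq_dim=1281` (the import path is an assumption): the input frequency axis is cropped to 1280 bins so it divides by the 16x pooling, restored in `unprocess_image`, and the 16-channel decoder output is flattened per frame into the final linear layer.

```python
import torch
# from vq.unet import UNet   # assumed import path

model = UNet(freq_dim=1281, out_channel=1024).eval()

spec = torch.randn(2, 1, 256, 1281)   # (B, 1, time_steps, freq_bins)
with torch.no_grad():
    out = model(spec)                 # per-frame features after the fc layer
print(out.shape)                      # torch.Size([2, 256, 1024])
```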
132 |
+
class ConvBlock(nn.Module):
|
133 |
+
def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
|
134 |
+
super(ConvBlock, self).__init__()
|
135 |
+
|
136 |
+
padding = [kernel_size[0] // 2, kernel_size[1] // 2]
|
137 |
+
|
138 |
+
self.bn1 = nn.BatchNorm2d(in_channels)
|
139 |
+
self.bn2 = nn.BatchNorm2d(out_channels)
|
140 |
+
|
141 |
+
self.conv1 = nn.Conv2d(
|
142 |
+
in_channels=in_channels,
|
143 |
+
out_channels=out_channels,
|
144 |
+
kernel_size=kernel_size,
|
145 |
+
padding=padding,
|
146 |
+
bias=False,
|
147 |
+
)
|
148 |
+
|
149 |
+
self.conv2 = nn.Conv2d(
|
150 |
+
in_channels=out_channels,
|
151 |
+
out_channels=out_channels,
|
152 |
+
kernel_size=kernel_size,
|
153 |
+
padding=padding,
|
154 |
+
bias=False,
|
155 |
+
)
|
156 |
+
|
157 |
+
if in_channels != out_channels:
|
158 |
+
self.shortcut = nn.Conv2d(
|
159 |
+
in_channels=in_channels,
|
160 |
+
out_channels=out_channels,
|
161 |
+
kernel_size=(1, 1),
|
162 |
+
padding=(0, 0),
|
163 |
+
)
|
164 |
+
self.is_shortcut = True
|
165 |
+
else:
|
166 |
+
self.is_shortcut = False
|
167 |
+
|
168 |
+
def forward(self, x):
|
169 |
+
h = self.conv1(F.leaky_relu_(self.bn1(x)))
|
170 |
+
h = self.conv2(F.leaky_relu_(self.bn2(h)))
|
171 |
+
|
172 |
+
if self.is_shortcut:
|
173 |
+
return self.shortcut(x) + h
|
174 |
+
else:
|
175 |
+
return x + h
|
176 |
+
|
177 |
+
|
178 |
+
def test_unet():
|
179 |
+
# Define input parameters
|
180 |
+
batch_size = 6
|
181 |
+
channels = 1  # number of audio channels
|
182 |
+
time_steps = 256  # number of time steps
|
183 |
+
freq_bins = 1281  # number of frequency bins (must match the model's freq_dim)
|
184 |
+
|
185 |
+
# Create a random real-valued spectrogram-shaped tensor as input
|
186 |
+
real_part = torch.randn(batch_size, channels, time_steps, freq_bins)
|
187 |
+
imag_part = torch.randn(batch_size, channels, time_steps, freq_bins)
|
188 |
+
complex_sp = real_part #torch.complex(real_part, imag_part)
|
189 |
+
|
190 |
+
# Instantiate the UNet model
|
191 |
+
model = UNet()
|
192 |
+
|
193 |
+
# Forward pass
|
194 |
+
output = model(complex_sp)
|
195 |
+
|
196 |
+
# Print the input and output shapes
|
197 |
+
print("输入形状:", complex_sp.shape)
|
198 |
+
print("输出形状:", output.shape)
|
199 |
+
|
200 |
+
# Check that the output is a real-valued feature tensor
|
201 |
+
assert output.dtype == torch.float32, "output is not a real-valued float tensor"
|
202 |
+
|
203 |
+
# Check that the output has one out_channel-dim vector per time step
|
204 |
+
assert output.shape == (batch_size, time_steps, 1024), "unexpected output shape"
|
205 |
+
|
206 |
+
print("测试通过,模型正常工作。")
|
207 |
+
|
208 |
+
# Run the test function
|
209 |
+
if __name__ == "__main__":
|
210 |
+
test_unet()
|