JahidHishab committed on
Commit ccd183c · verified · 1 Parent(s): bc4f968

Upload 23 files
README.md CHANGED
@@ -17,9 +17,9 @@ language:
 - ko
 ---
 
-# 🗣️ XCodec2 Retrained (Multilingual, 100 Hours)
+# 🗣️ XCodec2 Trained on 100K Hours of Multilingual Data
 
-This model is a retrained version of [HKUSTAudio/xcodec2](https://huggingface.co/HKUSTAudio/xcodec2), trained on a 100K-hour multilingual dataset across 7 languages. It is optimized for speech representation learning, compression, and high-fidelity reconstruction — particularly useful for TTS and bandwidth-efficient speech synthesis.
+This [model](https://huggingface.co/HKUSTAudio/xcodec2) is trained on a 100K-hour multilingual dataset across 7 languages. It is optimized for speech representation learning, compression, and high-fidelity reconstruction — particularly useful for TTS and bandwidth-efficient speech synthesis.
 
 ---
 
@@ -110,7 +110,7 @@ Reconstruction metrics are computed over 100 samples for English, Japanese, and
 - Japanese: 100 Examples (Emilia Dataset @ 24 kHz)
 - Bangla: 100 Examples (Inhouse TTS Dataset @ 22.05 kHz)
 
-| Model | Lang | MCD ↓ | MSE ↑ | BERTScore ↑ | BLEU ↑ | TokenDist ↑ |
+| Model | Lang | MCD ↓ | MSE ↑ | SpeechBERTScore ↑ | SpeechBLEU ↑ | SpeechTokenDist ↑ |
 |-------------------|------|--------|--------|-------------|--------|-------------|
 | **XCODEC** | BN | 2.823 | 0.003 | 0.939 | 0.500 | 0.816 |
 | | EN | 3.166 | 0.012 | 0.962 | 0.660 | 0.856 |
@@ -129,6 +129,7 @@ Reconstruction metrics are computed over 100 samples for English, Japanese, and
 | | JA | 2.677 | 0.003 | 0.955 | 0.614 | 0.853 |
 | **Overall** | | 2.597 | 0.003 | 0.960 | 0.636 | 0.863 |
 
+#### SpeechBERTScore, SpeechBLEU and SpeechTokenDistance are calculated using https://github.com/Takaaki-Saeki/DiscreteSpeechMetrics
 ---
 
 ## ✅ Intended Use
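For quick orientation, the encode/decode round trip exercised by the bundled `test.py` can be sketched as below. This is a minimal sketch only: the repository id and the audio file name are placeholders (not values fixed by this commit), and a CUDA device is assumed as in `test.py`.

```python
import torch
import soundfile as sf
from modeling_xcodec2 import XCodec2Model

model = XCodec2Model.from_pretrained("your-namespace/xcodec2-100k")  # placeholder repo id or local path
model.eval().cuda()

wav, sr = sf.read("sample.flac")                    # mono speech, ideally at 16 kHz
wav = torch.from_numpy(wav).float().unsqueeze(0)    # [1, time]

with torch.no_grad():
    codes = model.encode_code(input_waveform=wav)   # discrete codec tokens
    recon = model.decode_code(codes).cpu()          # [batch, 1, time]

sf.write("reconstructed.wav", recon[0, 0].numpy(), sr)
```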
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,11 @@
+ {
+     "model_type": "xcodec2",
+     "semantic_hidden_size": 1024,
+     "codec_encoder_hidden_size": 1024,
+     "codec_decoder_hidden_size": 1024,
+     "use_vocos": true,
+     "architectures": [
+         "XCodec2Model"
+     ]
+ }
+ 
configuration_bigcodec.py ADDED
@@ -0,0 +1,19 @@
+ from transformers import PretrainedConfig
+ 
+ class BigCodecConfig(PretrainedConfig):
+     model_type = "bigcodec"
+ 
+     def __init__(
+         self,
+         # These are only example hyperparameters
+         semantic_hidden_size=1024,
+         codec_encoder_hidden_size=1024,
+         codec_decoder_hidden_size=1024,
+         use_vocos=True,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.semantic_hidden_size = semantic_hidden_size
+         self.codec_encoder_hidden_size = codec_encoder_hidden_size
+         self.codec_decoder_hidden_size = codec_decoder_hidden_size
+         self.use_vocos = use_vocos
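As a quick sanity check, the config class above can be instantiated and serialized on its own; a minimal sketch (the output directory name is arbitrary). Note that the shipped config.json declares model_type "xcodec2" while the class default here is "bigcodec".

```python
from configuration_bigcodec import BigCodecConfig

# Mirrors the values shipped in config.json
cfg = BigCodecConfig(
    semantic_hidden_size=1024,
    codec_encoder_hidden_size=1024,
    codec_decoder_hidden_size=1024,
    use_vocos=True,
)
cfg.save_pretrained("xcodec2_config")  # writes a config.json into that folder
```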
modeling_xcodec2.py ADDED
@@ -0,0 +1,165 @@
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel
+ from configuration_bigcodec import BigCodecConfig
+ 
+ # Make sure these module paths are correct
+ from vq.codec_encoder import CodecEncoder_Transformer
+ from vq.codec_decoder_vocos import CodecDecoderVocos
+ from vq.module import SemanticEncoder
+ from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
+ 
+ class XCodec2Model(PreTrainedModel):
+     config_class = BigCodecConfig
+ 
+     def __init__(self, config: BigCodecConfig):
+         super().__init__(config)
+ 
+         # 1) Semantic model
+         self.semantic_model = Wav2Vec2BertModel.from_pretrained(
+             "facebook/w2v-bert-2.0",
+             output_hidden_states=True
+         )
+         self.semantic_model.eval()
+ 
+         self.SemanticEncoder_module = SemanticEncoder(
+             config.semantic_hidden_size,
+             config.semantic_hidden_size,
+             config.semantic_hidden_size
+         )
+ 
+         # 2) Codec encoder
+         self.CodecEnc = CodecEncoder_Transformer()
+ 
+         # 3) Codec decoder
+         self.generator = CodecDecoderVocos()
+ 
+         # 4) Two fully connected layers
+         self.fc_prior = nn.Linear(2048, 2048)
+         self.fc_post_a = nn.Linear(2048, 1024)
+         feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
+         self.feature_extractor = feature_extractor
+ 
+     def forward(self, input_waveform, sample_rate=16000):
+         """
+         This method does not have to be called `forward`; it could be split into
+         separate methods. Keeping the core logic in `forward`, however, keeps the
+         model compatible with pipelines.
+ 
+         Args:
+             input_waveform: [batch_size, waveform_length]
+             sample_rate: defaults to 16000
+         Returns:
+             The reconstructed speech audio (Tensor)
+         """
+         # 1) Feature extraction
+         # If padding is needed, it can be done here
+         input_features = self.feature_extractor(
+             input_waveform,
+             sampling_rate=sample_rate,
+             return_tensors="pt"
+         ).input_features.to(self.device)  # [batch, frames, feat_dim]
+ 
+         # 2) Semantic layer
+         semantic_output = self.semantic_model(input_features)
+         semantic_hidden_16 = semantic_output.hidden_states[16]  # take the 16th layer
+         semantic_hidden_16 = semantic_hidden_16.transpose(1, 2)  # [batch, hidden_dim, frames]
+         semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)
+ 
+         # 3) Codec encoder
+         wav = input_waveform.unsqueeze(1).to(self.device)  # shape: [batch, 1, time]
+         vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024], for example
+         vq_emb = vq_emb.transpose(1, 2)  # -> [batch, 1024, frames]
+ 
+         # Align the number of time frames with the semantic features; this is only an
+         # illustrative treatment. A real implementation may need to align the
+         # dimensions beforehand.
+         if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
+             # Simple truncation or zero-padding both work; decide what fits your case
+             min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
+             vq_emb = vq_emb[:, :, :min_len]
+             semantic_encoded = semantic_encoded[:, :, :min_len]
+ 
+         # 4) Concatenate
+         concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)  # [batch, 1024 + 1024, frames]
+ 
+         # 5) fc_prior
+         concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)
+ 
+         # 6) Quantization part of the decoder
+         _, vq_code, _ = self.generator(concat_emb, vq=True)
+         vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
+         vq_post_emb = vq_post_emb.transpose(1, 2)
+ 
+         # 7) fc_post_a
+         vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)
+ 
+         # 8) Finally decode back to a waveform
+         recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]
+         # recon_audio: [batch, time]
+         return recon_audio
+ 
+     def encode_code(self, input_waveform, sample_rate=16000):
+         """
+         Encode the input audio into a discrete code representation.
+ 
+         Args:
+             input_waveform: [batch_size, waveform_length]
+             sample_rate: defaults to 16000
+         Returns:
+             The encoded codes (Tensor)
+         """
+         with torch.no_grad():
+             # 1) Feature extraction
+             input_features = self.feature_extractor(
+                 input_waveform,
+                 sampling_rate=sample_rate,
+                 return_tensors="pt"
+             ).input_features.to(self.device)  # [batch, frames, feat_dim]
+ 
+             # 2) Semantic layer
+             semantic_output = self.semantic_model(input_features)
+             semantic_hidden_16 = semantic_output.hidden_states[16]  # take the 16th layer
+             semantic_hidden_16 = semantic_hidden_16.transpose(1, 2)  # [batch, hidden_dim, frames]
+             semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)
+ 
+             # 3) Codec encoder
+             wav = input_waveform.unsqueeze(1).to(self.device)  # shape: [batch, 1, time]
+             vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024], for example
+             vq_emb = vq_emb.transpose(1, 2)  # -> [batch, 1024, frames]
+ 
+             # Align the number of time frames with the semantic features (illustrative only)
+             if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
+                 min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
+                 vq_emb = vq_emb[:, :, :min_len]
+                 semantic_encoded = semantic_encoded[:, :, :min_len]
+ 
+             # 4) Concatenate
+             concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)  # [batch, 2048, frames]
+ 
+             # 5) fc_prior
+             concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)
+ 
+             # 6) Quantization part of the decoder; obtain the codes
+             _, vq_code, _ = self.generator(concat_emb, vq=True)
+             # vq_code: [batch, frames]
+             return vq_code
+ 
+     def decode_code(self, vq_code):
+         """
+         Decode the encoded codes back to audio.
+ 
+         Args:
+             vq_code: the encoded codes (Tensor) [batch, frames]
+         Returns:
+             The decoded audio (Tensor) [batch, waveform_length]
+         """
+         with torch.no_grad():
+             # Get the quantized embeddings
+             vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
+             vq_post_emb = vq_post_emb.transpose(1, 2)  # [batch, 1024, frames]
+ 
+             # 7) fc_post_a
+             vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)  # [batch, 1024, frames]
+ 
+             # 8) Finally decode back to a waveform
+             recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]  # [batch, time]
+             return recon_audio
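`forward()` and `encode_code()` above default to `sample_rate=16000`, so audio at other rates should be resampled first. A small helper sketch, assuming torchaudio is available (it is not a dependency pinned by this commit):

```python
import torchaudio


def load_as_16k_mono(path: str):
    """Load an audio file as the 16 kHz mono [1, time] tensor expected above."""
    wav, sr = torchaudio.load(path)              # [channels, time]
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)      # downmix to mono
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=16000)
    return wav
```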
module.py ADDED
File without changes
test.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+ import soundfile as sf
+ from transformers import AutoConfig
+ 
+ from modeling_xcodec2 import XCodec2Model
+ 
+ model_path = "/data/zheny/xcodec2"  # this is your repository name on Hugging Face
+ 
+ model = XCodec2Model.from_pretrained(model_path)
+ model.eval().cuda()
+ 
+ # Prepare an audio clip
+ wav, sr = sf.read("test.flac")
+ wav_tensor = torch.from_numpy(wav).float().unsqueeze(0)  # [1, time]
+ 
+ with torch.no_grad():
+     vq_code = model.encode_code(input_waveform=wav_tensor)
+     print(vq_code)
+     recon_wav = model.decode_code(vq_code).cpu()  # [batch, 1, time]
+ 
+ sf.write("reconstructed.wav", recon_wav[0, 0, :].numpy(), sr)
vq/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from vq.codec_encoder import CodecEncoder
+ from vq.codec_decoder import CodecDecoder
+ from vq.codec_decoder_vocos import CodecDecoderVocos
+ from vq.codec_encoder import CodecEncoder_Transformer, CodecEncoder_only_Transformer
vq/activations.py ADDED
@@ -0,0 +1,120 @@
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+ # LICENSE is in incl_licenses directory.
+ 
+ import torch
+ from torch import nn, sin, pow
+ from torch.nn import Parameter
+ 
+ 
+ class Snake(nn.Module):
+     '''
+     Implementation of a sine-based periodic activation function
+     Shape:
+         - Input: (B, C, T)
+         - Output: (B, C, T), same shape as the input
+     Parameters:
+         - alpha - trainable parameter
+     References:
+         - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+         https://arxiv.org/abs/2006.08195
+     Examples:
+         >>> a1 = snake(256)
+         >>> x = torch.randn(256)
+         >>> x = a1(x)
+     '''
+     def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+         '''
+         Initialization.
+         INPUT:
+             - in_features: shape of the input
+             - alpha: trainable parameter
+             alpha is initialized to 1 by default, higher values = higher-frequency.
+             alpha will be trained along with the rest of your model.
+         '''
+         super(Snake, self).__init__()
+         self.in_features = in_features
+ 
+         # initialize alpha
+         self.alpha_logscale = alpha_logscale
+         if self.alpha_logscale:  # log scale alphas initialized to zeros
+             self.alpha = Parameter(torch.zeros(in_features) * alpha)
+         else:  # linear scale alphas initialized to ones
+             self.alpha = Parameter(torch.ones(in_features) * alpha)
+ 
+         self.alpha.requires_grad = alpha_trainable
+ 
+         self.no_div_by_zero = 0.000000001
+ 
+     def forward(self, x):
+         '''
+         Forward pass of the function.
+         Applies the function to the input elementwise.
+         Snake := x + 1/a * sin^2 (xa)
+         '''
+         alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+         if self.alpha_logscale:
+             alpha = torch.exp(alpha)
+         x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+ 
+         return x
+ 
+ 
+ class SnakeBeta(nn.Module):
+     '''
+     A modified Snake function which uses separate parameters for the magnitude of the periodic components
+     Shape:
+         - Input: (B, C, T)
+         - Output: (B, C, T), same shape as the input
+     Parameters:
+         - alpha - trainable parameter that controls frequency
+         - beta - trainable parameter that controls magnitude
+     References:
+         - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+         https://arxiv.org/abs/2006.08195
+     Examples:
+         >>> a1 = snakebeta(256)
+         >>> x = torch.randn(256)
+         >>> x = a1(x)
+     '''
+     def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+         '''
+         Initialization.
+         INPUT:
+             - in_features: shape of the input
+             - alpha - trainable parameter that controls frequency
+             - beta - trainable parameter that controls magnitude
+             alpha is initialized to 1 by default, higher values = higher-frequency.
+             beta is initialized to 1 by default, higher values = higher-magnitude.
+             alpha will be trained along with the rest of your model.
+         '''
+         super(SnakeBeta, self).__init__()
+         self.in_features = in_features
+ 
+         # initialize alpha
+         self.alpha_logscale = alpha_logscale
+         if self.alpha_logscale:  # log scale alphas initialized to zeros
+             self.alpha = Parameter(torch.zeros(in_features) * alpha)
+             self.bias = Parameter(torch.zeros(in_features) * alpha)
+         else:  # linear scale alphas initialized to ones
+             self.alpha = Parameter(torch.ones(in_features) * alpha)
+             self.bias = Parameter(torch.ones(in_features) * alpha)
+ 
+         self.alpha.requires_grad = alpha_trainable
+         self.bias.requires_grad = alpha_trainable
+ 
+         self.no_div_by_zero = 0.000000001
+ 
+     def forward(self, x):
+         '''
+         Forward pass of the function.
+         Applies the function to the input elementwise.
+         SnakeBeta := x + 1/b * sin^2 (xa)
+         '''
+         alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+         beta = self.bias.unsqueeze(0).unsqueeze(-1)
+         if self.alpha_logscale:
+             alpha = torch.exp(alpha)
+             beta = torch.exp(beta)
+         x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+ 
+         return x
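The docstring examples above apply the activation to a 1-D tensor; in this codebase they act on (B, C, T) feature maps. A shape-correct sketch (assuming the repository's dependencies are installed, since importing the vq package also loads its codec modules):

```python
import torch
from vq.activations import Snake, SnakeBeta

x = torch.randn(2, 256, 100)                # (B, C, T)
y = Snake(256)(x)                           # same shape as the input
z = SnakeBeta(256, alpha_logscale=True)(x)  # log-scale variant used by the decoder
assert y.shape == x.shape and z.shape == x.shape
```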
vq/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+ # LICENSE is in incl_licenses directory.
+ 
+ from .filter import *
+ from .resample import *
+ from .act import *
vq/alias_free_torch/act.py ADDED
@@ -0,0 +1,28 @@
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+ # LICENSE is in incl_licenses directory.
+ 
+ import torch.nn as nn
+ from .resample import UpSample1d, DownSample1d
+ 
+ 
+ class Activation1d(nn.Module):
+     def __init__(self,
+                  activation,
+                  up_ratio: int = 2,
+                  down_ratio: int = 2,
+                  up_kernel_size: int = 12,
+                  down_kernel_size: int = 12):
+         super().__init__()
+         self.up_ratio = up_ratio
+         self.down_ratio = down_ratio
+         self.act = activation
+         self.upsample = UpSample1d(up_ratio, up_kernel_size)
+         self.downsample = DownSample1d(down_ratio, down_kernel_size)
+ 
+     # x: [B, C, T]
+     def forward(self, x):
+         x = self.upsample(x)
+         x = self.act(x)
+         x = self.downsample(x)
+ 
+         return x
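Activation1d is the anti-aliased wrapper that codec_decoder.py puts around SnakeBeta. A minimal sketch of that pattern (the channel count here is arbitrary):

```python
import torch
from vq.activations import SnakeBeta
from vq.alias_free_torch import Activation1d

act = Activation1d(activation=SnakeBeta(64, alpha_logscale=True))
x = torch.randn(1, 64, 320)   # [B, C, T]
y = act(x)                    # upsample -> activation -> downsample
print(y.shape)                # time length preserved: torch.Size([1, 64, 320])
```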
vq/alias_free_torch/filter.py ADDED
@@ -0,0 +1,95 @@
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+ # LICENSE is in incl_licenses directory.
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ 
+ if 'sinc' in dir(torch):
+     sinc = torch.sinc
+ else:
+     # This code is adopted from adefossez's julius.core.sinc under the MIT License
+     # https://adefossez.github.io/julius/julius/core.html
+     # LICENSE is in incl_licenses directory.
+     def sinc(x: torch.Tensor):
+         """
+         Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+         __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+         """
+         return torch.where(x == 0,
+                            torch.tensor(1., device=x.device, dtype=x.dtype),
+                            torch.sin(math.pi * x) / math.pi / x)
+ 
+ 
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+ # https://adefossez.github.io/julius/julius/lowpass.html
+ # LICENSE is in incl_licenses directory.
+ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
+     even = (kernel_size % 2 == 0)
+     half_size = kernel_size // 2
+ 
+     # For kaiser window
+     delta_f = 4 * half_width
+     A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+     if A > 50.:
+         beta = 0.1102 * (A - 8.7)
+     elif A >= 21.:
+         beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+     else:
+         beta = 0.
+     window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+ 
+     # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+     if even:
+         time = (torch.arange(-half_size, half_size) + 0.5)
+     else:
+         time = torch.arange(kernel_size) - half_size
+     if cutoff == 0:
+         filter_ = torch.zeros_like(time)
+     else:
+         filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+         # Normalize filter to have sum = 1, otherwise we will have a small leakage
+         # of the constant component in the input signal.
+         filter_ /= filter_.sum()
+     filter = filter_.view(1, 1, kernel_size)
+ 
+     return filter
+ 
+ 
+ class LowPassFilter1d(nn.Module):
+     def __init__(self,
+                  cutoff=0.5,
+                  half_width=0.6,
+                  stride: int = 1,
+                  padding: bool = True,
+                  padding_mode: str = 'replicate',
+                  kernel_size: int = 12):
+         # kernel_size should be even number for stylegan3 setup,
+         # in this implementation, odd number is also possible.
+         super().__init__()
+         if cutoff < -0.:
+             raise ValueError("Minimum cutoff must be larger than zero.")
+         if cutoff > 0.5:
+             raise ValueError("A cutoff above 0.5 does not make sense.")
+         self.kernel_size = kernel_size
+         self.even = (kernel_size % 2 == 0)
+         self.pad_left = kernel_size // 2 - int(self.even)
+         self.pad_right = kernel_size // 2
+         self.stride = stride
+         self.padding = padding
+         self.padding_mode = padding_mode
+         filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+         self.register_buffer("filter", filter)
+ 
+     # input [B, C, T]
+     def forward(self, x):
+         _, C, _ = x.shape
+ 
+         if self.padding:
+             x = F.pad(x, (self.pad_left, self.pad_right),
+                       mode=self.padding_mode)
+         out = F.conv1d(x, self.filter.expand(C, -1, -1),
+                        stride=self.stride, groups=C)
+ 
+         return out
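A quick numerical sketch of the Kaiser-windowed sinc design above (parameter values are illustrative): the taps are normalized to unit DC gain, and LowPassFilter1d with a stride of 2 halves the time axis.

```python
import torch
from vq.alias_free_torch.filter import kaiser_sinc_filter1d, LowPassFilter1d

taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(taps.shape, float(taps.sum()))   # torch.Size([1, 1, 12]), ~1.0 (unit DC gain)

lowpass = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
y = lowpass(torch.randn(1, 4, 100))    # -> [1, 4, 50]
print(y.shape)
```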
vq/alias_free_torch/resample.py ADDED
@@ -0,0 +1,49 @@
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+ # LICENSE is in incl_licenses directory.
+ 
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from .filter import LowPassFilter1d
+ from .filter import kaiser_sinc_filter1d
+ 
+ 
+ class UpSample1d(nn.Module):
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         self.stride = ratio
+         self.pad = self.kernel_size // ratio - 1
+         self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+         self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+         filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
+                                       half_width=0.6 / ratio,
+                                       kernel_size=self.kernel_size)
+         self.register_buffer("filter", filter)
+ 
+     # x: [B, C, T]
+     def forward(self, x):
+         _, C, _ = x.shape
+ 
+         x = F.pad(x, (self.pad, self.pad), mode='replicate')
+         x = self.ratio * F.conv_transpose1d(
+             x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+         x = x[..., self.pad_left:-self.pad_right]
+ 
+         return x
+ 
+ 
+ class DownSample1d(nn.Module):
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
+                                        half_width=0.6 / ratio,
+                                        stride=ratio,
+                                        kernel_size=self.kernel_size)
+ 
+     def forward(self, x):
+         xx = self.lowpass(x)
+ 
+         return xx
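A round-trip sketch of the two resamplers above: upsampling by a ratio r multiplies the time length by r, and the matching downsampler restores it (sizes here are illustrative).

```python
import torch
from vq.alias_free_torch import UpSample1d, DownSample1d

x = torch.randn(1, 8, 160)          # [B, C, T]
up = UpSample1d(ratio=2)(x)         # -> [1, 8, 320]
down = DownSample1d(ratio=2)(up)    # -> [1, 8, 160]
print(up.shape, down.shape)
```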
vq/blocks.py ADDED
@@ -0,0 +1,183 @@
+ from typing import Callable, Sequence, Type, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]]
8
+
9
+
10
+ class FeedForwardModule(nn.Module):
11
+
12
+ def __init__(self) -> None:
13
+ super().__init__()
14
+ self.net = None
15
+
16
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
17
+ return self.net(x)
18
+
19
+
20
+ class Residual(nn.Module):
21
+
22
+ def __init__(self, module: nn.Module) -> None:
23
+ super().__init__()
24
+ self.module = module
25
+
26
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
27
+ return self.module(x) + x
28
+
29
+
30
+ class DilatedConvolutionalUnit(FeedForwardModule):
31
+
32
+ def __init__(
33
+ self,
34
+ hidden_dim: int,
35
+ dilation: int,
36
+ kernel_size: int,
37
+ activation: ModuleFactory,
38
+ normalization: Callable[[nn.Module],
39
+ nn.Module] = lambda x: x) -> None:
40
+ super().__init__()
41
+ self.net = nn.Sequential(
42
+ activation(),
43
+ normalization(
44
+ nn.Conv1d(
45
+ in_channels=hidden_dim,
46
+ out_channels=hidden_dim,
47
+ kernel_size=kernel_size,
48
+ dilation=dilation,
49
+ padding=((kernel_size - 1) * dilation) // 2,
50
+ )),
51
+ activation(),
52
+ nn.Conv1d(in_channels=hidden_dim,
53
+ out_channels=hidden_dim,
54
+ kernel_size=1),
55
+ )
56
+
57
+
58
+ class UpsamplingUnit(FeedForwardModule):
59
+
60
+ def __init__(
61
+ self,
62
+ input_dim: int,
63
+ output_dim: int,
64
+ stride: int,
65
+ activation: ModuleFactory,
66
+ normalization: Callable[[nn.Module],
67
+ nn.Module] = lambda x: x) -> None:
68
+ super().__init__()
69
+ self.net = nn.Sequential(
70
+ activation(),
71
+ normalization(
72
+ nn.ConvTranspose1d(
73
+ in_channels=input_dim,
74
+ out_channels=output_dim,
75
+ kernel_size=2 * stride,
76
+ stride=stride,
77
+ padding=stride // 2+ stride % 2,
78
+ output_padding=1 if stride % 2 != 0 else 0
79
+ )))
80
+
81
+
82
+ class DownsamplingUnit(FeedForwardModule):
83
+
84
+ def __init__(
85
+ self,
86
+ input_dim: int,
87
+ output_dim: int,
88
+ stride: int,
89
+ activation: ModuleFactory,
90
+ normalization: Callable[[nn.Module],
91
+ nn.Module] = lambda x: x) -> None:
92
+ super().__init__()
93
+ self.net = nn.Sequential(
94
+ activation(),
95
+ normalization(
96
+ nn.Conv1d(
97
+ in_channels=input_dim,
98
+ out_channels=output_dim,
99
+ kernel_size=2 * stride,
100
+ stride=stride,
101
+ padding= stride // 2+ stride % 2,
102
+
103
+ )))
104
+
105
+
106
+ class DilatedResidualEncoder(FeedForwardModule):
107
+
108
+ def __init__(
109
+ self,
110
+ capacity: int,
111
+ dilated_unit: Type[DilatedConvolutionalUnit],
112
+ downsampling_unit: Type[DownsamplingUnit],
113
+ ratios: Sequence[int],
114
+ dilations: Union[Sequence[int], Sequence[Sequence[int]]],
115
+ pre_network_conv: Type[nn.Conv1d],
116
+ post_network_conv: Type[nn.Conv1d],
117
+ normalization: Callable[[nn.Module],
118
+ nn.Module] = lambda x: x) -> None:
119
+ super().__init__()
120
+ channels = capacity * 2**np.arange(len(ratios) + 1)
121
+
122
+ dilations_list = self.normalize_dilations(dilations, ratios)
123
+
124
+ net = [normalization(pre_network_conv(out_channels=channels[0]))]
125
+
126
+ for ratio, dilations, input_dim, output_dim in zip(
127
+ ratios, dilations_list, channels[:-1], channels[1:]):
128
+ for dilation in dilations:
129
+ net.append(Residual(dilated_unit(input_dim, dilation)))
130
+ net.append(downsampling_unit(input_dim, output_dim, ratio))
131
+
132
+ net.append(post_network_conv(in_channels=output_dim))
133
+
134
+ self.net = nn.Sequential(*net)
135
+
136
+ @staticmethod
137
+ def normalize_dilations(dilations: Union[Sequence[int],
138
+ Sequence[Sequence[int]]],
139
+ ratios: Sequence[int]):
140
+ if isinstance(dilations[0], int):
141
+ dilations = [dilations for _ in ratios]
142
+ return dilations
143
+
144
+
145
+ class DilatedResidualDecoder(FeedForwardModule):
146
+
147
+ def __init__(
148
+ self,
149
+ capacity: int,
150
+ dilated_unit: Type[DilatedConvolutionalUnit],
151
+ upsampling_unit: Type[UpsamplingUnit],
152
+ ratios: Sequence[int],
153
+ dilations: Union[Sequence[int], Sequence[Sequence[int]]],
154
+ pre_network_conv: Type[nn.Conv1d],
155
+ post_network_conv: Type[nn.Conv1d],
156
+ normalization: Callable[[nn.Module],
157
+ nn.Module] = lambda x: x) -> None:
158
+ super().__init__()
159
+ channels = capacity * 2**np.arange(len(ratios) + 1)
160
+ channels = channels[::-1]
161
+
162
+ dilations_list = self.normalize_dilations(dilations, ratios)
163
+ dilations_list = dilations_list[::-1]
164
+
165
+ net = [pre_network_conv(out_channels=channels[0])]
166
+
167
+ for ratio, dilations, input_dim, output_dim in zip(
168
+ ratios, dilations_list, channels[:-1], channels[1:]):
169
+ net.append(upsampling_unit(input_dim, output_dim, ratio))
170
+ for dilation in dilations:
171
+ net.append(Residual(dilated_unit(output_dim, dilation)))
172
+
173
+ net.append(normalization(post_network_conv(in_channels=output_dim)))
174
+
175
+ self.net = nn.Sequential(*net)
176
+
177
+ @staticmethod
178
+ def normalize_dilations(dilations: Union[Sequence[int],
179
+ Sequence[Sequence[int]]],
180
+ ratios: Sequence[int]):
181
+ if isinstance(dilations[0], int):
182
+ dilations = [dilations for _ in ratios]
183
+ return dilations
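The encoder/decoder stacks above take factory callables for their sub-modules. A hedged sketch of wiring DilatedResidualEncoder with functools.partial; the capacity, ratios, and channel counts here are illustrative and are not the values used elsewhere in this upload, and importing the vq package assumes its dependencies are installed.

```python
import functools
import torch
import torch.nn as nn
from vq import blocks

encoder = blocks.DilatedResidualEncoder(
    capacity=32,
    dilated_unit=functools.partial(blocks.DilatedConvolutionalUnit,
                                   kernel_size=3, activation=nn.ReLU),
    downsampling_unit=functools.partial(blocks.DownsamplingUnit, activation=nn.ReLU),
    ratios=(2, 2),
    dilations=(1, 3, 9),
    pre_network_conv=functools.partial(nn.Conv1d, in_channels=1, kernel_size=7, padding=3),
    post_network_conv=functools.partial(nn.Conv1d, out_channels=512, kernel_size=1),
)

x = torch.randn(1, 1, 16000)   # [B, 1, T] waveform
z = encoder(x)                 # [B, 512, T // 4] latent sequence
print(z.shape)
```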
vq/bs_roformer5.py ADDED
@@ -0,0 +1,123 @@
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import Module, ModuleList
5
+ import torchaudio
6
+ from einops import rearrange
7
+ import numpy as np
8
+ # from rotary_embedding_torch import RotaryEmbedding
9
+
10
+ from torchtune.modules import RotaryPositionalEmbeddings
11
+
12
+
13
+
14
+ class RMSNorm(torch.nn.Module):
15
+ def __init__(self, dim: int, eps: float = 1e-6):
16
+ r"""https://github.com/meta-llama/llama/blob/main/llama/model.py"""
17
+ super().__init__()
18
+ self.eps = eps
19
+ self.weight = nn.Parameter(torch.ones(dim))
20
+
21
+ def forward(self, x):
22
+ norm_x = torch.mean(x ** 2, dim=-1, keepdim=True)
23
+ output = x * torch.rsqrt(norm_x + self.eps) * self.weight
24
+ return output
25
+
26
+
27
+
28
+ class MLP(nn.Module):
29
+ def __init__(self, dim: int) -> None:
30
+ super().__init__()
31
+
32
+ self.fc1 = nn.Linear(dim, 4 * dim, bias=False)
33
+ self.silu = nn.SiLU()
34
+ self.fc2 = nn.Linear(4 * dim, dim, bias=False)
35
+
36
+ def forward(self, x):
37
+ x = self.fc1(x)
38
+ x = self.silu(x)
39
+ x = self.fc2(x)
40
+ return x
41
+
42
+
43
+ class Attention(nn.Module):
44
+
45
+ def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
46
+ super().__init__()
47
+
48
+ assert dim % n_heads == 0
49
+
50
+ self.n_heads = n_heads
51
+ self.dim = dim
52
+ self.rotary_embed = rotary_embed
53
+
54
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
55
+ assert self.flash, "Must have flash attention."
56
+
57
+ self.c_attn = nn.Linear(dim, 3 * dim, bias=False)
58
+ self.c_proj = nn.Linear(dim, dim, bias=False)
59
+
60
+ def forward(self, x):
61
+ r"""
62
+ Args:
63
+ x: (b, t, h*d)
64
+
65
+ Constants:
66
+ b: batch_size
67
+ t: time steps
68
+ r: 3
69
+ h: heads_num
70
+ d: heads_dim
71
+ """
72
+ B, T, C = x.size()
73
+
74
+ q, k, v = rearrange(self.c_attn(x), 'b t (r h d) -> r b h t d', r=3, h=self.n_heads)
75
+ # q, k, v: (b, h, t, d)
76
+
77
+ q = self.rotary_embed(q)
78
+ k = self.rotary_embed(k)
79
+
80
+ if self.flash:
81
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=False)
82
+
83
+ y = rearrange(y, 'b h t d -> b t (h d)')
84
+
85
+ y = self.c_proj(y)
86
+ # shape: (b, t, h*d)
87
+
88
+ return y
89
+
90
+
91
+ class TransformerBlock(nn.Module):
92
+ def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
93
+
94
+ super().__init__()
95
+ self.dim = dim
96
+ self.n_heads = n_heads
97
+
98
+ self.att_norm = RMSNorm(dim)
99
+ self.ffn_norm = RMSNorm(dim)
100
+ self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed)
101
+ self.mlp = MLP(dim=dim)
102
+
103
+
104
+ def forward(
105
+ self,
106
+ x: torch.Tensor,
107
+ ):
108
+ x = x + self.att(self.att_norm(x))
109
+ x = x + self.mlp(self.ffn_norm(x))
110
+ return x
111
+
112
+
113
+ if __name__ == '__main__':
114
+ rotary_embed_128 = RotaryPositionalEmbeddings(dim=128)
115
+ transformer_block = TransformerBlock(
116
+ dim=1024,
117
+ n_heads=8,
118
+ rotary_embed=rotary_embed_128
119
+ )
120
+ x = torch.randn(2, 128, 1024)
121
+ y = transformer_block(x)
122
+ print(y.shape)
123
+ c=1
vq/codec_decoder.py ADDED
@@ -0,0 +1,304 @@
+ import sys
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from vq.residual_vq import ResidualVQ
7
+ from vq.module import WNConv1d, DecoderBlock, ResLSTM
8
+ from vq.alias_free_torch import *
9
+ from vq import activations
10
+ import vq.blocks as blocks
11
+ from torch.nn import utils
12
+
13
+ from vq.bs_roformer5 import TransformerBlock
14
+
15
+ from torchtune.modules import RotaryPositionalEmbeddings
16
+
17
+ def init_weights(m):
18
+ if isinstance(m, nn.Conv1d):
19
+ nn.init.trunc_normal_(m.weight, std=0.02)
20
+ nn.init.constant_(m.bias, 0)
21
+
22
+ class CodecDecoder(nn.Module):
23
+ def __init__(self,
24
+ in_channels=1024,
25
+ upsample_initial_channel=1536,
26
+ ngf=48,
27
+ use_rnn=True,
28
+ rnn_bidirectional=False,
29
+ rnn_num_layers=2,
30
+ up_ratios=(5, 4, 4, 4, 2),
31
+ dilations=(1, 3, 9),
32
+ vq_num_quantizers=1,
33
+ vq_dim=2048,
34
+ vq_commit_weight=0.25,
35
+ vq_weight_init=False,
36
+ vq_full_commit_loss=False,
37
+ codebook_size=16384,
38
+ codebook_dim=32,
39
+ ):
40
+ super().__init__()
41
+ self.hop_length = np.prod(up_ratios)
42
+ self.ngf = ngf
43
+ self.up_ratios = up_ratios
44
+
45
+ self.quantizer = ResidualVQ(
46
+ num_quantizers=vq_num_quantizers,
47
+ dim=vq_dim, # double the dim for acoustic and semantic
48
+ codebook_size=codebook_size,
49
+ codebook_dim=codebook_dim,
50
+ threshold_ema_dead_code=2,
51
+ commitment=vq_commit_weight,
52
+ weight_init=vq_weight_init,
53
+ full_commit_loss=vq_full_commit_loss,
54
+ )
55
+ channels = upsample_initial_channel
56
+ layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
57
+
58
+ if use_rnn:
59
+ layers += [
60
+ ResLSTM(channels,
61
+ num_layers=rnn_num_layers,
62
+ bidirectional=rnn_bidirectional
63
+ )
64
+ ]
65
+
66
+ for i, stride in enumerate(up_ratios):
67
+ input_dim = channels // 2**i
68
+ output_dim = channels // 2 ** (i + 1)
69
+ layers += [DecoderBlock(input_dim, output_dim, stride, dilations)]
70
+
71
+ layers += [
72
+ Activation1d(activation=activations.SnakeBeta(output_dim, alpha_logscale=True)),
73
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3),
74
+ nn.Tanh(),
75
+ ]
76
+
77
+ self.model = nn.Sequential(*layers)
78
+
79
+ self.reset_parameters()
80
+
81
+ def forward(self, x, vq=True):
82
+ if vq is True:
83
+ x, q, commit_loss = self.quantizer(x)
84
+ return x, q, commit_loss
85
+ x = self.model(x)
86
+ return x
87
+
88
+ def vq2emb(self, vq):
89
+ self.quantizer = self.quantizer.eval()
90
+ x = self.quantizer.vq2emb(vq)
91
+ return x
92
+
93
+ def get_emb(self):
94
+ self.quantizer = self.quantizer.eval()
95
+ embs = self.quantizer.get_emb()
96
+ return embs
97
+
98
+ def inference_vq(self, vq):
99
+ x = vq[None,:,:]
100
+ x = self.model(x)
101
+ return x
102
+
103
+ def inference_0(self, x):
104
+ x, q, loss, perp = self.quantizer(x)
105
+ x = self.model(x)
106
+ return x, None
107
+
108
+ def inference(self, x):
109
+ x = self.model(x)
110
+ return x, None
111
+
112
+
113
+ def remove_weight_norm(self):
114
+ """Remove weight normalization module from all of the layers."""
115
+
116
+ def _remove_weight_norm(m):
117
+ try:
118
+ torch.nn.utils.remove_weight_norm(m)
119
+ except ValueError: # this module didn't have weight norm
120
+ return
121
+
122
+ self.apply(_remove_weight_norm)
123
+
124
+ def apply_weight_norm(self):
125
+ """Apply weight normalization module from all of the layers."""
126
+
127
+ def _apply_weight_norm(m):
128
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
129
+ torch.nn.utils.weight_norm(m)
130
+
131
+ self.apply(_apply_weight_norm)
132
+
133
+ def reset_parameters(self):
134
+ self.apply(init_weights)
135
+
136
+
137
+ class CodecDecoder_oobleck_Transformer(nn.Module):
138
+ def __init__(self,
139
+ ngf=32,
140
+ up_ratios=(5, 4, 4, 4, 2),
141
+ dilations=(1, 3, 9),
142
+ vq_num_quantizers=1,
143
+ vq_dim=1024,
144
+ vq_commit_weight=0.25,
145
+ vq_weight_init=False,
146
+ vq_full_commit_loss=False,
147
+ codebook_size=16384,
148
+ codebook_dim=16,
149
+ hidden_dim=1024,
150
+ depth=12,
151
+ heads=16,
152
+ pos_meb_dim=64,
153
+ ):
154
+ super().__init__()
155
+ self.hop_length = np.prod(up_ratios)
156
+ self.capacity = ngf
157
+ self.up_ratios = up_ratios
158
+ self.hidden_dim = hidden_dim
159
+ self.quantizer = ResidualVQ(
160
+ num_quantizers=vq_num_quantizers,
161
+ dim=vq_dim, # double the dim for acoustic and semantic
162
+ codebook_size=codebook_size,
163
+ codebook_dim=codebook_dim,
164
+ threshold_ema_dead_code=2,
165
+ commitment=vq_commit_weight,
166
+ weight_init=vq_weight_init,
167
+ full_commit_loss=vq_full_commit_loss,
168
+ )
169
+
170
+ time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
171
+
172
+ transformer_blocks = [
173
+ TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
174
+ for _ in range(depth)
175
+ ]
176
+
177
+ self.transformers = nn.Sequential(*transformer_blocks)
178
+
179
+ self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
180
+
181
+ self.conv_blocks = blocks.DilatedResidualDecoder(
182
+ capacity=self.capacity,
183
+ dilated_unit=self.dilated_unit,
184
+ upsampling_unit=self.upsampling_unit,
185
+ ratios=up_ratios, # the reverse of the encoder's downsampling ratios
186
+ dilations=dilations,
187
+ pre_network_conv=self.pre_conv,
188
+ post_network_conv=self.post_conv,
189
+ )
190
+
191
+
192
+
193
+ self.reset_parameters()
194
+
195
+ def forward(self, x, vq=True):
196
+ if vq is True:
197
+ x, q, commit_loss = self.quantizer(x)
198
+ return x, q, commit_loss
199
+ x= self.transformers(x)
200
+ x = self.final_layer_norm(x)
201
+ x = x.permute(0, 2, 1)
202
+ x = self.conv_blocks(x)
203
+ return x
204
+
205
+ def vq2emb(self, vq):
206
+ self.quantizer = self.quantizer.eval()
207
+ x = self.quantizer.vq2emb(vq)
208
+ return x
209
+
210
+ def get_emb(self):
211
+ self.quantizer = self.quantizer.eval()
212
+ embs = self.quantizer.get_emb()
213
+ return embs
214
+
215
+ def inference_vq(self, vq):
216
+ x = vq[None,:,:]
217
+ x = self.model(x)
218
+ return x
219
+
220
+ def inference_0(self, x):
221
+ x, q, loss, perp = self.quantizer(x)
222
+ x = self.model(x)
223
+ return x, None
224
+
225
+ def inference(self, x):
226
+ x = self.model(x)
227
+ return x, None
228
+
229
+
230
+ def remove_weight_norm(self):
231
+ """Remove weight normalization module from all of the layers."""
232
+
233
+ def _remove_weight_norm(m):
234
+ try:
235
+ torch.nn.utils.remove_weight_norm(m)
236
+ except ValueError: # this module didn't have weight norm
237
+ return
238
+
239
+ self.apply(_remove_weight_norm)
240
+
241
+ def apply_weight_norm(self):
242
+ """Apply weight normalization module from all of the layers."""
243
+
244
+ def _apply_weight_norm(m):
245
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
246
+ torch.nn.utils.weight_norm(m)
247
+
248
+ self.apply(_apply_weight_norm)
249
+
250
+ def reset_parameters(self):
251
+ self.apply(init_weights)
252
+
253
+ def pre_conv(self, out_channels):
254
+ return nn.Conv1d(in_channels=self.hidden_dim, out_channels=out_channels, kernel_size=1)
255
+
256
+ # Define the post-processing conv layer that maps the model output to the final number of output channels
257
+ def post_conv(self,in_channels):
258
+ return nn.Conv1d(in_channels=in_channels, out_channels=1, kernel_size=1)
259
+
260
+ def dilated_unit(self, hidden_dim, dilation):
261
+ return blocks.DilatedConvolutionalUnit(
262
+ hidden_dim=hidden_dim,
263
+ dilation=dilation,
264
+ kernel_size=3,
265
+ activation=nn.ReLU ,
266
+ normalization=utils.weight_norm
267
+ )
268
+
269
+ # Define the upsampling unit
270
+ def upsampling_unit(self,input_dim, output_dim, stride):
271
+ return blocks.UpsamplingUnit(
272
+ input_dim=input_dim,
273
+ output_dim=output_dim,
274
+ stride=stride,
275
+ activation=nn.ReLU ,
276
+ normalization=utils.weight_norm
277
+ )
278
+
279
+ def main():
280
+ # Set the device
281
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
282
+ print(f"Using device: {device}")
283
+
284
+ # Initialize the model
285
+ model = CodecDecoder_oobleck_Transformer().to(device)
286
+ print("Model initialized.")
287
+
288
+ # Create a test input: batch_size x in_channels x sequence_length
289
+ batch_size = 2
290
+ in_channels = 1024
291
+ sequence_length = 100 # example length; adjust as needed
292
+ dummy_input = torch.randn(batch_size, sequence_length, in_channels).to(device)
293
+ print(f"Dummy input shape: {dummy_input.shape}")
294
+
295
+ # Put the model in eval mode
296
+ model.eval()
297
+
298
+
299
+
300
+ output_no_vq = model(dummy_input, vq=False)
301
+ c=1
302
+
303
+ if __name__ == "__main__":
304
+ main()
vq/codec_decoder_vocos.py ADDED
@@ -0,0 +1,638 @@
+ import sys
2
+ sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv_transformer_vocos')
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from vq.residual_vq import ResidualVQ
7
+ from vq.module import WNConv1d, DecoderBlock, ResLSTM
8
+ from vq.alias_free_torch import *
9
+ from vq import activations
10
+ from typing import Optional
11
+ from vq.module import ConvNeXtBlock, AdaLayerNorm
12
+ from vq.bs_roformer5 import TransformerBlock
13
+ # from rotary_embedding_torch import RotaryEmbedding
14
+ from torchtune.modules import RotaryPositionalEmbeddings
15
+ from vector_quantize_pytorch import ResidualFSQ
16
+ import typing as tp  # needed for the tp.List annotations used below
+ from torch.nn import Module, ModuleList
17
+ class ISTFT(nn.Module):
18
+ """
19
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
20
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
21
+ See issue: https://github.com/pytorch/pytorch/issues/62323
22
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
23
+ The NOLA constraint is met as we trim padded samples anyway.
24
+
25
+ Args:
26
+ n_fft (int): Size of Fourier transform.
27
+ hop_length (int): The distance between neighboring sliding window frames.
28
+ win_length (int): The size of window frame and STFT filter.
29
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
30
+ """
31
+
32
+ def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
33
+ super().__init__()
34
+ if padding not in ["center", "same"]:
35
+ raise ValueError("Padding must be 'center' or 'same'.")
36
+ self.padding = padding
37
+ self.n_fft = n_fft
38
+ self.hop_length = hop_length
39
+ self.win_length = win_length
40
+ window = torch.hann_window(win_length)
41
+ self.register_buffer("window", window)
42
+
43
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
44
+ """
45
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
46
+
47
+ Args:
48
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
49
+ N is the number of frequency bins, and T is the number of time frames.
50
+
51
+ Returns:
52
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
53
+ """
54
+ if self.padding == "center":
55
+ # Fallback to pytorch native implementation
56
+ return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
57
+ elif self.padding == "same":
58
+ pad = (self.win_length - self.hop_length) // 2
59
+ else:
60
+ raise ValueError("Padding must be 'center' or 'same'.")
61
+
62
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
63
+ B, N, T = spec.shape
64
+
65
+ # Inverse FFT
66
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
67
+ ifft = ifft * self.window[None, :, None]
68
+
69
+ # Overlap and Add
70
+ output_size = (T - 1) * self.hop_length + self.win_length
71
+ y = torch.nn.functional.fold(
72
+ ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
73
+ )[:, 0, 0, pad:-pad]
74
+
75
+ # Window envelope
76
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
77
+ window_envelope = torch.nn.functional.fold(
78
+ window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
79
+ ).squeeze()[pad:-pad]
80
+
81
+ # Normalize
82
+ assert (window_envelope > 1e-11).all()
83
+ y = y / window_envelope
84
+
85
+ return y
86
+
87
+
88
+
89
+ class FourierHead(nn.Module):
90
+ """Base class for inverse fourier modules."""
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ """
94
+ Args:
95
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
96
+ L is the sequence length, and H denotes the model dimension.
97
+
98
+ Returns:
99
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
100
+ """
101
+ raise NotImplementedError("Subclasses must implement the forward method.")
102
+
103
+
104
+ class ISTFTHead(FourierHead):
105
+ """
106
+ ISTFT Head module for predicting STFT complex coefficients.
107
+
108
+ Args:
109
+ dim (int): Hidden dimension of the model.
110
+ n_fft (int): Size of Fourier transform.
111
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
112
+ the resolution of the input features.
113
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
114
+ """
115
+
116
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
117
+ super().__init__()
118
+ out_dim = n_fft + 2
119
+ self.out = torch.nn.Linear(dim, out_dim)
120
+ self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
121
+
122
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
123
+ """
124
+ Forward pass of the ISTFTHead module.
125
+
126
+ Args:
127
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
128
+ L is the sequence length, and H denotes the model dimension.
129
+
130
+ Returns:
131
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
132
+ """
133
+ x_pred = self.out(x )
134
+ # x_pred = x
135
+ x_pred = x_pred.transpose(1, 2)
136
+ mag, p = x_pred.chunk(2, dim=1)
137
+ mag = torch.exp(mag)
138
+ mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes
139
+ # wrapping happens here. These two lines produce real and imaginary value
140
+ x = torch.cos(p)
141
+ y = torch.sin(p)
142
+ # recalculating phase here does not produce anything new
143
+ # only costs time
144
+ # phase = torch.atan2(y, x)
145
+ # S = mag * torch.exp(phase * 1j)
146
+ # better directly produce the complex value
147
+ S = mag * (x + 1j * y)
148
+ audio = self.istft(S)
149
+ return audio.unsqueeze(1),x_pred
150
+
151
+
152
+ def nonlinearity(x):
153
+ # swish
154
+ return x * torch.sigmoid(x)
155
+
156
+
157
+ def Normalize(in_channels, num_groups=32):
158
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
159
+
160
+
161
+ class ResnetBlock(nn.Module):
162
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
163
+ dropout, temb_channels=512):
164
+ super().__init__()
165
+ self.in_channels = in_channels
166
+ out_channels = in_channels if out_channels is None else out_channels
167
+ self.out_channels = out_channels
168
+ self.use_conv_shortcut = conv_shortcut
169
+
170
+ self.norm1 = Normalize(in_channels)
171
+ self.conv1 = torch.nn.Conv1d(in_channels,
172
+ out_channels,
173
+ kernel_size=3,
174
+ stride=1,
175
+ padding=1)
176
+ if temb_channels > 0:
177
+ self.temb_proj = torch.nn.Linear(temb_channels,
178
+ out_channels)
179
+ self.norm2 = Normalize(out_channels)
180
+ self.dropout = torch.nn.Dropout(dropout)
181
+ self.conv2 = torch.nn.Conv1d(out_channels,
182
+ out_channels,
183
+ kernel_size=3,
184
+ stride=1,
185
+ padding=1)
186
+ if self.in_channels != self.out_channels:
187
+ if self.use_conv_shortcut:
188
+ self.conv_shortcut = torch.nn.Conv1d(in_channels,
189
+ out_channels,
190
+ kernel_size=3,
191
+ stride=1,
192
+ padding=1)
193
+ else:
194
+ self.nin_shortcut = torch.nn.Conv1d(in_channels,
195
+ out_channels,
196
+ kernel_size=1,
197
+ stride=1,
198
+ padding=0)
199
+
200
+ def forward(self, x, temb=None):
201
+ h = x
202
+ h = self.norm1(h)
203
+ h = nonlinearity(h)
204
+ h = self.conv1(h)
205
+
206
+ if temb is not None:
207
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
208
+
209
+ h = self.norm2(h)
210
+ h = nonlinearity(h)
211
+ h = self.dropout(h)
212
+ h = self.conv2(h)
213
+
214
+ if self.in_channels != self.out_channels:
215
+ if self.use_conv_shortcut:
216
+ x = self.conv_shortcut(x)
217
+ else:
218
+ x = self.nin_shortcut(x)
219
+
220
+ return x + h
221
+
222
+ class AttnBlock(nn.Module):
223
+ def __init__(self, in_channels):
224
+ super().__init__()
225
+ self.in_channels = in_channels
226
+
227
+ self.norm = Normalize(in_channels)
228
+ self.q = torch.nn.Conv1d(in_channels,
229
+ in_channels,
230
+ kernel_size=1,
231
+ stride=1,
232
+ padding=0)
233
+ self.k = torch.nn.Conv1d(in_channels,
234
+ in_channels,
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0)
238
+ self.v = torch.nn.Conv1d(in_channels,
239
+ in_channels,
240
+ kernel_size=1,
241
+ stride=1,
242
+ padding=0)
243
+ self.proj_out = torch.nn.Conv1d(in_channels,
244
+ in_channels,
245
+ kernel_size=1,
246
+ stride=1,
247
+ padding=0)
248
+
249
+ def forward(self, x):
250
+ h_ = x
251
+ h_ = self.norm(h_)
252
+ q = self.q(h_)
253
+ k = self.k(h_)
254
+ v = self.v(h_)
255
+
256
+ # compute attention
257
+ b, c, h = q.shape
258
+ q = q.permute(0, 2, 1) # b,hw,c
259
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
260
+ w_ = w_ * (int(c) ** (-0.5))
261
+ w_ = torch.nn.functional.softmax(w_, dim=2)
262
+
263
+ # attend to values
264
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
265
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
266
+
267
+ h_ = self.proj_out(h_)
268
+
269
+ return x + h_
270
+
271
+ def make_attn(in_channels, attn_type="vanilla"):
272
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
273
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
274
+ if attn_type == "vanilla":
275
+ return AttnBlock(in_channels)
276
+
277
+
278
+ class Backbone(nn.Module):
279
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
280
+
281
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
282
+ """
283
+ Args:
284
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
285
+ C denotes output features, and L is the sequence length.
286
+
287
+ Returns:
288
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
289
+ and H denotes the model dimension.
290
+ """
291
+ raise NotImplementedError("Subclasses must implement the forward method.")
292
+
293
+
294
+ class VocosBackbone(Backbone):
295
+ """
296
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
297
+
298
+ Args:
299
+ input_channels (int): Number of input features channels.
300
+ dim (int): Hidden dimension of the model.
301
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
302
+ num_layers (int): Number of ConvNeXtBlock layers.
303
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
304
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
305
+ None means non-conditional model. Defaults to None.
306
+ """
307
+
308
+ def __init__(
309
+ self, hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64):
310
+ super().__init__()
311
+
312
+ self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3)
313
+
314
+
315
+
316
+ self.temb_ch = 0
317
+ block_in = hidden_dim
318
+ dropout = 0.1
319
+
320
+ prior_net : tp.List[nn.Module] = [
321
+ ResnetBlock(in_channels=block_in,out_channels=block_in,
322
+ temb_channels=self.temb_ch,dropout=dropout),
323
+ ResnetBlock(in_channels=block_in,out_channels=block_in,
324
+ temb_channels=self.temb_ch,dropout=dropout),
325
+ ]
326
+ self.prior_net = nn.Sequential(*prior_net)
327
+
328
+ depth = depth
329
+ time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
330
+
331
+
332
+ transformer_blocks = [
333
+ TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
334
+ for _ in range(depth)
335
+ ]
336
+
337
+
338
+ self.transformers = nn.Sequential(*transformer_blocks)
339
+ self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
340
+ post_net : tp.List[nn.Module] = [
341
+ ResnetBlock(in_channels=block_in,out_channels=block_in,
342
+ temb_channels=self.temb_ch,dropout=dropout),
343
+ ResnetBlock(in_channels=block_in,out_channels=block_in,
344
+ temb_channels=self.temb_ch,dropout=dropout),
345
+ ]
346
+ self.post_net = nn.Sequential(*post_net)
347
+
348
+ def forward(self, x: torch.Tensor ) -> torch.Tensor:
349
+ x = x.transpose(1, 2)
350
+ x = self.embed(x)
351
+ x = self.prior_net(x)
352
+ x = x.transpose(1, 2)
353
+ x= self.transformers(x)
354
+ x = x.transpose(1, 2)
355
+ x = self.post_net(x)
356
+ x = x.transpose(1, 2)
357
+ x = self.final_layer_norm(x)
358
+ return x
359
+
360
+
361
+
362
+
363
+
364
+
365
+
366
+ def init_weights(m):
367
+ if isinstance(m, nn.Conv1d):
368
+ nn.init.trunc_normal_(m.weight, std=0.02)
369
+ nn.init.constant_(m.bias, 0)
370
+
371
+ class CodecDecoderVocos(nn.Module):
372
+ def __init__(self,
373
+ hidden_dim=1024,
374
+ depth=12,
375
+ heads=16,
376
+ pos_meb_dim=64,
377
+ hop_length=320,
378
+ vq_num_quantizers=1,
379
+ vq_dim=2048, #1024 2048
380
+ vq_commit_weight=0.25,
381
+ vq_weight_init=False,
382
+ vq_full_commit_loss=False,
383
+ codebook_size=16384,
384
+ codebook_dim=16,
385
+ ):
386
+ super().__init__()
387
+ self.hop_length = hop_length
388
+
389
+ self.quantizer = ResidualFSQ(
390
+ dim = vq_dim,
391
+ levels = [4, 4, 4, 4, 4,4,4,4],
392
+ num_quantizers = 1
393
+ )
394
+
395
+ # self.quantizer = ResidualVQ(
396
+ # num_quantizers=vq_num_quantizers,
397
+ # dim=vq_dim,
398
+ # codebook_size=codebook_size,
399
+ # codebook_dim=codebook_dim,
400
+ # threshold_ema_dead_code=2,
401
+ # commitment=vq_commit_weight,
402
+ # weight_init=vq_weight_init,
403
+ # full_commit_loss=vq_full_commit_loss,
404
+ # )
405
+
406
+
407
+ self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim)
408
+
409
+ self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same")
410
+
411
+ self.reset_parameters()
412
+
413
+ def forward(self, x, vq=True):
414
+ if vq is True:
415
+ # x, q, commit_loss = self.quantizer(x)
416
+ x = x.permute(0, 2, 1)
417
+ x, q = self.quantizer(x)
418
+ x = x.permute(0, 2, 1)
419
+ q = q.permute(0, 2, 1)
420
+ return x, q, None
421
+ x = self.backbone(x)
422
+ x,_ = self.head(x)
423
+
424
+ return x ,_
425
+
426
+ def vq2emb(self, vq):
427
+ self.quantizer = self.quantizer.eval()
428
+ x = self.quantizer.vq2emb(vq)
429
+ return x
430
+
431
+ def get_emb(self):
432
+ self.quantizer = self.quantizer.eval()
433
+ embs = self.quantizer.get_emb()
434
+ return embs
435
+
436
+ def inference_vq(self, vq):
437
+ x = vq[None,:,:]
438
+ x = self.model(x)
439
+ return x
440
+
441
+ def inference_0(self, x):
442
+ x, q, loss, perp = self.quantizer(x)
443
+ x = self.model(x)
444
+ return x, None
445
+
446
+ def inference(self, x):
447
+ x = self.model(x)
448
+ return x, None
449
+
450
+
451
+ def remove_weight_norm(self):
452
+ """Remove weight normalization module from all of the layers."""
453
+
454
+ def _remove_weight_norm(m):
455
+ try:
456
+ torch.nn.utils.remove_weight_norm(m)
457
+ except ValueError: # this module didn't have weight norm
458
+ return
459
+
460
+ self.apply(_remove_weight_norm)
461
+
462
+ def apply_weight_norm(self):
463
+ """Apply weight normalization module from all of the layers."""
464
+
465
+ def _apply_weight_norm(m):
466
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
467
+ torch.nn.utils.weight_norm(m)
468
+
469
+ self.apply(_apply_weight_norm)
470
+
471
+ def reset_parameters(self):
472
+ self.apply(init_weights)
473
+
474
+
475
+
476
+ class CodecDecoderVocos_transpose(nn.Module):
477
+ def __init__(self,
478
+ hidden_dim=1024,
479
+ depth=12,
480
+ heads=16,
481
+ pos_meb_dim=64,
482
+ hop_length=320,
483
+ vq_num_quantizers=1,
484
+ vq_dim=1024, #1024 2048
485
+ vq_commit_weight=0.25,
486
+ vq_weight_init=False,
487
+ vq_full_commit_loss=False,
488
+ codebook_size=16384,
489
+ codebook_dim=16,
490
+ ):
491
+ super().__init__()
492
+ self.hop_length = hop_length
493
+
494
+
495
+ self.quantizer = ResidualVQ(
496
+ num_quantizers=vq_num_quantizers,
497
+ dim=vq_dim,
498
+ codebook_size=codebook_size,
499
+ codebook_dim=codebook_dim,
500
+ threshold_ema_dead_code=2,
501
+ commitment=vq_commit_weight,
502
+ weight_init=vq_weight_init,
503
+ full_commit_loss=vq_full_commit_loss,
504
+ )
505
+
506
+
507
+ self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim)
508
+
509
+ self.inverse_mel_conv = nn.Sequential(
510
+ nn.GELU(),
511
+ nn.ConvTranspose1d(
512
+ in_channels=hidden_dim,
513
+ out_channels=hidden_dim,
514
+ kernel_size=3,
515
+ stride=2,
516
+ padding=1,
517
+ output_padding=1 # ensure the output length matches the pre-encoding length
518
+ ),
519
+ nn.GELU(),
520
+ nn.ConvTranspose1d(
521
+ in_channels=hidden_dim,
522
+ out_channels=hidden_dim,
523
+ kernel_size=3,
524
+ padding=1
525
+ )
526
+ )
527
+
528
+ self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same")
529
+
530
+ self.reset_parameters()
531
+
532
+ def forward(self, x, vq=True):
533
+ if vq is True:
534
+ x, q, commit_loss = self.quantizer(x)
535
+ return x, q, commit_loss
536
+ x = self.backbone(x)
537
+ x,_ = self.head(x)
538
+
539
+ return x ,_
540
+
541
+ def vq2emb(self, vq):
542
+ self.quantizer = self.quantizer.eval()
543
+ x = self.quantizer.vq2emb(vq)
544
+ return x
545
+
546
+ def get_emb(self):
547
+ self.quantizer = self.quantizer.eval()
548
+ embs = self.quantizer.get_emb()
549
+ return embs
550
+
551
+ def inference_vq(self, vq):
552
+ x = vq[None,:,:]
553
+ x = self.model(x)
554
+ return x
555
+
556
+ def inference_0(self, x):
557
+ x, q, loss, perp = self.quantizer(x)
558
+ x = self.model(x)
559
+ return x, None
560
+
561
+ def inference(self, x):
562
+ x = self.model(x)
563
+ return x, None
564
+
565
+
566
+ def remove_weight_norm(self):
567
+ """Remove weight normalization module from all of the layers."""
568
+
569
+ def _remove_weight_norm(m):
570
+ try:
571
+ torch.nn.utils.remove_weight_norm(m)
572
+ except ValueError: # this module didn't have weight norm
573
+ return
574
+
575
+ self.apply(_remove_weight_norm)
576
+
577
+ def apply_weight_norm(self):
578
+ """Apply weight normalization module from all of the layers."""
579
+
580
+ def _apply_weight_norm(m):
581
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
582
+ torch.nn.utils.weight_norm(m)
583
+
584
+ self.apply(_apply_weight_norm)
585
+
586
+ def reset_parameters(self):
587
+ self.apply(init_weights)
588
+
589
+
590
+
591
+
592
+ def main():
593
+ # Select the device (GPU if available)
594
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
595
+ print(f"Using device: {device}")
596
+
597
+ # Initialize the model
598
+ model = CodecDecoderVocos_transpose().to(device)
599
+ print("Model initialized.")
600
+
601
+ # Create a test input: batch_size x in_channels x sequence_length
602
+ batch_size = 2
603
+ in_channels = 1024
604
+ sequence_length = 50 # example length; adjust as needed
605
+ dummy_input = torch.randn(batch_size, in_channels, sequence_length).to(device)
606
+ print(f"Dummy input shape: {dummy_input.shape}")
607
+
608
+ # Put the model in evaluation mode
609
+ model.eval()
610
+
611
+ # Forward pass (with VQ)
612
+ # with torch.no_grad():
613
+ # try:
614
+ # output, q, commit_loss = model(dummy_input, vq=True)
615
+ # print("Forward pass with VQ:")
616
+ # print(f"Output shape: {output.shape}")
617
+ # print(f"Quantized codes shape: {q.shape}")
618
+ # print(f"Commitment loss: {commit_loss}")
619
+ # except Exception as e:
620
+ # print(f"Error during forward pass with VQ: {e}")
621
+
622
+ # Forward pass (without VQ)
623
+ with torch.no_grad():
624
+ # try:
625
+ output_no_vq, _ = model(dummy_input, vq=False) # forward returns a 2-tuple when vq=False
626
+ print("\nForward pass without VQ:")
627
+ print(f"Output shape: {output_no_vq.shape}")
628
+ c=1
629
+ # except Exception as e:
630
+ # print(f"Error during forward pass without VQ: {e}")
631
+
632
+
633
+ # model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
634
+ # model_size_mb = model_size_bytes / (1024 ** 2)
635
+ # print(f"Model size: {model_size_bytes} bytes ({model_size_mb:.2f} MB)")
636
+
637
+ if __name__ == "__main__":
638
+ main()
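A minimal usage sketch for the FSQ-based decoder defined above. The import path and the (batch, vq_dim, frames) input layout are assumptions inferred from the constructor defaults (vq_dim=2048), not something stated in the commit:

import torch
from vq.codec_decoder_vocos import CodecDecoderVocos  # assumed module path for the file above

decoder = CodecDecoderVocos().eval()
latent = torch.randn(1, 2048, 50)  # (batch, vq_dim, frames), assumed layout
with torch.no_grad():
    quantized, codes, _ = decoder(latent, vq=True)  # ResidualFSQ quantization pass
print(quantized.shape, codes.shape)
# The vq=False branch runs the Vocos backbone plus the ISTFT head to synthesize a waveform,
# but it expects backbone-sized features, so it is not exercised in this sketch.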
vq/codec_encoder.py ADDED
@@ -0,0 +1,335 @@
1
+ import sys
2
+
3
+ import torch
4
+ from torch import nn
5
+ import numpy as np
6
+ from vq.module import WNConv1d, EncoderBlock, ResLSTM
7
+ from vq.alias_free_torch import *
8
+ from vq import activations
9
+ from vq.bs_roformer5 import TransformerBlock
10
+
11
+ from torchtune.modules import RotaryPositionalEmbeddings
12
+ import vq.blocks as blocks
13
+ from torch.nn import utils
14
+ def init_weights(m):
15
+ if isinstance(m, nn.Conv1d):
16
+ nn.init.trunc_normal_(m.weight, std=0.02)
17
+ nn.init.constant_(m.bias, 0)
18
+
19
+ class CodecEncoder(nn.Module):
20
+ def __init__(self,
21
+ ngf=48,
22
+ use_rnn=True,
23
+ rnn_bidirectional=False,
24
+ rnn_num_layers=2,
25
+ up_ratios=(2, 2, 4, 4, 5),
26
+ dilations=(1, 3, 9),
27
+ out_channels=1024):
28
+ super().__init__()
29
+ self.hop_length = np.prod(up_ratios)
30
+ self.ngf = ngf
31
+ self.up_ratios = up_ratios
32
+
33
+ # Create first convolution
34
+ d_model = ngf
35
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
36
+
37
+ # Create EncoderBlocks that double channels as they downsample by `stride`
38
+ for i, stride in enumerate(up_ratios):
39
+ d_model *= 2
40
+ self.block += [EncoderBlock(d_model, stride=stride, dilations=dilations)]
41
+ # RNN
42
+ if use_rnn:
43
+ self.block += [
44
+ ResLSTM(d_model,
45
+ num_layers=rnn_num_layers,
46
+ bidirectional=rnn_bidirectional
47
+ )
48
+ ]
49
+ # Create last convolution
50
+ self.block += [
51
+ Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)),
52
+ WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
53
+ ]
54
+
55
+ # Wrap block into nn.Sequential
56
+ self.block = nn.Sequential(*self.block)
57
+ self.enc_dim = d_model
58
+
59
+ self.reset_parameters()
60
+
61
+ def forward(self, x):
62
+ out = self.block(x)
63
+ return out
64
+
65
+ def inference(self, x):
66
+ return self.block(x)
67
+
68
+ def remove_weight_norm(self):
69
+ """Remove weight normalization module from all of the layers."""
70
+
71
+ def _remove_weight_norm(m):
72
+ try:
73
+ torch.nn.utils.remove_weight_norm(m)
74
+ except ValueError: # this module didn't have weight norm
75
+ return
76
+
77
+ self.apply(_remove_weight_norm)
78
+
79
+ def apply_weight_norm(self):
80
+ """Apply weight normalization module from all of the layers."""
81
+
82
+ def _apply_weight_norm(m):
83
+ if isinstance(m, nn.Conv1d):
84
+ torch.nn.utils.weight_norm(m)
85
+
86
+ self.apply(_apply_weight_norm)
87
+
88
+ def reset_parameters(self):
89
+ self.apply(init_weights)
90
+
91
+
92
+ class Transpose(nn.Module):
93
+ def __init__(self, dim1, dim2):
94
+ super(Transpose, self).__init__()
95
+ self.dim1 = dim1
96
+ self.dim2 = dim2
97
+
98
+ def forward(self, x):
99
+ return x.transpose(self.dim1, self.dim2)
100
+
101
+ class CodecEncoder_Transformer(nn.Module):
102
+ def __init__(self,
103
+ ngf=48,
104
+ up_ratios=[2, 2, 4, 4, 5],
105
+ dilations=(1, 3, 9),
106
+ hidden_dim=1024,
107
+ depth=12,
108
+ heads=12,
109
+ pos_meb_dim=64,
110
+ ):
111
+ super().__init__()
112
+ self.hop_length = np.prod(up_ratios)
113
+ self.ngf =ngf
114
+ self.up_ratios = up_ratios
115
+
116
+ d_model = ngf
117
+ self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
118
+
119
+
120
+ for i, stride in enumerate(up_ratios):
121
+ d_model *= 2
122
+ self.conv_blocks += [EncoderBlock(d_model, stride=stride, dilations=dilations)]
123
+
124
+ self.conv_blocks = nn.Sequential(*self.conv_blocks)
125
+
126
+
127
+ # time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
128
+
129
+
130
+ # transformer_blocks = [
131
+ # TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
132
+ # for _ in range(depth)
133
+ # ]
134
+
135
+
136
+ # self.transformers = nn.Sequential(*transformer_blocks)
137
+
138
+ # self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
139
+
140
+ self.conv_final_block = [
141
+ Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)),
142
+ WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1),
143
+ ]
144
+ self.conv_final_block = nn.Sequential(*self.conv_final_block)
145
+
146
+ self.reset_parameters()
147
+
148
+ def forward(self, x):
149
+ x = self.conv_blocks(x)
150
+ # x = x.permute(0, 2, 1)
151
+ # x= self.transformers(x)
152
+ # x = self.final_layer_norm(x)
153
+ # x = x.permute(0, 2, 1)
154
+ x = self.conv_final_block(x)
155
+ x = x.permute(0, 2, 1)
156
+ return x
157
+
158
+ def inference(self, x):
159
+ return self.forward(x) # this class defines no self.block; run the full forward pass
160
+
161
+ def remove_weight_norm(self):
162
+ """Remove weight normalization module from all of the layers."""
163
+
164
+ def _remove_weight_norm(m):
165
+ try:
166
+ torch.nn.utils.remove_weight_norm(m)
167
+ except ValueError: # this module didn't have weight norm
168
+ return
169
+
170
+ self.apply(_remove_weight_norm)
171
+
172
+ def apply_weight_norm(self):
173
+ """Apply weight normalization module from all of the layers."""
174
+
175
+ def _apply_weight_norm(m):
176
+ if isinstance(m, nn.Conv1d):
177
+ torch.nn.utils.weight_norm(m)
178
+
179
+ self.apply(_apply_weight_norm)
180
+
181
+ def reset_parameters(self):
182
+ self.apply(init_weights)
183
+
184
+
185
+
186
+ class Codec_oobleck_Transformer(nn.Module):
187
+ def __init__(self,
188
+ ngf=32,
189
+ up_ratios=(2, 2,4,4, 5),
190
+ dilations=(1, 3, 9),
191
+ hidden_dim=1024,
192
+ depth=12,
193
+ heads=16,
194
+ pos_meb_dim=64,
195
+ ):
196
+ super().__init__()
197
+ self.hop_length = np.prod(up_ratios)
198
+ self.ngf =ngf
199
+ self.up_ratios = up_ratios
200
+ self.hidden_dim = hidden_dim
201
+
202
+
203
+ self.conv_blocks = blocks.DilatedResidualEncoder(
204
+ capacity=ngf,
205
+ dilated_unit=self.dilated_unit,
206
+ downsampling_unit=self.downsampling_unit,
207
+ ratios=up_ratios,
208
+ dilations=dilations,
209
+ pre_network_conv=self.pre_conv,
210
+ post_network_conv=self.post_conv,
211
+ )
212
+
213
+
214
+ time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
215
+
216
+ transformer_blocks = [
217
+ TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
218
+ for _ in range(depth)
219
+ ]
220
+
221
+ self.transformers = nn.Sequential(*transformer_blocks)
222
+
223
+ self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
224
+
225
+
226
+ self.reset_parameters()
227
+
228
+ def forward(self, x):
229
+ x = self.conv_blocks(x)
230
+ x = x.permute(0, 2, 1)
231
+ x= self.transformers(x)
232
+ x = self.final_layer_norm(x)
233
+ return x
234
+
235
+ def inference(self, x):
236
+ return self.forward(x) # this class defines no self.block; run the full forward pass
237
+
238
+ def remove_weight_norm(self):
239
+ """Remove weight normalization module from all of the layers."""
240
+
241
+ def _remove_weight_norm(m):
242
+ try:
243
+ torch.nn.utils.remove_weight_norm(m)
244
+ except ValueError: # this module didn't have weight norm
245
+ return
246
+
247
+ self.apply(_remove_weight_norm)
248
+
249
+ def apply_weight_norm(self):
250
+ """Apply weight normalization module from all of the layers."""
251
+
252
+ def _apply_weight_norm(m):
253
+ if isinstance(m, nn.Conv1d):
254
+ torch.nn.utils.weight_norm(m)
255
+
256
+ self.apply(_apply_weight_norm)
257
+
258
+ def reset_parameters(self):
259
+ self.apply(init_weights)
260
+
261
+ def dilated_unit(self,hidden_dim, dilation):
262
+ return blocks.DilatedConvolutionalUnit(hidden_dim,
263
+ dilation,
264
+ kernel_size=3,
265
+ activation=nn.ReLU,
266
+ normalization=utils.weight_norm)
267
+
268
+ def downsampling_unit(self, input_dim: int, output_dim: int, stride: int):
269
+ return blocks.DownsamplingUnit(input_dim,
270
+ output_dim,
271
+ stride,
272
+ nn.ReLU,
273
+ normalization=utils.weight_norm)
274
+
275
+ def pre_conv(self,out_channels):
276
+ return nn.Conv1d(1, out_channels, 1)
277
+
278
+ def post_conv(self,in_channels):
279
+ return nn.Conv1d(in_channels, self.hidden_dim, 1)
280
+
281
+
282
+
283
+
284
+
285
+ class CodecEncoder_only_Transformer(nn.Module):
286
+ def __init__(self,hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64):
287
+ super().__init__()
288
+ # self.embed = nn.Linear(input_dim, hidden_dim )input_dim=300,
289
+
290
+ depth = depth
291
+ time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
292
+
293
+
294
+ transformer_blocks = [
295
+ TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
296
+ for _ in range(depth)
297
+ ]
298
+
299
+
300
+ self.transformers = nn.Sequential(*transformer_blocks)
301
+
302
+ self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
303
+
304
+ def forward(self, x: torch.Tensor ) -> torch.Tensor:
305
+ # x = self.embed(x)
306
+
307
+
308
+ x= self.transformers(x)
309
+ x = self.final_layer_norm(x)
310
+
311
+ return x
312
+
313
+
314
+
315
+
316
+
317
+
318
+
319
+ def get_model_size(model):
320
+ # Count the total number of parameters
321
+ total_params = sum(p.numel() for p in model.parameters())
322
+
323
+ # Assume every parameter is a 32-bit float and compute the model size in bytes
324
+ model_size_bytes = total_params * 4 # 4 bytes per float32 parameter
325
+
326
+ # Convert to a more readable unit (e.g. MB)
327
+ model_size_mb = model_size_bytes / (1024 ** 2)
328
+
329
+ return total_params, model_size_mb
330
+
331
+ if __name__ == '__main__':
332
+ model = Codec_oobleck_Transformer()
333
+ x = torch.randn(1, 1, 16000) # example input tensor
334
+ output = model(x)
335
+ print("Output shape:", output.shape)
vq/factorized_vector_quantize.py ADDED
@@ -0,0 +1,109 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ class FactorizedVectorQuantize(nn.Module):
11
+ def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
12
+ super().__init__()
13
+ self.codebook_size = codebook_size
14
+ self.codebook_dim = codebook_dim
15
+ self.commitment = commitment
16
+
17
+ if dim != self.codebook_dim:
18
+ self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
19
+ self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
20
+ else:
21
+ self.in_proj = nn.Identity()
22
+ self.out_proj = nn.Identity()
23
+ self._codebook = nn.Embedding(codebook_size, self.codebook_dim)
24
+
25
+ @property
26
+ def codebook(self):
27
+ return self._codebook
28
+
29
+ def forward(self, z):
30
+ """Quantized the input tensor using a fixed codebook and returns
31
+ the corresponding codebook vectors
32
+
33
+ Parameters
34
+ ----------
35
+ z : Tensor[B x D x T]
36
+
37
+ Returns
38
+ -------
39
+ Tensor[B x D x T]
40
+ Quantized continuous representation of input
41
+ Tensor[1]
42
+ Commitment loss to train encoder to predict vectors closer to codebook
43
+ entries
44
+ Tensor[1]
45
+ Codebook loss to update the codebook
46
+ Tensor[B x T]
47
+ Codebook indices (quantized discrete representation of input)
48
+ Tensor[B x D x T]
49
+ Projected latents (continuous representation of input before quantization)
50
+ """
51
+ # transpose since we use linear
52
+
53
+ z = rearrange(z, "b d t -> b t d")
54
+
55
+ # Factorized codes project input into low-dimensional space
56
+ z_e = self.in_proj(z) # z_e : (B x T x D)
57
+ z_e = rearrange(z_e, "b t d -> b d t")
58
+ z_q, indices = self.decode_latents(z_e)
59
+
60
+
61
+ if self.training:
62
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction='none').mean([1, 2]) * self.commitment
63
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction='none').mean([1, 2])
64
+ commit_loss = commitment_loss + codebook_loss
65
+ else:
66
+ commit_loss = torch.zeros(z.shape[0], device = z.device)
67
+
68
+ z_q = (
69
+ z_e + (z_q - z_e).detach()
70
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
71
+
72
+ z_q = rearrange(z_q, "b d t -> b t d")
73
+ z_q = self.out_proj(z_q)
74
+ z_q = rearrange(z_q, "b t d -> b d t")
75
+
76
+ return z_q, indices, commit_loss
77
+
78
+ def vq2emb(self, vq, proj=True):
79
+ emb = self.embed_code(vq)
80
+ if proj:
81
+ emb = self.out_proj(emb)
82
+ return emb
83
+
84
+ def get_emb(self):
85
+ return self.codebook.weight
86
+
87
+ def embed_code(self, embed_id):
88
+ return F.embedding(embed_id, self.codebook.weight)
89
+
90
+ def decode_code(self, embed_id):
91
+ return self.embed_code(embed_id).transpose(1, 2)
92
+
93
+ def decode_latents(self, latents):
94
+ encodings = rearrange(latents, "b d t -> (b t) d")
95
+ codebook = self.codebook.weight # codebook: (N x D)
96
+
97
+ # L2 normalize encodings and codebook
98
+ encodings = F.normalize(encodings)
99
+ codebook = F.normalize(codebook)
100
+
101
+ # Compute euclidean distance with codebook
102
+ dist = (
103
+ encodings.pow(2).sum(1, keepdim=True)
104
+ - 2 * encodings @ codebook.t()
105
+ + codebook.pow(2).sum(1, keepdim=True).t()
106
+ )
107
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
108
+ z_q = self.decode_code(indices)
109
+ return z_q, indices
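A short sketch of how FactorizedVectorQuantize is called. The codebook_size=16384 and codebook_dim=16 values mirror the defaults used by the decoder classes; the 1024-dimensional input is an assumption:

import torch
from vq.factorized_vector_quantize import FactorizedVectorQuantize

vq = FactorizedVectorQuantize(dim=1024, codebook_size=16384, codebook_dim=16, commitment=0.25)
z = torch.randn(2, 1024, 50)  # (batch, dim, frames)
z_q, indices, commit_loss = vq(z)
print(z_q.shape, indices.shape, commit_loss.shape)  # (2, 1024, 50), (2, 50), (2,)
# z_q carries straight-through gradients back to z; indices are the per-frame codebook ids.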
vq/module.py ADDED
@@ -0,0 +1,420 @@
1
+ import torch.nn as nn
2
+ from einops import rearrange
3
+ from . import activations
4
+ from .alias_free_torch import *
5
+ from torch.nn.utils import weight_norm
6
+
7
+ from typing import Optional, Tuple
8
+
9
+ from torch.nn.utils import weight_norm, remove_weight_norm
10
+
11
+
12
+ def WNConv1d(*args, **kwargs):
13
+ return weight_norm(nn.Conv1d(*args, **kwargs))
14
+
15
+
16
+ def WNConvTranspose1d(*args, **kwargs):
17
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
18
+
19
+ class ResidualUnit(nn.Module):
20
+ def __init__(self, dim: int = 16, dilation: int = 1):
21
+ super().__init__()
22
+ pad = ((7 - 1) * dilation) // 2
23
+ self.block = nn.Sequential(
24
+ Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)),
25
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
26
+ Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)),
27
+ WNConv1d(dim, dim, kernel_size=1),
28
+ )
29
+
30
+ def forward(self, x):
31
+ return x + self.block(x)
32
+
33
+ class EncoderBlock(nn.Module):
34
+ def __init__(self, dim: int = 16, stride: int = 1, dilations = (1, 3, 9)):
35
+ super().__init__()
36
+ runits = [ResidualUnit(dim // 2, dilation=d) for d in dilations]
37
+ self.block = nn.Sequential(
38
+ *runits,
39
+ Activation1d(activation=activations.SnakeBeta(dim//2, alpha_logscale=True)),
40
+ WNConv1d(
41
+ dim // 2,
42
+ dim,
43
+ kernel_size=2 * stride,
44
+ stride=stride,
45
+ padding=stride // 2 + stride % 2,
46
+ ),
47
+ )
48
+
49
+ def forward(self, x):
50
+ return self.block(x)
51
+
52
+ class DecoderBlock(nn.Module):
53
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, dilations = (1, 3, 9)):
54
+ super().__init__()
55
+ self.block = nn.Sequential(
56
+ Activation1d(activation=activations.SnakeBeta(input_dim, alpha_logscale=True)),
57
+ WNConvTranspose1d(
58
+ input_dim,
59
+ output_dim,
60
+ kernel_size=2 * stride,
61
+ stride=stride,
62
+ padding=stride // 2 + stride % 2,
63
+ output_padding= stride % 2,
64
+ )
65
+ )
66
+ self.block.extend([ResidualUnit(output_dim, dilation=d) for d in dilations])
67
+
68
+ def forward(self, x):
69
+ return self.block(x)
70
+
71
+ class ResLSTM(nn.Module):
72
+ def __init__(self, dimension: int,
73
+ num_layers: int = 2,
74
+ bidirectional: bool = False,
75
+ skip: bool = True):
76
+ super().__init__()
77
+ self.skip = skip
78
+ self.lstm = nn.LSTM(dimension, dimension if not bidirectional else dimension // 2,
79
+ num_layers, batch_first=True,
80
+ bidirectional=bidirectional)
81
+
82
+ def forward(self, x):
83
+ """
84
+ Args:
85
+ x: [B, F, T]
86
+
87
+ Returns:
88
+ y: [B, F, T]
89
+ """
90
+ x = rearrange(x, "b f t -> b t f")
91
+ y, _ = self.lstm(x)
92
+ if self.skip:
93
+ y = y + x
94
+ y = rearrange(y, "b t f -> b f t")
95
+ return y
96
+
97
+
98
+
99
+ class ConvNeXtBlock(nn.Module):
100
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
101
+
102
+ Args:
103
+ dim (int): Number of input channels.
104
+ intermediate_dim (int): Dimensionality of the intermediate layer.
105
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
106
+ Defaults to None.
107
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
108
+ None means non-conditional LayerNorm. Defaults to None.
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ dim: int,
114
+ intermediate_dim: int,
115
+ layer_scale_init_value: float,
116
+ adanorm_num_embeddings: Optional[int] = None,
117
+ ):
118
+ super().__init__()
119
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
120
+ self.adanorm = adanorm_num_embeddings is not None
121
+ if adanorm_num_embeddings:
122
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
123
+ else:
124
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
125
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
126
+ self.act = nn.GELU()
127
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
128
+ self.gamma = (
129
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
130
+ if layer_scale_init_value > 0
131
+ else None
132
+ )
133
+
134
+ def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
135
+ residual = x
136
+ x = self.dwconv(x)
137
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
138
+ if self.adanorm:
139
+ assert cond_embedding_id is not None
140
+ x = self.norm(x, cond_embedding_id)
141
+ else:
142
+ x = self.norm(x)
143
+ x = self.pwconv1(x)
144
+ x = self.act(x)
145
+ x = self.pwconv2(x)
146
+ if self.gamma is not None:
147
+ x = self.gamma * x
148
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
149
+
150
+ x = residual + x
151
+ return x
152
+
153
+
154
+ class AdaLayerNorm(nn.Module):
155
+ """
156
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
157
+
158
+ Args:
159
+ num_embeddings (int): Number of embeddings.
160
+ embedding_dim (int): Dimension of the embeddings.
161
+ """
162
+
163
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
164
+ super().__init__()
165
+ self.eps = eps
166
+ self.dim = embedding_dim
167
+ self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
168
+ self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
169
+ torch.nn.init.ones_(self.scale.weight)
170
+ torch.nn.init.zeros_(self.shift.weight)
171
+
172
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
173
+ scale = self.scale(cond_embedding_id)
174
+ shift = self.shift(cond_embedding_id)
175
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
176
+ x = x * scale + shift
177
+ return x
178
+
179
+
180
+ class ResBlock1(nn.Module):
181
+ """
182
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
183
+ but without upsampling layers.
184
+
185
+ Args:
186
+ dim (int): Number of input channels.
187
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
188
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
189
+ Defaults to (1, 3, 5).
190
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
191
+ Defaults to 0.1.
192
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
193
+ Defaults to None.
194
+ """
195
+
196
+ def __init__(
197
+ self,
198
+ dim: int,
199
+ kernel_size: int = 3,
200
+ dilation: Tuple[int, int, int] = (1, 3, 5),
201
+ lrelu_slope: float = 0.1,
202
+ layer_scale_init_value: Optional[float] = None,
203
+ ):
204
+ super().__init__()
205
+ self.lrelu_slope = lrelu_slope
206
+ self.convs1 = nn.ModuleList(
207
+ [
208
+ weight_norm(
209
+ nn.Conv1d(
210
+ dim,
211
+ dim,
212
+ kernel_size,
213
+ 1,
214
+ dilation=dilation[0],
215
+ padding=self.get_padding(kernel_size, dilation[0]),
216
+ )
217
+ ),
218
+ weight_norm(
219
+ nn.Conv1d(
220
+ dim,
221
+ dim,
222
+ kernel_size,
223
+ 1,
224
+ dilation=dilation[1],
225
+ padding=self.get_padding(kernel_size, dilation[1]),
226
+ )
227
+ ),
228
+ weight_norm(
229
+ nn.Conv1d(
230
+ dim,
231
+ dim,
232
+ kernel_size,
233
+ 1,
234
+ dilation=dilation[2],
235
+ padding=self.get_padding(kernel_size, dilation[2]),
236
+ )
237
+ ),
238
+ ]
239
+ )
240
+
241
+ self.convs2 = nn.ModuleList(
242
+ [
243
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
244
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
245
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
246
+ ]
247
+ )
248
+
249
+ self.gamma = nn.ParameterList(
250
+ [
251
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
252
+ if layer_scale_init_value is not None
253
+ else None,
254
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
255
+ if layer_scale_init_value is not None
256
+ else None,
257
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
258
+ if layer_scale_init_value is not None
259
+ else None,
260
+ ]
261
+ )
262
+
263
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
264
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
265
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
266
+ xt = c1(xt)
267
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
268
+ xt = c2(xt)
269
+ if gamma is not None:
270
+ xt = gamma * xt
271
+ x = xt + x
272
+ return x
273
+
274
+ def remove_weight_norm(self):
275
+ for l in self.convs1:
276
+ remove_weight_norm(l)
277
+ for l in self.convs2:
278
+ remove_weight_norm(l)
279
+
280
+ @staticmethod
281
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
282
+ return int((kernel_size * dilation - dilation) / 2)
283
+
284
+
285
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
286
+ """
287
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
288
+
289
+ Args:
290
+ x (Tensor): Input tensor.
291
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
292
+
293
+ Returns:
294
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
295
+ """
296
+ return torch.log(torch.clip(x, min=clip_val))
297
+
298
+
299
+ def symlog(x: torch.Tensor) -> torch.Tensor:
300
+ return torch.sign(x) * torch.log1p(x.abs())
301
+
302
+
303
+ def symexp(x: torch.Tensor) -> torch.Tensor:
304
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
305
+
306
+
307
+
308
+ class SemanticEncoder(nn.Module):
309
+ def __init__(
310
+ self,
311
+ input_channels: int,
312
+ code_dim: int,
313
+ encode_channels: int,
314
+ kernel_size: int = 3,
315
+ bias: bool = True,
316
+ ):
317
+ super(SemanticEncoder, self).__init__()
318
+
319
+ # Initial convolution: map input_channels to encode_channels
320
+ self.initial_conv = nn.Conv1d(
321
+ in_channels=input_channels,
322
+ out_channels=encode_channels,
323
+ kernel_size=kernel_size,
324
+ stride=1,
325
+ padding=(kernel_size - 1) // 2,
326
+ bias=False
327
+ )
328
+
329
+ # Residual blocks
330
+ self.residual_blocks = nn.Sequential(
331
+ nn.ReLU(inplace=True),
332
+ nn.Conv1d(
333
+ encode_channels,
334
+ encode_channels,
335
+ kernel_size=kernel_size,
336
+ stride=1,
337
+ padding=(kernel_size - 1) // 2,
338
+ bias=bias
339
+ ),
340
+ nn.ReLU(inplace=True),
341
+ nn.Conv1d(
342
+ encode_channels,
343
+ encode_channels,
344
+ kernel_size=kernel_size,
345
+ stride=1,
346
+ padding=(kernel_size - 1) // 2,
347
+ bias=bias
348
+ )
349
+ )
350
+
351
+ # Final convolution: map encode_channels to code_dim
352
+ self.final_conv = nn.Conv1d(
353
+ in_channels=encode_channels,
354
+ out_channels=code_dim,
355
+ kernel_size=kernel_size,
356
+ stride=1,
357
+ padding=(kernel_size - 1) // 2,
358
+ bias=False
359
+ )
360
+
361
+ def forward(self, x):
362
+ """
363
+ Forward pass.
364
+
365
+ Args:
366
+ x (Tensor): input tensor of shape (Batch, Input_channels, Length)
367
+
368
+ Returns:
369
+ Tensor: encoded tensor of shape (Batch, Code_dim, Length)
370
+ """
371
+ x = self.initial_conv(x) # (Batch, Encode_channels, Length)
372
+ x = self.residual_blocks(x) + x # residual connection
373
+ x = self.final_conv(x) # (Batch, Code_dim, Length)
374
+ return x
375
+
376
+ class SemanticDecoder(nn.Module):
377
+ def __init__(
378
+ self,
379
+ code_dim: int,
380
+ output_channels: int,
381
+ decode_channels: int,
382
+ kernel_size: int = 3,
383
+ bias: bool = True,
384
+ ):
385
+ super(SemanticDecoder, self).__init__()
386
+
387
+ # Initial convolution to map code_dim to decode_channels
388
+ self.initial_conv = nn.Conv1d(
389
+ in_channels=code_dim,
390
+ out_channels=decode_channels,
391
+ kernel_size=kernel_size,
392
+ stride=1,
393
+ padding=(kernel_size - 1) // 2,
394
+ bias=False
395
+ )
396
+
397
+ # Residual Blocks
398
+ self.residual_blocks = nn.Sequential(
399
+ nn.ReLU(inplace=True),
400
+ nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias),
401
+ nn.ReLU(inplace=True),
402
+ nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias)
403
+ )
404
+
405
+ # Final convolution to map decode_channels to output_channels
406
+ self.final_conv = nn.Conv1d(
407
+ in_channels=decode_channels,
408
+ out_channels=output_channels,
409
+ kernel_size=kernel_size,
410
+ stride=1,
411
+ padding=(kernel_size - 1) // 2,
412
+ bias=False
413
+ )
414
+
415
+ def forward(self, z):
416
+ # z: (Batch, Code_dim, Length)
417
+ x = self.initial_conv(z) # (Batch, Decode_channels, Length)
418
+ x = self.residual_blocks(x) + x # Residual connection
419
+ x = self.final_conv(x) # (Batch, Output_channels, Length)
420
+ return x
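A small round-trip sketch for the SemanticEncoder / SemanticDecoder pair above; the channel sizes are illustrative assumptions, not values taken from the commit:

import torch
from vq.module import SemanticEncoder, SemanticDecoder

enc = SemanticEncoder(input_channels=1024, code_dim=1024, encode_channels=1536)
dec = SemanticDecoder(code_dim=1024, output_channels=1024, decode_channels=1536)
feats = torch.randn(2, 1024, 50)  # (batch, channels, frames)
codes = enc(feats)                # (2, 1024, 50)
recon = dec(codes)                # (2, 1024, 50)
print(codes.shape, recon.shape)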
vq/residual_vq.py ADDED
@@ -0,0 +1,53 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from .factorized_vector_quantize import FactorizedVectorQuantize
5
+
6
+ class ResidualVQ(nn.Module):
7
+ def __init__(
8
+ self,
9
+ *,
10
+ num_quantizers,
11
+ codebook_size,
12
+ **kwargs
13
+ ):
14
+ super().__init__()
15
+ VQ = FactorizedVectorQuantize
16
+ if type(codebook_size) == int:
17
+ codebook_size = [codebook_size] * num_quantizers
18
+ self.layers = nn.ModuleList([VQ(codebook_size=size, **kwargs) for size in codebook_size])
19
+ self.num_quantizers = num_quantizers
20
+
21
+ def forward(self, x):
22
+ quantized_out = 0.
23
+ residual = x
24
+
25
+ all_losses = []
26
+ all_indices = []
27
+
28
+ for idx, layer in enumerate(self.layers):
29
+ quantized, indices, loss = layer(residual)
30
+
31
+ residual = residual - quantized
32
+
33
+ quantized_out = quantized_out + quantized
34
+
35
+ loss = loss.mean()
36
+
37
+ all_indices.append(indices)
38
+ all_losses.append(loss)
39
+ all_losses, all_indices = map(torch.stack, (all_losses, all_indices))
40
+ return quantized_out, all_indices, all_losses
41
+
42
+ def vq2emb(self, vq, proj=True):
43
+ # [B, T, num_quantizers]
44
+ quantized_out = 0.
45
+ for idx, layer in enumerate(self.layers):
46
+ quantized = layer.vq2emb(vq[:, :, idx], proj=proj)
47
+ quantized_out = quantized_out + quantized
48
+ return quantized_out
49
+ def get_emb(self):
50
+ embs = []
51
+ for idx, layer in enumerate(self.layers):
52
+ embs.append(layer.get_emb())
53
+ return embs
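A usage sketch for ResidualVQ: keyword arguments other than num_quantizers and codebook_size are forwarded to FactorizedVectorQuantize, so dim, codebook_dim and commitment must be supplied (the concrete values below are assumptions):

import torch
from vq.residual_vq import ResidualVQ

rvq = ResidualVQ(num_quantizers=2, codebook_size=1024, dim=256, codebook_dim=16, commitment=0.25)
x = torch.randn(2, 256, 50)  # (batch, dim, frames)
quantized, indices, losses = rvq(x)
print(quantized.shape, indices.shape, losses.shape)
# quantized is (2, 256, 50); indices are stacked per quantizer as (num_quantizers, batch, frames);
# losses holds one averaged commitment loss per quantizer.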
vq/unet.py ADDED
@@ -0,0 +1,210 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ import numpy as np
6
+
7
+
8
+ class EncoderBlock(nn.Module):
9
+ def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
10
+ super(EncoderBlock, self).__init__()
11
+
12
+ self.pool_size = 2
13
+
14
+ self.conv_block = ConvBlock(in_channels, out_channels, kernel_size)
15
+
16
+ def forward(self, x):
17
+ latent = self.conv_block(x)
18
+ output = F.avg_pool2d(latent, kernel_size=self.pool_size)
19
+ return output, latent
20
+
21
+ class DecoderBlock(nn.Module):
22
+ def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
23
+ super(DecoderBlock, self).__init__()
24
+
25
+ stride = 2
26
+
27
+ self.upsample = nn.ConvTranspose2d(
28
+ in_channels=in_channels,
29
+ out_channels=in_channels,
30
+ kernel_size=stride,
31
+ stride=stride,
32
+ padding=(0, 0),
33
+ bias=False,
34
+ )
35
+
36
+ self.conv_block = ConvBlock(in_channels * 2, out_channels, kernel_size)
37
+
38
+ def forward(self, x, latent):
39
+ x = self.upsample(x)
40
+ x = torch.cat((x, latent), dim=1)
41
+ output = self.conv_block(x)
42
+ return output
43
+
44
+
45
+ class UNet(nn.Module):
46
+ def __init__(self,freq_dim=1281,out_channel=1024):
47
+ super(UNet, self).__init__()
48
+
49
+ self.downsample_ratio = 16
50
+
51
+
52
+ in_channels = 1 #self.audio_channels * self.cmplx_num
53
+
54
+ self.encoder_block1 = EncoderBlock(in_channels, 16)
55
+ self.encoder_block2 = EncoderBlock(16, 64)
56
+ self.encoder_block3 = EncoderBlock(64, 256)
57
+ self.encoder_block4 = EncoderBlock(256, 1024)
58
+ self.middle = EncoderBlock(1024, 1024)
59
+ self.decoder_block1 = DecoderBlock(1024, 256)
60
+ self.decoder_block2 = DecoderBlock(256, 64)
61
+ self.decoder_block3 = DecoderBlock(64, 16)
62
+ self.decoder_block4 = DecoderBlock(16, 16)
63
+
64
+ self.fc = nn.Linear(freq_dim*16, out_channel)
65
+
66
+ def forward(self, x_ori):
67
+ """
68
+ Args:
69
+ x_ori: (batch_size, channels_num, time_steps, freq_bins), spectrogram tensor
70
+
71
+ Returns:
72
+ output: (batch_size, time_steps, out_channel), frame embeddings from the final linear layer
73
+ """
74
+
75
+
76
+ x= self.process_image(x_ori)
77
+ x1, latent1 = self.encoder_block1(x)
78
+ x2, latent2 = self.encoder_block2(x1)
79
+ x3, latent3 = self.encoder_block3(x2)
80
+ x4, latent4 = self.encoder_block4(x3)
81
+ _, h = self.middle(x4)
82
+ x5 = self.decoder_block1(h, latent4)
83
+ x6 = self.decoder_block2(x5, latent3)
84
+ x7 = self.decoder_block3(x6, latent2)
85
+ x8 = self.decoder_block4(x7, latent1)
86
+ x= self.unprocess_image(x8,x_ori.shape[2])
87
+ x = x.permute(0, 2, 1, 3).contiguous() # rearrange to (batch, time_steps, channels, freq_bins)
88
+ x = x.view(x.size(0), x.size(1), -1)
89
+ x= self.fc(x)
90
+
91
+ return x
92
+
93
+ def process_image(self, x):
94
+ """
95
+ Pad the spectrogram so its time axis is divisible by downsample_ratio (the last frequency bin is dropped).
96
+
97
+ Args:
98
+ x: (B, C, T, F)
99
+
100
+ Returns:
101
+ output: (B, C, T_padded, F_reduced)
102
+ """
103
+
104
+ B, C, T, Freq = x.shape
105
+
106
+ pad_len = (
107
+ int(np.ceil(T / self.downsample_ratio)) * self.downsample_ratio
108
+ - T
109
+ )
110
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
111
+
112
+ output = x[:, :, :, 0 : Freq - 1]
113
+
114
+ return output
115
+
116
+ def unprocess_image(self, x,time_steps):
117
+ """
118
+ Restore the spectrogram to its original shape (re-pad the frequency axis and crop the time axis).
119
+
120
+ Args:
121
+ x: (B, C, T_padded, F_reduced)
122
+
123
+ Returns:
124
+ output: (B, C, T_original, F_original)
125
+ """
126
+ x = F.pad(x, pad=(0, 1))
127
+
128
+ output = x[:, :,0:time_steps, :]
129
+
130
+ return output
131
+
132
+ class ConvBlock(nn.Module):
133
+ def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
134
+ super(ConvBlock, self).__init__()
135
+
136
+ padding = [kernel_size[0] // 2, kernel_size[1] // 2]
137
+
138
+ self.bn1 = nn.BatchNorm2d(in_channels)
139
+ self.bn2 = nn.BatchNorm2d(out_channels)
140
+
141
+ self.conv1 = nn.Conv2d(
142
+ in_channels=in_channels,
143
+ out_channels=out_channels,
144
+ kernel_size=kernel_size,
145
+ padding=padding,
146
+ bias=False,
147
+ )
148
+
149
+ self.conv2 = nn.Conv2d(
150
+ in_channels=out_channels,
151
+ out_channels=out_channels,
152
+ kernel_size=kernel_size,
153
+ padding=padding,
154
+ bias=False,
155
+ )
156
+
157
+ if in_channels != out_channels:
158
+ self.shortcut = nn.Conv2d(
159
+ in_channels=in_channels,
160
+ out_channels=out_channels,
161
+ kernel_size=(1, 1),
162
+ padding=(0, 0),
163
+ )
164
+ self.is_shortcut = True
165
+ else:
166
+ self.is_shortcut = False
167
+
168
+ def forward(self, x):
169
+ h = self.conv1(F.leaky_relu_(self.bn1(x)))
170
+ h = self.conv2(F.leaky_relu_(self.bn2(h)))
171
+
172
+ if self.is_shortcut:
173
+ return self.shortcut(x) + h
174
+ else:
175
+ return x + h
176
+
177
+
178
+ def test_unet():
179
+ # Define input parameters
180
+ batch_size = 6
181
+ channels = 1 # number of audio channels
182
+ time_steps = 256 # number of time steps
183
+ freq_bins = 1281 # number of frequency bins (matches the UNet default freq_dim=1281)
184
+
185
+ # Create a random input tensor (the complex construction is commented out; only the real part is used)
186
+ real_part = torch.randn(batch_size, channels, time_steps, freq_bins)
187
+ imag_part = torch.randn(batch_size, channels, time_steps, freq_bins)
188
+ complex_sp = real_part #torch.complex(real_part, imag_part)
189
+
190
+ # Instantiate the UNet model
191
+ model = UNet()
192
+
193
+ # Forward pass
194
+ output = model(complex_sp)
195
+
196
+ # Print the input and output shapes
197
+ print("输入形状:", complex_sp.shape)
198
+ print("输出形状:", output.shape)
199
+
200
+ # Check that the output is real-valued (only the real part is fed in above)
201
+ assert torch.is_complex(output), "输出不是复数张量"
202
+
203
+ # Check that the output is a (batch, time_steps, out_channel) tensor after the final linear layer
204
+ assert output.shape == complex_sp.shape, "输出形状与输入形状不一致"
205
+
206
+ print("测试通过,模型正常工作。")
207
+
208
+ # Run the test function
209
+ if __name__ == "__main__":
210
+ test_unet()
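For reference, with the default freq_dim=1281 the UNet expects spectrogram-like input of shape (batch, 1, time_steps, 1281) with time_steps a multiple of 16, and returns (batch, time_steps, 1024) frame embeddings; a minimal check:

import torch
from vq.unet import UNet

model = UNet()
spec = torch.randn(2, 1, 64, 1281)
out = model(spec)
print(out.shape)  # torch.Size([2, 64, 1024])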