harryjulian committed
Commit: 2dfcd66
Parent(s): 6a9bb58

removed code
.gitattributes DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore DELETED
@@ -1,189 +0,0 @@
- # Emacs
- *~
-
- # Byte-compiled / optimized / DLL files
- __pycache__/
- *.py[cod]
- *$py.class
-
- # C extensions
- *.so
-
- # Distribution / packaging
- .Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- share/python-wheels/
- *.egg-info/
- .installed.cfg
- *.egg
- MANIFEST
- /runs
- /checkpoints
- /base
-
- # PyInstaller
- # Usually these files are written by a python script from a template
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
- *.manifest
- *.spec
-
- # Installer logs
- pip-log.txt
- pip-delete-this-directory.txt
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .nox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- *.py,cover
- .hypothesis/
- .pytest_cache/
- cover/
-
- # Translations
- *.mo
- *.pot
-
- # Django stuff:
- *.log
- local_settings.py
- db.sqlite3
- db.sqlite3-journal
-
- # Flask stuff:
- instance/
- .webassets-cache
-
- # Scrapy stuff:
- .scrapy
-
- # Sphinx documentation
- docs/_build/
-
- # PyBuilder
- .pybuilder/
- target/
-
- # Jupyter Notebook
- .ipynb_checkpoints
-
- # IPython
- profile_default/
- ipython_config.py
-
- # pyenv
- # For a library or package, you might want to ignore these files since the code is
- # intended to run in multiple environments; otherwise, check them in:
- # .python-version
-
- # pipenv
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
- # install all needed dependencies.
- #Pipfile.lock
-
- # poetry
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
- # This is especially recommended for binary packages to ensure reproducibility, and is more
- # commonly ignored for libraries.
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
- #poetry.lock
-
- # pdm
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
- #pdm.lock
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
- # in version control.
- # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
- .pdm.toml
- .pdm-python
- .pdm-build/
-
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
- __pypackages__/
-
- # Celery stuff
- celerybeat-schedule
- celerybeat.pid
-
- # SageMath parsed files
- *.sage.py
-
- # Environments
- .env
- .venv
- env/
- venv/
- ENV/
- env.bak/
- venv.bak/
-
- # Spyder project settings
- .spyderproject
- .spyproject
-
- # Rope project settings
- .ropeproject
-
- # mkdocs documentation
- /site
-
- # mypy
- .mypy_cache/
- .dmypy.json
- dmypy.json
-
- # Pyre type checker
- .pyre/
-
- # pytype static type analyzer
- .pytype/
-
- # Cython debug symbols
- cython_debug/
-
- # PyCharm
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
- # and can be added to the global gitignore or merged into this file. For a more nuclear
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
- #.idea/
-
- /runs
- /.cache
- /__pycache__
-
- *.wav
- *.pth
- *.pt
- *.pt.gz
- wandb/
- sven_latest_checkpoint/
- sven_qwen/
- pretrained_models/
- xcodec/
- small_speaker_shards_all/
- sven_all_shards/
- qwen_380k/
- evals/
- *.safetensors
- *.pt
- .ruff_cache
neucodec/__init__.py DELETED
@@ -1,3 +0,0 @@
- from .codec_encoder import CodecEncoder
- from .codec_decoder_vocos import CodecDecoderVocos
- from .model import NeuCodec
neucodec/activations.py DELETED
@@ -1,120 +0,0 @@
- # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
- # LICENSE is in incl_licenses directory.
-
- import torch
- from torch import nn, sin, pow
- from torch.nn import Parameter
-
-
- class Snake(nn.Module):
-     '''
-     Implementation of a sine-based periodic activation function
-     Shape:
-         - Input: (B, C, T)
-         - Output: (B, C, T), same shape as the input
-     Parameters:
-         - alpha - trainable parameter
-     References:
-         - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
-           https://arxiv.org/abs/2006.08195
-     Examples:
-         >>> a1 = snake(256)
-         >>> x = torch.randn(256)
-         >>> x = a1(x)
-     '''
-     def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
-         '''
-         Initialization.
-         INPUT:
-             - in_features: shape of the input
-             - alpha: trainable parameter
-             alpha is initialized to 1 by default, higher values = higher-frequency.
-             alpha will be trained along with the rest of your model.
-         '''
-         super(Snake, self).__init__()
-         self.in_features = in_features
-
-         # initialize alpha
-         self.alpha_logscale = alpha_logscale
-         if self.alpha_logscale:  # log scale alphas initialized to zeros
-             self.alpha = Parameter(torch.zeros(in_features) * alpha)
-         else:  # linear scale alphas initialized to ones
-             self.alpha = Parameter(torch.ones(in_features) * alpha)
-
-         self.alpha.requires_grad = alpha_trainable
-
-         self.no_div_by_zero = 0.000000001
-
-     def forward(self, x):
-         '''
-         Forward pass of the function.
-         Applies the function to the input elementwise.
-         Snake ∶= x + 1/a * sin^2 (xa)
-         '''
-         alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
-         if self.alpha_logscale:
-             alpha = torch.exp(alpha)
-         x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
-
-         return x
-
-
- class SnakeBeta(nn.Module):
-     '''
-     A modified Snake function which uses separate parameters for the magnitude of the periodic components
-     Shape:
-         - Input: (B, C, T)
-         - Output: (B, C, T), same shape as the input
-     Parameters:
-         - alpha - trainable parameter that controls frequency
-         - beta - trainable parameter that controls magnitude
-     References:
-         - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
-           https://arxiv.org/abs/2006.08195
-     Examples:
-         >>> a1 = snakebeta(256)
-         >>> x = torch.randn(256)
-         >>> x = a1(x)
-     '''
-     def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
-         '''
-         Initialization.
-         INPUT:
-             - in_features: shape of the input
-             - alpha - trainable parameter that controls frequency
-             - beta - trainable parameter that controls magnitude
-             alpha is initialized to 1 by default, higher values = higher-frequency.
-             beta is initialized to 1 by default, higher values = higher-magnitude.
-             alpha will be trained along with the rest of your model.
-         '''
-         super(SnakeBeta, self).__init__()
-         self.in_features = in_features
-
-         # initialize alpha
-         self.alpha_logscale = alpha_logscale
-         if self.alpha_logscale:  # log scale alphas initialized to zeros
-             self.alpha = Parameter(torch.zeros(in_features) * alpha)
-             self.beta = Parameter(torch.zeros(in_features) * alpha)
-         else:  # linear scale alphas initialized to ones
-             self.alpha = Parameter(torch.ones(in_features) * alpha)
-             self.beta = Parameter(torch.ones(in_features) * alpha)
-
-         self.alpha.requires_grad = alpha_trainable
-         self.beta.requires_grad = alpha_trainable
-
-         self.no_div_by_zero = 0.000000001
-
-     def forward(self, x):
-         '''
-         Forward pass of the function.
-         Applies the function to the input elementwise.
-         SnakeBeta ∶= x + 1/b * sin^2 (xa)
-         '''
-         alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
-         beta = self.beta.unsqueeze(0).unsqueeze(-1)
-         if self.alpha_logscale:
-             alpha = torch.exp(alpha)
-             beta = torch.exp(beta)
-         x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
-
-         return x
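
The removed Snake activation computes x + (1/alpha) * sin^2(alpha * x) per channel. A minimal usage sketch on a (B, C, T) tensor, assuming the pre-removal `neucodec.activations` module is importable:

    import torch
    from neucodec.activations import Snake, SnakeBeta  # import path before this commit

    x = torch.randn(2, 256, 100)   # (batch, channels, time)
    act = Snake(in_features=256)   # one learnable alpha per channel
    y = act(x)                     # y = x + (1/alpha) * sin(alpha * x) ** 2
    assert y.shape == x.shape      # elementwise, shape-preserving

    # SnakeBeta decouples frequency (alpha) from magnitude (1/beta)
    act_b = SnakeBeta(256, alpha_logscale=True)
    assert act_b(x).shape == x.shape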
neucodec/alias_free_torch/__init__.py DELETED
@@ -1,6 +0,0 @@
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
- # LICENSE is in incl_licenses directory.
-
- from .filter import *
- from .resample import *
- from .act import *
neucodec/alias_free_torch/act.py DELETED
@@ -1,28 +0,0 @@
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
- # LICENSE is in incl_licenses directory.
-
- import torch.nn as nn
- from .resample import UpSample1d, DownSample1d
-
-
- class Activation1d(nn.Module):
-     def __init__(self,
-                  activation,
-                  up_ratio: int = 2,
-                  down_ratio: int = 2,
-                  up_kernel_size: int = 12,
-                  down_kernel_size: int = 12):
-         super().__init__()
-         self.up_ratio = up_ratio
-         self.down_ratio = down_ratio
-         self.act = activation
-         self.upsample = UpSample1d(up_ratio, up_kernel_size)
-         self.downsample = DownSample1d(down_ratio, down_kernel_size)
-
-     # x: [B, C, T]
-     def forward(self, x):
-         x = self.upsample(x)
-         x = self.act(x)
-         x = self.downsample(x)
-
-         return x
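
Activation1d wraps a pointwise nonlinearity between a 2x upsample and a 2x downsample, so the nonlinearity is applied at a higher sample rate and the aliasing it would otherwise introduce is suppressed. A short sketch, assuming the pre-removal modules:

    import torch
    from neucodec.alias_free_torch import Activation1d  # import paths before this commit
    from neucodec.activations import SnakeBeta

    x = torch.randn(1, 48, 320)  # [B, C, T]
    act = Activation1d(activation=SnakeBeta(48, alpha_logscale=True))
    y = act(x)                   # upsample 2x -> SnakeBeta -> downsample 2x
    assert y.shape == x.shape    # temporal length is preserved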
neucodec/alias_free_torch/filter.py DELETED
@@ -1,95 +0,0 @@
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
- # LICENSE is in incl_licenses directory.
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import math
-
- if 'sinc' in dir(torch):
-     sinc = torch.sinc
- else:
-     # This code is adopted from adefossez's julius.core.sinc under the MIT License
-     # https://adefossez.github.io/julius/julius/core.html
-     # LICENSE is in incl_licenses directory.
-     def sinc(x: torch.Tensor):
-         """
-         Implementation of sinc, i.e. sin(pi * x) / (pi * x)
-         __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
-         """
-         return torch.where(x == 0,
-                            torch.tensor(1., device=x.device, dtype=x.dtype),
-                            torch.sin(math.pi * x) / math.pi / x)
-
-
- # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
- # https://adefossez.github.io/julius/julius/lowpass.html
- # LICENSE is in incl_licenses directory.
- def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
-     even = (kernel_size % 2 == 0)
-     half_size = kernel_size // 2
-
-     # For kaiser window
-     delta_f = 4 * half_width
-     A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
-     if A > 50.:
-         beta = 0.1102 * (A - 8.7)
-     elif A >= 21.:
-         beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
-     else:
-         beta = 0.
-     window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
-
-     # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
-     if even:
-         time = (torch.arange(-half_size, half_size) + 0.5)
-     else:
-         time = torch.arange(kernel_size) - half_size
-     if cutoff == 0:
-         filter_ = torch.zeros_like(time)
-     else:
-         filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
-         # Normalize filter to have sum = 1, otherwise we will have a small leakage
-         # of the constant component in the input signal.
-         filter_ /= filter_.sum()
-     filter = filter_.view(1, 1, kernel_size)
-
-     return filter
-
-
- class LowPassFilter1d(nn.Module):
-     def __init__(self,
-                  cutoff=0.5,
-                  half_width=0.6,
-                  stride: int = 1,
-                  padding: bool = True,
-                  padding_mode: str = 'replicate',
-                  kernel_size: int = 12):
-         # kernel_size should be even number for stylegan3 setup,
-         # in this implementation, odd number is also possible.
-         super().__init__()
-         if cutoff < -0.:
-             raise ValueError("Minimum cutoff must be larger than zero.")
-         if cutoff > 0.5:
-             raise ValueError("A cutoff above 0.5 does not make sense.")
-         self.kernel_size = kernel_size
-         self.even = (kernel_size % 2 == 0)
-         self.pad_left = kernel_size // 2 - int(self.even)
-         self.pad_right = kernel_size // 2
-         self.stride = stride
-         self.padding = padding
-         self.padding_mode = padding_mode
-         filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
-         self.register_buffer("filter", filter)
-
-     # input [B, C, T]
-     def forward(self, x):
-         _, C, _ = x.shape
-
-         if self.padding:
-             x = F.pad(x, (self.pad_left, self.pad_right),
-                       mode=self.padding_mode)
-         out = F.conv1d(x, self.filter.expand(C, -1, -1),
-                        stride=self.stride, groups=C)
-
-         return out
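
kaiser_sinc_filter1d is the standard windowed-ideal-lowpass design: a sinc truncated by a Kaiser window whose beta is chosen from the target attenuation A. A quick sketch of building and applying the filter, assuming the pre-removal module:

    import torch
    from neucodec.alias_free_torch.filter import LowPassFilter1d, kaiser_sinc_filter1d

    kernel = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
    print(kernel.shape)         # torch.Size([1, 1, 12]); taps sum to ~1.0

    lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
    x = torch.randn(1, 4, 160)  # [B, C, T]
    y = lpf(x)                  # depthwise conv, one shared kernel per channel
    assert y.shape == x.shape   # "same" padding at stride 1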
neucodec/alias_free_torch/resample.py DELETED
@@ -1,49 +0,0 @@
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
- # LICENSE is in incl_licenses directory.
-
- import torch.nn as nn
- from torch.nn import functional as F
- from .filter import LowPassFilter1d
- from .filter import kaiser_sinc_filter1d
-
-
- class UpSample1d(nn.Module):
-     def __init__(self, ratio=2, kernel_size=None):
-         super().__init__()
-         self.ratio = ratio
-         self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
-         self.stride = ratio
-         self.pad = self.kernel_size // ratio - 1
-         self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
-         self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
-         filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
-                                       half_width=0.6 / ratio,
-                                       kernel_size=self.kernel_size)
-         self.register_buffer("filter", filter)
-
-     # x: [B, C, T]
-     def forward(self, x):
-         _, C, _ = x.shape
-
-         x = F.pad(x, (self.pad, self.pad), mode='replicate')
-         x = self.ratio * F.conv_transpose1d(
-             x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
-         x = x[..., self.pad_left:-self.pad_right]
-
-         return x
-
-
- class DownSample1d(nn.Module):
-     def __init__(self, ratio=2, kernel_size=None):
-         super().__init__()
-         self.ratio = ratio
-         self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
-         self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
-                                        half_width=0.6 / ratio,
-                                        stride=ratio,
-                                        kernel_size=self.kernel_size)
-
-     def forward(self, x):
-         xx = self.lowpass(x)
-
-         return xx
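
UpSample1d inserts samples with a transposed depthwise convolution against the sinc kernel; DownSample1d lowpasses and then decimates by the same ratio. A round-trip shape sketch, assuming the pre-removal module:

    import torch
    from neucodec.alias_free_torch import UpSample1d, DownSample1d

    x = torch.randn(1, 8, 100)  # [B, C, T]
    up = UpSample1d(ratio=2)
    dn = DownSample1d(ratio=2)

    y = up(x)
    assert y.shape[-1] == 2 * x.shape[-1]  # twice the samples, same band content
    z = dn(y)
    assert z.shape[-1] == x.shape[-1]      # back to the original length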
neucodec/bs_roformer5.py DELETED
@@ -1,120 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torchaudio
- import numpy as np
-
- from torch.nn import Module, ModuleList
- from einops import rearrange
- from torchtune.modules import RotaryPositionalEmbeddings
-
-
- class RMSNorm(torch.nn.Module):
-     def __init__(self, dim: int, eps: float = 1e-6):
-         r"""https://github.com/meta-llama/llama/blob/main/llama/model.py"""
-         super().__init__()
-         self.eps = eps
-         self.weight = nn.Parameter(torch.ones(dim))
-
-     def forward(self, x):
-         norm_x = torch.mean(x ** 2, dim=-1, keepdim=True)
-         output = x * torch.rsqrt(norm_x + self.eps) * self.weight
-         return output
-
-
- class MLP(nn.Module):
-     def __init__(self, dim: int) -> None:
-         super().__init__()
-
-         self.fc1 = nn.Linear(dim, 4 * dim, bias=False)
-         self.silu = nn.SiLU()
-         self.fc2 = nn.Linear(4 * dim, dim, bias=False)
-
-     def forward(self, x):
-         x = self.fc1(x)
-         x = self.silu(x)
-         x = self.fc2(x)
-         return x
-
-
- class Attention(nn.Module):
-
-     def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
-         super().__init__()
-
-         assert dim % n_heads == 0
-
-         self.n_heads = n_heads
-         self.dim = dim
-         self.rotary_embed = rotary_embed
-
-         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
-         assert self.flash, "Must have flash attention."
-
-         self.c_attn = nn.Linear(dim, 3 * dim, bias=False)
-         self.c_proj = nn.Linear(dim, dim, bias=False)
-
-     def forward(self, x):
-         r"""
-         Args:
-             x: (b, t, h*d)
-
-         Constants:
-             b: batch_size
-             t: time steps
-             r: 3
-             h: heads_num
-             d: heads_dim
-         """
-         B, T, C = x.size()
-
-         q, k, v = rearrange(self.c_attn(x), 'b t (r h d) -> r b h t d', r=3, h=self.n_heads)
-         # q, k, v: (b, h, t, d)
-
-         q = self.rotary_embed(q)
-         k = self.rotary_embed(k)
-
-         if self.flash:
-             y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=False)
-
-         y = rearrange(y, 'b h t d -> b t (h d)')
-
-         y = self.c_proj(y)
-         # shape: (b, t, h*d)
-
-         return y
-
-
- class TransformerBlock(nn.Module):
-     def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
-         super().__init__()
-         self.dim = dim
-         self.n_heads = n_heads
-
-         self.att_norm = RMSNorm(dim)
-         self.ffn_norm = RMSNorm(dim)
-         self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed)
-         self.mlp = MLP(dim=dim)
-
-     def forward(
-         self,
-         x: torch.Tensor,
-     ):
-         x = x + self.att(self.att_norm(x))
-         x = x + self.mlp(self.ffn_norm(x))
-         return x
-
-
- if __name__ == '__main__':
-     rotary_embed_128 = RotaryPositionalEmbeddings(dim=128)
-     transformer_block = TransformerBlock(
-         dim=1024,
-         n_heads=8,
-         rotary_embed=rotary_embed_128
-     )
-     x = torch.randn(2, 128, 1024)
-     y = transformer_block(x)
-     print(y.shape)
-     c = 1
neucodec/codec_decoder_vocos.py DELETED
@@ -1,431 +0,0 @@
- import torch
- import torch.nn as nn
-
- from typing import List
- from torchtune.modules import RotaryPositionalEmbeddings
- from vector_quantize_pytorch import ResidualFSQ
-
- from .bs_roformer5 import TransformerBlock
-
-
- class ISTFT(nn.Module):
-     """
-     Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
-     windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
-     See issue: https://github.com/pytorch/pytorch/issues/62323
-     Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
-     The NOLA constraint is met as we trim padded samples anyway.
-
-     Args:
-         n_fft (int): Size of Fourier transform.
-         hop_length (int): The distance between neighboring sliding window frames.
-         win_length (int): The size of window frame and STFT filter.
-         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
-     """
-
-     def __init__(
-         self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
-     ):
-         super().__init__()
-         if padding not in ["center", "same"]:
-             raise ValueError("Padding must be 'center' or 'same'.")
-         self.padding = padding
-         self.n_fft = n_fft
-         self.hop_length = hop_length
-         self.win_length = win_length
-         window = torch.hann_window(win_length)
-         self.register_buffer("window", window)
-
-     def forward(self, spec: torch.Tensor) -> torch.Tensor:
-         """
-         Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
-
-         Args:
-             spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
-                 N is the number of frequency bins, and T is the number of time frames.
-
-         Returns:
-             Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
-         """
-         if self.padding == "center":
-             # Fallback to pytorch native implementation
-             return torch.istft(
-                 spec,
-                 self.n_fft,
-                 self.hop_length,
-                 self.win_length,
-                 self.window,
-                 center=True,
-             )
-         elif self.padding == "same":
-             pad = (self.win_length - self.hop_length) // 2
-         else:
-             raise ValueError("Padding must be 'center' or 'same'.")
-
-         assert spec.dim() == 3, "Expected a 3D tensor as input"
-         B, N, T = spec.shape
-
-         # Inverse FFT
-         ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
-         ifft = ifft * self.window[None, :, None]
-
-         # Overlap and Add
-         output_size = (T - 1) * self.hop_length + self.win_length
-         y = torch.nn.functional.fold(
-             ifft,
-             output_size=(1, output_size),
-             kernel_size=(1, self.win_length),
-             stride=(1, self.hop_length),
-         )[:, 0, 0, pad:-pad]
-
-         # Window envelope
-         window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
-         window_envelope = torch.nn.functional.fold(
-             window_sq,
-             output_size=(1, output_size),
-             kernel_size=(1, self.win_length),
-             stride=(1, self.hop_length),
-         ).squeeze()[pad:-pad]
-
-         # Normalize
-         assert (window_envelope > 1e-11).all()
-         y = y / window_envelope
-
-         return y
-
-
- class FourierHead(nn.Module):
-     """Base class for inverse fourier modules."""
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
-                 L is the sequence length, and H denotes the model dimension.
-
-         Returns:
-             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
-         """
-         raise NotImplementedError("Subclasses must implement the forward method.")
-
-
- class ISTFTHead(FourierHead):
-     """
-     ISTFT Head module for predicting STFT complex coefficients.
-
-     Args:
-         dim (int): Hidden dimension of the model.
-         n_fft (int): Size of Fourier transform.
-         hop_length (int): The distance between neighboring sliding window frames, which should align with
-             the resolution of the input features.
-         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
-     """
-
-     def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
-         super().__init__()
-         out_dim = n_fft + 2
-         self.out = torch.nn.Linear(dim, out_dim)
-         self.istft = ISTFT(
-             n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         Forward pass of the ISTFTHead module.
-
-         Args:
-             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
-                 L is the sequence length, and H denotes the model dimension.
-
-         Returns:
-             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
-         """
-         x_pred = self.out(x)
-         # x_pred = x
-         x_pred = x_pred.transpose(1, 2)
-         mag, p = x_pred.chunk(2, dim=1)
-         mag = torch.exp(mag)
-         mag = torch.clip(
-             mag, max=1e2
-         )  # safeguard to prevent excessively large magnitudes
-         # wrapping happens here. These two lines produce real and imaginary value
-         x = torch.cos(p)
-         y = torch.sin(p)
-         # recalculating phase here does not produce anything new
-         # only costs time
-         # phase = torch.atan2(y, x)
-         # S = mag * torch.exp(phase * 1j)
-         # better directly produce the complex value
-         S = mag * (x + 1j * y)
-         audio = self.istft(S)
-         return audio.unsqueeze(1), x_pred
-
-
- def nonlinearity(x):
-     # swish
-     return x * torch.sigmoid(x)
-
-
- def Normalize(in_channels, num_groups=32):
-     return torch.nn.GroupNorm(
-         num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
-     )
-
-
- class ResnetBlock(nn.Module):
-     def __init__(
-         self,
-         *,
-         in_channels,
-         out_channels=None,
-         conv_shortcut=False,
-         dropout,
-         temb_channels=512,
-     ):
-         super().__init__()
-         self.in_channels = in_channels
-         out_channels = in_channels if out_channels is None else out_channels
-         self.out_channels = out_channels
-         self.use_conv_shortcut = conv_shortcut
-
-         self.norm1 = Normalize(in_channels)
-         self.conv1 = torch.nn.Conv1d(
-             in_channels, out_channels, kernel_size=3, stride=1, padding=1
-         )
-         if temb_channels > 0:
-             self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-         self.norm2 = Normalize(out_channels)
-         self.dropout = torch.nn.Dropout(dropout)
-         self.conv2 = torch.nn.Conv1d(
-             out_channels, out_channels, kernel_size=3, stride=1, padding=1
-         )
-         if self.in_channels != self.out_channels:
-             if self.use_conv_shortcut:
-                 self.conv_shortcut = torch.nn.Conv1d(
-                     in_channels, out_channels, kernel_size=3, stride=1, padding=1
-                 )
-             else:
-                 self.nin_shortcut = torch.nn.Conv1d(
-                     in_channels, out_channels, kernel_size=1, stride=1, padding=0
-                 )
-
-     def forward(self, x, temb=None):
-         h = x
-         h = self.norm1(h)
-         h = nonlinearity(h)
-         h = self.conv1(h)
-
-         if temb is not None:
-             h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
-
-         h = self.norm2(h)
-         h = nonlinearity(h)
-         h = self.dropout(h)
-         h = self.conv2(h)
-
-         if self.in_channels != self.out_channels:
-             if self.use_conv_shortcut:
-                 x = self.conv_shortcut(x)
-             else:
-                 x = self.nin_shortcut(x)
-
-         return x + h
-
-
- class Backbone(nn.Module):
-     """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
-
-     def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
-         """
-         Args:
-             x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
-                 C denotes output features, and L is the sequence length.
-
-         Returns:
-             Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
-                 and H denotes the model dimension.
-         """
-         raise NotImplementedError("Subclasses must implement the forward method.")
-
-
- class VocosBackbone(Backbone):
-     """
-     Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
-
-     Args:
-         input_channels (int): Number of input features channels.
-         dim (int): Hidden dimension of the model.
-         intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
-         num_layers (int): Number of ConvNeXtBlock layers.
-         layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
-         adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
-             None means non-conditional model. Defaults to None.
-     """
-
-     def __init__(self, hidden_dim=1024, depth=12, heads=16, pos_meb_dim=64):
-         super().__init__()
-
-         self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3)
-
-         self.temb_ch = 0
-         block_in = hidden_dim
-         dropout = 0.1
-
-         prior_net: List[nn.Module] = [
-             ResnetBlock(
-                 in_channels=block_in,
-                 out_channels=block_in,
-                 temb_channels=self.temb_ch,
-                 dropout=dropout,
-             ),
-             ResnetBlock(
-                 in_channels=block_in,
-                 out_channels=block_in,
-                 temb_channels=self.temb_ch,
-                 dropout=dropout,
-             ),
-         ]
-         self.prior_net = nn.Sequential(*prior_net)
-
-         depth = depth
-         time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
-
-         transformer_blocks = [
-             TransformerBlock(
-                 dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed
-             )
-             for _ in range(depth)
-         ]
-
-         self.transformers = nn.Sequential(*transformer_blocks)
-         self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
-         post_net: List[nn.Module] = [
-             ResnetBlock(
-                 in_channels=block_in,
-                 out_channels=block_in,
-                 temb_channels=self.temb_ch,
-                 dropout=dropout,
-             ),
-             ResnetBlock(
-                 in_channels=block_in,
-                 out_channels=block_in,
-                 temb_channels=self.temb_ch,
-                 dropout=dropout,
-             ),
-         ]
-         self.post_net = nn.Sequential(*post_net)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = x.transpose(1, 2)
-         x = self.embed(x)
-         x = self.prior_net(x)
-         x = x.transpose(1, 2)
-         x = self.transformers(x)
-         x = x.transpose(1, 2)
-         x = self.post_net(x)
-         x = x.transpose(1, 2)
-         x = self.final_layer_norm(x)
-         return x
-
-
- def init_weights(m):
-     if isinstance(m, nn.Conv1d):
-         nn.init.trunc_normal_(m.weight, std=0.02)
-         nn.init.constant_(m.bias, 0)
-
-
- class CodecDecoderVocos(nn.Module):
-     def __init__(
-         self,
-         hidden_dim=1024,
-         depth=12,
-         heads=16,
-         pos_meb_dim=64,
-         hop_length=320,
-         vq_num_quantizers=1,
-         vq_dim=2048,  # 1024 2048
-         vq_commit_weight=0.25,
-         vq_weight_init=False,
-         vq_full_commit_loss=False,
-         codebook_size=16384,
-         codebook_dim=16,
-     ):
-         super().__init__()
-         self.hop_length = hop_length
-
-         self.quantizer = ResidualFSQ(
-             dim=vq_dim, levels=[4, 4, 4, 4, 4, 4, 4, 4], num_quantizers=1
-         )
-
-         self.backbone = VocosBackbone(
-             hidden_dim=hidden_dim, depth=depth, heads=heads, pos_meb_dim=pos_meb_dim
-         )
-
-         self.head = ISTFTHead(
-             dim=hidden_dim,
-             n_fft=self.hop_length * 4,
-             hop_length=self.hop_length,
-             padding="same",
-         )
-
-         self.reset_parameters()
-
-     def forward(self, x, vq=True):
-         if vq is True:
-             # x, q, commit_loss = self.quantizer(x)
-             x = x.permute(0, 2, 1)
-             x, q = self.quantizer(x)
-             x = x.permute(0, 2, 1)
-             q = q.permute(0, 2, 1)
-             return x, q, None
-         x = self.backbone(x)
-         x, _ = self.head(x)
-
-         return x, _
-
-     def vq2emb(self, vq):
-         self.quantizer = self.quantizer.eval()
-         x = self.quantizer.vq2emb(vq)
-         return x
-
-     def get_emb(self):
-         self.quantizer = self.quantizer.eval()
-         embs = self.quantizer.get_emb()
-         return embs
-
-     def inference_vq(self, vq):
-         x = vq[None, :, :]
-         x = self.model(x)
-         return x
-
-     def inference_0(self, x):
-         x, q, loss, perp = self.quantizer(x)
-         x = self.model(x)
-         return x, None
-
-     def inference(self, x):
-         x = self.model(x)
-         return x, None
-
-     def remove_weight_norm(self):
-         """Remove weight normalization module from all of the layers."""
-
-         def _remove_weight_norm(m):
-             try:
-                 torch.nn.utils.remove_weight_norm(m)
-             except ValueError:  # this module didn't have weight norm
-                 return
-
-         self.apply(_remove_weight_norm)
-
-     def apply_weight_norm(self):
-         """Apply weight normalization module from all of the layers."""
-
-         def _apply_weight_norm(m):
-             if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
-                 torch.nn.utils.weight_norm(m)
-
-         self.apply(_apply_weight_norm)
-
-     def reset_parameters(self):
-         self.apply(init_weights)
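
The custom ISTFT exists because torch.istft only supports center padding under its NOLA check, whereas a vocoder wants "same" padding; with win_length = n_fft = 4 * hop the trimmed output works out to exactly hop_length samples per frame. A shape sketch, assuming the pre-removal module:

    import torch
    from neucodec.codec_decoder_vocos import ISTFT, CodecDecoderVocos

    n_fft, hop = 1280, 320
    istft = ISTFT(n_fft=n_fft, hop_length=hop, win_length=n_fft, padding="same")
    spec = torch.randn(1, n_fft // 2 + 1, 50, dtype=torch.complex64)  # (B, N, T)
    wav = istft(spec)
    assert wav.shape[-1] == 50 * hop  # 16_000 samples for 50 frames

    # the decoder's two modes: vq=True quantizes [B, vq_dim, T] features to
    # FSQ indices; vq=False runs the backbone + ISTFT head to produce audio
    dec = CodecDecoderVocos(hop_length=hop)
    feats = torch.randn(1, 2048, 50)
    _, codes, _ = dec(feats, vq=True)  # residual-FSQ token indices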
neucodec/codec_encoder.py DELETED
@@ -1,84 +0,0 @@
- import torch
- import numpy as np
-
- from torch import nn
-
- from .module import WNConv1d, EncoderBlock
- from .alias_free_torch import Activation1d
- from . import activations
-
-
- def init_weights(m):
-     if isinstance(m, nn.Conv1d):
-         nn.init.trunc_normal_(m.weight, std=0.02)
-         nn.init.constant_(m.bias, 0)
-
-
- class CodecEncoder(nn.Module):
-     def __init__(
-         self,
-         ngf=48,
-         up_ratios=[2, 2, 4, 4, 5],
-         dilations=(1, 3, 9),
-         hidden_dim=1024,
-         depth=12,
-         heads=12,
-         pos_meb_dim=64,
-     ):
-         super().__init__()
-         self.hop_length = np.prod(up_ratios)
-         self.ngf = ngf
-         self.up_ratios = up_ratios
-
-         d_model = ngf
-         self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
-
-         for i, stride in enumerate(up_ratios):
-             d_model *= 2
-             self.conv_blocks += [
-                 EncoderBlock(d_model, stride=stride, dilations=dilations)
-             ]
-
-         self.conv_blocks = nn.Sequential(*self.conv_blocks)
-
-         self.conv_final_block = [
-             Activation1d(
-                 activation=activations.SnakeBeta(d_model, alpha_logscale=True)
-             ),
-             WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1),
-         ]
-         self.conv_final_block = nn.Sequential(*self.conv_final_block)
-
-         self.reset_parameters()
-
-     def forward(self, x):
-         x = self.conv_blocks(x)
-         x = self.conv_final_block(x)
-         x = x.permute(0, 2, 1)
-         return x
-
-     def inference(self, x):
-         return self.block(x)
-
-     def remove_weight_norm(self):
-         """Remove weight normalization module from all of the layers."""
-
-         def _remove_weight_norm(m):
-             try:
-                 torch.nn.utils.remove_weight_norm(m)
-             except ValueError:  # this module didn't have weight norm
-                 return
-
-         self.apply(_remove_weight_norm)
-
-     def apply_weight_norm(self):
-         """Apply weight normalization module from all of the layers."""
-
-         def _apply_weight_norm(m):
-             if isinstance(m, nn.Conv1d):
-                 torch.nn.utils.weight_norm(m)
-
-         self.apply(_apply_weight_norm)
-
-     def reset_parameters(self):
-         self.apply(init_weights)
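
CodecEncoder downsamples mono audio by prod(up_ratios) = 2*2*4*4*5 = 320 samples per frame, doubling channels at each stage (48 -> 1536) before the final projection to hidden_dim. A shape check, assuming the pre-removal module:

    import torch
    from neucodec.codec_encoder import CodecEncoder

    enc = CodecEncoder()             # hop_length = 320
    wav = torch.randn(1, 1, 16_000)  # 1 s of 16 kHz mono audio, [B, 1, T]
    feats = enc(wav)                 # permuted to [B, T // 320, hidden_dim]
    assert feats.shape == (1, 50, 1024)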
neucodec/model.py DELETED
@@ -1,269 +0,0 @@
- import soundfile as sf
- import os
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torchaudio
-
- from typing import Optional
- from torchaudio import transforms as T
- from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
-
- from .codec_encoder import CodecEncoder
- from .codec_decoder_vocos import CodecDecoderVocos
- from .module import SemanticEncoder
-
-
- class NeuCodec(nn.Module):
-     def __init__(self, ckpt_path: str, sample_rate: int, hop_length: int):
-         super().__init__()
-
-         # load ckpt
-         ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
-         self.sample_rate = sample_rate
-         self.hop_length = hop_length
-
-         # load modules
-         self.semantic_model = Wav2Vec2BertModel.from_pretrained(
-             "facebook/w2v-bert-2.0", output_hidden_states=True
-         )
-         self.semantic_model.eval()
-         self.feature_extractor = AutoFeatureExtractor.from_pretrained(
-             "facebook/w2v-bert-2.0"
-         )
-         self.SemanticEncoder_module = SemanticEncoder(1024, 1024, 1024)
-         self.CodecEnc = CodecEncoder()
-         self.generator = CodecDecoderVocos(hop_length=hop_length)
-         self.fc_prior = nn.Linear(2048, 2048)
-         self.fc_post_a = nn.Linear(2048, 1024)
-
-         # load checkpoint
-         self._load_ckpt(ckpt)
-
-     def _load_ckpt(self, ckpt):
-         # differentiate between `.ckpt` and `.bin`
-         if ckpt.get("state_dict"):
-             state_dicts = ckpt.get("state_dict")
-         else:
-             state_dicts = ckpt
-
-         # assign keys to correct model components
-         filtered_enc = {}
-         filtered_gen = {}
-         filtered_post = {}
-         filtered_prior = {}
-         filtered_semantic = {}
-         for key, value in state_dicts.items():
-             if key.startswith("CodecEnc."):
-                 new_key = key[len("CodecEnc."):]
-                 filtered_enc[new_key] = value
-             elif key.startswith("generator."):
-                 new_key = key[len("generator."):]
-                 filtered_gen[new_key] = value
-             elif key.startswith("fc_post_a."):
-                 new_key = key[len("fc_post_a."):]
-                 filtered_post[new_key] = value
-             elif key.startswith("SemanticEncoder_module."):
-                 new_key = key[len("SemanticEncoder_module."):]
-                 filtered_semantic[new_key] = value
-             elif key.startswith("fc_prior."):
-                 new_key = key[len("fc_prior."):]
-                 filtered_prior[new_key] = value
-
-         # load
-         self.CodecEnc.load_state_dict(filtered_enc)
-         self.CodecEnc.eval()
-         self.generator.load_state_dict(filtered_gen, strict=False)
-         self.generator.eval()
-         self.fc_post_a.load_state_dict(filtered_post)
-         self.fc_post_a.eval()
-         self.fc_prior.load_state_dict(filtered_prior)
-         self.SemanticEncoder_module.load_state_dict(filtered_semantic)
-         self.SemanticEncoder_module.eval()
-
-     @torch.inference_mode()
-     def encode_code(
-         self,
-         input_waveform: torch.Tensor,
-         semantic_features: torch.Tensor = None,
-         sample_rate: int = 16_000,
-     ) -> torch.Tensor:
-         pad_for_wav = 320 - (input_waveform.shape[1] % 320)
-         input_waveform = torch.nn.functional.pad(input_waveform, (0, pad_for_wav))
-
-         if semantic_features is None:
-             semantic_features = self.feature_extractor(
-                 input_waveform, sampling_rate=sample_rate, return_tensors="pt"
-             ).input_features.to(self.device)  # [batch, frames, feat_dim]
-         else:
-             semantic_features = semantic_features[:, 0, :, :]
-
-         semantic_output = self.semantic_model(semantic_features)
-         semantic_hidden_16 = semantic_output.hidden_states[16]
-         semantic_hidden_16 = semantic_hidden_16.transpose(
-             1, 2
-         )  # [batch, hidden_dim, frames]
-         semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)
-         if len(input_waveform.shape) == 2:
-             wav = input_waveform.unsqueeze(1).to(self.device)  # shape: [batch, 1, time]
-         else:
-             wav = input_waveform.to(self.device)
-
-         vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024]
-         vq_emb = vq_emb.transpose(1, 2)  # -> [batch, 1024, frames]
-
-         if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
-             min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
-             vq_emb = vq_emb[:, :, :min_len]
-             semantic_encoded = semantic_encoded[:, :, :min_len]
-         concat_emb = torch.cat(
-             [semantic_encoded, vq_emb], dim=1
-         )  # [batch, 2048, frames]
-         concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)
-         _, vq_code, _ = self.generator(concat_emb, vq=True)
-         return vq_code
-
-     @torch.inference_mode()
-     def decode_code(self, vq_code: torch.Tensor) -> torch.Tensor:
-         vq_post_emb = self.generator.quantizer.get_output_from_indices(
-             vq_code.transpose(1, 2)
-         )
-         vq_post_emb = vq_post_emb.transpose(1, 2)  # [batch, 1024, frames]
-         vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(
-             1, 2
-         )  # [batch, 1024, frames]
-         recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[
-             0
-         ]  # [batch, time]
-         return recon_audio
-
-     @torch.inference_mode()
-     def autoencode(self, fpath: str, output_fpath: Optional[str] = None):
-         y, sr = torchaudio.load(fpath)
-         if sr != 16_000:
-             y = T.Resample(sr, 16_000)(y)
-         vq_codes = self.encode_code(y)
-         recon = self.decode_code(vq_codes)
-
-         if output_fpath is None:
-             name, fext = os.path.splitext(fpath)
-             output_fpath = f"{name}_recon{fext}"
-
-         sf.write(output_fpath, recon[0, 0, :].cpu(), self.sample_rate)
-
-     @torch.inference_mode()
-     def batch_encode(
-         self, fpaths: list[str], return_tensor: bool = False
-     ) -> tuple[list[torch.Tensor], list[int]] | tuple[torch.Tensor, list[int]]:
-         # prepare batch
-         wavs_batch, semantic_batch, token_durations = self._pad_batch(
-             [self._preprocess_file(fpath) for fpath in fpaths]
-         )
-         vq_codes = self.encode_code(wavs_batch, semantic_batch)
-
-         # return, unpad if we want to
-         if return_tensor:
-             return vq_codes, list(token_durations)
-
-         unpadded_vq_codes = []
-         for idx, token_dur in enumerate(token_durations):
-             curr_codes = vq_codes[idx, :, :token_dur]
-             unpadded_vq_codes.append(curr_codes)
-
-         return unpadded_vq_codes, None
-
-     @torch.inference_mode()
-     def batch_decode(
-         self,
-         vq_codes: list[torch.Tensor] | torch.Tensor,
-         token_durations: Optional[list[int]] = None,
-     ):
-         # pad tensor if need be
-         if isinstance(vq_codes, list):
-             vq_codes, token_durations = self._pad_codes(vq_codes)
-         else:
-             assert token_durations is not None
-
-         # decode
-         recons = self.decode_code(vq_codes)
-
-         # unpad
-         cut_recons = []
-         for idx, token_dur in enumerate(token_durations):
-             curr_recon = recons[idx, :, : int(token_dur * self.hop_length)]
-             cut_recons.append(curr_recon)
-
-         return cut_recons
-
-     @torch.inference_mode()
-     def batch_autoencode(
-         self, fpaths: list[str], output_fpaths: Optional[list[str]] = None
-     ) -> list[torch.Tensor]:
-         vq_codes, token_durations = self.batch_encode(fpaths, return_tensor=True)
-         cut_recons = self.batch_decode(vq_codes, token_durations)
-
-         if output_fpaths:
-             for recon, output_fpath in zip(cut_recons, output_fpaths):
-                 sf.write(output_fpath, recon.cpu().numpy()[0, :], self.sample_rate)
-
-         return cut_recons
-
-     def _preprocess_file(self, fpath: str):
-         # load and resample
-         y, sr = torchaudio.load(fpath)
-         if sr != 16_000:
-             y = T.Resample(sr, 16_000)(y)
-
-         # compute duration for any cutting we might need to do, in terms of n_tokens
-         token_duration = int((y.shape[-1] / 16_000) * 50)
-
-         # get semantic model features: [harry] note i don't think this can be batched
-         semantic_model_input = self.feature_extractor(
-             y, sampling_rate=16_000, return_tensors="pt"
-         ).input_features
-
-         return y.to(self.device), semantic_model_input.to(self.device), token_duration
-
-     def _pad_batch(self, batch: list[tuple[torch.Tensor, torch.Tensor, int]]):
-         # unpack batch
-         wavs, semantic_features, token_durations = zip(*batch)
-         max_length_semantic = max([f.shape[1] for f in semantic_features])
-         max_length = max_length_semantic * 320
-
-         # pad wavs
-         wavs_padded = []
-         for audio in wavs:
-             padding = max_length - audio.shape[1]
-             if padding > 0:
-                 padded_audio = F.pad(audio, (0, padding), mode="constant", value=0)
-             else:
-                 padded_audio = audio[:, :max_length]
-             wavs_padded.append(padded_audio)
-         wavs_tensor = torch.stack(wavs_padded)
-
-         # pad semantic features
-         semantic_features_padded = []
-         for feat in semantic_features:
-             padding = max_length_semantic - feat.shape[1]
-             padded_feat = F.pad(feat, (0, 0, 0, padding), mode="constant", value=0)
-             semantic_features_padded.append(padded_feat)
-         semantic_feature_tensor = torch.stack(semantic_features_padded)
-
-         return wavs_tensor, semantic_feature_tensor, token_durations
-
-     def _pad_codes(self, vq_codes: list[torch.Tensor]):
-         max_len = max([i.shape[-1] for i in vq_codes])
-         token_durations = []
-         padded_codes = []
-         for curr_codes in vq_codes:
-             curr_len = curr_codes.shape[-1]
-             token_durations.append(curr_len)
-             padding = max_len - curr_len
-             curr_codes = F.pad(curr_codes, (0, padding), mode="constant", value=0)
-             padded_codes.append(curr_codes)
-         return torch.stack(padded_codes), token_durations
-
-     @property
-     def device(self):
-         return next(self.parameters()).device
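
End to end, the removed NeuCodec wrapper fuses W2v-BERT semantic features with the acoustic encoder output, quantizes to FSQ indices at roughly 50 tokens per second of 16 kHz audio, and vocodes indices back to a waveform. A usage sketch; the checkpoint path and input file below are illustrative, and a compatible checkpoint is assumed to exist:

    import torch
    from neucodec import NeuCodec

    codec = NeuCodec(ckpt_path="neucodec.ckpt", sample_rate=16_000, hop_length=320)

    wav = torch.randn(1, 16_000)      # [batch, samples], 1 s at 16 kHz
    codes = codec.encode_code(wav)    # [batch, 1, ~50] FSQ token indices
    recon = codec.decode_code(codes)  # [batch, 1, ~16_000] reconstructed audio

    # or round-trip a file; writes `<name>_recon.<ext>` next to the input
    codec.autoencode("speech.wav")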
neucodec/module.py DELETED
@@ -1,114 +0,0 @@
- import torch.nn as nn
-
- from torch.nn.utils import weight_norm
-
- from .activations import SnakeBeta
- from .alias_free_torch import Activation1d
-
-
- def WNConv1d(*args, **kwargs):
-     return weight_norm(nn.Conv1d(*args, **kwargs))
-
-
- class ResidualUnit(nn.Module):
-     def __init__(self, dim: int = 16, dilation: int = 1):
-         super().__init__()
-         pad = ((7 - 1) * dilation) // 2
-         self.block = nn.Sequential(
-             Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
-             WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
-             Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
-             WNConv1d(dim, dim, kernel_size=1),
-         )
-
-     def forward(self, x):
-         return x + self.block(x)
-
-
- class EncoderBlock(nn.Module):
-     def __init__(self, dim: int = 16, stride: int = 1, dilations=(1, 3, 9)):
-         super().__init__()
-         runits = [ResidualUnit(dim // 2, dilation=d) for d in dilations]
-         self.block = nn.Sequential(
-             *runits,
-             Activation1d(activation=SnakeBeta(dim // 2, alpha_logscale=True)),
-             WNConv1d(
-                 dim // 2,
-                 dim,
-                 kernel_size=2 * stride,
-                 stride=stride,
-                 padding=stride // 2 + stride % 2,
-             ),
-         )
-
-     def forward(self, x):
-         return self.block(x)
-
-
- class SemanticEncoder(nn.Module):
-     def __init__(
-         self,
-         input_channels: int,
-         code_dim: int,
-         encode_channels: int,
-         kernel_size: int = 3,
-         bias: bool = True,
-     ):
-         super(SemanticEncoder, self).__init__()
-
-         # initial convolution: map input_channels to encode_channels
-         self.initial_conv = nn.Conv1d(
-             in_channels=input_channels,
-             out_channels=encode_channels,
-             kernel_size=kernel_size,
-             stride=1,
-             padding=(kernel_size - 1) // 2,
-             bias=False,
-         )
-
-         # residual blocks
-         self.residual_blocks = nn.Sequential(
-             nn.ReLU(inplace=True),
-             nn.Conv1d(
-                 encode_channels,
-                 encode_channels,
-                 kernel_size=kernel_size,
-                 stride=1,
-                 padding=(kernel_size - 1) // 2,
-                 bias=bias,
-             ),
-             nn.ReLU(inplace=True),
-             nn.Conv1d(
-                 encode_channels,
-                 encode_channels,
-                 kernel_size=kernel_size,
-                 stride=1,
-                 padding=(kernel_size - 1) // 2,
-                 bias=bias,
-             ),
-         )
-
-         # final convolution: map encode_channels to code_dim
-         self.final_conv = nn.Conv1d(
-             in_channels=encode_channels,
-             out_channels=code_dim,
-             kernel_size=kernel_size,
-             stride=1,
-             padding=(kernel_size - 1) // 2,
-             bias=False,
-         )
-
-     def forward(self, x):
-         """
-         Forward pass.
-
-         Args:
-             x (Tensor): Input tensor of shape (Batch, Input_channels, Length)
-
-         Returns:
-             Tensor: Encoded tensor of shape (Batch, Code_dim, Length)
-         """
-         x = self.initial_conv(x)         # (Batch, Encode_channels, Length)
-         x = self.residual_blocks(x) + x  # residual connection
-         x = self.final_conv(x)           # (Batch, Code_dim, Length)
-         return x
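
SemanticEncoder is a length-preserving 1-D conv stack: an initial projection, a two-conv residual branch, and a final projection, all with "same" padding. A quick shape sketch, assuming the pre-removal module:

    import torch
    from neucodec.module import SemanticEncoder

    enc = SemanticEncoder(input_channels=1024, code_dim=1024, encode_channels=1024)
    h = torch.randn(1, 1024, 50)  # e.g. one W2v-BERT hidden state, [B, C, T]
    z = enc(h)                    # [B, code_dim, T], length unchanged
    assert z.shape == (1, 1024, 50)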
setup.py DELETED
@@ -1,32 +0,0 @@
- from setuptools import setup, find_packages
-
-
- setup(
-     name='neucodec',
-     version='0.0.1',
-     description='A package for neucodec, based on xcodec2.',
-     long_description_content_type='text/markdown',
-     author='Harry Julian',
-     author_email='[email protected]',
-     packages=find_packages(),
-     install_requires=[
-         'librosa',
-         'soundfile',
-         'numpy>=2.0.2',
-         'omegaconf>=2.3.0',
-         'torch>=2.5.1',
-         'torchaudio>=2.5.1',
-         'torchao>=0.5.0',
-         'torchtune>=0.3.1',
-         'vector-quantize-pytorch>=1.17.8',
-         'rotary-embedding-torch>=0.8.4',
-         'transformers>=4.44.2',
-         'boto3>1.0',
-         'tqdm',
-     ],
-     classifiers=[
-         'Programming Language :: Python',
-         'Programming Language :: Python :: 3',
-         'Programming Language :: Python :: 3.10',
-     ],
- )
tests/__init__.py DELETED
File without changes
tests/test_neucodec.py DELETED
@@ -1,128 +0,0 @@
- import pytest
- import torch
- import torchaudio
- import librosa
- from xcodec2 import XCodec2, MiniXCodec2Encoder
-
-
- @pytest.fixture
- def model_16khz():
-     return XCodec2.from_cache("16khz")
-
-
- @pytest.fixture
- def model_24khz():
-     return XCodec2.from_cache("24khz")
-
-
- @pytest.fixture
- def model_asr_encoder():
-     return MiniXCodec2Encoder.from_cache()
-
-
- @pytest.fixture
- def example_audio():
-     y, sr = torchaudio.load(librosa.ex("libri1"))
-     return y, sr
-
-
- @pytest.fixture
- def example_fpath():
-     return librosa.ex("libri1")
-
-
- @pytest.fixture
- def batch_fpaths():
-     return [librosa.ex("libri1"), librosa.ex("libri2")]
-
-
- def load_and_validate_audio(save_path, sample_rate):
-     _, sr = torchaudio.load(save_path)
-     assert sr == sample_rate
-
-
- def test_16khz_autoencode(example_fpath, tmp_path, model_16khz):
-     save_path = str(tmp_path / "0.wav")
-     model_16khz.autoencode(example_fpath, save_path)
-     load_and_validate_audio(save_path, 16_000)
-
-
- def test_24khz_autoencode(example_fpath, tmp_path, model_24khz):
-     save_path = str(tmp_path / "0.wav")
-     model_24khz.autoencode(example_fpath, save_path)
-     load_and_validate_audio(save_path, 24_000)
-
-
- def test_24khz_encode_decode_single(example_audio, model_24khz):
-     y, sr = example_audio
-     if sr != 16_000:
-         y = torchaudio.transforms.Resample(sr, 16_000)(y)
-         sr = 16_000
-
-     # encode
-     vq_codes = model_24khz.encode_code(y, sample_rate=sr)
-     assert isinstance(vq_codes, torch.Tensor)
-     assert vq_codes.dim() == 3  # [batch, channels, time]
-
-     # decode
-     reconstructed = model_24khz.decode_code(vq_codes)
-     assert isinstance(reconstructed, torch.Tensor)
-     assert reconstructed.dim() == 3  # [batch, channels, time]
-
-
- def test_24khz_batch_encode(batch_fpaths, model_24khz):
-     vq_codes_list, token_durations = model_24khz.batch_encode(batch_fpaths, return_tensor=False)
-     assert isinstance(vq_codes_list, list)
-     assert token_durations is None
-     assert len(vq_codes_list) == 2
-
-     for codes in vq_codes_list:
-         assert isinstance(codes, torch.Tensor)
-         assert codes.dim() == 2  # [channels, time]
-
-
- def test_24khz_batch_encode_tensor(batch_fpaths, model_24khz):
-     vq_codes_tensor, token_durations = model_24khz.batch_encode(batch_fpaths, return_tensor=True)
-     assert isinstance(vq_codes_tensor, torch.Tensor)
-     assert isinstance(token_durations, list)
-     assert vq_codes_tensor.dim() == 3  # [batch, channels, time]
-     assert len(token_durations) == 2
-     assert len(set(token_durations)) == 2  # ensure we get two different durations back
-
-
- def test_24khz_batch_decode(batch_fpaths, model_24khz):
-     vq_codes_tensor, token_durations = model_24khz.batch_encode(batch_fpaths, return_tensor=True)
-     reconstructed_list = model_24khz.batch_decode(vq_codes_tensor, token_durations)
-     assert isinstance(reconstructed_list, list)
-     assert len(reconstructed_list) == 2
-     for recon in reconstructed_list:
-         assert isinstance(recon, torch.Tensor)
-         assert recon.dim() == 2  # [channels, time]
-
-
- def test_24khz_batch_decode_list_input(batch_fpaths, model_24khz):
-     vq_codes_list, _ = model_24khz.batch_encode(batch_fpaths, return_tensor=False)
-     reconstructed_list = model_24khz.batch_decode(vq_codes_list)
-     assert isinstance(reconstructed_list, list)
-     assert len(reconstructed_list) == 2
-     for recon in reconstructed_list:
-         assert isinstance(recon, torch.Tensor)
-         assert recon.dim() == 2  # [channels, time]
-
-
- def test_24khz_batch_autoencode(batch_fpaths, tmp_path, model_24khz):
-     output_paths = [str(tmp_path / f"{i}.wav") for i in range(len(batch_fpaths))]
-     reconstructed_list = model_24khz.batch_autoencode(batch_fpaths, output_paths)
-     assert isinstance(reconstructed_list, list)
-     assert len(reconstructed_list) == 2
-     for i, output_path in enumerate(output_paths):
-         load_and_validate_audio(output_path, 24_000)
-
-
- def test_asr_encoder_encode(example_audio, model_asr_encoder):
-     y, sr = example_audio
-     if sr != model_asr_encoder.sample_rate:
-         y = torchaudio.transforms.Resample(sr, model_asr_encoder.sample_rate)(y)
-     vq_codes = model_asr_encoder.encode_code(y)
-     assert isinstance(vq_codes, torch.Tensor)
-     assert vq_codes.dim() == 3