Commit 2dfcd66
Parent(s): 6a9bb58
removed code

- .gitattributes +0 -35
- .gitignore +0 -189
- neucodec/__init__.py +0 -3
- neucodec/activations.py +0 -120
- neucodec/alias_free_torch/__init__.py +0 -6
- neucodec/alias_free_torch/act.py +0 -28
- neucodec/alias_free_torch/filter.py +0 -95
- neucodec/alias_free_torch/resample.py +0 -49
- neucodec/bs_roformer5.py +0 -120
- neucodec/codec_decoder_vocos.py +0 -431
- neucodec/codec_encoder.py +0 -84
- neucodec/model.py +0 -269
- neucodec/module.py +0 -114
- setup.py +0 -32
- tests/__init__.py +0 -0
- tests/test_neucodec.py +0 -128
.gitattributes
DELETED
@@ -1,35 +0,0 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

.gitignore
DELETED
@@ -1,189 +0,0 @@
# Emacs
*~

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
/runs
/checkpoints
/base

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

/runs
/.cache
/__pycache__

*.wav
*.pth
*.pt
*.pt.gz
wandb/
sven_latest_checkpoint/
sven_qwen/
pretrained_models/
xcodec/
small_speaker_shards_all/
sven_all_shards/
qwen_380k/
evals/
*.safetensors
*.pt
.ruff_cache

neucodec/__init__.py
DELETED
@@ -1,3 +0,0 @@
from .codec_encoder import CodecEncoder
from .codec_decoder_vocos import CodecDecoderVocos
from .model import NeuCodec

neucodec/activations.py
DELETED
@@ -1,120 +0,0 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    '''
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        '''
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x

neucodec/alias_free_torch/__init__.py
DELETED
@@ -1,6 +0,0 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *

neucodec/alias_free_torch/act.py
DELETED
@@ -1,28 +0,0 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(self,
                 activation,
                 up_ratio: int = 2,
                 down_ratio: int = 2,
                 up_kernel_size: int = 12,
                 down_kernel_size: int = 12):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B,C,T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x

neucodec/alias_free_torch/filter.py
DELETED
@@ -1,95 +0,0 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if 'sinc' in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(x == 0,
                           torch.tensor(1., device=x.device, dtype=x.dtype),
                           torch.sin(math.pi * x) / math.pi / x)


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = (torch.arange(-half_size, half_size) + 0.5)
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(self,
                 cutoff=0.5,
                 half_width=0.6,
                 stride: int = 1,
                 padding: bool = True,
                 padding_mode: str = 'replicate',
                 kernel_size: int = 12):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < -0.:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = (kernel_size % 2 == 0)
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right),
                      mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1),
                       stride=self.stride, groups=C)

        return out

neucodec/alias_free_torch/resample.py
DELETED
@@ -1,49 +0,0 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
                                      half_width=0.6 / ratio,
                                      kernel_size=self.kernel_size)
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode='replicate')
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
        x = x[..., self.pad_left:-self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
                                       half_width=0.6 / ratio,
                                       stride=ratio,
                                       kernel_size=self.kernel_size)

    def forward(self, x):
        xx = self.lowpass(x)

        return xx

neucodec/bs_roformer5.py
DELETED
@@ -1,120 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np

from torch.nn import Module, ModuleList
from einops import rearrange
from torchtune.modules import RotaryPositionalEmbeddings


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        r"""https://github.com/meta-llama/llama/blob/main/llama/model.py"""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm_x = torch.mean(x ** 2, dim=-1, keepdim=True)
        output = x * torch.rsqrt(norm_x + self.eps) * self.weight
        return output


class MLP(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()

        self.fc1 = nn.Linear(dim, 4 * dim, bias=False)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.silu(x)
        x = self.fc2(x)
        return x


class Attention(nn.Module):

    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()

        assert dim % n_heads == 0

        self.n_heads = n_heads
        self.dim = dim
        self.rotary_embed = rotary_embed

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        assert self.flash, "Must have flash attention."

        self.c_attn = nn.Linear(dim, 3 * dim, bias=False)
        self.c_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        r"""
        Args:
            x: (b, t, h*d)

        Constants:
            b: batch_size
            t: time steps
            r: 3
            h: heads_num
            d: heads_dim
        """
        B, T, C = x.size()

        q, k, v = rearrange(self.c_attn(x), 'b t (r h d) -> r b h t d', r=3, h=self.n_heads)
        # q, k, v: (b, h, t, d)

        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=False)

        y = rearrange(y, 'b h t d -> b t (h d)')

        y = self.c_proj(y)
        # shape: (b, t, h*d)

        return y


class TransformerBlock(nn.Module):
    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):

        super().__init__()
        self.dim = dim
        self.n_heads = n_heads

        self.att_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)
        self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed)
        self.mlp = MLP(dim=dim)

    def forward(
        self,
        x: torch.Tensor,
    ):
        x = x + self.att(self.att_norm(x))
        x = x + self.mlp(self.ffn_norm(x))
        return x


if __name__ == '__main__':
    rotary_embed_128 = RotaryPositionalEmbeddings(dim=128)
    transformer_block = TransformerBlock(
        dim=1024,
        n_heads=8,
        rotary_embed=rotary_embed_128
    )
    x = torch.randn(2, 128, 1024)
    y = transformer_block(x)
    print(y.shape)
    c = 1

neucodec/codec_decoder_vocos.py
DELETED
@@ -1,431 +0,0 @@
import torch
import torch.nn as nn

from typing import List
from torchtune.modules import RotaryPositionalEmbeddings
from vector_quantize_pytorch import ResidualFSQ

from .bs_roformer5 import TransformerBlock


class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(
        self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
    ):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(
                spec,
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
            )
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape

        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Window envelope
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # Normalize
        assert (window_envelope > 1e-11).all()
        y = y / window_envelope

        return y


class FourierHead(nn.Module):
    """Base class for inverse fourier modules."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x_pred = self.out(x)
        # x_pred = x
        x_pred = x_pred.transpose(1, 2)
        mag, p = x_pred.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(
            mag, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        # wrapping happens here. These two lines produce real and imaginary value
        x = torch.cos(p)
        y = torch.sin(p)
        # recalculating phase here does not produce anything new
        # only costs time
        # phase = torch.atan2(y, x)
        # S = mag * torch.exp(phase * 1j)
        # better directly produce the complex value
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio.unsqueeze(1), x_pred


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv1d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv1d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv1d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv1d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb=None):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class Backbone(nn.Module):
    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                C denotes output features, and L is the sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                and H denotes the model dimension.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class VocosBackbone(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional model. Defaults to None.
    """

    def __init__(self, hidden_dim=1024, depth=12, heads=16, pos_meb_dim=64):
        super().__init__()

        self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3)

        self.temb_ch = 0
        block_in = hidden_dim
        dropout = 0.1

        prior_net: List[nn.Module] = [
            ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
            ),
            ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
            ),
        ]
        self.prior_net = nn.Sequential(*prior_net)

        depth = depth
        time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)

        transformer_blocks = [
            TransformerBlock(
                dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed
            )
            for _ in range(depth)
        ]

        self.transformers = nn.Sequential(*transformer_blocks)
        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
        post_net: List[nn.Module] = [
            ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
            ),
            ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
            ),
        ]
        self.post_net = nn.Sequential(*post_net)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.transpose(1, 2)
        x = self.embed(x)
        x = self.prior_net(x)
        x = x.transpose(1, 2)
        x = self.transformers(x)
        x = x.transpose(1, 2)
        x = self.post_net(x)
        x = x.transpose(1, 2)
        x = self.final_layer_norm(x)
        return x


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class CodecDecoderVocos(nn.Module):
    def __init__(
        self,
        hidden_dim=1024,
        depth=12,
        heads=16,
        pos_meb_dim=64,
        hop_length=320,
        vq_num_quantizers=1,
        vq_dim=2048,  # 1024 2048
        vq_commit_weight=0.25,
        vq_weight_init=False,
        vq_full_commit_loss=False,
        codebook_size=16384,
        codebook_dim=16,
    ):
        super().__init__()
        self.hop_length = hop_length

        self.quantizer = ResidualFSQ(
            dim=vq_dim, levels=[4, 4, 4, 4, 4, 4, 4, 4], num_quantizers=1
        )

        self.backbone = VocosBackbone(
            hidden_dim=hidden_dim, depth=depth, heads=heads, pos_meb_dim=pos_meb_dim
        )

        self.head = ISTFTHead(
            dim=hidden_dim,
            n_fft=self.hop_length * 4,
            hop_length=self.hop_length,
            padding="same",
        )

        self.reset_parameters()

    def forward(self, x, vq=True):
        if vq is True:
            # x, q, commit_loss = self.quantizer(x)
            x = x.permute(0, 2, 1)
            x, q = self.quantizer(x)
            x = x.permute(0, 2, 1)
            q = q.permute(0, 2, 1)
            return x, q, None
        x = self.backbone(x)
        x, _ = self.head(x)

        return x, _

    def vq2emb(self, vq):
        self.quantizer = self.quantizer.eval()
        x = self.quantizer.vq2emb(vq)
        return x

    def get_emb(self):
        self.quantizer = self.quantizer.eval()
        embs = self.quantizer.get_emb()
        return embs

    def inference_vq(self, vq):
        x = vq[None, :, :]
        x = self.model(x)
        return x

    def inference_0(self, x):
        x, q, loss, perp = self.quantizer(x)
        x = self.model(x)
        return x, None

    def inference(self, x):
        x = self.model(x)
        return x, None

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)

neucodec/codec_encoder.py
DELETED
@@ -1,84 +0,0 @@
import torch
import numpy as np

from torch import nn

from .module import WNConv1d, EncoderBlock
from .alias_free_torch import Activation1d
from . import activations


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class CodecEncoder(nn.Module):
    def __init__(
        self,
        ngf=48,
        up_ratios=[2, 2, 4, 4, 5],
        dilations=(1, 3, 9),
        hidden_dim=1024,
        depth=12,
        heads=12,
        pos_meb_dim=64,
    ):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios

        d_model = ngf
        self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        for i, stride in enumerate(up_ratios):
            d_model *= 2
            self.conv_blocks += [
                EncoderBlock(d_model, stride=stride, dilations=dilations)
            ]

        self.conv_blocks = nn.Sequential(*self.conv_blocks)

        self.conv_final_block = [
            Activation1d(
                activation=activations.SnakeBeta(d_model, alpha_logscale=True)
            ),
            WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1),
        ]
        self.conv_final_block = nn.Sequential(*self.conv_final_block)

        self.reset_parameters()

    def forward(self, x):
        x = self.conv_blocks(x)
        x = self.conv_final_block(x)
        x = x.permute(0, 2, 1)
        return x

    def inference(self, x):
        return self.block(x)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)

neucodec/model.py
DELETED
@@ -1,269 +0,0 @@
import soundfile as sf
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

from typing import Optional
from torchaudio import transforms as T
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel

from .codec_encoder import CodecEncoder
from .codec_decoder_vocos import CodecDecoderVocos
from .module import SemanticEncoder


class NeuCodec(nn.Module):
    def __init__(self, ckpt_path: str, sample_rate: int, hop_length: int):
        super().__init__()

        # load ckpt
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        self.sample_rate = sample_rate
        self.hop_length = hop_length

        # load modules
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
            "facebook/w2v-bert-2.0", output_hidden_states=True
        )
        self.semantic_model.eval()
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            "facebook/w2v-bert-2.0"
        )
        self.SemanticEncoder_module = SemanticEncoder(1024, 1024, 1024)
        self.CodecEnc = CodecEncoder()
        self.generator = CodecDecoderVocos(hop_length=hop_length)
        self.fc_prior = nn.Linear(2048, 2048)
        self.fc_post_a = nn.Linear(2048, 1024)

        # load checkpoint
        self._load_ckpt(ckpt)

    def _load_ckpt(self, ckpt):
        # differentiate between `.ckpt` and `.bin`
        if ckpt.get("state_dict"):
            state_dicts = ckpt.get("state_dict")
        else:
            state_dicts = ckpt

        # assign keys to correct model components
        filtered_enc = {}
        filtered_gen = {}
        filtered_post = {}
        filtered_prior = {}
        filtered_semantic = {}
        for key, value in state_dicts.items():
            if key.startswith("CodecEnc."):
                new_key = key[len("CodecEnc."):]
                filtered_enc[new_key] = value
            elif key.startswith("generator."):
                new_key = key[len("generator."):]
                filtered_gen[new_key] = value
            elif key.startswith("fc_post_a."):
                new_key = key[len("fc_post_a."):]
                filtered_post[new_key] = value
            elif key.startswith("SemanticEncoder_module."):
                new_key = key[len("SemanticEncoder_module."):]
                filtered_semantic[new_key] = value
            elif key.startswith("fc_prior."):
                new_key = key[len("fc_prior."):]
                filtered_prior[new_key] = value

        # load
        self.CodecEnc.load_state_dict(filtered_enc)
        self.CodecEnc.eval()
        self.generator.load_state_dict(filtered_gen, strict=False)
        self.generator.eval()
        self.fc_post_a.load_state_dict(filtered_post)
        self.fc_post_a.eval()
        self.fc_prior.load_state_dict(filtered_prior)
        self.SemanticEncoder_module.load_state_dict(filtered_semantic)
        self.SemanticEncoder_module.eval()

    @torch.inference_mode()
    def encode_code(
        self,
        input_waveform: torch.Tensor,
        semantic_features: torch.Tensor = None,
        sample_rate: int = 16_000,
    ) -> torch.Tensor:
        pad_for_wav = 320 - (input_waveform.shape[1] % 320)
        input_waveform = torch.nn.functional.pad(input_waveform, (0, pad_for_wav))

        if semantic_features is None:
            semantic_features = self.feature_extractor(
                input_waveform, sampling_rate=sample_rate, return_tensors="pt"
            ).input_features.to(self.device)  # [batch, frames, feat_dim]
        else:
            semantic_features = semantic_features[:, 0, :, :]

        semantic_output = self.semantic_model(semantic_features)
        semantic_hidden_16 = semantic_output.hidden_states[16]
        semantic_hidden_16 = semantic_hidden_16.transpose(
            1, 2
        )  # [batch, hidden_dim, frames]
        semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)
        if len(input_waveform.shape) == 2:
            wav = input_waveform.unsqueeze(1).to(self.device)  # shape: [batch, 1, time]
        else:
            wav = input_waveform.to(self.device)

        vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024]
        vq_emb = vq_emb.transpose(1, 2)  # -> [batch, 1024, frames]

        if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
            min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
            vq_emb = vq_emb[:, :, :min_len]
            semantic_encoded = semantic_encoded[:, :, :min_len]
        concat_emb = torch.cat(
            [semantic_encoded, vq_emb], dim=1
        )  # [batch, 2048, frames]
        concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)
        _, vq_code, _ = self.generator(concat_emb, vq=True)
        return vq_code

    @torch.inference_mode()
    def decode_code(self, vq_code: torch.Tensor) -> torch.Tensor:
        vq_post_emb = self.generator.quantizer.get_output_from_indices(
            vq_code.transpose(1, 2)
        )
        vq_post_emb = vq_post_emb.transpose(1, 2)  # [batch, 1024, frames]
        vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(
            1, 2
        )  # [batch, 1024, frames]
        recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[
            0
        ]  # [batch, time]
        return recon_audio

    @torch.inference_mode()
    def autoencode(self, fpath: str, output_fpath: Optional[str] = None):
        y, sr = torchaudio.load(fpath)
        if sr != 16_000:
            y = T.Resample(sr, 16_000)(y)
        vq_codes = self.encode_code(y)
        recon = self.decode_code(vq_codes)

        if output_fpath is None:
            name, fext = os.path.splitext(fpath)
            output_fpath = f"{name}_recon{fext}"

        sf.write(output_fpath, recon[0, 0, :].cpu(), self.sample_rate)

    @torch.inference_mode()
    def batch_encode(
        self, fpaths: list[str], return_tensor: bool = False
    ) -> tuple[list[torch.Tensor], list[int]] | tuple[torch.Tensor, list[int]]:
        # prepare batch
        wavs_batch, semantic_batch, token_durations = self._pad_batch(
            [self._preprocess_file(fpath) for fpath in fpaths]
        )
        vq_codes = self.encode_code(wavs_batch, semantic_batch)

        # return, unpad if we want to
        if return_tensor:
            return vq_codes, list(token_durations)

        unpadded_vq_codes = []
        for idx, token_dur in enumerate(token_durations):
            curr_codes = vq_codes[idx, :, :token_dur]
            unpadded_vq_codes.append(curr_codes)

        return unpadded_vq_codes, None

    @torch.inference_mode()
    def batch_decode(
        self,
        vq_codes: list[torch.Tensor] | torch.Tensor,
        token_durations: Optional[list[int]] = None,
    ):
        # pad tensor if need be
        if isinstance(vq_codes, list):
            vq_codes, token_durations = self._pad_codes(vq_codes)
        else:
            assert token_durations is not None

        # decode
        recons = self.decode_code(vq_codes)

        # unpad
        cut_recons = []
        for idx, token_dur in enumerate(token_durations):
            curr_recon = recons[idx, :, : int(token_dur * self.hop_length)]
            cut_recons.append(curr_recon)

        return cut_recons

    @torch.inference_mode()
    def batch_autoencode(
        self, fpaths: list[str], output_fpaths: Optional[list[str]] = None
    ) -> list[torch.Tensor]:
        vq_codes, token_durations = self.batch_encode(fpaths, return_tensor=True)
        cut_recons = self.batch_decode(vq_codes, token_durations)

        if output_fpaths:
            for recon, output_fpath in zip(cut_recons, output_fpaths):
                sf.write(output_fpath, recon.cpu().numpy()[0, :], self.sample_rate)

        return cut_recons

    def _preprocess_file(self, fpath: str):
        # load and resample
        y, sr = torchaudio.load(fpath)
        if sr != 16_000:
            y = T.Resample(sr, 16_000)(y)

        # compute duration for any cutting we might need to do, in terms of n_tokens
        token_duration = int((y.shape[-1] / 16_000) * 50)

        # get semantic model features: [harry] note i don't think this can be batched
        semantic_model_input = self.feature_extractor(
            y, sampling_rate=16_000, return_tensors="pt"
        ).input_features

        return y.to(self.device), semantic_model_input.to(self.device), token_duration

    def _pad_batch(self, batch: list[tuple[torch.Tensor, torch.Tensor, int]]):
        # unpack batch
        wavs, semantic_features, token_durations = zip(*batch)
        max_length_semantic = max([f.shape[1] for f in semantic_features])
        max_length = max_length_semantic * 320

        # pad wavs
        wavs_padded = []
        for audio in wavs:
            padding = max_length - audio.shape[1]
            if padding > 0:
                padded_audio = F.pad(audio, (0, padding), mode="constant", value=0)
            else:
                padded_audio = audio[:, :max_length]
            wavs_padded.append(padded_audio)
        wavs_tensor = torch.stack(wavs_padded)

        # pad semantic features
        semantic_features_padded = []
        for feat in semantic_features:
            padding = max_length_semantic - feat.shape[1]
            padded_feat = F.pad(feat, (0, 0, 0, padding), mode="constant", value=0)
            semantic_features_padded.append(padded_feat)
        semantic_feature_tensor = torch.stack(semantic_features_padded)

        return wavs_tensor, semantic_feature_tensor, token_durations

    def _pad_codes(self, vq_codes: list[torch.Tensor]):
        max_len = max([i.shape[-1] for i in vq_codes])
        token_durations = []
        padded_codes = []
        for curr_codes in vq_codes:
            curr_len = curr_codes.shape[-1]
            token_durations.append(curr_len)
            padding = max_len - curr_len
            curr_codes = F.pad(curr_codes, (0, padding), mode="constant", value=0)
            padded_codes.append(curr_codes)
        return torch.stack(padded_codes), token_durations

    @property
    def device(self):
        return next(self.parameters()).device

neucodec/module.py
DELETED
@@ -1,114 +0,0 @@
-import torch.nn as nn
-
-from torch.nn.utils import weight_norm
-
-from .activations import SnakeBeta
-from .alias_free_torch import Activation1d
-
-
-def WNConv1d(*args, **kwargs):
-    return weight_norm(nn.Conv1d(*args, **kwargs))
-
-
-class ResidualUnit(nn.Module):
-    def __init__(self, dim: int = 16, dilation: int = 1):
-        super().__init__()
-        pad = ((7 - 1) * dilation) // 2
-        self.block = nn.Sequential(
-            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
-            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
-            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
-            WNConv1d(dim, dim, kernel_size=1),
-        )
-
-    def forward(self, x):
-        return x + self.block(x)
-
-
-class EncoderBlock(nn.Module):
-    def __init__(self, dim: int = 16, stride: int = 1, dilations=(1, 3, 9)):
-        super().__init__()
-        runits = [ResidualUnit(dim // 2, dilation=d) for d in dilations]
-        self.block = nn.Sequential(
-            *runits,
-            Activation1d(activation=SnakeBeta(dim // 2, alpha_logscale=True)),
-            WNConv1d(
-                dim // 2,
-                dim,
-                kernel_size=2 * stride,
-                stride=stride,
-                padding=stride // 2 + stride % 2,
-            ),
-        )
-
-    def forward(self, x):
-        return self.block(x)
-
-
-class SemanticEncoder(nn.Module):
-    def __init__(
-        self,
-        input_channels: int,
-        code_dim: int,
-        encode_channels: int,
-        kernel_size: int = 3,
-        bias: bool = True,
-    ):
-        super(SemanticEncoder, self).__init__()
-
-        # initial convolution, mapping input_channels to encode_channels
-        self.initial_conv = nn.Conv1d(
-            in_channels=input_channels,
-            out_channels=encode_channels,
-            kernel_size=kernel_size,
-            stride=1,
-            padding=(kernel_size - 1) // 2,
-            bias=False,
-        )
-
-        # residual blocks
-        self.residual_blocks = nn.Sequential(
-            nn.ReLU(inplace=True),
-            nn.Conv1d(
-                encode_channels,
-                encode_channels,
-                kernel_size=kernel_size,
-                stride=1,
-                padding=(kernel_size - 1) // 2,
-                bias=bias,
-            ),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(
-                encode_channels,
-                encode_channels,
-                kernel_size=kernel_size,
-                stride=1,
-                padding=(kernel_size - 1) // 2,
-                bias=bias,
-            ),
-        )
-
-        # final convolution, mapping encode_channels to code_dim
-        self.final_conv = nn.Conv1d(
-            in_channels=encode_channels,
-            out_channels=code_dim,
-            kernel_size=kernel_size,
-            stride=1,
-            padding=(kernel_size - 1) // 2,
-            bias=False,
-        )
-
-    def forward(self, x):
-        """
-        Forward pass.
-
-        Args:
-            x (Tensor): input tensor of shape (Batch, Input_channels, Length)
-
-        Returns:
-            Tensor: encoded tensor of shape (Batch, Code_dim, Length)
-        """
-        x = self.initial_conv(x)  # (Batch, Encode_channels, Length)
-        x = self.residual_blocks(x) + x  # residual connection
-        x = self.final_conv(x)  # (Batch, Code_dim, Length)
-        return x
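A quick shape check for the SemanticEncoder above (a sketch assuming the class definition is in scope; the channel sizes are illustrative, not the values the codec actually uses):

import torch

enc = SemanticEncoder(input_channels=1024, code_dim=512, encode_channels=768)
x = torch.randn(2, 1024, 50)  # (batch, input_channels, length): 1 s of 50 Hz features
y = enc(x)
print(y.shape)  # torch.Size([2, 512, 50]); all convs are stride 1, so length is preserved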
setup.py
DELETED
@@ -1,32 +0,0 @@
-from setuptools import setup, find_packages
-
-
-setup(
-    name='neucodec',
-    version='0.0.1',
-    description='A package for neucodec, based on xcodec2.',
-    long_description_content_type='text/markdown',
-    author='Harry Julian',
-    author_email='[email protected]',
-    packages=find_packages(),
-    install_requires=[
-        'librosa',
-        'soundfile',
-        'numpy>=2.0.2',
-        'omegaconf>=2.3.0',
-        'torch>=2.5.1',
-        'torchaudio>=2.5.1',
-        'torchao>=0.5.0',
-        'torchtune>=0.3.1',
-        'vector-quantize-pytorch>=1.17.8',
-        'rotary-embedding-torch>=0.8.4',
-        'transformers>=4.44.2',
-        'boto3>1.0',
-        'tqdm',
-    ],
-    classifiers=[
-        'Programming Language :: Python',
-        'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.10',
-    ],
-)
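Since this setup.py was the package's only dependency record, here is a small environment check against a few of its pins (a standard-library sketch; it reports versions rather than enforcing them, and the subset of packages is arbitrary):

from importlib.metadata import PackageNotFoundError, version

pins = {"numpy": "2.0.2", "torch": "2.5.1", "transformers": "4.44.2"}  # from install_requires above
for pkg, minimum in pins.items():
    try:
        print(f"{pkg}: installed {version(pkg)}, required >= {minimum}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")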
tests/__init__.py
DELETED
File without changes
tests/test_neucodec.py
DELETED
@@ -1,128 +0,0 @@
-import pytest
-import torch
-import torchaudio
-import librosa
-from xcodec2 import XCodec2, MiniXCodec2Encoder
-
-
-@pytest.fixture
-def model_16khz():
-    return XCodec2.from_cache("16khz")
-
-
-@pytest.fixture
-def model_24khz():
-    return XCodec2.from_cache("24khz")
-
-
-@pytest.fixture
-def model_asr_encoder():
-    return MiniXCodec2Encoder.from_cache()
-
-
-@pytest.fixture
-def example_audio():
-    y, sr = torchaudio.load(librosa.ex("libri1"))
-    return y, sr
-
-
-@pytest.fixture
-def example_fpath():
-    return librosa.ex("libri1")
-
-
-@pytest.fixture
-def batch_fpaths():
-    return [librosa.ex("libri1"), librosa.ex("libri2")]
-
-
-def load_and_validate_audio(save_path, sample_rate):
-    _, sr = torchaudio.load(save_path)
-    assert sr == sample_rate
-
-
-def test_16khz_autoencode(example_fpath, tmp_path, model_16khz):
-    save_path = str(tmp_path / "0.wav")
-    model_16khz.autoencode(example_fpath, save_path)
-    load_and_validate_audio(save_path, 16_000)
-
-
-def test_24khz_autoencode(example_fpath, tmp_path, model_24khz):
-    save_path = str(tmp_path / "0.wav")
-    model_24khz.autoencode(example_fpath, save_path)
-    load_and_validate_audio(save_path, 24_000)
-
-
-def test_24khz_encode_decode_single(example_audio, model_24khz):
-    y, sr = example_audio
-    if sr != 16_000:
-        y = torchaudio.transforms.Resample(sr, 16_000)(y)
-        sr = 16_000
-
-    # encode
-    vq_codes = model_24khz.encode_code(y, sample_rate=sr)
-    assert isinstance(vq_codes, torch.Tensor)
-    assert vq_codes.dim() == 3  # [batch, channels, time]
-
-    # decode
-    reconstructed = model_24khz.decode_code(vq_codes)
-    assert isinstance(reconstructed, torch.Tensor)
-    assert reconstructed.dim() == 3  # [batch, channels, time]
-
-
-def test_24khz_batch_encode(batch_fpaths, model_24khz):
-    vq_codes_list, token_durations = model_24khz.batch_encode(batch_fpaths, return_tensor=False)
-    assert isinstance(vq_codes_list, list)
-    assert token_durations is None
-    assert len(vq_codes_list) == 2
-
-    for codes in vq_codes_list:
-        assert isinstance(codes, torch.Tensor)
-        assert codes.dim() == 2  # [channels, time]
-
-
-def test_24khz_batch_encode_tensor(batch_fpaths, model_24khz):
-    vq_codes_tensor, token_durations = model_24khz.batch_encode(batch_fpaths, return_tensor=True)
-    assert isinstance(vq_codes_tensor, torch.Tensor)
-    assert isinstance(token_durations, list)
-    assert vq_codes_tensor.dim() == 3  # [batch, channels, time]
-    assert len(token_durations) == 2
-    assert len(set(token_durations)) == 2  # ensure we get two different durations back
-
-
-def test_24khz_batch_decode(batch_fpaths, model_24khz):
-    vq_codes_tensor, token_durations = model_24khz.batch_encode(batch_fpaths, return_tensor=True)
-    reconstructed_list = model_24khz.batch_decode(vq_codes_tensor, token_durations)
-    assert isinstance(reconstructed_list, list)
-    assert len(reconstructed_list) == 2
-    for recon in reconstructed_list:
-        assert isinstance(recon, torch.Tensor)
-        assert recon.dim() == 2  # [channels, time]
-
-
-def test_24khz_batch_decode_list_input(batch_fpaths, model_24khz):
-    vq_codes_list, _ = model_24khz.batch_encode(batch_fpaths, return_tensor=False)
-    reconstructed_list = model_24khz.batch_decode(vq_codes_list)
-    assert isinstance(reconstructed_list, list)
-    assert len(reconstructed_list) == 2
-    for recon in reconstructed_list:
-        assert isinstance(recon, torch.Tensor)
-        assert recon.dim() == 2  # [channels, time]
-
-
-def test_24khz_batch_autoencode(batch_fpaths, tmp_path, model_24khz):
-    output_paths = [str(tmp_path / f"{i}.wav") for i in range(len(batch_fpaths))]
-    reconstructed_list = model_24khz.batch_autoencode(batch_fpaths, output_paths)
-    assert isinstance(reconstructed_list, list)
-    assert len(reconstructed_list) == 2
-    for i, output_path in enumerate(output_paths):
-        load_and_validate_audio(output_path, 24_000)
-
-
-def test_asr_encoder_encode(example_audio, model_asr_encoder):
-    y, sr = example_audio
-    if sr != model_asr_encoder.sample_rate:
-        y = torchaudio.transforms.Resample(sr, model_asr_encoder.sample_rate)(y)
-    vq_codes = model_asr_encoder.encode_code(y)
-    assert isinstance(vq_codes, torch.Tensor)
-    assert vq_codes.dim() == 3