gosummer committed (verified)
Commit e625816 · 1 Parent(s): d958504

Upload 14 files
MDX23v24/modules/__pycache__/segm_models.cpython-310.pyc ADDED
Binary file (4.09 kB).
 
MDX23v24/modules/__pycache__/tfc_tdf_v2.cpython-310.pyc ADDED
Binary file (2.48 kB).
 
MDX23v24/modules/__pycache__/tfc_tdf_v3.cpython-310.pyc ADDED
Binary file (6.65 kB).
 
MDX23v24/modules/bs_roformer/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from modules.bs_roformer.bs_roformer import BSRoformer
+ from modules.bs_roformer.mel_band_roformer import MelBandRoformer
MDX23v24/modules/bs_roformer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (312 Bytes).
 
MDX23v24/modules/bs_roformer/__pycache__/attend.cpython-310.pyc ADDED
Binary file (3.33 kB).
 
MDX23v24/modules/bs_roformer/__pycache__/bs_roformer.cpython-310.pyc ADDED
Binary file (13.2 kB).
 
MDX23v24/modules/bs_roformer/__pycache__/mel_band_roformer.cpython-310.pyc ADDED
Binary file (14.3 kB).
 
MDX23v24/modules/bs_roformer/attend.py ADDED
@@ -0,0 +1,120 @@
+ from functools import wraps
+ from packaging import version
+ from collections import namedtuple
+
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+
+ from einops import rearrange, reduce
+
+ # constants
+
+ FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+
+ # helpers
+
+ def exists(val):
+     return val is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def once(fn):
+     called = False
+     @wraps(fn)
+     def inner(x):
+         nonlocal called
+         if called:
+             return
+         called = True
+         return fn(x)
+     return inner
+
+ print_once = once(print)
+
+ # main class
+
+ class Attend(nn.Module):
+     def __init__(
+         self,
+         dropout = 0.,
+         flash = False,
+         scale = None
+     ):
+         super().__init__()
+         self.scale = scale
+         self.dropout = dropout
+         self.attn_dropout = nn.Dropout(dropout)
+
+         self.flash = flash
+         assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+         # determine efficient attention configs for cuda and cpu
+
+         self.cpu_config = FlashAttentionConfig(True, True, True)
+         self.cuda_config = None
+
+         if not torch.cuda.is_available() or not flash:
+             return
+
+         device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+         if device_properties.major == 8 and device_properties.minor == 0:
+             print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
+             self.cuda_config = FlashAttentionConfig(True, False, False)
+         else:
+             print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
+             self.cuda_config = FlashAttentionConfig(False, True, True)
+
+     def flash_attn(self, q, k, v):
+         _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+
+         if exists(self.scale):
+             default_scale = q.shape[-1] ** -0.5
+             q = q * (self.scale / default_scale)
+
+         # Check if there is a compatible device for flash attention
+
+         config = self.cuda_config if is_cuda else self.cpu_config
+
+         # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
+
+         with torch.backends.cuda.sdp_kernel(**config._asdict()):
+             out = F.scaled_dot_product_attention(
+                 q, k, v,
+                 dropout_p = self.dropout if self.training else 0.
+             )
+
+         return out
+
+     def forward(self, q, k, v):
+         """
+         einstein notation
+         b - batch
+         h - heads
+         n, i, j - sequence length (base sequence length, source, target)
+         d - feature dimension
+         """
+
+         q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+
+         scale = default(self.scale, q.shape[-1] ** -0.5)
+
+         if self.flash:
+             return self.flash_attn(q, k, v)
+
+         # similarity
+
+         sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
+
+         # attention
+
+         attn = sim.softmax(dim=-1)
+         attn = self.attn_dropout(attn)
+
+         # aggregate values
+
+         out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
+
+         return out
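Not part of the commit: a minimal usage sketch for the Attend module above, assuming the repository root is on PYTHONPATH so that modules.bs_roformer.attend is importable. flash=False keeps it on the plain einsum path, so it also runs on CPU and on PyTorch < 2.0.

import torch
from modules.bs_roformer.attend import Attend

attend = Attend(dropout=0.1, flash=False)

# tensors follow the (batch, heads, seq_len, dim_head) layout used in forward()
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

out = attend(q, k, v)   # -> torch.Size([2, 8, 128, 64])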
MDX23v24/modules/bs_roformer/bs_roformer.py ADDED
@@ -0,0 +1,577 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from modules.bs_roformer.attend import Attend
9
+
10
+ from beartype.typing import Tuple, Optional, List, Callable
11
+ from beartype import beartype
12
+
13
+ from rotary_embedding_torch import RotaryEmbedding
14
+
15
+ from einops import rearrange, pack, unpack
16
+ from einops.layers.torch import Rearrange
17
+
18
+ # helper functions
19
+
20
+ def exists(val):
21
+ return val is not None
22
+
23
+
24
+ def default(v, d):
25
+ return v if exists(v) else d
26
+
27
+
28
+ def pack_one(t, pattern):
29
+ return pack([t], pattern)
30
+
31
+
32
+ def unpack_one(t, ps, pattern):
33
+ return unpack(t, ps, pattern)[0]
34
+
35
+
36
+ # norm
37
+
38
+ def l2norm(t):
39
+ return F.normalize(t, dim = -1, p = 2)
40
+
41
+
42
+ class RMSNorm(Module):
43
+ def __init__(self, dim):
44
+ super().__init__()
45
+ self.scale = dim ** 0.5
46
+ self.gamma = nn.Parameter(torch.ones(dim))
47
+
48
+ def forward(self, x):
49
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
50
+
51
+
52
+ # attention
53
+
54
+ class FeedForward(Module):
55
+ def __init__(
56
+ self,
57
+ dim,
58
+ mult=4,
59
+ dropout=0.
60
+ ):
61
+ super().__init__()
62
+ dim_inner = int(dim * mult)
63
+ self.net = nn.Sequential(
64
+ RMSNorm(dim),
65
+ nn.Linear(dim, dim_inner),
66
+ nn.GELU(),
67
+ nn.Dropout(dropout),
68
+ nn.Linear(dim_inner, dim),
69
+ nn.Dropout(dropout)
70
+ )
71
+
72
+ def forward(self, x):
73
+ return self.net(x)
74
+
75
+
76
+ class Attention(Module):
77
+ def __init__(
78
+ self,
79
+ dim,
80
+ heads=8,
81
+ dim_head=64,
82
+ dropout=0.,
83
+ rotary_embed=None,
84
+ flash=True
85
+ ):
86
+ super().__init__()
87
+ self.heads = heads
88
+ self.scale = dim_head ** -0.5
89
+ dim_inner = heads * dim_head
90
+
91
+ self.rotary_embed = rotary_embed
92
+
93
+ self.attend = Attend(flash=flash, dropout=dropout)
94
+
95
+ self.norm = RMSNorm(dim)
96
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
97
+
98
+ self.to_gates = nn.Linear(dim, heads)
99
+
100
+ self.to_out = nn.Sequential(
101
+ nn.Linear(dim_inner, dim, bias=False),
102
+ nn.Dropout(dropout)
103
+ )
104
+
105
+ def forward(self, x):
106
+ x = self.norm(x)
107
+
108
+ q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
109
+
110
+ if exists(self.rotary_embed):
111
+ q = self.rotary_embed.rotate_queries_or_keys(q)
112
+ k = self.rotary_embed.rotate_queries_or_keys(k)
113
+
114
+ out = self.attend(q, k, v)
115
+
116
+ gates = self.to_gates(x)
117
+ out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
118
+
119
+ out = rearrange(out, 'b h n d -> b n (h d)')
120
+ return self.to_out(out)
121
+
122
+
123
+ class LinearAttention(Module):
124
+ """
125
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
126
+ """
127
+
128
+ @beartype
129
+ def __init__(
130
+ self,
131
+ *,
132
+ dim,
133
+ dim_head=32,
134
+ heads=8,
135
+ scale=8,
136
+ flash=False,
137
+ dropout=0.
138
+ ):
139
+ super().__init__()
140
+ dim_inner = dim_head * heads
141
+ self.norm = RMSNorm(dim)
142
+
143
+ self.to_qkv = nn.Sequential(
144
+ nn.Linear(dim, dim_inner * 3, bias=False),
145
+ Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
146
+ )
147
+
148
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
+
150
+ self.attend = Attend(
151
+ scale=scale,
152
+ dropout=dropout,
153
+ flash=flash
154
+ )
155
+
156
+ self.to_out = nn.Sequential(
157
+ Rearrange('b h d n -> b n (h d)'),
158
+ nn.Linear(dim_inner, dim, bias=False)
159
+ )
160
+
161
+ def forward(
162
+ self,
163
+ x
164
+ ):
165
+ x = self.norm(x)
166
+
167
+ q, k, v = self.to_qkv(x)
168
+
169
+ q, k = map(l2norm, (q, k))
170
+ q = q * self.temperature.exp()
171
+
172
+ out = self.attend(q, k, v)
173
+
174
+ return self.to_out(out)
175
+
176
+
177
+ class Transformer(Module):
178
+ def __init__(
179
+ self,
180
+ *,
181
+ dim,
182
+ depth,
183
+ dim_head=64,
184
+ heads=8,
185
+ attn_dropout=0.,
186
+ ff_dropout=0.,
187
+ ff_mult=4,
188
+ norm_output=True,
189
+ rotary_embed=None,
190
+ flash_attn=True,
191
+ linear_attn=False
192
+ ):
193
+ super().__init__()
194
+ self.layers = ModuleList([])
195
+
196
+ for _ in range(depth):
197
+ if linear_attn:
198
+ attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
199
+ else:
200
+ attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
201
+ rotary_embed=rotary_embed, flash=flash_attn)
202
+
203
+ self.layers.append(ModuleList([
204
+ attn,
205
+ FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
206
+ ]))
207
+
208
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
209
+
210
+ def forward(self, x):
211
+
212
+ for attn, ff in self.layers:
213
+ x = attn(x) + x
214
+ x = ff(x) + x
215
+
216
+ return self.norm(x)
217
+
218
+
219
+ # bandsplit module
220
+
221
+ class BandSplit(Module):
222
+ @beartype
223
+ def __init__(
224
+ self,
225
+ dim,
226
+ dim_inputs: Tuple[int, ...]
227
+ ):
228
+ super().__init__()
229
+ self.dim_inputs = dim_inputs
230
+ self.to_features = ModuleList([])
231
+
232
+ for dim_in in dim_inputs:
233
+ net = nn.Sequential(
234
+ RMSNorm(dim_in),
235
+ nn.Linear(dim_in, dim)
236
+ )
237
+
238
+ self.to_features.append(net)
239
+
240
+ def forward(self, x):
241
+ x = x.split(self.dim_inputs, dim=-1)
242
+
243
+ outs = []
244
+ for split_input, to_feature in zip(x, self.to_features):
245
+ split_output = to_feature(split_input)
246
+ outs.append(split_output)
247
+
248
+ return torch.stack(outs, dim=-2)
249
+
250
+
251
+ def MLP(
252
+ dim_in,
253
+ dim_out,
254
+ dim_hidden=None,
255
+ depth=1,
256
+ activation=nn.Tanh
257
+ ):
258
+ dim_hidden = default(dim_hidden, dim_in)
259
+
260
+ net = []
261
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
262
+
263
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
264
+ is_last = ind == (len(dims) - 2)
265
+
266
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
267
+
268
+ if is_last:
269
+ continue
270
+
271
+ net.append(activation())
272
+
273
+ return nn.Sequential(*net)
274
+
275
+
276
+ class MaskEstimator(Module):
277
+ @beartype
278
+ def __init__(
279
+ self,
280
+ dim,
281
+ dim_inputs: Tuple[int, ...],
282
+ depth,
283
+ mlp_expansion_factor=4
284
+ ):
285
+ super().__init__()
286
+ self.dim_inputs = dim_inputs
287
+ self.to_freqs = ModuleList([])
288
+ dim_hidden = dim * mlp_expansion_factor
289
+
290
+ for dim_in in dim_inputs:
291
+ net = []
292
+
293
+ mlp = nn.Sequential(
294
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
295
+ nn.GLU(dim=-1)
296
+ )
297
+
298
+ self.to_freqs.append(mlp)
299
+
300
+ def forward(self, x):
301
+ x = x.unbind(dim=-2)
302
+
303
+ outs = []
304
+
305
+ for band_features, mlp in zip(x, self.to_freqs):
306
+ freq_out = mlp(band_features)
307
+ outs.append(freq_out)
308
+
309
+ return torch.cat(outs, dim=-1)
310
+
311
+
312
+ # main class
313
+
314
+ DEFAULT_FREQS_PER_BANDS = (
315
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
316
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
317
+ 2, 2, 2, 2,
318
+ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
319
+ 12, 12, 12, 12, 12, 12, 12, 12,
320
+ 24, 24, 24, 24, 24, 24, 24, 24,
321
+ 48, 48, 48, 48, 48, 48, 48, 48,
322
+ 128, 129,
323
+ )
324
+
325
+
326
+ class BSRoformer(Module):
327
+
328
+ @beartype
329
+ def __init__(
330
+ self,
331
+ dim,
332
+ *,
333
+ depth,
334
+ stereo=False,
335
+ num_stems=1,
336
+ time_transformer_depth=2,
337
+ freq_transformer_depth=2,
338
+ linear_transformer_depth=0,
339
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
340
+ # in the paper, they divide into ~60 bands, test with 1 for starters
341
+ dim_head=64,
342
+ heads=8,
343
+ attn_dropout=0.,
344
+ ff_dropout=0.,
345
+ flash_attn=True,
346
+ dim_freqs_in=1025,
347
+ stft_n_fft=2048,
348
+ stft_hop_length=512,
349
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
350
+ stft_win_length=2048,
351
+ stft_normalized=False,
352
+ stft_window_fn: Optional[Callable] = None,
353
+ mask_estimator_depth=2,
354
+ multi_stft_resolution_loss_weight=1.,
355
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
356
+ multi_stft_hop_size=147,
357
+ multi_stft_normalized=False,
358
+ multi_stft_window_fn: Callable = torch.hann_window
359
+ ):
360
+ super().__init__()
361
+
362
+ self.stereo = stereo
363
+ self.audio_channels = 2 if stereo else 1
364
+ self.num_stems = num_stems
365
+
366
+ self.layers = ModuleList([])
367
+
368
+ transformer_kwargs = dict(
369
+ dim=dim,
370
+ heads=heads,
371
+ dim_head=dim_head,
372
+ attn_dropout=attn_dropout,
373
+ ff_dropout=ff_dropout,
374
+ flash_attn=flash_attn,
375
+ norm_output=False
376
+ )
377
+
378
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
379
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
380
+
381
+ for _ in range(depth):
382
+ tran_modules = []
383
+ if linear_transformer_depth > 0:
384
+ tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
385
+ tran_modules.append(
386
+ Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
387
+ )
388
+ tran_modules.append(
389
+ Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
390
+ )
391
+ self.layers.append(nn.ModuleList(tran_modules))
392
+
393
+ self.final_norm = RMSNorm(dim)
394
+
395
+ self.stft_kwargs = dict(
396
+ n_fft=stft_n_fft,
397
+ hop_length=stft_hop_length,
398
+ win_length=stft_win_length,
399
+ normalized=stft_normalized
400
+ )
401
+
402
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
403
+
404
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
405
+
406
+ assert len(freqs_per_bands) > 1
407
+ assert sum(
408
+ freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
409
+
410
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
411
+
412
+ self.band_split = BandSplit(
413
+ dim=dim,
414
+ dim_inputs=freqs_per_bands_with_complex
415
+ )
416
+
417
+ self.mask_estimators = nn.ModuleList([])
418
+
419
+ for _ in range(num_stems):
420
+ mask_estimator = MaskEstimator(
421
+ dim=dim,
422
+ dim_inputs=freqs_per_bands_with_complex,
423
+ depth=mask_estimator_depth
424
+ )
425
+
426
+ self.mask_estimators.append(mask_estimator)
427
+
428
+ # for the multi-resolution stft loss
429
+
430
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
431
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
432
+ self.multi_stft_n_fft = stft_n_fft
433
+ self.multi_stft_window_fn = multi_stft_window_fn
434
+
435
+ self.multi_stft_kwargs = dict(
436
+ hop_length=multi_stft_hop_size,
437
+ normalized=multi_stft_normalized
438
+ )
439
+
440
+ def forward(
441
+ self,
442
+ raw_audio,
443
+ target=None,
444
+ return_loss_breakdown=False
445
+ ):
446
+ """
447
+ einops
448
+
449
+ b - batch
450
+ f - freq
451
+ t - time
452
+ s - audio channel (1 for mono, 2 for stereo)
453
+ n - number of 'stems'
454
+ c - complex (2)
455
+ d - feature dimension
456
+ """
457
+
458
+ device = raw_audio.device
459
+
460
+ if raw_audio.ndim == 2:
461
+ raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
462
+
463
+ channels = raw_audio.shape[1]
464
+ assert (not self.stereo and channels == 1) or (
465
+ self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
466
+
467
+ # to stft
468
+
469
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
470
+
471
+ stft_window = self.stft_window_fn(device=device)
472
+
473
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
474
+ stft_repr = torch.view_as_real(stft_repr)
475
+
476
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
477
+ stft_repr = rearrange(stft_repr,
478
+ 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
479
+
480
+ x = rearrange(stft_repr, 'b f t c -> b t (f c)')
481
+
482
+ x = self.band_split(x)
483
+
484
+ # axial / hierarchical attention
485
+
486
+ for transformer_block in self.layers:
487
+
488
+ if len(transformer_block) == 3:
489
+ linear_transformer, time_transformer, freq_transformer = transformer_block
490
+
491
+ x, ft_ps = pack([x], 'b * d')
492
+ x = linear_transformer(x)
493
+ x, = unpack(x, ft_ps, 'b * d')
494
+ else:
495
+ time_transformer, freq_transformer = transformer_block
496
+
497
+ x = rearrange(x, 'b t f d -> b f t d')
498
+ x, ps = pack([x], '* t d')
499
+
500
+ x = time_transformer(x)
501
+
502
+ x, = unpack(x, ps, '* t d')
503
+ x = rearrange(x, 'b f t d -> b t f d')
504
+ x, ps = pack([x], '* f d')
505
+
506
+ x = freq_transformer(x)
507
+
508
+ x, = unpack(x, ps, '* f d')
509
+
510
+ x = self.final_norm(x)
511
+
512
+ num_stems = len(self.mask_estimators)
513
+
514
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
515
+ mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
516
+
517
+ # modulate frequency representation
518
+
519
+ stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
520
+
521
+ # complex number multiplication
522
+
523
+ stft_repr = torch.view_as_complex(stft_repr)
524
+ mask = torch.view_as_complex(mask)
525
+
526
+ stft_repr = stft_repr * mask
527
+
528
+ # istft
529
+
530
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
531
+
532
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
533
+
534
+ recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
535
+
536
+ if num_stems == 1:
537
+ recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
538
+
539
+ # if a target is passed in, calculate loss for learning
540
+
541
+ if not exists(target):
542
+ return recon_audio
543
+
544
+ if self.num_stems > 1:
545
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
546
+
547
+ if target.ndim == 2:
548
+ target = rearrange(target, '... t -> ... 1 t')
549
+
550
+ target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
551
+
552
+ loss = F.l1_loss(recon_audio, target)
553
+
554
+ multi_stft_resolution_loss = 0.
555
+
556
+ for window_size in self.multi_stft_resolutions_window_sizes:
557
+ res_stft_kwargs = dict(
558
+ n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
559
+ win_length=window_size,
560
+ return_complex=True,
561
+ window=self.multi_stft_window_fn(window_size, device=device),
562
+ **self.multi_stft_kwargs,
563
+ )
564
+
565
+ recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
566
+ target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
567
+
568
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
569
+
570
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
571
+
572
+ total_loss = loss + weighted_multi_resolution_loss
573
+
574
+ if not return_loss_breakdown:
575
+ return total_loss
576
+
577
+ return total_loss, (loss, multi_stft_resolution_loss)
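Not part of the commit: an illustrative instantiation sketch for BSRoformer. The dim/depth values are placeholders (the MDX23v24 checkpoints ship with their own settings); the default freqs_per_bands sums to 1025 bins, matching stft_n_fft=2048, and flash_attn=False keeps the demo CPU-friendly.

import torch
from modules.bs_roformer import BSRoformer

model = BSRoformer(dim=192, depth=2, stereo=False, num_stems=1, flash_attn=False).eval()

audio = torch.randn(1, 44100)              # (batch, time), mono
with torch.no_grad():
    estimate = model(audio)                # -> (batch, channels, time'), slightly shorter after istft

# passing a target instead returns the L1 + multi-resolution STFT training loss
loss = model(audio, target=torch.randn_like(estimate))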
MDX23v24/modules/bs_roformer/mel_band_roformer.py ADDED
@@ -0,0 +1,637 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from modules.bs_roformer.attend import Attend
9
+
10
+ from beartype.typing import Tuple, Optional, List, Callable
11
+ from beartype import beartype
12
+
13
+ from rotary_embedding_torch import RotaryEmbedding
14
+
15
+ from einops import rearrange, pack, unpack, reduce, repeat
16
+ from einops.layers.torch import Rearrange
17
+
18
+ from librosa import filters
19
+
20
+
21
+ # helper functions
22
+
23
+ def exists(val):
24
+ return val is not None
25
+
26
+
27
+ def default(v, d):
28
+ return v if exists(v) else d
29
+
30
+
31
+ def pack_one(t, pattern):
32
+ return pack([t], pattern)
33
+
34
+
35
+ def unpack_one(t, ps, pattern):
36
+ return unpack(t, ps, pattern)[0]
37
+
38
+
39
+ def pad_at_dim(t, pad, dim=-1, value=0.):
40
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
41
+ zeros = ((0, 0) * dims_from_right)
42
+ return F.pad(t, (*zeros, *pad), value=value)
43
+
44
+
45
+ def l2norm(t):
46
+ return F.normalize(t, dim=-1, p=2)
47
+
48
+
49
+ # norm
50
+
51
+ class RMSNorm(Module):
52
+ def __init__(self, dim):
53
+ super().__init__()
54
+ self.scale = dim ** 0.5
55
+ self.gamma = nn.Parameter(torch.ones(dim))
56
+
57
+ def forward(self, x):
58
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
59
+
60
+
61
+ # attention
62
+
63
+ class FeedForward(Module):
64
+ def __init__(
65
+ self,
66
+ dim,
67
+ mult=4,
68
+ dropout=0.
69
+ ):
70
+ super().__init__()
71
+ dim_inner = int(dim * mult)
72
+ self.net = nn.Sequential(
73
+ RMSNorm(dim),
74
+ nn.Linear(dim, dim_inner),
75
+ nn.GELU(),
76
+ nn.Dropout(dropout),
77
+ nn.Linear(dim_inner, dim),
78
+ nn.Dropout(dropout)
79
+ )
80
+
81
+ def forward(self, x):
82
+ return self.net(x)
83
+
84
+
85
+ class Attention(Module):
86
+ def __init__(
87
+ self,
88
+ dim,
89
+ heads=8,
90
+ dim_head=64,
91
+ dropout=0.,
92
+ rotary_embed=None,
93
+ flash=True
94
+ ):
95
+ super().__init__()
96
+ self.heads = heads
97
+ self.scale = dim_head ** -0.5
98
+ dim_inner = heads * dim_head
99
+
100
+ self.rotary_embed = rotary_embed
101
+
102
+ self.attend = Attend(flash=flash, dropout=dropout)
103
+
104
+ self.norm = RMSNorm(dim)
105
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
106
+
107
+ self.to_gates = nn.Linear(dim, heads)
108
+
109
+ self.to_out = nn.Sequential(
110
+ nn.Linear(dim_inner, dim, bias=False),
111
+ nn.Dropout(dropout)
112
+ )
113
+
114
+ def forward(self, x):
115
+ x = self.norm(x)
116
+
117
+ q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
118
+
119
+ if exists(self.rotary_embed):
120
+ q = self.rotary_embed.rotate_queries_or_keys(q)
121
+ k = self.rotary_embed.rotate_queries_or_keys(k)
122
+
123
+ out = self.attend(q, k, v)
124
+
125
+ gates = self.to_gates(x)
126
+ out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
127
+
128
+ out = rearrange(out, 'b h n d -> b n (h d)')
129
+ return self.to_out(out)
130
+
131
+
132
+ class LinearAttention(Module):
133
+ """
134
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
135
+ """
136
+
137
+ @beartype
138
+ def __init__(
139
+ self,
140
+ *,
141
+ dim,
142
+ dim_head=32,
143
+ heads=8,
144
+ scale=8,
145
+ flash=False,
146
+ dropout=0.
147
+ ):
148
+ super().__init__()
149
+ dim_inner = dim_head * heads
150
+ self.norm = RMSNorm(dim)
151
+
152
+ self.to_qkv = nn.Sequential(
153
+ nn.Linear(dim, dim_inner * 3, bias=False),
154
+ Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
155
+ )
156
+
157
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
158
+
159
+ self.attend = Attend(
160
+ scale=scale,
161
+ dropout=dropout,
162
+ flash=flash
163
+ )
164
+
165
+ self.to_out = nn.Sequential(
166
+ Rearrange('b h d n -> b n (h d)'),
167
+ nn.Linear(dim_inner, dim, bias=False)
168
+ )
169
+
170
+ def forward(
171
+ self,
172
+ x
173
+ ):
174
+ x = self.norm(x)
175
+
176
+ q, k, v = self.to_qkv(x)
177
+
178
+ q, k = map(l2norm, (q, k))
179
+ q = q * self.temperature.exp()
180
+
181
+ out = self.attend(q, k, v)
182
+
183
+ return self.to_out(out)
184
+
185
+
186
+ class Transformer(Module):
187
+ def __init__(
188
+ self,
189
+ *,
190
+ dim,
191
+ depth,
192
+ dim_head=64,
193
+ heads=8,
194
+ attn_dropout=0.,
195
+ ff_dropout=0.,
196
+ ff_mult=4,
197
+ norm_output=True,
198
+ rotary_embed=None,
199
+ flash_attn=True,
200
+ linear_attn=False
201
+ ):
202
+ super().__init__()
203
+ self.layers = ModuleList([])
204
+
205
+ for _ in range(depth):
206
+ if linear_attn:
207
+ attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
208
+ else:
209
+ attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
210
+ rotary_embed=rotary_embed, flash=flash_attn)
211
+
212
+ self.layers.append(ModuleList([
213
+ attn,
214
+ FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
215
+ ]))
216
+
217
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
218
+
219
+ def forward(self, x):
220
+
221
+ for attn, ff in self.layers:
222
+ x = attn(x) + x
223
+ x = ff(x) + x
224
+
225
+ return self.norm(x)
226
+
227
+
228
+ # bandsplit module
229
+
230
+ class BandSplit(Module):
231
+ @beartype
232
+ def __init__(
233
+ self,
234
+ dim,
235
+ dim_inputs: Tuple[int, ...]
236
+ ):
237
+ super().__init__()
238
+ self.dim_inputs = dim_inputs
239
+ self.to_features = ModuleList([])
240
+
241
+ for dim_in in dim_inputs:
242
+ net = nn.Sequential(
243
+ RMSNorm(dim_in),
244
+ nn.Linear(dim_in, dim)
245
+ )
246
+
247
+ self.to_features.append(net)
248
+
249
+ def forward(self, x):
250
+ x = x.split(self.dim_inputs, dim=-1)
251
+
252
+ outs = []
253
+ for split_input, to_feature in zip(x, self.to_features):
254
+ split_output = to_feature(split_input)
255
+ outs.append(split_output)
256
+
257
+ return torch.stack(outs, dim=-2)
258
+
259
+
260
+ def MLP(
261
+ dim_in,
262
+ dim_out,
263
+ dim_hidden=None,
264
+ depth=1,
265
+ activation=nn.Tanh
266
+ ):
267
+ dim_hidden = default(dim_hidden, dim_in)
268
+
269
+ net = []
270
+ dims = (dim_in, *((dim_hidden,) * depth), dim_out)
271
+
272
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
273
+ is_last = ind == (len(dims) - 2)
274
+
275
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
276
+
277
+ if is_last:
278
+ continue
279
+
280
+ net.append(activation())
281
+
282
+ return nn.Sequential(*net)
283
+
284
+
285
+ class MaskEstimator(Module):
286
+ @beartype
287
+ def __init__(
288
+ self,
289
+ dim,
290
+ dim_inputs: Tuple[int, ...],
291
+ depth,
292
+ mlp_expansion_factor=4
293
+ ):
294
+ super().__init__()
295
+ self.dim_inputs = dim_inputs
296
+ self.to_freqs = ModuleList([])
297
+ dim_hidden = dim * mlp_expansion_factor
298
+
299
+ for dim_in in dim_inputs:
300
+ net = []
301
+
302
+ mlp = nn.Sequential(
303
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
304
+ nn.GLU(dim=-1)
305
+ )
306
+
307
+ self.to_freqs.append(mlp)
308
+
309
+ def forward(self, x):
310
+ x = x.unbind(dim=-2)
311
+
312
+ outs = []
313
+
314
+ for band_features, mlp in zip(x, self.to_freqs):
315
+ freq_out = mlp(band_features)
316
+ outs.append(freq_out)
317
+
318
+ return torch.cat(outs, dim=-1)
319
+
320
+
321
+ # main class
322
+
323
+ class MelBandRoformer(Module):
324
+
325
+ @beartype
326
+ def __init__(
327
+ self,
328
+ dim,
329
+ *,
330
+ depth,
331
+ stereo=False,
332
+ num_stems=1,
333
+ time_transformer_depth=2,
334
+ freq_transformer_depth=2,
335
+ linear_transformer_depth=0,
336
+ num_bands=60,
337
+ dim_head=64,
338
+ heads=8,
339
+ attn_dropout=0.1,
340
+ ff_dropout=0.1,
341
+ flash_attn=True,
342
+ dim_freqs_in=1025,
343
+ sample_rate=44100, # needed for mel filter bank from librosa
344
+ stft_n_fft=2048,
345
+ stft_hop_length=512,
346
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
347
+ stft_win_length=2048,
348
+ stft_normalized=False,
349
+ stft_window_fn: Optional[Callable] = None,
350
+ mask_estimator_depth=1,
351
+ multi_stft_resolution_loss_weight=1.,
352
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
353
+ multi_stft_hop_size=147,
354
+ multi_stft_normalized=False,
355
+ multi_stft_window_fn: Callable = torch.hann_window,
356
+ match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
357
+ ):
358
+ super().__init__()
359
+
360
+ self.stereo = stereo
361
+ self.audio_channels = 2 if stereo else 1
362
+ self.num_stems = num_stems
363
+
364
+ self.layers = ModuleList([])
365
+
366
+ transformer_kwargs = dict(
367
+ dim=dim,
368
+ heads=heads,
369
+ dim_head=dim_head,
370
+ attn_dropout=attn_dropout,
371
+ ff_dropout=ff_dropout,
372
+ flash_attn=flash_attn
373
+ )
374
+
375
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
376
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
377
+
378
+ for _ in range(depth):
379
+ tran_modules = []
380
+ if linear_transformer_depth > 0:
381
+ tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
382
+ tran_modules.append(
383
+ Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
384
+ )
385
+ tran_modules.append(
386
+ Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
387
+ )
388
+ self.layers.append(nn.ModuleList(tran_modules))
389
+
390
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
391
+
392
+ self.stft_kwargs = dict(
393
+ n_fft=stft_n_fft,
394
+ hop_length=stft_hop_length,
395
+ win_length=stft_win_length,
396
+ normalized=stft_normalized
397
+ )
398
+
399
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
400
+
401
+ # create mel filter bank
402
+ # with librosa.filters.mel as in section 2 of paper
403
+
404
+ mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
405
+
406
+ mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
407
+
408
+ # for some reason, it doesn't include the first freq? just force a value for now
409
+
410
+ mel_filter_bank[0][0] = 1.
411
+
412
+ # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
413
+ # so let's force a positive value
414
+
415
+ mel_filter_bank[-1, -1] = 1.
416
+
417
+ # binary as in paper (then estimated masks are averaged for overlapping regions)
418
+
419
+ freqs_per_band = mel_filter_bank > 0
420
+ assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
421
+
422
+ repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
423
+ freq_indices = repeated_freq_indices[freqs_per_band]
424
+
425
+ if stereo:
426
+ freq_indices = repeat(freq_indices, 'f -> f s', s=2)
427
+ freq_indices = freq_indices * 2 + torch.arange(2)
428
+ freq_indices = rearrange(freq_indices, 'f s -> (f s)')
429
+
430
+ self.register_buffer('freq_indices', freq_indices, persistent=False)
431
+ self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
432
+
433
+ num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
434
+ num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
435
+
436
+ self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
437
+ self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
438
+
439
+ # band split and mask estimator
440
+
441
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
442
+
443
+ self.band_split = BandSplit(
444
+ dim=dim,
445
+ dim_inputs=freqs_per_bands_with_complex
446
+ )
447
+
448
+ self.mask_estimators = nn.ModuleList([])
449
+
450
+ for _ in range(num_stems):
451
+ mask_estimator = MaskEstimator(
452
+ dim=dim,
453
+ dim_inputs=freqs_per_bands_with_complex,
454
+ depth=mask_estimator_depth
455
+ )
456
+
457
+ self.mask_estimators.append(mask_estimator)
458
+
459
+ # for the multi-resolution stft loss
460
+
461
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
462
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
463
+ self.multi_stft_n_fft = stft_n_fft
464
+ self.multi_stft_window_fn = multi_stft_window_fn
465
+
466
+ self.multi_stft_kwargs = dict(
467
+ hop_length=multi_stft_hop_size,
468
+ normalized=multi_stft_normalized
469
+ )
470
+
471
+ self.match_input_audio_length = match_input_audio_length
472
+
473
+ def forward(
474
+ self,
475
+ raw_audio,
476
+ target=None,
477
+ return_loss_breakdown=False
478
+ ):
479
+ """
480
+ einops
481
+
482
+ b - batch
483
+ f - freq
484
+ t - time
485
+ s - audio channel (1 for mono, 2 for stereo)
486
+ n - number of 'stems'
487
+ c - complex (2)
488
+ d - feature dimension
489
+ """
490
+
491
+ device = raw_audio.device
492
+
493
+ if raw_audio.ndim == 2:
494
+ raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
495
+
496
+ batch, channels, raw_audio_length = raw_audio.shape
497
+
498
+ istft_length = raw_audio_length if self.match_input_audio_length else None
499
+
500
+ assert (not self.stereo and channels == 1) or (
501
+ self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
502
+
503
+ # to stft
504
+
505
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
506
+
507
+ stft_window = self.stft_window_fn(device=device)
508
+
509
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
510
+ stft_repr = torch.view_as_real(stft_repr)
511
+
512
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
513
+ stft_repr = rearrange(stft_repr,
514
+ 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
515
+
516
+ # index out all frequencies for all frequency ranges across bands ascending in one go
517
+
518
+ batch_arange = torch.arange(batch, device=device)[..., None]
519
+
520
+ # account for stereo
521
+
522
+ x = stft_repr[batch_arange, self.freq_indices]
523
+
524
+ # fold the complex (real and imag) into the frequencies dimension
525
+
526
+ x = rearrange(x, 'b f t c -> b t (f c)')
527
+
528
+ x = self.band_split(x)
529
+
530
+ # axial / hierarchical attention
531
+
532
+ for transformer_block in self.layers:
533
+
534
+ if len(transformer_block) == 3:
535
+ linear_transformer, time_transformer, freq_transformer = transformer_block
536
+
537
+ x, ft_ps = pack([x], 'b * d')
538
+ x = linear_transformer(x)
539
+ x, = unpack(x, ft_ps, 'b * d')
540
+ else:
541
+ time_transformer, freq_transformer = transformer_block
542
+
543
+ x = rearrange(x, 'b t f d -> b f t d')
544
+ x, ps = pack([x], '* t d')
545
+
546
+ x = time_transformer(x)
547
+
548
+ x, = unpack(x, ps, '* t d')
549
+ x = rearrange(x, 'b f t d -> b t f d')
550
+ x, ps = pack([x], '* f d')
551
+
552
+ x = freq_transformer(x)
553
+
554
+ x, = unpack(x, ps, '* f d')
555
+
556
+ num_stems = len(self.mask_estimators)
557
+
558
+ masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
559
+ masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
560
+
561
+ # modulate frequency representation
562
+
563
+ stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
564
+
565
+ # complex number multiplication
566
+
567
+ stft_repr = torch.view_as_complex(stft_repr)
568
+ masks = torch.view_as_complex(masks)
569
+
570
+ masks = masks.type(stft_repr.dtype)
571
+
572
+ # need to average the estimated mask for the overlapped frequencies
573
+
574
+ scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
575
+
576
+ stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
577
+ masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
578
+
579
+ denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
580
+
581
+ masks_averaged = masks_summed / denom.clamp(min=1e-8)
582
+
583
+ # modulate stft repr with estimated mask
584
+
585
+ stft_repr = stft_repr * masks_averaged
586
+
587
+ # istft
588
+
589
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
590
+
591
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
592
+ length=istft_length)
593
+
594
+ recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
595
+
596
+ if num_stems == 1:
597
+ recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
598
+
599
+ # if a target is passed in, calculate loss for learning
600
+
601
+ if not exists(target):
602
+ return recon_audio
603
+
604
+ if self.num_stems > 1:
605
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
606
+
607
+ if target.ndim == 2:
608
+ target = rearrange(target, '... t -> ... 1 t')
609
+
610
+ target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
611
+
612
+ loss = F.l1_loss(recon_audio, target)
613
+
614
+ multi_stft_resolution_loss = 0.
615
+
616
+ for window_size in self.multi_stft_resolutions_window_sizes:
617
+ res_stft_kwargs = dict(
618
+ n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
619
+ win_length=window_size,
620
+ return_complex=True,
621
+ window=self.multi_stft_window_fn(window_size, device=device),
622
+ **self.multi_stft_kwargs,
623
+ )
624
+
625
+ recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
626
+ target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
627
+
628
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
629
+
630
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
631
+
632
+ total_loss = loss + weighted_multi_resolution_loss
633
+
634
+ if not return_loss_breakdown:
635
+ return total_loss
636
+
637
+ return total_loss, (loss, multi_stft_resolution_loss)
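Not part of the commit: the mel-band variant can be exercised the same way (dim/depth again illustrative). librosa must be installed for filters.mel, and with the default match_input_audio_length=False the output is a few hundred samples shorter than the input.

import torch
from modules.bs_roformer import MelBandRoformer

model = MelBandRoformer(dim=192, depth=2, num_bands=60, flash_attn=False).eval()

audio = torch.randn(1, 44100)              # (batch, time), mono
with torch.no_grad():
    estimate = model(audio)                # -> (batch, channels, time')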
MDX23v24/modules/segm_models.py ADDED
@@ -0,0 +1,150 @@
1
+ if __name__ == '__main__':
2
+ import os
3
+
4
+ gpu_use = "2"
5
+ print('GPU use: {}'.format(gpu_use))
6
+ os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_use)
7
+
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import segmentation_models_pytorch as smp
12
+
13
+
14
+ class STFT:
15
+ def __init__(self, config):
16
+ self.n_fft = config.n_fft
17
+ self.hop_length = config.hop_length
18
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
19
+ self.dim_f = config.dim_f
20
+
21
+ def __call__(self, x):
22
+ window = self.window.to(x.device)
23
+ batch_dims = x.shape[:-2]
24
+ c, t = x.shape[-2:]
25
+ x = x.reshape([-1, t])
26
+ x = torch.stft(
27
+ x,
28
+ n_fft=self.n_fft,
29
+ hop_length=self.hop_length,
30
+ window=window,
31
+ center=True,
32
+ return_complex=True
33
+ )
34
+ x = torch.view_as_real(x)
35
+ x = x.permute([0, 3, 1, 2])
36
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
37
+ return x[..., :self.dim_f, :]
38
+
39
+ def inverse(self, x):
40
+ window = self.window.to(x.device)
41
+ batch_dims = x.shape[:-3]
42
+ c, f, t = x.shape[-3:]
43
+ n = self.n_fft // 2 + 1
44
+ f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
45
+ x = torch.cat([x, f_pad], -2)
46
+ x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
47
+ x = x.permute([0, 2, 3, 1])
48
+ x = x[..., 0] + x[..., 1] * 1.j
49
+ x = torch.istft(
50
+ x,
51
+ n_fft=self.n_fft,
52
+ hop_length=self.hop_length,
53
+ window=window,
54
+ center=True
55
+ )
56
+ x = x.reshape([*batch_dims, 2, -1])
57
+ return x
58
+
59
+
60
+ def get_act(act_type):
61
+ if act_type == 'gelu':
62
+ return nn.GELU()
63
+ elif act_type == 'relu':
64
+ return nn.ReLU()
65
+ elif act_type[:3] == 'elu':
66
+ alpha = float(act_type.replace('elu', ''))
67
+ return nn.ELU(alpha)
68
+ else:
69
+ raise Exception
70
+
71
+
72
+ class Segm_Models_Net(nn.Module):
73
+ def __init__(self, config):
74
+ super().__init__()
75
+ self.config = config
76
+
77
+ act = get_act(act_type=config.model.act)
78
+
79
+ self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
80
+ self.num_subbands = config.model.num_subbands
81
+
82
+ dim_c = self.num_subbands * config.audio.num_channels * 2
83
+ c = config.model.num_channels
84
+ f = config.audio.dim_f // self.num_subbands
85
+
86
+ self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
87
+
88
+ if config.model.decoder_type == 'unet':
89
+ self.unet_model = smp.Unet(
90
+ encoder_name=config.model.encoder_name,
91
+ encoder_weights="imagenet",
92
+ in_channels=c,
93
+ classes=c,
94
+ )
95
+ elif config.model.decoder_type == 'fpn':
96
+ self.unet_model = smp.FPN(
97
+ encoder_name=config.model.encoder_name,
98
+ encoder_weights="imagenet",
99
+ in_channels=c,
100
+ classes=c,
101
+ )
102
+
103
+ self.final_conv = nn.Sequential(
104
+ nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
105
+ act,
106
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
107
+ )
108
+
109
+ self.stft = STFT(config.audio)
110
+
111
+ def cac2cws(self, x):
112
+ k = self.num_subbands
113
+ b, c, f, t = x.shape
114
+ x = x.reshape(b, c, k, f // k, t)
115
+ x = x.reshape(b, c * k, f // k, t)
116
+ return x
117
+
118
+ def cws2cac(self, x):
119
+ k = self.num_subbands
120
+ b, c, f, t = x.shape
121
+ x = x.reshape(b, c // k, k, f, t)
122
+ x = x.reshape(b, c // k, f * k, t)
123
+ return x
124
+
125
+ def forward(self, x):
126
+
127
+ x = self.stft(x)
128
+
129
+ mix = x = self.cac2cws(x)
130
+
131
+ first_conv_out = x = self.first_conv(x)
132
+
133
+ x = x.transpose(-1, -2)
134
+
135
+ x = self.unet_model(x)
136
+
137
+ x = x.transpose(-1, -2)
138
+
139
+ x = x * first_conv_out # reduce artifacts
140
+
141
+ x = self.final_conv(torch.cat([mix, x], 1))
142
+
143
+ x = self.cws2cac(x)
144
+
145
+ if self.num_target_instruments > 1:
146
+ b, c, f, t = x.shape
147
+ x = x.reshape(b, self.num_target_instruments, -1, f, t)
148
+
149
+ x = self.stft.inverse(x)
150
+ return x
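Not part of the commit: a hypothetical config sketch showing the attributes Segm_Models_Net actually reads. The field values are made up for shape-checking only, not taken from any shipped MDX23v24 config, and segmentation_models_pytorch downloads ImageNet encoder weights on first use.

from types import SimpleNamespace
import torch
from modules.segm_models import Segm_Models_Net

config = SimpleNamespace(
    audio=SimpleNamespace(n_fft=4096, hop_length=1024, dim_f=2048, num_channels=2),
    model=SimpleNamespace(act='gelu', num_subbands=4, num_channels=16,
                          decoder_type='unet', encoder_name='resnet18'),
    training=SimpleNamespace(target_instrument=None, instruments=['vocals', 'other']),
)

model = Segm_Models_Net(config)
chunk = torch.randn(1, 2, config.audio.hop_length * 255)    # (batch, channels, time)
with torch.no_grad():
    out = model(chunk)                                       # -> (batch, num_instruments, channels, time)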
MDX23v24/modules/tfc_tdf_v2.py ADDED
@@ -0,0 +1,58 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from functools import partial
+
+
+ class Conv_TDF_net_trim_model(nn.Module):
+     def __init__(self, device, target_name, L, n_fft, hop=1024, dim_f=3072):
+         super(Conv_TDF_net_trim_model, self).__init__()
+         self.dim_c = 4
+         self.dim_f, self.dim_t = dim_f, 256
+         self.n_fft = n_fft
+         self.hop = hop
+         self.n_bins = self.n_fft // 2 + 1
+         self.chunk_size = hop * (self.dim_t - 1)
+         self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
+         self.target_name = target_name
+         out_c = self.dim_c * 4 if target_name == '*' else self.dim_c
+         self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device)
+         self.n = L // 2
+
+     def stft(self, x):
+         x = x.reshape([-1, self.chunk_size])
+         x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True)
+         x = torch.view_as_real(x)
+         x = x.permute([0, 3, 1, 2])
+         x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t])
+         return x[:, :, :self.dim_f]
+
+     def istft(self, x, freq_pad=None):
+         freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
+         x = torch.cat([x, freq_pad], -2)
+         x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
+         x = x.permute([0, 2, 3, 1])
+         x = x.contiguous()
+         x = torch.view_as_complex(x)
+         x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
+         return x.reshape([-1, 2, self.chunk_size])
+
+     def forward(self, x):
+         x = self.first_conv(x)
+         x = x.transpose(-1, -2)
+
+         ds_outputs = []
+         for i in range(self.n):
+             x = self.ds_dense[i](x)
+             ds_outputs.append(x)
+             x = self.ds[i](x)
+
+         x = self.mid_dense(x)
+         for i in range(self.n):
+             x = self.us[i](x)
+             x *= ds_outputs[-i - 1]
+             x = self.us_dense[i](x)
+
+         x = x.transpose(-1, -2)
+         x = self.final_conv(x)
+         return x
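Not part of the commit: in MDX23-style pipelines this trimmed class is mainly used for its STFT/iSTFT helpers and dimension bookkeeping (the separation itself typically runs through an ONNX model); note that forward() references layers (first_conv, ds_dense, ds, mid_dense, us, us_dense, final_conv) that are not created in __init__, so it cannot be called as-is. A round-trip sketch with illustrative n_fft/dim_f values:

import torch
from modules.tfc_tdf_v2 import Conv_TDF_net_trim_model

device = torch.device('cpu')
m = Conv_TDF_net_trim_model(device, target_name='vocals', L=11, n_fft=7680, hop=1024, dim_f=3072)

wave = torch.randn(1, 2, m.chunk_size)     # (batch, stereo, chunk_size)
spec = m.stft(wave)                        # -> (batch, 4, dim_f, dim_t)
recon = m.istft(spec)                      # -> (batch, 2, chunk_size)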
MDX23v24/modules/tfc_tdf_v3.py ADDED
@@ -0,0 +1,187 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+
6
+ class STFT:
7
+ def __init__(self, config):
8
+ self.n_fft = config.n_fft
9
+ self.hop_length = config.hop_length
10
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
11
+ self.dim_f = config.dim_f
12
+
13
+ def __call__(self, x):
14
+ window = self.window.to(x.device)
15
+ batch_dims = x.shape[:-2]
16
+ c, t = x.shape[-2:]
17
+ x = x.reshape([-1, t])
18
+ x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True, return_complex=True)
19
+ x = torch.view_as_real(x)
20
+ x = x.permute([0, 3, 1, 2])
21
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
22
+ return x[..., :self.dim_f, :]
23
+
24
+ def inverse(self, x):
25
+ window = self.window.to(x.device)
26
+ batch_dims = x.shape[:-3]
27
+ c, f, t = x.shape[-3:]
28
+ n = self.n_fft // 2 + 1
29
+ f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
30
+ x = torch.cat([x, f_pad], -2)
31
+ x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
32
+ x = x.permute([0, 2, 3, 1])
33
+ x = x[..., 0] + x[..., 1] * 1.j
34
+ x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
35
+ x = x.reshape([*batch_dims, 2, -1])
36
+ return x
37
+
38
+ def get_norm(norm_type):
39
+ def norm(c, norm_type):
40
+ if norm_type == 'BatchNorm':
41
+ return nn.BatchNorm2d(c)
42
+ elif norm_type == 'InstanceNorm':
43
+ return nn.InstanceNorm2d(c, affine=True)
44
+ elif 'GroupNorm' in norm_type:
45
+ g = int(norm_type.replace('GroupNorm', ''))
46
+ return nn.GroupNorm(num_groups=g, num_channels=c)
47
+ else:
48
+ return nn.Identity()
49
+
50
+ return partial(norm, norm_type=norm_type)
51
+
52
+ def get_act(act_type):
53
+ if act_type == 'gelu':
54
+ return nn.GELU()
55
+ elif act_type == 'relu':
56
+ return nn.ReLU()
57
+ elif act_type[:3] == 'elu':
58
+ alpha = float(act_type.replace('elu', ''))
59
+ return nn.ELU(alpha)
60
+ else:
61
+ raise Exception
62
+
63
+ class Upscale(nn.Module):
64
+ def __init__(self, in_c, out_c, scale, norm, act):
65
+ super().__init__()
66
+ self.conv = nn.Sequential(norm(in_c), act, nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False))
67
+
68
+ def forward(self, x):
69
+ return self.conv(x)
70
+
71
+ class Downscale(nn.Module):
72
+ def __init__(self, in_c, out_c, scale, norm, act):
73
+ super().__init__()
74
+ self.conv = nn.Sequential(norm(in_c), act, nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False))
75
+
76
+ def forward(self, x):
77
+ return self.conv(x)
78
+
79
+ class TFC_TDF(nn.Module):
80
+ def __init__(self, in_c, c, l, f, bn, norm, act):
81
+ super().__init__()
82
+ self.blocks = nn.ModuleList()
83
+
84
+ for i in range(l):
85
+ block = nn.Module()
86
+ block.tfc1 = nn.Sequential(norm(in_c), act, nn.Conv2d(in_c, c, 3, 1, 1, bias=False),)
87
+ block.tdf = nn.Sequential(norm(c), act, nn.Linear(f, f // bn, bias=False), norm(c), act, nn.Linear(f // bn, f, bias=False))
88
+ block.tfc2 = nn.Sequential(norm(c), act, nn.Conv2d(c, c, 3, 1, 1, bias=False))
89
+ block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
90
+ self.blocks.append(block)
91
+ in_c = c
92
+
93
+ def forward(self, x):
94
+ for block in self.blocks:
95
+ s = block.shortcut(x)
96
+ x = block.tfc1(x)
97
+ x = x + block.tdf(x)
98
+ x = block.tfc2(x)
99
+ x = x + s
100
+ return x
101
+
102
+ class TFC_TDF_net(nn.Module):
103
+ def __init__(self, config):
104
+ super().__init__()
105
+ self.config = config
106
+ norm = get_norm(norm_type=config.model.norm)
107
+ act = get_act(act_type=config.model.act)
108
+ self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
109
+ self.num_subbands = config.model.num_subbands
110
+ dim_c = self.num_subbands * config.audio.num_channels * 2
111
+ n = config.model.num_scales
112
+ scale = config.model.scale
113
+ l = config.model.num_blocks_per_scale
114
+ c = config.model.num_channels
115
+ g = config.model.growth
116
+ bn = config.model.bottleneck_factor
117
+ f = config.audio.dim_f // self.num_subbands
118
+ self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
119
+ self.encoder_blocks = nn.ModuleList()
120
+
121
+ for i in range(n):
122
+ block = nn.Module()
123
+ block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
124
+ block.downscale = Downscale(c, c + g, scale, norm, act)
125
+ f = f // scale[1]
126
+ c += g
127
+ self.encoder_blocks.append(block)
128
+
129
+ self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)
130
+ self.decoder_blocks = nn.ModuleList()
131
+
132
+ for i in range(n):
133
+ block = nn.Module()
134
+ block.upscale = Upscale(c, c - g, scale, norm, act)
135
+ f = f * scale[1]
136
+ c -= g
137
+ block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act)
138
+ self.decoder_blocks.append(block)
139
+
140
+ self.final_conv = nn.Sequential(nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), act, nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False))
141
+ self.stft = STFT(config.audio)
142
+
143
+ def cac2cws(self, x):
144
+ k = self.num_subbands
145
+ b, c, f, t = x.shape
146
+ x = x.reshape(b, c, k, f // k, t)
147
+ x = x.reshape(b, c * k, f // k, t)
148
+ return x
149
+
150
+ def cws2cac(self, x):
151
+ k = self.num_subbands
152
+ b, c, f, t = x.shape
153
+ x = x.reshape(b, c // k, k, f, t)
154
+ x = x.reshape(b, c // k, f * k, t)
155
+ return x
156
+
157
+ def forward(self, x):
158
+ x = self.stft(x)
159
+ mix = x = self.cac2cws(x)
160
+ first_conv_out = x = self.first_conv(x)
161
+ x = x.transpose(-1, -2)
162
+ encoder_outputs = []
163
+
164
+ for block in self.encoder_blocks:
165
+ x = block.tfc_tdf(x)
166
+ encoder_outputs.append(x)
167
+ x = block.downscale(x)
168
+
169
+ x = self.bottleneck_block(x)
170
+
171
+ for block in self.decoder_blocks:
172
+ x = block.upscale(x)
173
+ x = torch.cat([x, encoder_outputs.pop()], 1)
174
+ x = block.tfc_tdf(x)
175
+
176
+ x = x.transpose(-1, -2)
177
+ x = x * first_conv_out
178
+ x = self.final_conv(torch.cat([mix, x], 1))
179
+ x = self.cws2cac(x)
180
+
181
+ if self.num_target_instruments > 1:
182
+ b, c, f, t = x.shape
183
+ x = x.reshape(b, self.num_target_instruments, -1, f, t)
184
+
185
+ x = self.stft.inverse(x)
186
+
187
+ return x
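Not part of the commit: a hypothetical config sketch for TFC_TDF_net. The values are chosen only so the shapes divide cleanly (dim_f // num_subbands must survive num_scales downscalings by scale[1]); real MDX23v24 configs define their own settings.

from types import SimpleNamespace
import torch
from modules.tfc_tdf_v3 import TFC_TDF_net

config = SimpleNamespace(
    audio=SimpleNamespace(n_fft=4096, hop_length=1024, dim_f=2048, num_channels=2),
    model=SimpleNamespace(norm='InstanceNorm', act='gelu', num_subbands=4,
                          num_scales=3, scale=(2, 2), num_blocks_per_scale=1,
                          num_channels=32, growth=32, bottleneck_factor=4),
    training=SimpleNamespace(target_instrument='vocals', instruments=['vocals', 'other']),
)

model = TFC_TDF_net(config)
wave = torch.randn(1, 2, config.audio.hop_length * 255)    # (batch, channels, time)
with torch.no_grad():
    out = model(wave)                                       # single target -> (batch, channels, time)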