gosummer committed (verified)
Commit: b9fd956
Parent(s): f3f17e2

Upload 2 files

Files changed (2)
  1. MDX23v24/inference.py +980 -0
  2. MDX23v24/requirements.txt +15 -0
MDX23v24/inference.py ADDED
@@ -0,0 +1,980 @@
+# coding: utf-8
+
+if __name__ == '__main__':
+    import os
+
+    gpu_use = "0"
+
+    print('GPU use: {}'.format(gpu_use))
+    os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_use)
+import warnings
+warnings.filterwarnings("ignore")
+
+from tqdm import tqdm
+import numpy as np
+import torch
+import torch.nn as nn
+import os
+import argparse
+import soundfile as sf
+from demucs.states import load_model
+from demucs import pretrained
+from demucs.apply import apply_model
+import onnxruntime as ort
+from time import time
+import librosa
+import hashlib
+from scipy import signal
+import gc
+import yaml
+from ml_collections import ConfigDict
+import sys
+import math
+import pathlib
+import warnings
+from scipy.signal import resample_poly
+
+from modules.tfc_tdf_v2 import Conv_TDF_net_trim_model
+from modules.tfc_tdf_v3 import TFC_TDF_net, STFT
+from modules.segm_models import Segm_Models_Net
+from modules.bs_roformer import BSRoformer
+
+
+
+def get_models(name, device, load=True, vocals_model_type=0):
+    if vocals_model_type == 2:
+        model_vocals = Conv_TDF_net_trim_model(
+            device=device,
+            target_name='vocals',
+            L=11,
+            n_fft=7680,
+            dim_f=3072
+        )
+    elif vocals_model_type == 3:
+        model_vocals = Conv_TDF_net_trim_model(
+            device=device,
+            target_name='instrum',
+            L=11,
+            n_fft=5120,
+            dim_f=2560
+        )
+
+    return [model_vocals]
+
+
+def get_model_from_config(model_type, config_path):
+    with open(config_path) as f:
+        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+    if model_type == 'mdx23c':
+        from modules.tfc_tdf_v3 import TFC_TDF_net
+        model = TFC_TDF_net(config)
+    elif model_type == 'segm_models':
+        from modules.segm_models import Segm_Models_Net
+        model = Segm_Models_Net(config)
+    elif model_type == 'bs_roformer':
+        from modules.bs_roformer import BSRoformer
+        model = BSRoformer(
+            **dict(config.model)
+        )
+    else:
+        print('Unknown model: {}'.format(model_type))
+        model = None
+    return model, config
+
+
+def demix_new(model, mix, device, config, dim_t=256):
+    mix = torch.tensor(mix)
+    #N = options["overlap_BSRoformer"]
+    N = 2  # overlap 50%
+    batch_size = 1
+    mdx_window_size = dim_t
+    C = config.audio.hop_length * (mdx_window_size - 1)
+    fade_size = C // 100
+    step = int(C // N)
+    border = C - step
+    length_init = mix.shape[-1]
+    #print(f"1: {mix.shape}")
+
+    # Pad the beginning and end so the sliding-window results near the edges are handled better
+    if length_init > 2 * border and (border > 0):
+        mix = nn.functional.pad(mix, (border, border), mode='reflect')
+
+
+    # Prepare the window arrays once (for speed). This trick fixes click problems at segment edges
+    window_size = C
+    fadein = torch.linspace(0, 1, fade_size)
+    fadeout = torch.linspace(1, 0, fade_size)
+    window_start = torch.ones(window_size)
+    window_middle = torch.ones(window_size)
+    window_finish = torch.ones(window_size)
+    window_start[-fade_size:] *= fadeout  # First audio chunk, no fadein
+    window_finish[:fade_size] *= fadein  # Last audio chunk, no fadeout
+    window_middle[-fade_size:] *= fadeout
+    window_middle[:fade_size] *= fadein
+
+
+
+
+    with torch.cuda.amp.autocast():
+        with torch.inference_mode():
+            if config.training.target_instrument is not None:
+                req_shape = (1, ) + tuple(mix.shape)
+            else:
+                req_shape = (len(config.training.instruments),) + tuple(mix.shape)
+
+            result = torch.zeros(req_shape, dtype=torch.float32)
+            counter = torch.zeros(req_shape, dtype=torch.float32)
+            i = 0
+            batch_data = []
+            batch_locations = []
+            while i < mix.shape[1]:
+                # print(i, i + C, mix.shape[1])
+                part = mix[:, i:i + C].to(device)
+                length = part.shape[-1]
+                if length < C:
+                    if length > C // 2 + 1:
+                        part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
+                    else:
+                        part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
+                batch_data.append(part)
+                batch_locations.append((i, length))
+                i += step
+
+                if len(batch_data) >= batch_size or (i >= mix.shape[1]):
+                    arr = torch.stack(batch_data, dim=0)
+                    x = model(arr)
+
+                    window = window_middle
+                    if i - step == 0:  # First audio chunk, no fadein
+                        window = window_start
+                    elif i >= mix.shape[1]:  # Last audio chunk, no fadeout
+                        window = window_finish
+
+                    for j in range(len(batch_locations)):
+                        start, l = batch_locations[j]
+                        result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l]
+                        counter[..., start:start+l] += window[..., :l]
+
+                    batch_data = []
+                    batch_locations = []
+
+            estimated_sources = result / counter
+            estimated_sources = estimated_sources.cpu().numpy()
+            np.nan_to_num(estimated_sources, copy=False, nan=0.0)
+
+    if length_init > 2 * border and (border > 0):
+        # Remove pad
+        estimated_sources = estimated_sources[..., border:-border]
+
+    if config.training.target_instrument is None:
+        return {k: v for k, v in zip(config.training.instruments, estimated_sources)}
+    else:
+        return {k: v for k, v in zip([config.training.target_instrument], estimated_sources)}
+
+
+def demix_new_wrapper(mix, device, model, config, dim_t=256):
+    if options["BigShifts"] <= 0:
+        bigshifts = 1
+    else:
+        bigshifts = options["BigShifts"]
+
+    shift_in_samples = mix.shape[1] // bigshifts
+    shifts = [x * shift_in_samples for x in range(bigshifts)]
+
+    results = []
+
+    for shift in tqdm(shifts, position=0):
+        shifted_mix = np.concatenate((mix[:, -shift:], mix[:, :-shift]), axis=-1)
+        sources = demix_new(model, shifted_mix, device, config, dim_t=dim_t)
+        vocals = next(sources[key] for key in sources.keys() if key.lower() == "vocals")
+        unshifted_vocals = np.concatenate((vocals[..., shift:], vocals[..., :shift]), axis=-1)
+        vocals *= 1  # 1.0005168 CHECK NEEDED! volume compensation
+
+        results.append(unshifted_vocals)
+
+    vocals = np.mean(results, axis=0)
+
+    return vocals
+
+def demix_vitlarge(model, mix, device):
+    C = model.config.audio.hop_length * (2 * model.config.inference.dim_t - 1)
+    N = 2
+    step = C // N
+
+    with torch.cuda.amp.autocast():
+        with torch.no_grad():
+            if model.config.training.target_instrument is not None:
+                req_shape = (1, ) + tuple(mix.shape)
+            else:
+                req_shape = (len(model.config.training.instruments),) + tuple(mix.shape)
+
+            mix = mix.to(device)
+            result = torch.zeros(req_shape, dtype=torch.float32).to(device)
+            counter = torch.zeros(req_shape, dtype=torch.float32).to(device)
+            i = 0
+
+            while i < mix.shape[1]:
+                part = mix[:, i:i + C]
+                length = part.shape[-1]
+                if length < C:
+                    part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
+                x = model(part.unsqueeze(0))[0]
+                result[..., i:i+length] += x[..., :length]
+                counter[..., i:i+length] += 1.
+                i += step
+            estimated_sources = result / counter
+
+    if model.config.training.target_instrument is None:
+        return {k: v for k, v in zip(model.config.training.instruments, estimated_sources.cpu().numpy())}
+    else:
+        return {k: v for k, v in zip([model.config.training.target_instrument], estimated_sources.cpu().numpy())}
+
+
+def demix_full_vitlarge(mix, device, model):
+    if options["BigShifts"] <= 0:
+        bigshifts = 1
+    else:
+        bigshifts = options["BigShifts"]
+    shift_in_samples = mix.shape[1] // bigshifts
+    shifts = [x * shift_in_samples for x in range(bigshifts)]
+
+    results1 = []
+    results2 = []
+    mix = torch.from_numpy(mix).type('torch.FloatTensor').to(device)
+    for shift in tqdm(shifts, position=0):
+        shifted_mix = torch.cat((mix[:, -shift:], mix[:, :-shift]), dim=-1)
+        sources = demix_vitlarge(model, shifted_mix, device)
+        sources1 = sources["vocals"]
+        sources2 = sources["other"]
+        restored_sources1 = np.concatenate((sources1[..., shift:], sources1[..., :shift]), axis=-1)
+        restored_sources2 = np.concatenate((sources2[..., shift:], sources2[..., :shift]), axis=-1)
+        results1.append(restored_sources1)
+        results2.append(restored_sources2)
+
+
+    sources1 = np.mean(results1, axis=0)
+    sources2 = np.mean(results2, axis=0)
+
+    return sources1, sources2
+
+
+def demix_wrapper(mix, device, models, infer_session, overlap=0.2, bigshifts=1, vc=1.0):
+    if bigshifts <= 0:
+        bigshifts = 1
+    shift_in_samples = mix.shape[1] // bigshifts
+    shifts = [x * shift_in_samples for x in range(bigshifts)]
+    results = []
+
+    for shift in tqdm(shifts, position=0):
+        shifted_mix = np.concatenate((mix[:, -shift:], mix[:, :-shift]), axis=-1)
+        sources = demix(shifted_mix, device, models, infer_session, overlap) * vc  # 1.021 volume compensation
+        restored_sources = np.concatenate((sources[..., shift:], sources[..., :shift]), axis=-1)
+        results.append(restored_sources)
+
+    sources = np.mean(results, axis=0)
+
+    return sources
+
+def demix(mix, device, models, infer_session, overlap=0.2):
+    start_time = time()
+    sources = []
+    n_sample = mix.shape[1]
+    n_fft = models[0].n_fft
+    n_bins = n_fft//2+1
+    trim = n_fft//2
+    hop = models[0].hop
+    dim_f = models[0].dim_f
+    dim_t = models[0].dim_t  # * 2
+    chunk_size = hop * (dim_t - 1)
+    org_mix = mix
+    tar_waves_ = []
+    mdx_batch_size = 1
+    overlap = overlap
+    gen_size = chunk_size - 2 * trim
+    pad = gen_size + trim - ((mix.shape[-1]) % gen_size)
+
+    mixture = np.concatenate((np.zeros((2, trim), dtype='float32'), mix, np.zeros((2, pad), dtype='float32')), 1)
+
+    step = int((1 - overlap) * chunk_size)
+    result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
+    divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
+    total = 0
+    total_chunks = (mixture.shape[-1] + step - 1) // step
+
+    for i in range(0, mixture.shape[-1], step):
+        total += 1
+        start = i
+        end = min(i + chunk_size, mixture.shape[-1])
+        chunk_size_actual = end - start
+
+        if overlap == 0:
+            window = None
+        else:
+            window = np.hanning(chunk_size_actual)
+            window = np.tile(window[None, None, :], (1, 2, 1))
+
+        mix_part_ = mixture[:, start:end]
+        if end != i + chunk_size:
+            pad_size = (i + chunk_size) - end
+            mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype='float32')), axis=-1)
+
+
+        mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(device)
+        mix_waves = mix_part.split(mdx_batch_size)
+
+        with torch.no_grad():
+            for mix_wave in mix_waves:
+                _ort = infer_session
+                stft_res = models[0].stft(mix_wave)
+                stft_res[:, :, :3, :] *= 0
+                res = _ort.run(None, {'input': stft_res.cpu().numpy()})[0]
+                ten = torch.tensor(res)
+                tar_waves = models[0].istft(ten.to(device))
+                tar_waves = tar_waves.cpu().detach().numpy()
+
+                if window is not None:
+                    tar_waves[..., :chunk_size_actual] *= window
+                    divider[..., start:end] += window
+                else:
+                    divider[..., start:end] += 1
+                result[..., start:end] += tar_waves[..., :end-start]
+
+
+    tar_waves = result / divider
+    tar_waves_.append(tar_waves)
+    tar_waves_ = np.vstack(tar_waves_)[:, :, trim:-trim]
+    tar_waves = np.concatenate(tar_waves_, axis=-1)[:, :mix.shape[-1]]
+    source = tar_waves[:, 0:None]
+
+    return source
+
+class EnsembleDemucsMDXMusicSeparationModel:
+    """
+    Ensemble of Demucs, MDX-Net, MDX23C and BS-Roformer models that performs the separation.
+    """
+    def __init__(self, options):
+        """
+            options - user options
+        """
+
+        if torch.cuda.is_available():
+            device = 'cuda:0'
+        else:
+            device = 'cpu'
+        if 'cpu' in options:
+            if options['cpu']:
+                device = 'cpu'
+        # print('Use device: {}'.format(device))
+        self.single_onnx = False
+        if 'single_onnx' in options:
+            if options['single_onnx']:
+                self.single_onnx = True
+                # print('Use single vocal ONNX')
+        self.overlap_demucs = float(options['overlap_demucs'])
+        self.overlap_MDX = float(options['overlap_VOCFT'])
+        if self.overlap_demucs > 0.99:
+            self.overlap_demucs = 0.99
+        if self.overlap_demucs < 0.0:
+            self.overlap_demucs = 0.0
+        if self.overlap_MDX > 0.99:
+            self.overlap_MDX = 0.99
+        if self.overlap_MDX < 0.0:
+            self.overlap_MDX = 0.0
+        model_folder = os.path.dirname(os.path.realpath(__file__)) + '/models/'
+        """
+
+        remote_url = 'https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th'
+        model_path = model_folder + '04573f0d-f3cf25b2.th'
+        if not os.path.isfile(model_path):
+            torch.hub.download_url_to_file(remote_url, model_folder + '04573f0d-f3cf25b2.th')
+        model_vocals = load_model(model_path)
+        model_vocals.to(device)
+        self.model_vocals_only = model_vocals
+        """
+
+        if options['vocals_only'] is False:
+            self.models = []
+            self.weights_vocals = np.array([10, 1, 8, 9])
+            self.weights_bass = np.array([19, 4, 5, 8])
+            self.weights_drums = np.array([18, 2, 4, 9])
+            self.weights_other = np.array([14, 2, 5, 10])
+
+            model1 = pretrained.get_model('htdemucs_ft')
+            model1.to(device)
+            self.models.append(model1)
+
+            model2 = pretrained.get_model('htdemucs')
+            model2.to(device)
+            self.models.append(model2)
+
+            model3 = pretrained.get_model('htdemucs_6s')
+            model3.to(device)
+            self.models.append(model3)
+
+            model4 = pretrained.get_model('hdemucs_mmi')
+            model4.to(device)
+            self.models.append(model4)
+
+        if 0:
+            for model in self.models:
+                pass
+                # print(model.sources)
+            '''
+            ['drums', 'bass', 'other', 'vocals']
+            ['drums', 'bass', 'other', 'vocals']
+            ['drums', 'bass', 'other', 'vocals', 'guitar', 'piano']
+            ['drums', 'bass', 'other', 'vocals']
+            '''
+
+        """
+        #BS-RoformerDRUMS+BASS init
+        print("Loading BS-RoformerDB into memory")
+        remote_url_bsrofoDB = 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_937_sdr_10.5309.ckpt'
+        remote_url_conf_bsrofoDB = 'https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/model_bs_roformer_ep_937_sdr_10.5309.yaml'
+        if not os.path.isfile(model_folder+'model_bs_roformer_ep_937_sdr_10.5309.ckpt'):
+            torch.hub.download_url_to_file(remote_url_bsrofoDB, model_folder+'model_bs_roformer_ep_937_sdr_10.5309.ckpt')
+        if not os.path.isfile(model_folder+'model_bs_roformer_ep_937_sdr_10.5309.yaml'):
+            torch.hub.download_url_to_file(remote_url_conf_bsrofoDB, model_folder+'model_bs_roformer_ep_937_sdr_10.5309.yaml')
+
+        with open(model_folder + 'model_bs_roformer_ep_937_sdr_10.5309.yaml') as f:
+            config_bsrofoDB = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+
+        self.model_bsrofoDB = BSRoformer(**dict(config_bsrofoDB.model))
+        self.config_bsrofoDB = config_bsrofoDB
+        self.model_bsrofoDB.load_state_dict(torch.load(model_folder+'model_bs_roformer_ep_937_sdr_10.5309.ckpt'))
+        self.device = torch.device(device)
+        self.model_bsrofoDB = self.model_bsrofoDB.to(device)
+        self.model_bsrofoDB.eval()
+        """
+
+        if device == 'cpu':
+            providers = ["CPUExecutionProvider"]
+        else:
+            providers = ["CUDAExecutionProvider"]
+
+
+        #BS-RoformerVOC init
+        print("Loading BS-Roformer into memory")
+        if options["BSRoformer_model"] == "ep_368_1296":
+            model_name = "model_bs_roformer_ep_368_sdr_12.9628"
+        else:
+            model_name = "model_bs_roformer_ep_317_sdr_12.9755"
+        remote_url_bsrofo = f'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/{model_name}.ckpt'
+        remote_url_conf_bsrofo = f'https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/{model_name}.yaml'
+        if not os.path.isfile(model_folder+f'{model_name}.ckpt'):
+            torch.hub.download_url_to_file(remote_url_bsrofo, model_folder+f'{model_name}.ckpt')
+        if not os.path.isfile(model_folder+f'{model_name}.yaml'):
+            torch.hub.download_url_to_file(remote_url_conf_bsrofo, model_folder+f'{model_name}.yaml')
+
+        with open(model_folder + f'{model_name}.yaml') as f:
+            config_bsrofo = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+
+        self.model_bsrofo = BSRoformer(**dict(config_bsrofo.model))
+        self.config_bsrofo = config_bsrofo
+        self.model_bsrofo.load_state_dict(torch.load(model_folder+f'{model_name}.ckpt'))
+        self.device = torch.device(device)
+        self.model_bsrofo = self.model_bsrofo.to(device)
+        self.model_bsrofo.eval()
+
+
+        #MDXv3 init
+        print("Loading InstVoc into memory")
+        remote_url_mdxv3 = 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/MDX23C-8KFFT-InstVoc_HQ.ckpt'
+        remote_url_conf_mdxv3 = 'https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/model_2_stem_full_band_8k.yaml'
+        if not os.path.isfile(model_folder+'MDX23C-8KFFT-InstVoc_HQ.ckpt'):
+            torch.hub.download_url_to_file(remote_url_mdxv3, model_folder+'MDX23C-8KFFT-InstVoc_HQ.ckpt')
+        if not os.path.isfile(model_folder+'model_2_stem_full_band_8k.yaml'):
+            torch.hub.download_url_to_file(remote_url_conf_mdxv3, model_folder+'model_2_stem_full_band_8k.yaml')
+
+        with open(model_folder + 'model_2_stem_full_band_8k.yaml') as f:
+            config_mdxv3 = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+
+        self.config_mdxv3 = config_mdxv3
+        self.model_mdxv3 = TFC_TDF_net(config_mdxv3)
+        self.model_mdxv3.load_state_dict(torch.load(model_folder+'MDX23C-8KFFT-InstVoc_HQ.ckpt'))
+        self.device = torch.device(device)
+        self.model_mdxv3 = self.model_mdxv3.to(device)
+        self.model_mdxv3.eval()
+
+        #VitLarge init
+        if options['use_VitLarge'] is True:
+            print("Loading VitLarge into memory")
+            remote_url_vitlarge = 'https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_segm_models_sdr_9.77.ckpt'
+            remote_url_vl_conf = 'https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/config_vocals_segm_models.yaml'
+            if not os.path.isfile(model_folder+'model_vocals_segm_models_sdr_9.77.ckpt'):
+                torch.hub.download_url_to_file(remote_url_vitlarge, model_folder+'model_vocals_segm_models_sdr_9.77.ckpt')
+            if not os.path.isfile(model_folder+'config_vocals_segm_models.yaml'):
+                torch.hub.download_url_to_file(remote_url_vl_conf, model_folder+'config_vocals_segm_models.yaml')
+
+            with open(model_folder + 'config_vocals_segm_models.yaml') as f:
+                config_vl = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+
+            self.config_vl = config_vl
+            self.model_vl = Segm_Models_Net(config_vl)
+            self.model_vl.load_state_dict(torch.load(model_folder+'model_vocals_segm_models_sdr_9.77.ckpt'))
+            self.device = torch.device(device)
+            self.model_vl = self.model_vl.to(device)
+            self.model_vl.eval()
+
+        # VOCFT init
+        if options['use_VOCFT']:
+            print("Loading VOCFT into memory")
+            self.mdx_models1 = get_models('tdf_extra', load=False, device=device, vocals_model_type=2)
+            model_path_onnx1 = model_folder + 'UVR-MDX-NET-Voc_FT.onnx'
+            remote_url_onnx1 = 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/UVR-MDX-NET-Voc_FT.onnx'
+
+            if not os.path.isfile(model_path_onnx1):
+                torch.hub.download_url_to_file(remote_url_onnx1, model_path_onnx1)
+
+            self.infer_session1 = ort.InferenceSession(
+                model_path_onnx1,
+                providers=providers,
+                provider_options=[{"device_id": 0}],
+            )
+
+        # InstHQ4 init
+        if options['use_InstHQ4']:
+            print("Loading InstHQ4 into memory")
+            self.mdx_models2 = get_models('tdf_extra', load=False, device=device, vocals_model_type=3)
+            model_path_onnx2 = model_folder + 'UVR-MDX-NET-Inst_HQ_4.onnx'
+            remote_url_onnx2 = 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/UVR-MDX-NET-Inst_HQ_4.onnx'
+
+            if not os.path.isfile(model_path_onnx2):
+                torch.hub.download_url_to_file(remote_url_onnx2, model_path_onnx2)
+
+            self.infer_session2 = ort.InferenceSession(
+                model_path_onnx2,
+                providers=providers,
+                provider_options=[{"device_id": 0}],
+            )
+
+
+        self.device = device
+        pass
+
+    @property
+    def instruments(self):
+
+        if options['vocals_only'] is False:
+            return ['bass', 'drums', 'other', 'vocals']
+        else:
+            return ['vocals']
+
+    def raise_aicrowd_error(self, msg):
+        """ Will be used by the evaluator to provide logs, DO NOT CHANGE """
+        raise NameError(msg)
+
+    def separate_music_file(
+        self,
+        mixed_sound_array,
+        sample_rate,
+        current_file_number=0,
+        total_files=0,
+    ):
+        """
+        Implements the sound separation for a single sound file
+        Inputs: Outputs from soundfile.read('mixture.wav')
+            mixed_sound_array
+            sample_rate
+
+        Outputs:
+            separated_music_arrays: Dictionary of numpy arrays, one per separated instrument
+            output_sample_rates: Dictionary of sample rates for each separated sequence
+        """
+
+        # print('Update percent func: {}'.format(update_percent_func))
+
+        separated_music_arrays = {}
+        output_sample_rates = {}
+        #print(mixed_sound_array.T.shape)
+        #audio = np.expand_dims(mixed_sound_array.T, axis=0)
+
+
+        overlap_demucs = self.overlap_demucs
+        overlap_MDX = self.overlap_MDX
+        shifts = 0
+        overlap = overlap_demucs
+
+        vocals_model_names = [
+            "BSRoformer",
+            "InstVoc",
+            "VitLarge",
+            "VOCFT",
+            "InstHQ4"
+        ]
+
+        vocals_model_outputs = []
+        weights = []
+
+        for model_name in vocals_model_names:
+
+            if options[f"use_{model_name}"]:
+
+                if model_name == "BSRoformer":
+                    print(f'Processing vocals with {model_name} model...')
+                    sources_bs = demix_new_wrapper(mixed_sound_array.T, self.device, self.model_bsrofo, self.config_bsrofo, dim_t=1101)
+                    vocals_bs = match_array_shapes(sources_bs, mixed_sound_array.T)
+                    vocals_model_outputs.append(vocals_bs)
+                    weights.append(options.get(f"weight_{model_name}"))
+
+
+                if model_name == "InstVoc":
+                    print(f'Processing vocals with {model_name} model...')
+                    sources3 = demix_new_wrapper(mixed_sound_array.T, self.device, self.model_mdxv3, self.config_mdxv3, dim_t=1024)
+                    vocals3 = match_array_shapes(sources3, mixed_sound_array.T)
+                    vocals_model_outputs.append(vocals3)
+                    weights.append(options.get(f"weight_{model_name}"))
+
+                elif model_name == "VitLarge":
+                    print(f'Processing vocals with {model_name} model...')
+                    vocals4, instrum4 = demix_full_vitlarge(mixed_sound_array.T, self.device, self.model_vl)#, self.config_vl, dim_t=512)
+                    vocals4 = match_array_shapes(vocals4, mixed_sound_array.T)
+                    vocals_model_outputs.append(vocals4)
+                    weights.append(options.get(f"weight_{model_name}"))
+
+                elif model_name == "VOCFT":
+                    print(f'Processing vocals with {model_name} model...')
+                    overlap = overlap_MDX
+                    sources1 = 0.5 * demix_wrapper(
+                        mixed_sound_array.T,
+                        self.device,
+                        self.mdx_models1,
+                        self.infer_session1,
+                        overlap=overlap,
+                        vc=1.021,
+                        bigshifts=options['BigShifts'] // 3
+                    )
+                    sources1 += 0.5 * -demix_wrapper(
+                        -mixed_sound_array.T,
+                        self.device,
+                        self.mdx_models1,
+                        self.infer_session1,
+                        overlap=overlap,
+                        vc=1.021,
+                        bigshifts=options['BigShifts'] // 3
+                    )
+                    vocals_mdxb1 = sources1
+                    vocals_model_outputs.append(vocals_mdxb1)
+                    weights.append(options.get(f"weight_{model_name}"))
+
+                elif model_name == "InstHQ4":
+                    print(f'Processing vocals with {model_name} model...')
+                    overlap = overlap_MDX
+                    sources2 = 0.5 * demix_wrapper(
+                        mixed_sound_array.T,
+                        self.device,
+                        self.mdx_models2,
+                        self.infer_session2,
+                        overlap=overlap,
+                        vc=1.019,
+                        bigshifts=options['BigShifts'] // 3
+                    )
+                    sources2 += 0.5 * -demix_wrapper(
+                        -mixed_sound_array.T,
+                        self.device,
+                        self.mdx_models2,
+                        self.infer_session2,
+                        overlap=overlap,
+                        vc=1.019,
+                        bigshifts=options['BigShifts'] // 3
+                    )
+                    vocals_mdxb2 = mixed_sound_array.T - sources2
+                    vocals_model_outputs.append(vocals_mdxb2)
+                    weights.append(options.get(f"weight_{model_name}"))
+
+                else:
+                    # No more model to process or unknown one
+                    pass
+
+        print('Processing vocals: DONE!')
+
+        vocals_combined = np.zeros_like(vocals_model_outputs[0])
+
+        for output, weight in zip(vocals_model_outputs, weights):
+            vocals_combined += output * weight
+
+        vocals_combined /= np.sum(weights)
+
+        vocals_low = lr_filter(vocals_combined.T, 12000, 'lowpass')  # * 1.01055 # remember to check if new final finetuned volume compensation is needed !
+        vocals_high = lr_filter(vocals3.T, 12000, 'highpass')
+
+        vocals = vocals_low + vocals_high
+        #vocals = vocals_combined.T
+
+        if options['filter_vocals'] is True:
+            vocals = lr_filter(vocals, 50, 'highpass', order=8)
+
+        # Generate instrumental
+        instrum = mixed_sound_array - vocals
+
+        if options['vocals_only'] is False:
+
+            """
+            print(f'Processing drums & bass with 2nd BS-Roformer model...')
+            other_bs2 = demix_full_bsrofo(instrum.T, self.device, self.model_bsrofoDB, self.config_bsrofoDB)
+            other_bs2 = match_array_shapes(other_bs2, mixed_sound_array.T)
+            drums_bass_bs2 = mixed_sound_array.T - other_bs2
+
+
+            print('Starting Demucs processing...')
+
+            drums_bass_bs2 = np.expand_dims(drums_bass_bs2.T, axis=0)
+            drums_bass_bs2 = torch.from_numpy(drums_bass_bs2).type('torch.FloatTensor').to(self.device)
+            """
+            audio = np.expand_dims(instrum.T, axis=0)
+            audio = torch.from_numpy(audio).type('torch.FloatTensor').to(self.device)
+            all_outs = []
+            print('Processing with htdemucs_ft...')
+            i = 0
+            overlap = overlap_demucs
+            model = pretrained.get_model('htdemucs_ft')
+            model.to(self.device)
+            out = 0.5 * apply_model(model, audio, shifts=shifts, overlap=overlap)[0].cpu().numpy() \
+                + 0.5 * -apply_model(model, -audio, shifts=shifts, overlap=overlap)[0].cpu().numpy()
+
+            out[0] = self.weights_drums[i] * out[0]
+            out[1] = self.weights_bass[i] * out[1]
+            out[2] = self.weights_other[i] * out[2]
+            out[3] = self.weights_vocals[i] * out[3]
+            all_outs.append(out)
+            model = model.cpu()
+            del model
+            gc.collect()
+            i = 1
+            print('Processing with htdemucs...')
+            overlap = overlap_demucs
+            model = pretrained.get_model('htdemucs')
+            model.to(self.device)
+            out = 0.5 * apply_model(model, audio, shifts=shifts, overlap=overlap)[0].cpu().numpy() \
+                + 0.5 * -apply_model(model, -audio, shifts=shifts, overlap=overlap)[0].cpu().numpy()
+
+            out[0] = self.weights_drums[i] * out[0]
+            out[1] = self.weights_bass[i] * out[1]
+            out[2] = self.weights_other[i] * out[2]
+            out[3] = self.weights_vocals[i] * out[3]
+            all_outs.append(out)
+            model = model.cpu()
+            del model
+            gc.collect()
+            i = 2
+            print('Processing with htdemucs_6s...')
+            overlap = overlap_demucs
+            model = pretrained.get_model('htdemucs_6s')
+            model.to(self.device)
+            out = apply_model(model, audio, shifts=shifts, overlap=overlap)[0].cpu().numpy()
+
+            # Fold the extra stems (guitar, piano) of the 6-stem model into 'other'
+            out[2] = out[2] + out[4] + out[5]
+            out = out[:4]
+            out[0] = self.weights_drums[i] * out[0]
+            out[1] = self.weights_bass[i] * out[1]
+            out[2] = self.weights_other[i] * out[2]
+            out[3] = self.weights_vocals[i] * out[3]
+            all_outs.append(out)
+            model = model.cpu()
+            del model
+            gc.collect()
+            i = 3
+            print('Processing with htdemucs_mmi...')
+            model = pretrained.get_model('hdemucs_mmi')
+            model.to(self.device)
+            out = 0.5 * apply_model(model, audio, shifts=shifts, overlap=overlap)[0].cpu().numpy() \
+                + 0.5 * -apply_model(model, -audio, shifts=shifts, overlap=overlap)[0].cpu().numpy()
+
+            out[0] = self.weights_drums[i] * out[0]
+            out[1] = self.weights_bass[i] * out[1]
+            out[2] = self.weights_other[i] * out[2]
+            out[3] = self.weights_vocals[i] * out[3]
+            all_outs.append(out)
+            model = model.cpu()
+            del model
+            gc.collect()
+            out = np.array(all_outs).sum(axis=0)
+            out[0] = out[0] / self.weights_drums.sum()
+            out[1] = out[1] / self.weights_bass.sum()
+            out[2] = out[2] / self.weights_other.sum()
+            out[3] = out[3] / self.weights_vocals.sum()
+
+            # other
+            res = mixed_sound_array - vocals - out[0].T - out[1].T
+            res = np.clip(res, -1, 1)
+            separated_music_arrays['other'] = (2 * res + out[2].T) / 3.0
+            output_sample_rates['other'] = sample_rate
+
+            # drums
+            res = mixed_sound_array - vocals - out[1].T - out[2].T
+            res = np.clip(res, -1, 1)
+            separated_music_arrays['drums'] = (res + 2 * out[0].T.copy()) / 3.0
+            output_sample_rates['drums'] = sample_rate
+
+            # bass
+            res = mixed_sound_array - vocals - out[0].T - out[2].T
+            res = np.clip(res, -1, 1)
+            separated_music_arrays['bass'] = (res + 2 * out[1].T) / 3.0
+            output_sample_rates['bass'] = sample_rate
+
+            bass = separated_music_arrays['bass']
+            drums = separated_music_arrays['drums']
+            other = separated_music_arrays['other']
+
+            separated_music_arrays['other'] = mixed_sound_array - vocals - bass - drums
+            separated_music_arrays['drums'] = mixed_sound_array - vocals - bass - other
+            separated_music_arrays['bass'] = mixed_sound_array - vocals - drums - other
+
+        # vocals
+        separated_music_arrays['vocals'] = vocals
+        output_sample_rates['vocals'] = sample_rate
+
+        # instrum
+        separated_music_arrays['instrum'] = instrum
+
+        return separated_music_arrays, output_sample_rates
+
+
+def predict_with_model(options):
+
+    output_format = options['output_format']
+    output_extension = 'flac' if output_format == 'FLAC' else "wav"
+    output_format = 'PCM_16' if output_format == 'FLAC' else options['output_format']
+
+    for input_audio in options['input_audio']:
+        if not os.path.isfile(input_audio):
+            print('Error. No such file: {}. Please check path!'.format(input_audio))
+            return
+    output_folder = options['output_folder']
+    if not os.path.isdir(output_folder):
+        os.mkdir(output_folder)
+
+    model = None
+    model = EnsembleDemucsMDXMusicSeparationModel(options)
+
+    for i, input_audio in enumerate(options['input_audio']):
+        print('Go for: {}'.format(input_audio))
+        audio, sr = librosa.load(input_audio, mono=False, sr=44100)
+        if len(audio.shape) == 1:
+            audio = np.stack([audio, audio], axis=0)
+
+
+        if options['input_gain'] != 0:
+            audio = dBgain(audio, options['input_gain'])
+
+        print("Input audio: {} Sample rate: {}".format(audio.shape, sr))
+        result, sample_rates = model.separate_music_file(audio.T, sr, i, len(options['input_audio']))
+
+        for instrum in model.instruments:
+            output_name = os.path.splitext(os.path.basename(input_audio))[0] + '_{}.{}'.format(instrum, output_extension)
+            if options["restore_gain"] is True:  # restoring original gain
+                result[instrum] = dBgain(result[instrum], -options['input_gain'])
+            sf.write(output_folder + '/' + output_name, result[instrum], sample_rates[instrum], subtype=output_format)
+            print('File created: {}'.format(output_folder + '/' + output_name))
+
+        # instrumental part 1
+        # inst = (audio.T - result['vocals'])
+        inst = result['instrum']
+
+        if options["restore_gain"] is True:  # restoring original gain
+            inst = dBgain(inst, -options['input_gain'])
+
+        output_name = os.path.splitext(os.path.basename(input_audio))[0] + '_{}.{}'.format('instrum', output_extension)
+        sf.write(output_folder + '/' + output_name, inst, sr, subtype=output_format)
+        print('File created: {}'.format(output_folder + '/' + output_name))
+
+        if options['vocals_only'] is False:
+            # instrumental part 2
+            inst2 = (result['bass'] + result['drums'] + result['other'])
+            output_name = os.path.splitext(os.path.basename(input_audio))[0] + '_{}.{}'.format('instrum2', output_extension)
+            sf.write(output_folder + '/' + output_name, inst2, sr, subtype=output_format)
+            print('File created: {}'.format(output_folder + '/' + output_name))
+
+
+# Linkwitz-Riley filter
+def lr_filter(audio, cutoff, filter_type, order=6, sr=44100):
+    audio = audio.T
+    nyquist = 0.5 * sr
+    normal_cutoff = cutoff / nyquist
+    b, a = signal.butter(order//2, normal_cutoff, btype=filter_type, analog=False)
+    sos = signal.tf2sos(b, a)
+    filtered_audio = signal.sosfiltfilt(sos, audio)
+    return filtered_audio.T
+
+def match_array_shapes(array_1:np.ndarray, array_2:np.ndarray):
+    if array_1.shape[1] > array_2.shape[1]:
+        array_1 = array_1[:,:array_2.shape[1]]
+    elif array_1.shape[1] < array_2.shape[1]:
+        padding = array_2.shape[1] - array_1.shape[1]
+        array_1 = np.pad(array_1, ((0,0), (0,padding)), 'constant', constant_values=0)
+    return array_1
+
+def dBgain(audio, volume_gain_dB):
+    attenuation = 10 ** (volume_gain_dB / 20)
+    gained_audio = audio * attenuation
+    return gained_audio
+
+
+
+if __name__ == '__main__':
+    start_time = time()
+    print("started!\n")
+    m = argparse.ArgumentParser()
+    m.add_argument("--input_audio", "-i", nargs='+', type=str, help="Input audio location. You can provide multiple files at once", required=True)
+    m.add_argument("--output_folder", "-r", type=str, help="Output audio folder", required=True)
+    m.add_argument("--large_gpu", action='store_true', help="It will store all models on GPU for faster processing of multiple audio files. Requires 11 GB or more of free GPU memory.")
+    m.add_argument("--single_onnx", action='store_true', help="Only use single ONNX model for vocals. Can be useful if you don't have enough GPU memory.")
+    m.add_argument("--cpu", action='store_true', help="Choose CPU instead of GPU for processing. Can be very slow.")
+    m.add_argument("--overlap_demucs", type=float, help="Overlap of split audio for light models. Closer to 1.0 - slower", required=False, default=0.1)
+    m.add_argument("--overlap_VOCFT", type=float, help="Overlap of split audio for heavy models. Closer to 1.0 - slower", required=False, default=0.1)
+    m.add_argument("--overlap_InstHQ4", type=float, help="Overlap of split audio for heavy models. Closer to 1.0 - slower", required=False, default=0.1)
+    m.add_argument("--overlap_VitLarge", type=int, help="Overlap of split audio for heavy models. Closer to 1.0 - slower", required=False, default=1)
+    m.add_argument("--overlap_InstVoc", type=int, help="MDXv3 overlap", required=False, default=2)
+    m.add_argument("--overlap_BSRoformer", type=int, help="BSRoformer overlap", required=False, default=2)
+    m.add_argument("--weight_InstVoc", type=float, help="Weight of MDXv3 model", required=False, default=4)
+    m.add_argument("--weight_VOCFT", type=float, help="Weight of VOC-FT model", required=False, default=1)
+    m.add_argument("--weight_InstHQ4", type=float, help="Weight of InstHQ4 model", required=False, default=1)
+    m.add_argument("--weight_VitLarge", type=float, help="Weight of VitLarge model", required=False, default=1)
+    m.add_argument("--weight_BSRoformer", type=float, help="Weight of BS-Roformer model", required=False, default=10)
+    m.add_argument("--BigShifts", type=int, help="Managing MDX 'BigShifts' trick value.", required=False, default=7)
+    m.add_argument("--vocals_only", action='store_true', help="Vocals + instrumental only")
+    m.add_argument("--use_BSRoformer", action='store_true', help="use BSRoformer in vocal ensemble")
+    m.add_argument("--BSRoformer_model", type=str, help="Which checkpoint to use", required=False, default="ep_317_1297")
+    m.add_argument("--use_InstVoc", action='store_true', help="use InstVoc in vocal ensemble")
+    m.add_argument("--use_VitLarge", action='store_true', help="use VitLarge in vocal ensemble")
+    m.add_argument("--use_InstHQ4", action='store_true', help="use InstHQ4 in vocal ensemble")
+    m.add_argument("--use_VOCFT", action='store_true', help="use VOCFT in vocal ensemble")
+    m.add_argument("--output_format", type=str, help="Output audio format", default="float")
+    m.add_argument("--input_gain", type=int, help="Input volume gain", required=False, default=0)
+    m.add_argument("--restore_gain", action='store_true', help="Restore original gain after separation")
+    m.add_argument("--filter_vocals", action='store_true', help="Remove audio below 50 Hz in vocals stem")
+    options = m.parse_args().__dict__
+    print("Options: ")
+
+    print(f'Input Gain: {options["input_gain"]}dB')
+    print(f'Restore Gain: {options["restore_gain"]}')
+    print(f'BigShifts: {options["BigShifts"]}\n')
+
+    print(f'BSRoformer_model: {options["BSRoformer_model"]}')
+    print(f'weight_BSRoformer: {options["weight_BSRoformer"]}')
+    print(f'weight_InstVoc: {options["weight_InstVoc"]}\n')
+
+    print(f'use_VitLarge: {options["use_VitLarge"]}')
+    if options["use_VitLarge"] is True:
+        print(f'weight_VitLarge: {options["weight_VitLarge"]}\n')
+
+    print(f'use_VOCFT: {options["use_VOCFT"]}')
+    if options["use_VOCFT"] is True:
+        print(f'overlap_VOCFT: {options["overlap_VOCFT"]}')
+        print(f'weight_VOCFT: {options["weight_VOCFT"]}\n')
+
+    print(f'use_InstHQ4: {options["use_InstHQ4"]}')
+    if options["use_InstHQ4"] is True:
+        print(f'overlap_InstHQ4: {options["overlap_InstHQ4"]}')
+        print(f'weight_InstHQ4: {options["weight_InstHQ4"]}\n')
+
+    print(f'vocals_only: {options["vocals_only"]}')
+
+    if options["vocals_only"] is False:
+        print(f'overlap_demucs: {options["overlap_demucs"]}\n')
+
+    print(f'output_format: {options["output_format"]}\n')
+    predict_with_model(options)
+    print('Time: {:.0f} sec'.format(time() - start_time))
+
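Note on the vocal-recombination step above: the script blends the enabled vocal models by their weights and then splits the result at 12 kHz with lr_filter, keeping the weighted blend below the crossover and the full-band InstVoc output above it. The following is a minimal, self-contained sketch of that idea only, assuming 44.1 kHz stereo arrays shaped (samples, 2); the helper names here (lr_crossover, combine_vocals) are illustrative and not part of the committed file.

import numpy as np
from scipy import signal

def lr_crossover(audio, cutoff_hz, btype, order=6, sr=44100):
    # Zero-phase Butterworth via sosfiltfilt, mirroring lr_filter() in inference.py.
    sos = signal.butter(order // 2, cutoff_hz / (0.5 * sr), btype=btype, output='sos')
    return signal.sosfiltfilt(sos, audio, axis=0)

def combine_vocals(model_outputs, weights, instvoc_vocals):
    # Weighted average of every enabled vocal model...
    blended = sum(w * out for w, out in zip(weights, model_outputs)) / np.sum(weights)
    # ...then keep the blend below 12 kHz and the full-band InstVoc result above it.
    return (lr_crossover(blended, 12000, 'lowpass')
            + lr_crossover(instvoc_vocals, 12000, 'highpass'))

One plausible reason for this crossover is that some ensemble members (e.g. the ONNX MDX models) are band-limited, while the MDX23C InstVoc output is full band, so the region above 12 kHz is taken from the full-band model alone.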
MDX23v24/requirements.txt ADDED
@@ -0,0 +1,15 @@
+numpy
+soundfile
+scipy
+tqdm
+librosa
+demucs
+#onnxruntime-gpu  # nightly version installed within the notebook to fix the CUDA 12.2 issue.
+torch
+pyyaml
+ml_collections
+#pytorch_lightning
+samplerate==0.1.0
+segmentation_models_pytorch==0.3.3
+beartype==0.14.1
+rotary_embedding_torch==0.3.5
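For reference, a typical invocation of the uploaded script could look like the sketch below. This is a hypothetical example rather than part of the commit: the file names are placeholders, and only flags actually defined in the argparse block of inference.py are used.

import subprocess

subprocess.run([
    "python", "MDX23v24/inference.py",
    "--input_audio", "mixture.wav",        # one or more input files
    "--output_folder", "separated",        # created if it does not already exist
    "--use_BSRoformer", "--use_InstVoc",   # vocal-ensemble members to enable
    "--weight_BSRoformer", "10",
    "--weight_InstVoc", "4",
    "--vocals_only",                       # skip the Demucs bass/drums/other stage
    "--output_format", "FLAC",
], check=True)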