uvr / MDX23v24 /modules /segm_models.py
gosummer's picture
Upload 14 files
e625816 verified
raw
history blame
4.28 kB
if __name__ == '__main__':
    # Standalone debugging entry: restrict the process to a single GPU.
    # This runs before the module-level `import torch` that follows this guard,
    # because CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA
    # is initialised.
    import os

    gpu_use = "2"
    print('GPU use: {}'.format(gpu_use))
    # gpu_use is already a str, so it can be stored directly.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_use
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
class STFT:
    """STFT front-end that packs real/imaginary parts as extra channels.

    ``config`` must provide ``n_fft``, ``hop_length`` and ``dim_f``.
    The forward transform keeps only the lowest ``dim_f`` frequency bins.
    :meth:`inverse` zero-pads the dropped bins back before calling
    ``torch.istft``.
    """

    def __init__(self, config):
        self.n_fft = config.n_fft
        self.hop_length = config.hop_length
        # Periodic Hann window; it is moved to the input's device on every call.
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
        self.dim_f = config.dim_f

    def __call__(self, x):
        """Return the packed spectrogram of waveform ``x``.

        Args:
            x: real tensor of shape ``(*batch, c, t)``: ``c`` audio channels
                of ``t`` samples each.

        Returns:
            Real tensor of shape ``(*batch, 2 * c, f, frames)`` with
            ``f = min(dim_f, n_fft // 2 + 1)``. Channel ``2 * i`` holds the
            real part and channel ``2 * i + 1`` the imaginary part of audio
            channel ``i``.
        """
        window = self.window.to(x.device)
        batch_dims = x.shape[:-2]
        c, t = x.shape[-2:]
        # Flatten every leading dim so torch.stft sees a batch of 1-D signals.
        x = x.reshape([-1, t])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            return_complex=True,
        )
        x = torch.view_as_real(x)        # (B*c, F, T, 2)
        x = x.permute([0, 3, 1, 2])      # (B*c, 2, F, T)
        x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape(
            [*batch_dims, c * 2, -1, x.shape[-1]]
        )
        return x[..., :self.dim_f, :]

    def inverse(self, x):
        """Invert :meth:`__call__`.

        Args:
            x: real tensor of shape ``(*batch, c, f, frames)`` in the packed
                layout produced by ``__call__``. ``c`` must be even and
                ``f <= n_fft // 2 + 1``.

        Returns:
            Waveform tensor of shape ``(*batch, c // 2, samples)``, with one
            signal per audio channel.
        """
        window = self.window.to(x.device)
        batch_dims = x.shape[:-3]
        c, f, t = x.shape[-3:]
        n = self.n_fft // 2 + 1
        # Restore the high-frequency bins discarded by ``dim_f`` as zeros.
        # Allocate directly on the target device (default dtype, as before)
        # rather than building on the CPU and copying.
        f_pad = torch.zeros([*batch_dims, c, n - f, t], device=x.device)
        x = torch.cat([x, f_pad], -2)
        x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
        x = x.permute([0, 2, 3, 1])
        x = x[..., 0] + x[..., 1] * 1.j
        x = torch.istft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
        )
        # c packed channels -> c // 2 waveforms. Do not assume stereo: a fixed
        # value of 2 silently mangles mono or multichannel input.
        x = x.reshape([*batch_dims, c // 2, -1])
        return x
def get_act(act_type):
    """Build a new activation module from its config name.

    Supported names:
        ``'gelu'``          -> ``nn.GELU()``
        ``'relu'``          -> ``nn.ReLU()``
        ``'elu<alpha>'``    -> ``nn.ELU(alpha)``, e.g. ``'elu0.5'``. A bare
                               ``'elu'`` uses ``alpha=1.0``.

    Raises:
        ValueError: if ``act_type`` is not one of the names above, or the
            ELU alpha suffix is not a valid float.
    """
    if act_type == 'gelu':
        return nn.GELU()
    if act_type == 'relu':
        return nn.ReLU()
    if act_type.startswith('elu'):
        suffix = act_type[3:]
        # float('') would raise, so treat a bare 'elu' as the PyTorch default.
        alpha = float(suffix) if suffix else 1.0
        return nn.ELU(alpha)
    raise ValueError(f"Unknown activation type: {act_type!r}")
class Segm_Models_Net(nn.Module):
    """Spectrogram-domain separator built around a segmentation_models_pytorch net.

    Pipeline in :meth:`forward`:

    1. :class:`STFT` converts the mixture waveform to a packed spectrogram.
    2. The frequency axis is split into ``num_subbands`` bands that are
       stacked as channels.
    3. A 1x1 conv projects to ``config.model.num_channels`` channels.
    4. An smp ``Unet`` or ``FPN`` processes the result.
    5. A small conv head maps it to ``num_target_instruments`` packed
       spectrograms, which are inverted back to waveforms.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        act = get_act(act_type=config.model.act)

        # One output when a single target instrument is configured, else one
        # output per entry in config.training.instruments.
        if config.training.target_instrument:
            self.num_target_instruments = 1
        else:
            self.num_target_instruments = len(config.training.instruments)
        self.num_subbands = config.model.num_subbands

        # Channels after cac2cws: sub-bands x audio channels x (real, imag).
        dim_c = self.num_subbands * config.audio.num_channels * 2
        c = config.model.num_channels

        self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)

        decoder_type = config.model.decoder_type
        if decoder_type == 'unet':
            decoder_cls = smp.Unet
        elif decoder_type == 'fpn':
            decoder_cls = smp.FPN
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(
                f"Unsupported decoder_type {decoder_type!r}; expected 'unet' or 'fpn'"
            )
        # The attribute stays named `unet_model` for both decoders so existing
        # state_dict keys keep matching.
        self.unet_model = decoder_cls(
            encoder_name=config.model.encoder_name,
            encoder_weights="imagenet",
            in_channels=c,
            classes=c,
        )

        # The head sees the raw sub-band mixture concatenated with the network
        # features (see forward), hence c + dim_c input channels.
        self.final_conv = nn.Sequential(
            nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
            act,
            nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False),
        )

        self.stft = STFT(config.audio)

    def cac2cws(self, x):
        """Split the frequency axis into ``num_subbands`` bands stacked as channels.

        ``(b, c, f, t) -> (b, c * k, f // k, t)`` with ``k = num_subbands``;
        ``f`` must be divisible by ``k``.
        """
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c, k, f // k, t)
        x = x.reshape(b, c * k, f // k, t)
        return x

    def cws2cac(self, x):
        """Inverse of :meth:`cac2cws`: ``(b, c, f, t) -> (b, c // k, f * k, t)``."""
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c // k, k, f, t)
        x = x.reshape(b, c // k, f * k, t)
        return x

    def forward(self, x):
        """Separate the mixture waveform ``x``.

        Args:
            x: waveform tensor of shape ``(b, channels, samples)``. It must be
                3-D, because ``cac2cws`` unpacks exactly four dims after the
                STFT.

        Returns:
            ``(b, channels', samples')`` when there is one target instrument,
            otherwise ``(b, num_target_instruments, channels', samples')``.
        """
        x = self.stft(x)

        mix = x = self.cac2cws(x)

        first_conv_out = x = self.first_conv(x)

        # Swap to (b, c, t, f) for the 2-D segmentation net, then swap back.
        x = x.transpose(-1, -2)
        x = self.unet_model(x)
        x = x.transpose(-1, -2)

        x = x * first_conv_out  # reduce artifacts

        x = self.final_conv(torch.cat([mix, x], 1))

        x = self.cws2cac(x)

        if self.num_target_instruments > 1:
            b, c, f, t = x.shape
            x = x.reshape(b, self.num_target_instruments, -1, f, t)

        x = self.stft.inverse(x)
        return x