|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from functools import partial |
|
|
|
|
|
class Conv_TDF_net_trim_model(nn.Module):
    """STFT front/back end plus an encoder/decoder ``forward`` for a Conv-TDF model.

    ``__init__`` only configures the spectrogram geometry. ``forward`` uses the
    sub-modules ``first_conv``, ``ds_dense``, ``ds``, ``mid_dense``, ``us``,
    ``us_dense`` and ``final_conv``, none of which are created here. They must
    be assigned (e.g. by a subclass) before ``forward`` is called.
    """

    def __init__(self, device, target_name, L, n_fft, hop=1024, dim_f=3072):
        """Configure the STFT geometry.

        Args:
            device: Device on which the analysis window and the default
                frequency padding are initially allocated.
            target_name: Target name; ``'*'`` makes the default frequency
                padding used by ``istft`` have ``4 * dim_c`` channels
                instead of ``dim_c``.
            L: Level count; ``forward`` runs ``L // 2`` down/up stages.
            n_fft: FFT size; the full spectrogram has ``n_fft // 2 + 1`` bins.
            hop: STFT hop length in samples.
            dim_f: Number of low-frequency bins kept by ``stft``. Must not
                exceed ``n_fft // 2 + 1`` (otherwise the padding tensor would
                need a negative size and ``torch.zeros`` raises).
        """
        super().__init__()

        # Channel axis of the spectrogram = 2 signals x (real, imag); see the
        # reshapes in stft/istft.
        self.dim_c = 4
        self.dim_f, self.dim_t = dim_f, 256
        self.n_fft = n_fft
        self.hop = hop
        self.n_bins = self.n_fft // 2 + 1
        # With center=True, hop * (dim_t - 1) samples yield exactly dim_t frames.
        self.chunk_size = hop * (self.dim_t - 1)
        self.target_name = target_name

        out_c = self.dim_c * 4 if target_name == '*' else self.dim_c

        # Non-persistent buffers instead of plain attributes: they follow
        # module.to()/.cuda()/DataParallel replicas, yet are excluded from
        # state_dict() so previously saved checkpoints still load.
        self.register_buffer(
            'window',
            torch.hann_window(window_length=self.n_fft, periodic=True).to(device),
            persistent=False,
        )
        # Zeros appended above bin dim_f in istft to restore all n_bins.
        self.register_buffer(
            'freq_pad',
            torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device),
            persistent=False,
        )

        self.n = L // 2

    def stft(self, x):
        """Return the frequency-truncated spectrogram of ``x``.

        Args:
            x: Tensor regrouped into rows of ``chunk_size`` samples. Consecutive
                rows are paired, so the row count must be even (presumably
                ``(batch, 2, chunk_size)`` -- TODO confirm with callers).

        Returns:
            Tensor ``(batch, dim_c, dim_f, dim_t)`` whose channels are
            (row0 real, row0 imag, row1 real, row1 imag).
        """
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        x = torch.view_as_real(x)  # (rows, n_bins, dim_t, 2)
        x = x.permute([0, 3, 1, 2])  # (rows, 2, n_bins, dim_t)
        # Pair consecutive rows and fold (row, re/im) into dim_c channels.
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
            [-1, self.dim_c, self.n_bins, self.dim_t]
        )
        return x[:, :, :self.dim_f]

    def istft(self, x, freq_pad=None):
        """Invert ``stft``: restore the cut-off bins, then resynthesize audio.

        Args:
            x: Spectrogram ``(batch, C, F, dim_t)`` in the channel layout
                produced by ``stft``.
            freq_pad: Tensor concatenated after ``x`` on the frequency axis so
                the result has ``n_bins`` bins. Defaults to ``self.freq_pad``
                (zeros) repeated over the batch, which requires ``C`` to equal
                the ``out_c`` chosen in ``__init__``.

        Returns:
            Tensor ``(-1, 2, chunk_size)``.
        """
        if freq_pad is None:
            freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1])
        x = torch.cat([x, freq_pad], -2)
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
            [-1, 2, self.n_bins, self.dim_t]
        )
        # view_as_complex needs a contiguous trailing (re, im) axis of size 2.
        x = x.permute([0, 2, 3, 1]).contiguous()
        x = torch.view_as_complex(x)
        # Explicit length pins the output to chunk_size samples per row.
        x = torch.istft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            length=self.chunk_size,
        )
        return x.reshape([-1, 2, self.chunk_size])

    def forward(self, x):
        """Run the encoder/decoder stack.

        Requires the sub-modules listed in the class docstring to exist. Each
        encoder stage's ``ds_dense`` output is multiplied into the mirrored
        decoder stage (decoder step ``i`` uses encoder step ``n - 1 - i``).
        """
        x = self.first_conv(x)
        x = x.transpose(-1, -2)

        ds_outputs = []
        for i in range(self.n):
            x = self.ds_dense[i](x)
            ds_outputs.append(x)
            x = self.ds[i](x)

        x = self.mid_dense(x)
        for i in range(self.n):
            x = self.us[i](x)
            # Out-of-place on purpose: an in-place `*=` would overwrite a
            # tensor the last op of us[i] may have saved for backward (ReLU
            # saves its output), making loss.backward() fail.
            x = x * ds_outputs[-i - 1]
            x = self.us_dense[i](x)

        x = x.transpose(-1, -2)
        x = self.final_conv(x)
        return x
|
|