import torch import torch.nn as nn import torch.nn.functional as F from functools import partial class Conv_TDF_net_trim_model(nn.Module): def __init__(self, device, target_name, L, n_fft, hop=1024, dim_f=3072): super(Conv_TDF_net_trim_model, self).__init__() self.dim_c = 4 self.dim_f, self.dim_t = dim_f, 256 self.n_fft = n_fft self.hop = hop self.n_bins = self.n_fft // 2 + 1 self.chunk_size = hop * (self.dim_t - 1) self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device) self.target_name = target_name out_c = self.dim_c * 4 if target_name == '*' else self.dim_c self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device) self.n = L // 2 def stft(self, x): x = x.reshape([-1, self.chunk_size]) x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True) x = torch.view_as_real(x) x = x.permute([0, 3, 1, 2]) x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t]) return x[:, :, :self.dim_f] def istft(self, x, freq_pad=None): freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad x = torch.cat([x, freq_pad], -2) x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t]) x = x.permute([0, 2, 3, 1]) x = x.contiguous() x = torch.view_as_complex(x) x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) return x.reshape([-1, 2, self.chunk_size]) def forward(self, x): x = self.first_conv(x) x = x.transpose(-1, -2) ds_outputs = [] for i in range(self.n): x = self.ds_dense[i](x) ds_outputs.append(x) x = self.ds[i](x) x = self.mid_dense(x) for i in range(self.n): x = self.us[i](x) x *= ds_outputs[-i - 1] x = self.us_dense[i](x) x = x.transpose(-1, -2) x = self.final_conv(x) return x