File size: 2,228 Bytes
e625816 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
class Conv_TDF_net_trim_model(nn.Module):
    """STFT/iSTFT front-end plus the forward skeleton of an encoder/decoder net.

    This class only creates the spectrogram bookkeeping (window, zero padding
    for trimmed frequency bins, sizes).  The layers used by ``forward``
    (``first_conv``, ``ds_dense``, ``ds``, ``mid_dense``, ``us``,
    ``us_dense``, ``final_conv``) are NOT created here; presumably a subclass
    or the caller attaches them -- confirm against the rest of the project.

    Note: ``window`` and ``freq_pad`` are plain tensor attributes placed on
    ``device`` at construction time.  They are not registered as buffers, so
    a later ``module.to(...)`` does not move them.
    """

    def __init__(self, device, target_name: str, L: int, n_fft: int,
                 hop: int = 1024, dim_f: int = 3072):
        """Set up STFT parameters and helper tensors.

        Args:
            device: Device on which ``window`` and ``freq_pad`` are allocated.
            target_name: Target identifier; the special value ``'*'`` makes
                ``freq_pad`` four times wider in the channel dimension.
            L: Depth parameter; ``forward`` runs ``L // 2`` down-sampling and
                ``L // 2`` up-sampling stages.
            n_fft: FFT size used by ``stft`` / ``istft``.
            hop: Hop length between STFT frames.
            dim_f: Number of low-frequency bins kept by ``stft``.
        """
        super().__init__()

        # Spectrogram layout: 4 channels, dim_f kept bins, 256 frames.
        self.dim_c = 4
        self.dim_f, self.dim_t = dim_f, 256
        self.n_fft = n_fft
        self.hop = hop
        self.n_bins = self.n_fft // 2 + 1

        # With center=True, a chunk of hop * (dim_t - 1) samples yields
        # exactly dim_t STFT frames.
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)

        self.target_name = target_name
        out_c = self.dim_c * 4 if target_name == '*' else self.dim_c
        # Zeros that restore the bins dropped by stft() before inversion.
        self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device)

        self.n = L // 2

    def stft(self, x: torch.Tensor) -> torch.Tensor:
        """Convert waveform chunks into a real-valued, trimmed spectrogram.

        Args:
            x: Tensor whose element count is a multiple of
                ``2 * chunk_size``; it is flattened to rows of ``chunk_size``
                samples and consecutive row pairs are grouped together
                (presumably the two channels of stereo audio -- confirm).

        Returns:
            Tensor of shape ``[-1, dim_c, dim_f, dim_t]`` where real and
            imaginary parts are stored as separate channels.
        """
        waves = x.reshape([-1, self.chunk_size])
        spec = torch.stft(
            waves,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        # [rows, n_bins, dim_t] complex -> [rows, 2 (re/im), n_bins, dim_t] real
        spec = torch.view_as_real(spec).permute([0, 3, 1, 2])
        # Group row pairs, then merge (pair member, re/im) into dim_c channels.
        spec = spec.reshape([-1, 2, 2, self.n_bins, self.dim_t])
        spec = spec.reshape([-1, self.dim_c, self.n_bins, self.dim_t])
        # Keep only the lowest dim_f frequency bins.
        return spec[:, :, :self.dim_f]

    def istft(self, x: torch.Tensor, freq_pad=None) -> torch.Tensor:
        """Invert a trimmed spectrogram produced by ``stft`` back to waveforms.

        Args:
            x: Tensor of shape ``[batch, channels, dim_f, dim_t]``.
            freq_pad: Optional tensor concatenated along the frequency axis
                to restore ``n_bins`` bins.  When ``None``, ``self.freq_pad``
                (zeros) is repeated over the batch dimension.

        Returns:
            Tensor of shape ``[-1, 2, chunk_size]``.

        NOTE(review): the regrouping below always uses 2 as the channel
        count.  When the input has more than ``dim_c`` channels (e.g. the
        ``'*'`` target) the extra groups end up in the leading dimension of
        the result -- confirm that callers expect this.
        """
        if freq_pad is None:
            freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1])
        spec = torch.cat([x, freq_pad], -2)
        # Undo the channel merge done in stft(): -> [rows, 2 (re/im), n_bins, dim_t]
        spec = spec.reshape([-1, 2, 2, self.n_bins, self.dim_t])
        spec = spec.reshape([-1, 2, self.n_bins, self.dim_t])
        # view_as_complex needs the re/im pair last and contiguous memory.
        spec = spec.permute([0, 2, 3, 1]).contiguous()
        spec = torch.view_as_complex(spec)
        waves = torch.istft(
            spec,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
        )
        return waves.reshape([-1, 2, self.chunk_size])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder/decoder with multiplicative skip connections.

        The last two axes are swapped before the down/up path and swapped
        back before ``final_conv``.  Each up-sampling stage multiplies its
        output with the matching down-sampling stage output (last saved is
        used first).
        """
        x = self.first_conv(x)
        x = x.transpose(-1, -2)

        ds_outputs = []
        for i in range(self.n):
            x = self.ds_dense[i](x)
            ds_outputs.append(x)  # saved before down-sampling for the skip
            x = self.ds[i](x)

        x = self.mid_dense(x)

        for i in range(self.n):
            x = self.us[i](x)
            # Out-of-place multiply: an in-place `x *= ...` would overwrite
            # the output of self.us[i], which autograd may still need for the
            # backward pass (e.g. when that block ends in ReLU).
            x = x * ds_outputs[-i - 1]
            x = self.us_dense[i](x)

        x = x.transpose(-1, -2)
        x = self.final_conv(x)
        return x
|