dangtr0408 committed on
Commit
2b1b519
·
1 Parent(s): 2914730

Update inference.py and meldataset.py

Browse files
Files changed (2) hide show
  1. inference.py +48 -9
  2. meldataset.py +129 -40
inference.py CHANGED
@@ -65,9 +65,31 @@ class StyleTTS2(torch.nn.Module):
65
  super().__init__()
66
  self.register_buffer("get_device", torch.empty(0))
67
  self.preprocess = Preprocess()
68
-
69
- config = yaml.safe_load(open(config_path))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  args = self.__recursive_munch(config['model_params'])
 
 
 
71
 
72
  assert args.decoder.type in ['hifigan'], 'Decoder type unknown'
73
 
@@ -186,7 +208,7 @@ class StyleTTS2(torch.nn.Module):
186
  speed = min(max(speed, 0.0001), 2) #speed range [0, 2]
187
 
188
  phonem = ' '.join(word_tokenize(phonem))
189
- tokens = TextCleaner()(phonem)
190
  tokens.insert(0, 0)
191
  tokens.append(0)
192
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
@@ -231,17 +253,34 @@ class StyleTTS2(torch.nn.Module):
231
 
232
  return out.squeeze().cpu().numpy(), duration.mean()
233
 
234
- def get_styles(self, speaker, denoise=0.3, avg_style=True):
235
- if avg_style: split_dur = 3
236
- else: split_dur = 0
237
- style = {}
238
- ref_s = self.__compute_style(speaker['path'], denoise=denoise, split_dur=split_dur)
 
 
 
239
  style = {
240
- 'style': ref_s,
241
  'path': speaker['path'],
242
  'speed': speaker['speed'],
243
  }
244
  return style
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  def generate(self, phonem, style, stabilize=True, n_merge=16):
247
  if stabilize: smooth_value=0.2
 
65
  super().__init__()
66
  self.register_buffer("get_device", torch.empty(0))
67
  self.preprocess = Preprocess()
68
+ self.ref_s = None
69
+ config = yaml.safe_load(open(config_path, "r", encoding="utf-8"))
70
+
71
+ try:
72
+ symbols = (
73
+ list(config['symbol']['pad']) +
74
+ list(config['symbol']['punctuation']) +
75
+ list(config['symbol']['letters']) +
76
+ list(config['symbol']['letters_ipa']) +
77
+ list(config['symbol']['extend'])
78
+ )
79
+ symbol_dict = {}
80
+ for i in range(len((symbols))):
81
+ symbol_dict[symbols[i]] = i
82
+
83
+ n_token = len(symbol_dict) + 1
84
+ print("\nFound:", n_token, "symbols")
85
+ except Exception as e:
86
+ print(f"\nERROR: Cannot find {e} in config file!\nYour config file is likely outdated, please download updated version from the repository.")
87
+ raise SystemExit(1)
88
+
89
  args = self.__recursive_munch(config['model_params'])
90
+ args['n_token'] = n_token
91
+
92
+ self.cleaner = TextCleaner(symbol_dict, debug=False)
93
 
94
  assert args.decoder.type in ['hifigan'], 'Decoder type unknown'
95
 
 
208
  speed = min(max(speed, 0.0001), 2) #speed range [0, 2]
209
 
210
  phonem = ' '.join(word_tokenize(phonem))
211
+ tokens = self.cleaner(phonem)
212
  tokens.insert(0, 0)
213
  tokens.append(0)
214
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
 
253
 
254
  return out.squeeze().cpu().numpy(), duration.mean()
255
 
def get_styles(self, speaker, denoise=0.3, avg_style=True, load_styles=False):
    """Return the style dictionary for ``speaker``.

    Args:
        speaker: Mapping that provides at least ``'path'`` and ``'speed'``.
        denoise: Forwarded unchanged to the style computation.
        avg_style: When truthy the style is computed with ``split_dur=3``,
            otherwise with ``split_dur=0`` (the meaning of ``split_dur`` is
            defined by the style computation helper, not shown here).
        load_styles: When truthy, reuse the style already cached on
            ``self.ref_s`` instead of computing a new one.

    Returns:
        A dict with keys ``'style'``, ``'path'`` and ``'speed'``.

    Raises:
        Exception: If ``load_styles`` is truthy but ``self.ref_s`` is ``None``.
    """
    if load_styles:
        # Reuse path: a style must already have been computed or loaded.
        if self.ref_s is None:
            raise Exception("Have to compute or load the styles first!")
    else:
        # Compute path: the fresh style is also cached on the instance so
        # that it can be saved or reused later.
        split_dur = 3 if avg_style else 0
        self.ref_s = self.__compute_style(speaker['path'], denoise=denoise, split_dur=split_dur)

    return {
        'style': self.ref_s,
        'path': speaker['path'],
        'speed': speaker['speed'],
    }
270
+
def save_styles(self, save_dir):
    """Write the cached style (``self.ref_s``) to ``save_dir`` via ``torch.save``.

    Args:
        save_dir: Destination handed directly to ``torch.save``.

    Raises:
        Exception: If no style has been cached yet (``self.ref_s`` is ``None``).
    """
    # Guard first: there is nothing to serialise until a style exists.
    if self.ref_s is None:
        raise Exception("Have to compute the styles before saving it.")

    torch.save(self.ref_s, save_dir)
    print("Saved styles!")
277
+
def load_styles(self, save_dir):
    """Load a previously saved style into ``self.ref_s``.

    The stored object is mapped onto the device of the ``get_device`` buffer
    so that a style saved on one device (e.g. a GPU) can still be loaded when
    the model currently lives on another (e.g. CPU).

    Loading is best-effort: on failure the error is printed and
    ``self.ref_s`` keeps its previous value.

    Args:
        save_dir: Source handed directly to ``torch.load``.
    """
    # The module registers an empty ``get_device`` buffer; its device tracks
    # wherever the module has been moved to.
    target_device = self.get_device.device
    try:
        # NOTE(review): torch.load unpickles data - only load trusted files.
        self.ref_s = torch.load(save_dir, map_location=target_device)
    except Exception as e:
        # Deliberately non-fatal: report the problem and leave the cached
        # style untouched.
        print(e)
    else:
        # Only announce success when the load actually succeeded.
        print("Loaded styles!")
284
 
285
  def generate(self, phonem, style, stabilize=True, n_merge=16):
286
  if stabilize: smooth_value=0.2
meldataset.py CHANGED
@@ -1,7 +1,5 @@
1
  #coding: utf-8
2
- import os
3
  import os.path as osp
4
- import time
5
  import random
6
  import numpy as np
7
  import random
@@ -9,10 +7,10 @@ import soundfile as sf
9
  import librosa
10
 
11
  import torch
12
- from torch import nn
13
- import torch.nn.functional as F
14
  import torchaudio
15
- from torch.utils.data import DataLoader
 
 
16
 
17
  import logging
18
  logger = logging.getLogger(__name__)
@@ -20,33 +18,19 @@ logger.setLevel(logging.DEBUG)
20
 
21
  import pandas as pd
22
 
23
- ##########################################################
24
- _pad = "$"
25
- _punctuation = ';:,.!?¡¿—…"«»“” '
26
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
27
- _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
28
- _extend = "" #ADD MORE SYMBOLS HERE
29
-
30
- # Export all symbols:
31
- symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + list(_extend)
32
-
33
- dicts = {}
34
- for i in range(len((symbols))):
35
- dicts[symbols[i]] = i
36
-
37
- # Copy this code somewhere else then run with print(len(dicts) + 1) to check total symbols
38
- ##########################################################
39
-
40
  class TextCleaner:
41
- def __init__(self, dummy=None):
42
- self.word_index_dictionary = dicts
 
43
  def __call__(self, text):
44
  indexes = []
45
  for char in text:
46
  try:
47
  indexes.append(self.word_index_dictionary[char])
48
  except KeyError as e:
49
- #print(char)
 
 
50
  continue
51
  return indexes
52
 
@@ -75,17 +59,16 @@ class FilePathDataset(torch.utils.data.Dataset):
75
  def __init__(self,
76
  data_list,
77
  root_path,
 
78
  sr=24000,
79
  data_augmentation=False,
80
- validation=False
 
81
  ):
82
 
83
- spect_params = SPECT_PARAMS
84
- mel_params = MEL_PARAMS
85
-
86
  _data_list = [l.strip().split('|') for l in data_list]
87
  self.data_list = _data_list #[data if len(data) == 3 else (*data, 0) for data in _data_list] #append speakerid=0 for all
88
- self.text_cleaner = TextCleaner()
89
  self.sr = sr
90
 
91
  self.df = pd.DataFrame(self.data_list)
@@ -195,9 +178,13 @@ class Collater(object):
195
  return waves, texts, input_lengths, mels, output_lengths
196
 
197
 
 
 
 
198
 
199
  def build_dataloader(path_list,
200
  root_path,
 
201
  validation=False,
202
  batch_size=4,
203
  num_workers=1,
@@ -205,14 +192,116 @@ def build_dataloader(path_list,
205
  collate_config={},
206
  dataset_config={}):
207
 
208
- dataset = FilePathDataset(path_list, root_path, validation=validation, **dataset_config)
209
  collate_fn = Collater(**collate_config)
210
- data_loader = DataLoader(dataset,
211
- batch_size=batch_size,
212
- shuffle=(not validation),
213
- num_workers=num_workers,
214
- drop_last=(not validation),
215
- collate_fn=collate_fn,
216
- pin_memory=(device != 'cpu'))
217
-
218
- return data_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #coding: utf-8
 
2
  import os.path as osp
 
3
  import random
4
  import numpy as np
5
  import random
 
7
  import librosa
8
 
9
  import torch
 
 
10
  import torchaudio
11
+ import torch.utils.data
12
+ import torch.distributed as dist
13
+ from multiprocessing import Pool
14
 
15
  import logging
16
  logger = logging.getLogger(__name__)
 
18
 
19
  import pandas as pd
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
class TextCleaner:
    """Map each character of a string to its integer index.

    Characters missing from the lookup table are skipped; when ``debug`` is
    enabled a warning is printed for each of them.
    """

    def __init__(self, symbol_dict, debug=True):
        """Store the character-to-index table and the warning switch.

        Args:
            symbol_dict: Mapping from a single character to its index.
            debug: When truthy, print a warning for every unknown character.
        """
        self.word_index_dictionary = symbol_dict
        self.debug = debug

    def __call__(self, text):
        """Return the list of indices for the known characters of ``text``."""
        lookup = self.word_index_dictionary
        encoded = []
        for symbol in text:
            try:
                position = lookup[symbol]
            except KeyError:
                # Unknown symbol: drop it, optionally telling the user.
                if self.debug:
                    print("\nWARNING UNKNOWN IPA CHARACTERS/LETTERS: ", symbol)
                    print("To ignore set 'debug' to false in the config")
            else:
                encoded.append(position)
        return encoded
36
 
 
59
  def __init__(self,
60
  data_list,
61
  root_path,
62
+ symbol_dict,
63
  sr=24000,
64
  data_augmentation=False,
65
+ validation=False,
66
+ debug=True
67
  ):
68
 
 
 
 
69
  _data_list = [l.strip().split('|') for l in data_list]
70
  self.data_list = _data_list #[data if len(data) == 3 else (*data, 0) for data in _data_list] #append speakerid=0 for all
71
+ self.text_cleaner = TextCleaner(symbol_dict, debug)
72
  self.sr = sr
73
 
74
  self.df = pd.DataFrame(self.data_list)
 
178
  return waves, texts, input_lengths, mels, output_lengths
179
 
180
 
def get_length(wave_path, root_path):
    """Return the length of an audio file expressed in samples at 24000 Hz.

    Only the file header is inspected (``sf.info``); the frame count is
    rescaled by the ratio between 24000 and the file's own sample rate.

    Args:
        wave_path: Audio file path relative to ``root_path``.
        root_path: Directory the relative path is joined onto.

    Returns:
        The rescaled frame count (a float, because of the division).
    """
    full_path = osp.join(root_path, wave_path)
    meta = sf.info(full_path)
    rate_ratio = 24000 / meta.samplerate
    return meta.frames * rate_ratio
184
 
185
  def build_dataloader(path_list,
186
  root_path,
187
+ symbol_dict,
188
  validation=False,
189
  batch_size=4,
190
  num_workers=1,
 
192
  collate_config={},
193
  dataset_config={}):
194
 
195
+ dataset = FilePathDataset(path_list, root_path, symbol_dict, validation=validation, **dataset_config)
196
  collate_fn = Collater(**collate_config)
197
+
198
+ print("Getting sample lengths...")
199
+
200
+ num_processes = num_workers * 2
201
+ if num_processes != 0:
202
+ list_of_tuples = [(d[0], root_path) for d in dataset.data_list]
203
+ with Pool(processes=num_processes) as pool:
204
+ sample_lengths = pool.starmap(get_length, list_of_tuples, chunksize=16)
205
+ else:
206
+ sample_lengths = []
207
+ for d in dataset.data_list:
208
+ sample_lengths.append(get_length(d[0], root_path))
209
+
210
+ data_loader = torch.utils.data.DataLoader(
211
+ dataset,
212
+ num_workers=num_workers,
213
+ batch_sampler=BatchSampler(
214
+ sample_lengths,
215
+ batch_size,
216
+ shuffle=(not validation),
217
+ drop_last=(not validation),
218
+ num_replicas=1,
219
+ rank=0,
220
+ ),
221
+ collate_fn=collate_fn,
222
+ pin_memory=(device != "cpu"),
223
+ )
224
+
225
+ return data_loader
226
+
227
+ #https://github.com/duerig/StyleTTS2/
class BatchSampler(torch.utils.data.Sampler):
    """Batch sampler that groups samples of similar length into the same batch.

    Every sample is assigned to a "time bin" derived from its length (see
    ``get_time_bin``). Each yielded batch contains dataset indices taken from
    a single bin. Within a bin the indices are split across replicas with a
    ``DistributedSampler`` and then chunked with PyTorch's own
    ``BatchSampler``.

    Args:
        sample_lengths: Sequence with one length value per dataset item, in
            dataset order.
        batch_sizes: Number of items per batch.
        num_replicas: Number of distributed replicas. ``None`` queries
            ``torch.distributed`` for the world size.
        rank: Rank of this replica. ``None`` queries ``torch.distributed``.
        shuffle: Shuffle the bin order and the order inside each bin.
        drop_last: Drop incomplete trailing batches of each bin.
    """

    def __init__(
        self,
        sample_lengths,
        batch_sizes,
        num_replicas=None,
        rank=None,
        shuffle=True,
        drop_last=False,
    ):
        self.batch_sizes = batch_sizes
        self.num_replicas = dist.get_world_size() if num_replicas is None else num_replicas
        self.rank = dist.get_rank() if rank is None else rank
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.time_bins = {}    # bin number -> list of dataset indices
        self.epoch = 0         # seed offset handed to the per-bin DistributedSampler
        self.total_len = 0     # number of batches this replica yields per epoch
        self.last_bin = None   # bin key of the most recently yielded batch

        for index, length in enumerate(sample_lengths):
            bin_num = self.get_time_bin(length)
            # -1 marks samples that are too short to be placed in any bin.
            if bin_num != -1:
                self.time_bins.setdefault(bin_num, []).append(index)

        # Use the *resolved* replica count: the raw ``num_replicas`` argument
        # may be None, which made the multiplication below raise TypeError.
        total_batch = self.batch_sizes * self.num_replicas
        for members in self.time_bins.values():
            self.total_len += len(members) // total_batch
            if not self.drop_last and len(members) % total_batch != 0:
                self.total_len += 1

    def __iter__(self):
        """Yield lists of dataset indices, one list per batch."""
        bin_keys = list(self.time_bins)

        if self.shuffle:
            visit_order = torch.randperm(len(bin_keys)).tolist()
        else:
            visit_order = list(range(len(bin_keys)))

        for position in visit_order:
            key = bin_keys[position]
            current_bin = self.time_bins[key]
            # Named ``shard_sampler`` (not ``dist``) so the torch.distributed
            # module alias is not shadowed inside this method.
            shard_sampler = torch.utils.data.distributed.DistributedSampler(
                current_bin,
                num_replicas=self.num_replicas,
                rank=self.rank,
                shuffle=self.shuffle,
                drop_last=self.drop_last,
            )
            shard_sampler.set_epoch(self.epoch)
            batches = torch.utils.data.sampler.BatchSampler(
                shard_sampler, self.batch_sizes, self.drop_last
            )
            for item_list in batches:
                self.last_bin = key
                # item_list holds positions inside the bin; translate them
                # back to dataset indices.
                yield [current_bin[i] for i in item_list]

    def __len__(self):
        """Return the number of batches produced per epoch for this replica."""
        return self.total_len

    def set_epoch(self, epoch):
        """Set the epoch forwarded to each per-bin ``DistributedSampler``."""
        self.epoch = epoch

    def get_time_bin(self, sample_count):
        """Return the bin number for a length, or -1 if it is too short.

        The length is converted to frames by floor-dividing by 300
        (presumably the hop length - confirm against the mel settings).
        Anything under 20 frames is rejected; above that, bins are 20
        frames wide.
        """
        frames = sample_count // 300
        if frames >= 20:
            return (frames - 20) // 20
        return -1