Guilherme34 committed
Commit 62498a7 · verified · 1 Parent(s): 72a8e98

Upload processing_minicpmo.py with huggingface_hub

Files changed (1)
  1. processing_minicpmo.py +505 -0
processing_minicpmo.py ADDED
@@ -0,0 +1,505 @@
+ # coding=utf-8
+ # Copyright 2025 The OpenBMB Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Processor class for MiniCPMO.
+ """
+
+ import math
+ import re
+ from typing import List
+ from typing import Literal
+ from typing import Optional
+ from typing import Union
+
+ import numpy as np
+ import torch
+ import torchaudio
+ from transformers.image_utils import ImageInput
+ from transformers.processing_utils import ProcessorMixin
+ from transformers.tokenization_utils_base import PreTokenizedInput
+ from transformers.tokenization_utils_base import TextInput
+ from transformers.utils import TensorType
+
+ from .image_processing_minicpmv import MiniCPMOBatchFeature
+
+
+ class MiniCPMOProcessor(ProcessorMixin):
+     r"""
+     Constructs a MiniCPM-o processor which wraps a MiniCPM-V image processor, a Whisper feature extractor, and a
+     tokenizer into a single processor.
+
+     [`MiniCPMOProcessor`] offers all the functionalities of [`MiniCPMVImageProcessor`], [`WhisperFeatureExtractor`]
+     and [`LlamaTokenizerWrapper`]. See the [`~MiniCPMOProcessor.__call__`] and [`~MiniCPMOProcessor.decode`] for more
+     information.
+
+     Args:
+         image_processor ([`MiniCPMVImageProcessor`], *optional*):
+             The image processor is a required input.
+         feature_extractor ([`WhisperFeatureExtractor`], *optional*):
+             The audio feature extractor is a required input.
+         tokenizer ([`LlamaTokenizerWrapper`], *optional*):
+             The tokenizer is a required input.
+     """
+
+     attributes = ["image_processor", "feature_extractor", "tokenizer"]
+     feature_extractor_class = "WhisperFeatureExtractor"
+     image_processor_class = "AutoImageProcessor"
+     tokenizer_class = "AutoTokenizer"
+
+     def __init__(self, image_processor=None, feature_extractor=None, tokenizer=None):
+         super().__init__(image_processor, feature_extractor, tokenizer)
+         self.version = image_processor.version
+
+     def __call__(
+         self,
+         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+         images: ImageInput = None,
+         audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]] = None,
+         audio_parts: Optional[list] = None,
+         max_length: Optional[int] = None,
+         do_pad: Optional[bool] = True,
+         max_slice_nums: int = None,
+         use_image_id: bool = True,
+         chunk_input: bool = False,
+         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+         sampling_rate: Optional[int] = 16000,
+         **kwargs,
+     ) -> MiniCPMOBatchFeature:
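+         # Images are run through the image processor, audio waveforms through the Whisper feature
+         # extractor, and the text (with its "(<image>./</image>)" / "(<audio>./</audio>)" tags expanded
+         # into placeholder tokens) through the tokenizer; the results are merged into one batch feature.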
+         if images is not None:
+             image_inputs = self.image_processor(
+                 images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
+             )
+         else:
+             image_inputs = None
+
+         if audios is not None:
+             audio_features, audio_feature_lens, audio_phs = self.audio_feature_extract(
+                 audios, audio_parts, chunk_input, sampling_rate
+             )
+         else:
+             audio_features, audio_feature_lens, audio_phs = [], [], []
+
+         model_inputs = self._convert_omni_to_inputs(
+             image_inputs,
+             audio_phs,
+             text,
+             max_slice_nums=max_slice_nums,
+             use_image_id=use_image_id,
+             max_length=max_length,
+             **kwargs,
+         )
+
+         model_inputs["audio_features"] = audio_features
+         model_inputs["audio_feature_lens"] = audio_feature_lens
+
+         return MiniCPMOBatchFeature(data={**model_inputs})
+
+     def get_audio_placeholder(self, audio_lens, chunk_input, chunk_length):
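+         # The number of "<unk>" placeholders must match the number of audio embeddings the model will
+         # produce: raw samples -> log-mel frames (one per hop_length samples), then a stride-2 conv,
+         # then average pooling with stride `pool_step`. For example, assuming the Whisper default
+         # hop_length of 160 at 16 kHz, 1 s of audio gives 100 frames -> 50 conv features -> 25 tokens.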
+         pool_step = 2
+         feature_lens = math.ceil(audio_lens / self.feature_extractor.hop_length)
+
+         feature_lens = (feature_lens - 1) // 2 + 1
+         output_lens = (feature_lens - pool_step) // pool_step + 1
+
+         if chunk_input:
+             fbank_feat_in_chunk = int(chunk_length * 100)
+             cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
+             audio_embeds_in_chunk = (cnn_feat_in_chunk - pool_step) // pool_step + 1
+             num_audio_chunks = (output_lens + audio_embeds_in_chunk - 1) // audio_embeds_in_chunk
+
+             place_holders = ""
+             total_unk_len = 0
+             for _ in range(num_audio_chunks):
+                 unk_len = min(audio_embeds_in_chunk, output_lens - total_unk_len)
+                 place_holders += self.tokenizer.audio_start + "<unk>" * unk_len + self.tokenizer.audio_end
+                 total_unk_len += unk_len
+             audio_placeholder = place_holders
+         else:
+             audio_placeholder = self.tokenizer.audio_start + "<unk>" * output_lens + self.tokenizer.audio_end
+
+         return audio_placeholder
+
+     def audio_feature_extract(
+         self,
+         audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]],
+         audio_parts: Optional[list] = None,
+         chunk_input: Optional[bool] = False,
+         sampling_rate: Optional[int] = None,
+         chunk_length: Optional[int] = 1,
+         **kwargs,
+     ):
+         if isinstance(audios, np.ndarray):
+             audios_list = [[audios]]
+         elif isinstance(audios[0], np.ndarray):
+             audios_list = [audios]
+         else:
+             audios_list = audios
+
+         if audio_parts is not None:
+             assert len(audio_parts) == len(audios_list)
+             for parts, audios in zip(audio_parts, audios_list):
+                 assert len(parts) == len(audios)
+
+         audio_feature_lens_list = []
+         audio_ph_list = []
+
+         audio_features_all = []
+
+         # audio placeholder not dependent on audio_parts
+         for audios in audios_list:
+             if audios:
+                 audio_ph_list.append([self.get_audio_placeholder(len(a), chunk_input, chunk_length) for a in audios])
+             else:
+                 audio_ph_list.append([])
+
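+         # Consecutive audio segments that share the same entry in `audio_parts` are concatenated and
+         # featurized as one waveform; anything longer than 30 s is then re-split into 30 s windows,
+         # matching the fixed input length the Whisper feature extractor pads to.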
+         for idx, audios in enumerate(audios_list):
+             if audio_parts is not None:
+                 # same audio part merge
+                 audio_part = audio_parts[idx]
+                 merge_audio = []
+                 cur_audio = []
+                 for aid, (part, audio) in enumerate(zip(audio_part, audios)):
+                     if aid == 0 or audio_part[aid] == audio_part[aid - 1]:
+                         cur_audio.append(audio)
+                     else:
+                         merge_audio.append(np.hstack(cur_audio))
+                         cur_audio = [audio]
+                 if cur_audio:
+                     merge_audio.append(np.hstack(cur_audio))
+
+             else:
+                 merge_audio = audios
+
+             audio_feature_lens = []
+
+             # If the audio exceeds 30 seconds, split it into chunks every 30 seconds.
+             final_merge_audio = []
+             max_audio_inp_len = 30 * sampling_rate
+             for audio in merge_audio:
+                 if len(audio) <= max_audio_inp_len:
+                     final_merge_audio.append(audio)
+                 else:
+                     for i in range(math.ceil(len(audio) / max_audio_inp_len)):
+                         final_merge_audio.append(audio[i * max_audio_inp_len : (i + 1) * max_audio_inp_len])
+
+             if audios:
+                 audio_inputs = self.feature_extractor(
+                     final_merge_audio,
+                     sampling_rate=sampling_rate,
+                     return_attention_mask=True,
+                     padding="max_length",
+                     return_tensors="pt",
+                     **kwargs,
+                 )
+                 audio_feature = audio_inputs["input_features"]
+                 actual_lens = audio_inputs["attention_mask"].sum(dim=1)
+
+                 for feat, lens in zip(audio_feature, actual_lens):
+                     audio_features_all.append(feat[:, :lens])
+                     audio_feature_lens.append(lens)
+
+                 audio_feature_lens = torch.hstack(audio_feature_lens)
+                 audio_feature_lens_list.append(audio_feature_lens)
+             else:
+                 audio_feature_lens_list.append([])
+
+         if audio_features_all:
+             audio_features = [i.permute(1, 0) for i in audio_features_all]
+             audio_features = torch.nn.utils.rnn.pad_sequence(
+                 audio_features, batch_first=True, padding_value=0.0
+             ).permute(0, 2, 1)
+         else:
+             audio_features = []
+
+         return audio_features, audio_feature_lens_list, audio_ph_list
+
+     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
+     def batch_decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+         refer to the docstring of this method for more information.
+         """
+         output_ids = args[0]
+         result_text = []
+         for result in output_ids:
+             result = result[result != 0]
+             if result[0] == self.tokenizer.bos_id:
+                 result = result[1:]
+             if result[-1] == self.tokenizer.eos_id:
+                 result = result[:-1]
+             result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip())
+         return result_text
+         # return self.tokenizer.batch_decode(*args, **kwargs)
+
+     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
+     def decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+         the docstring of this method for more information.
+         """
+         result = args[0]
+         result = result[result != 0]
+         if result[0] == self.tokenizer.bos_id:
+             result = result[1:]
+         if result[-1] == self.tokenizer.eos_id or (
+             hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id
+         ):
+             result = result[:-1]
+         return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
+
+     def _convert(self, input_str, max_inp_length: Optional[int] = None, **kwargs):
+         input_ids = self.tokenizer.encode(input_str, **kwargs)
+         if max_inp_length is not None:
+             input_ids = input_ids[:max_inp_length]
+         input_ids = torch.tensor(input_ids, dtype=torch.int32)
+
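+         # So that image/audio/speaker embeddings can later be spliced back into the token sequence,
+         # record for every <start>/<end> special-token pair the (start + 1, end) index range that the
+         # corresponding placeholder tokens occupy inside `input_ids`.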
+         ## image bound
+         start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
+         end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
+
+         image_start_idx = torch.where(start_cond)[0]
+         image_start_idx += 1
+         image_end_idx = torch.where(end_cond)[0]
+
+         valid_image_nums = max(len(image_start_idx), len(image_end_idx))
+
+         image_bounds = torch.hstack(
+             [
+                 image_start_idx[:valid_image_nums].unsqueeze(-1),
+                 image_end_idx[:valid_image_nums].unsqueeze(-1),
+             ]
+         )
+
+         ## audio bound
+         audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
+         audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
+         assert len(audio_start_idx) == len(audio_end_idx)
+         audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
+
+         spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
+         spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
+         assert len(spk_start_idx) == len(spk_end_idx)
+         spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
+
+         return input_ids, image_bounds, audio_bounds, spk_bounds
+
+     def _convert_omni_to_inputs(
+         self,
+         images,
+         audio_phs,
+         texts: Union[str, List[str]],
+         truncation=None,
+         max_length=None,
+         max_slice_nums=None,
+         use_image_id=None,
+         return_tensors=None,
+         **kwargs,
+     ):
+         if images is None and audio_phs is None:
+             model_inputs = self.tokenizer(
+                 texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs
+             )
+             return MiniCPMOBatchFeature(data={**model_inputs})
+
+         image_tag = "(<image>./</image>)"
+         image_pattern = r"\(<image>./</image>\)"
+         audio_tag = "(<audio>./</audio>)"
+         audio_pattern = r"\(<audio>./</audio>\)"
+         split_pattern = f"({image_pattern}|{audio_pattern})"
+
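+         # Each text is split on the literal "(<image>./</image>)" / "(<audio>./</audio>)" tags; every tag
+         # is then replaced in place by the placeholder string for the corresponding image slice or audio
+         # segment, and the rejoined string is tokenized with _convert().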
+         if isinstance(texts, str):
+             texts = [texts]
+
+         bs = len(texts)
+         if images is not None:
+             images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"]
+         else:
+             images, image_sizes, tgt_sizes = [[]] * bs, [[]] * bs, [[]] * bs
+
+         input_ids_list = []
+         image_bounds_list = []
+         audio_bounds_list = []
+         spk_bounds_list = []
+
+         for index, text in enumerate(texts):
+             text_chunks = re.split(split_pattern, text)
+
+             image_tags = re.findall(image_pattern, text)
+             audio_tags = re.findall(audio_pattern, text)
+
+             if image_tags:
+                 assert images is not None
+                 assert len(image_tags) == len(image_sizes[index])
+             if audio_tags:
+                 assert audio_phs is not None
+                 assert len(audio_tags) == len(audio_phs[index])
+
+             image_id = 0
+             audio_id = 0
+             for i, chunk in enumerate(text_chunks):
+                 if chunk == image_tag:
+                     image_placeholder = self.image_processor.get_slice_image_placeholder(
+                         image_sizes[index][image_id], image_id, max_slice_nums, use_image_id
+                     )
+                     image_id += 1
+                     text_chunks[i] = image_placeholder
+                 elif chunk == audio_tag:
+                     audio_placeholder = audio_phs[index][audio_id]
+                     audio_id += 1
+                     text_chunks[i] = audio_placeholder
+
+             final_text = "".join(text_chunks)
+             input_ids, image_bounds, audio_bounds, spk_bounds = self._convert(final_text, max_length, **kwargs)
+
+             input_ids_list.append(input_ids)
+             image_bounds_list.append(image_bounds)
+             audio_bounds_list.append(audio_bounds)
+             spk_bounds_list.append(spk_bounds)
+
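+         # Sequences are left-padded to a common length, so every recorded bound index shifts right by
+         # that sequence's pad length; the attention mask is cleared over the padded prefix.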
+         padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
+         attention_mask = torch.ones_like(padded_input_ids, dtype=torch.bool)
+         for i, length in enumerate(padding_lengths):
+             image_bounds_list[i] = image_bounds_list[i] + length
+             audio_bounds_list[i] = audio_bounds_list[i] + length
+             spk_bounds_list[i] = spk_bounds_list[i] + length
+             attention_mask[i, :length] = False
+
+         data = {
+             "input_ids": padded_input_ids,
+             "attention_mask": attention_mask,
+             "pixel_values": images,
+             "image_sizes": image_sizes,
+             "image_bound": image_bounds_list,
+             "tgt_sizes": tgt_sizes,
+             "audio_bounds": audio_bounds_list,
+             "spk_bounds": spk_bounds_list,
+         }
+
+         return data
+
+     @property
+     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
+     def model_input_names(self):
+         tokenizer_input_names = self.tokenizer.model_input_names
+         image_processor_input_names = self.image_processor.model_input_names
+         feature_extractor_input_names = self.feature_extractor.model_input_names
+         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + feature_extractor_input_names))
+
+     def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
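+         # Pads a list of 1-D (or 2-D, padded along the first axis) tensors to a common length on the
+         # requested side and also returns, per item, how many padded positions were added; the caller
+         # uses those lengths to shift the image/audio/speaker bound indices.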
+         items = []
+         if isinstance(inputs[0], list):
+             assert isinstance(inputs[0][0], torch.Tensor)
+             for it in inputs:
+                 for tr in it:
+                     items.append(tr)
+         else:
+             assert isinstance(inputs[0], torch.Tensor)
+             items = inputs
+
+         batch_size = len(items)
+         shape = items[0].shape
+         dim = len(shape)
+         assert dim <= 2
+         if max_length is None:
+             max_length = 0
+         max_length = max(max_length, max(item.shape[-1] for item in items))
+         min_length = min(item.shape[-1] for item in items)
+         dtype = items[0].dtype
+
+         if dim == 0:
+             return torch.stack([item for item in items], dim=0), [0]
+         elif dim == 1:
+             if max_length == min_length:
+                 return torch.stack([item for item in items], dim=0), [0] * batch_size
+             tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
+         else:
+             tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
+
+         padding_length = []
+         for i, item in enumerate(items):
+             if dim == 1:
+                 if padding_side == "left":
+                     tensor[i, -len(item) :] = item.clone()
+                 else:
+                     tensor[i, : len(item)] = item.clone()
+             elif dim == 2:
+                 if padding_side == "left":
+                     tensor[i, -len(item) :, :] = item.clone()
+                 else:
+                     tensor[i, : len(item), :] = item.clone()
+             padding_length.append(tensor.shape[-1] - len(item))
+
+         return tensor, padding_length
+
+
+ class MelSpectrogramFeatures(torch.nn.Module):
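+     # Log-mel front end used by the TTS side: a torchaudio MelSpectrogram (magnitude, power=1) followed
+     # by a log with a 1e-5 floor. The 24 kHz / 1024-FFT / 256-hop / 100-mel defaults below are the values
+     # this file ships with; they would need to match whatever vocoder consumes these features.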
+     def __init__(
+         self,
+         sample_rate=24000,
+         n_fft=1024,
+         hop_length=256,
+         n_mels=100,
+         padding: Literal["center", "same"] = "center",
+     ):
+         super().__init__()
+         if padding not in ["center", "same"]:
+             raise ValueError("Padding must be 'center' or 'same'.")
+         self.padding = padding
+         self.mel_spec = torchaudio.transforms.MelSpectrogram(
+             sample_rate=sample_rate,
+             n_fft=n_fft,
+             hop_length=hop_length,
+             n_mels=n_mels,
+             center=padding == "center",
+             power=1,
+         )
+
+     def __call__(self, audio: torch.Tensor) -> torch.Tensor:
+         """
+         audio: Tensor([num_channels, num_samples])
+         """
+         return super().__call__(audio)
+
+     def forward(self, audio: torch.Tensor) -> torch.Tensor:
+         """
+         audio: Tensor([num_channels, num_samples])
+         """
+         mel: torch.Tensor = self.mel_spec(audio)
+         features = torch.log(torch.clip(mel, min=1e-5))
+         return features
+
+
+ class ChatTTSProcessor:
+     def __init__(self, text_tokenizer):
+         self.audio_processor = MelSpectrogramFeatures()
+         self.text_tokenizer = text_tokenizer
+
+     def __call__(self, text_list, audio_list):
+         assert len(text_list) == len(audio_list)
+         input_ids_varlen = []
+         for text in text_list:
+             input_ids_ = self.text_tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)  # [1, seq_len]
+             input_ids_ = input_ids_.squeeze(0)  # [seq_len]
+             input_ids_varlen.append(input_ids_)
+
+         audio_features_varlen = []
+         for audio in audio_list:
+             assert audio.shape.__len__() == 1  # [seq_len]
+             try:
+                 mel = self.audio_processor(audio)  # [100(num_mel_bins), seq_len_mel]
+             except Exception as e:
+                 raise e
+             audio_features_varlen.append(mel)
+
+         return {
+             "tts_input_ids_varlen": input_ids_varlen,  # return List[Tensor]
+             "tts_input_features_varlen": audio_features_varlen,  # return List[Tensor]
+         }
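
For context, a minimal usage sketch of the uploaded processor. The repository id, image, and audio below are placeholders, and loading assumes a checkpoint whose processor config registers this class via trust_remote_code.

import numpy as np
from PIL import Image
from transformers import AutoProcessor

# Placeholder repo id; substitute the checkpoint that ships this processing_minicpmo.py.
processor = AutoProcessor.from_pretrained("<repo-id>", trust_remote_code=True)

image = Image.new("RGB", (448, 448))               # dummy image
audio = np.zeros(16000 * 5, dtype=np.float32)      # dummy 5 s mono waveform at 16 kHz

# One "(<image>./</image>)" tag and one "(<audio>./</audio>)" tag per supplied image / audio segment.
text = "(<image>./</image>)(<audio>./</audio>)Describe what you see and hear."
batch = processor(text=[text], images=[[image]], audios=[[audio]], chunk_input=True)

print(batch["input_ids"].shape, batch["audio_features"].shape)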