rulerman committed on
Commit c29df8a · verified · 1 Parent(s): 19aa13f

update config

README.md CHANGED
@@ -22,4 +22,32 @@ MOSS-TTSD supports voice cloning and single-session speech generation of up to 9
22
  - **Two-Speaker Voice Cloning**: MOSS-TTSD supports zero-shot two-speaker voice cloning and can generate conversational speech with accurate speaker switching based on dialogue scripts.
23
  - **Chinese-English Bilingual Support**: MOSS-TTSD enables highly expressive speech generation in both Chinese and English.
24
  - **Long-Form Speech Generation (up to 960 seconds)**: Thanks to its low-bitrate codec and training framework optimizations, MOSS-TTSD has been trained for long speech generation, enabling single-session speech generation of up to 960 seconds.
25
- - **Fully Open Source & Commercial-Ready**: MOSS-TTSD and its future updates will be fully open-source and support free commercial use.
22
  - **Two-Speaker Voice Cloning**: MOSS-TTSD supports zero-shot two-speaker voice cloning and can generate conversational speech with accurate speaker switching based on dialogue scripts.
23
  - **Chinese-English Bilingual Support**: MOSS-TTSD enables highly expressive speech generation in both Chinese and English.
24
  - **Long-Form Speech Generation (up to 960 seconds)**: Thanks to its low-bitrate codec and training framework optimizations, MOSS-TTSD has been trained for long speech generation, enabling single-session speech generation of up to 960 seconds.
25
+ - **Fully Open Source & Commercial-Ready**: MOSS-TTSD and its future updates will be fully open-source and support free commercial use.
26
+
27
+
28
+ ```python
29
+ import os
30
+ import torchaudio
31
+ from transformers import AutoModel, AutoProcessor
32
+
33
+ processor = AutoProcessor.from_pretrained("fnlp/MOSS-TTSD-v0.5", codec_path="fnlp/XY_Tokenizer_TTSD_V0_hf", trust_remote_code=True)
34
+ model = AutoModel.from_pretrained("fnlp/MOSS-TTSD-v0.5", trust_remote_code=True, device_map="auto").eval()
35
+
36
+ data = [{
37
+ "base_path": "/path/to/data/",
38
+ "text": "跟踪他们,他俩不行,从屋上平安下来没有扭伤脖子,",
39
+ "system_prompt": "你是一个根据文本生成对应音频的语音合成器。",
40
+ "prompt_text": "这支史诗级的美国迷幻摇滚乐队创建于,",
41
+ "prompt_audio": "prompt.wav",
42
+ }]
43
+
44
+ inputs = processor(data)
45
+ token_ids = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
46
+ text, audios = processor.batch_decode(token_ids)
47
+
48
+ if not os.path.exists("outputs/"):
49
+ os.mkdir("outputs/")
50
+ for i, data in enumerate(audios):
51
+ for j, fragment in enumerate(data):
52
+ torchaudio.save(f"outputs/audio_{i}_{j}.wav", fragment.cpu(), 24000)
53
+ ```
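As a rough sketch building on the snippet above (not shown in the committed README): `batch_decode` returns a list of audio fragments per item, so the fragments of one item can also be merged into a single waveform before saving. This assumes each fragment is a `(channels, samples)` tensor at the 24 kHz rate used in the snippet.

```python
# Rough sketch building on the README snippet above: merge the fragments of the
# first item into one waveform before saving.
import torch
import torchaudio

merged = torch.cat([fragment.cpu() for fragment in audios[0]], dim=-1)
torchaudio.save("outputs/audio_0_full.wav", merged, 24000)
```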
config.json CHANGED
@@ -1,7 +1,11 @@
1
  {
2
- "architectures": [
3
- "AsteroidTTSModel"
4
- ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
  "bos_token_id": 151643,
@@ -14,7 +18,6 @@
14
  "intermediate_size": 6144,
15
  "max_position_embeddings": 32768,
16
  "max_window_layers": 28,
17
- "model_type": "qwen3",
18
  "num_attention_heads": 16,
19
  "num_hidden_layers": 28,
20
  "num_key_value_heads": 8,
@@ -30,7 +33,7 @@
30
  "speech_vocab_size": 1025,
31
  "tie_word_embeddings": true,
32
  "torch_dtype": "bfloat16",
33
- "transformers_version": "4.51.3",
34
  "use_cache": true,
35
  "use_sliding_window": false,
36
  "vocab_size": 152697
 
1
  {
2
+ "model_type": "moss_ttsd",
3
+ "architectures": ["MossTTSDModel"],
4
+ "auto_map": {
5
+ "AutoProcessor": "processing_moss_ttsd.MossTTSDProcessor",
6
+ "AutoConfig": "configuration_moss_ttsd.MossTTSDConfig",
7
+ "AutoModel": "modeling_moss_ttsd.MossTTSDForCausalLM"
8
+ },
9
  "attention_bias": false,
10
  "attention_dropout": 0.0,
11
  "bos_token_id": 151643,
 
18
  "intermediate_size": 6144,
19
  "max_position_embeddings": 32768,
20
  "max_window_layers": 28,
 
21
  "num_attention_heads": 16,
22
  "num_hidden_layers": 28,
23
  "num_key_value_heads": 8,
 
33
  "speech_vocab_size": 1025,
34
  "tie_word_embeddings": true,
35
  "torch_dtype": "bfloat16",
36
+ "transformers_version": "4.53.2",
37
  "use_cache": true,
38
  "use_sliding_window": false,
39
  "vocab_size": 152697
configuration_moss_ttsd.py ADDED
@@ -0,0 +1,260 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/asteroid/modular_asteroid.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_asteroid.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ from transformers.configuration_utils import PretrainedConfig, layer_type_validation
22
+ from transformers.modeling_rope_utils import rope_config_validation
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class MossTTSDConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`MossTTSDModel`]. It is used to instantiate a
32
+ MOSS-TTSD model according to the specified arguments, defining the model architecture. Instantiating a
33
+ configuration with the defaults will yield a similar configuration to that of the MOSS-TTSD
34
+ [fnlp/MOSS-TTSD-v0.5](https://huggingface.co/fnlp/MOSS-TTSD-v0.5) architecture.
35
+
36
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
+ documentation from [`PretrainedConfig`] for more information.
38
+
39
+ Example:
40
+
41
+ ```python
42
+ >>> from transformers import MossTTSDConfig, MossTTSDModel
43
+
44
+ >>> # Initializing a MOSS-TTSD configuration
45
+ >>> configuration = MossTTSDConfig()
46
+
47
+ >>> # Initializing a model from the configuration
48
+ >>> model = MossTTSDModel(configuration)
49
+
50
+ >>> # Accessing the model configuration
51
+ >>> configuration = model.config
52
+ ```
53
+
54
+ Args:
55
+ vocab_size (`int`, *optional*, defaults to 152697):
56
+ Vocabulary size of the MOSS-TTSD model. Defines the number of different tokens that can be represented by the
57
+ `inputs_ids` passed when calling [`MossTTSDModel`]
58
+ hidden_size (`int`, *optional*, defaults to 2048):
59
+ Dimension of the hidden representations.
60
+ intermediate_size (`int`, *optional*, defaults to 6144):
61
+ Dimension of the MLP representations.
62
+ num_hidden_layers (`int`, *optional*, defaults to 28):
63
+ Number of hidden layers in the Transformer encoder.
64
+ num_attention_heads (`int`, *optional*, defaults to 16):
65
+ Number of attention heads for each attention layer in the Transformer encoder.
66
+ num_key_value_heads (`int`, *optional*, defaults to 8):
67
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
68
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
69
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
70
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
71
+ by meanpooling all the original heads within that group. For more details, check out [this
72
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
73
+ head_dim (`int`, *optional*, defaults to 128):
74
+ The attention head dimension.
75
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
76
+ The non-linear activation function (function or string) in the decoder.
77
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
78
+ The maximum sequence length that this model might ever be used with.
79
+ initializer_range (`float`, *optional*, defaults to 0.02):
80
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
81
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
82
+ The epsilon used by the rms normalization layers.
83
+ use_cache (`bool`, *optional*, defaults to `True`):
84
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
85
+ relevant if `config.is_decoder=True`.
86
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
87
+ Whether the model's input and output word embeddings should be tied.
88
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
89
+ The base period of the RoPE embeddings.
90
+ rope_scaling (`Dict`, *optional*):
91
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
92
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
93
+ accordingly.
94
+ Expected contents:
95
+ `rope_type` (`str`):
96
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
97
+ 'llama3'], with 'default' being the original RoPE implementation.
98
+ `factor` (`float`, *optional*):
99
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
100
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
101
+ original maximum pre-trained length.
102
+ `original_max_position_embeddings` (`int`, *optional*):
103
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
104
+ pretraining.
105
+ `attention_factor` (`float`, *optional*):
106
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
107
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
108
+ `factor` field to infer the suggested value.
109
+ `beta_fast` (`float`, *optional*):
110
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
111
+ ramp function. If unspecified, it defaults to 32.
112
+ `beta_slow` (`float`, *optional*):
113
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
114
+ ramp function. If unspecified, it defaults to 1.
115
+ `short_factor` (`list[float]`, *optional*):
116
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
117
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
118
+ size divided by the number of attention heads divided by 2
119
+ `long_factor` (`list[float]`, *optional*):
120
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
121
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
122
+ size divided by the number of attention heads divided by 2
123
+ `low_freq_factor` (`float`, *optional*):
124
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
125
+ `high_freq_factor` (`float`, *optional*):
126
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
127
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
128
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
129
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
130
+ Whether to use sliding window attention.
131
+ sliding_window (`int`, *optional*, defaults to 4096):
132
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
133
+ max_window_layers (`int`, *optional*, defaults to 28):
134
+ The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
135
+ additional layer afterwards will use SWA (Sliding Window Attention).
136
+ layer_types (`list`, *optional*):
137
+ Attention pattern for each layer.
138
+ attention_dropout (`float`, *optional*, defaults to 0.0):
139
+ The dropout ratio for the attention probabilities.
140
+ channels (`int`, *optional*, defaults to 8): Number of parallel token channels: channel 0 carries text tokens, channels 1 to `channels - 1` carry speech codec tokens.
141
+ speech_vocab_size (`int`, *optional*, defaults to 1025): Size of the speech token vocabulary used by each speech channel (including the speech padding token).
142
+ speech_pad_token (`int`, *optional*, defaults to 1024): Id of the padding token for the speech channels.
143
+ speech_token_range (`tuple(int, int)`, *optional*, defaults to `(151665, 152689)`): Half-open range `[start, end)` of ids in the text vocabulary reserved for speech tokens.
144
+ speech_eos_token (`int`, *optional*, defaults to 152694): Id of the end-of-speech token predicted on the text channel.
145
+
146
+ ```python
147
+ >>> from transformers import MossTTSDModel, MossTTSDConfig
148
+
149
+ >>> # Initializing a MOSS-TTSD style configuration
150
+ >>> configuration = MossTTSDConfig()
151
+
152
+ >>> # Initializing a model from the MOSS-TTSD-v0.5 style configuration
153
+ >>> model = MossTTSDModel(configuration)
154
+
155
+ >>> # Accessing the model configuration
156
+ >>> configuration = model.config
157
+ ```"""
158
+
159
+ model_type = "moss_ttsd"
160
+ keys_to_ignore_at_inference = ["past_key_values"]
161
+
162
+ # Default tensor parallel plan for base model `MossTTSD`
163
+ base_model_tp_plan = {
164
+ "layers.*.self_attn.q_proj": "colwise",
165
+ "layers.*.self_attn.k_proj": "colwise",
166
+ "layers.*.self_attn.v_proj": "colwise",
167
+ "layers.*.self_attn.o_proj": "rowwise",
168
+ "layers.*.mlp.gate_proj": "colwise",
169
+ "layers.*.mlp.up_proj": "colwise",
170
+ "layers.*.mlp.down_proj": "rowwise",
171
+ }
172
+ base_model_pp_plan = {
173
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
174
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
175
+ "norm": (["hidden_states"], ["hidden_states"]),
176
+ }
177
+
178
+ def __init__(
179
+ self,
180
+ vocab_size=152697,
181
+ hidden_size=2048,
182
+ intermediate_size=6144,
183
+ num_hidden_layers=28,
184
+ num_attention_heads=16,
185
+ num_key_value_heads=8,
186
+ head_dim=128,
187
+ hidden_act="silu",
188
+ max_position_embeddings=32768,
189
+ initializer_range=0.02,
190
+ rms_norm_eps=1e-6,
191
+ use_cache=True,
192
+ tie_word_embeddings=True,
193
+ rope_theta=1000000.0,
194
+ rope_scaling=None,
195
+ attention_bias=False,
196
+ use_sliding_window=False,
197
+ sliding_window=None,
198
+ max_window_layers=28,
199
+ layer_types=None,
200
+ attention_dropout=0.0,
201
+ channels=8,
202
+ speech_vocab_size=1025,
203
+ speech_pad_token=1024,
204
+ speech_token_range=(151665, 152689),
205
+ speech_eos_token=152694,
206
+ **kwargs,
207
+ ):
208
+ self.vocab_size = vocab_size
209
+ self.max_position_embeddings = max_position_embeddings
210
+ self.hidden_size = hidden_size
211
+ self.intermediate_size = intermediate_size
212
+ self.num_hidden_layers = num_hidden_layers
213
+ self.num_attention_heads = num_attention_heads
214
+ self.use_sliding_window = use_sliding_window
215
+ self.sliding_window = sliding_window if self.use_sliding_window else None
216
+ self.max_window_layers = max_window_layers
217
+
218
+ # for backward compatibility
219
+ if num_key_value_heads is None:
220
+ num_key_value_heads = num_attention_heads
221
+
222
+ self.num_key_value_heads = num_key_value_heads
223
+ self.head_dim = head_dim
224
+ self.hidden_act = hidden_act
225
+ self.initializer_range = initializer_range
226
+ self.rms_norm_eps = rms_norm_eps
227
+ self.use_cache = use_cache
228
+ self.rope_theta = rope_theta
229
+ self.rope_scaling = rope_scaling
230
+ self.attention_bias = attention_bias
231
+ self.attention_dropout = attention_dropout
232
+ # Validate the correctness of rotary position embeddings parameters
233
+ # BC: if there is a 'type' field, move it to 'rope_type'.
234
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
235
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
236
+ rope_config_validation(self)
237
+
238
+ self.layer_types = layer_types
239
+ if self.layer_types is None:
240
+ self.layer_types = [
241
+ "sliding_attention"
242
+ if self.sliding_window is not None and i >= self.max_window_layers
243
+ else "full_attention"
244
+ for i in range(self.num_hidden_layers)
245
+ ]
246
+ layer_type_validation(self.layer_types)
247
+
248
+ self.channels = channels
249
+ self.speech_vocab_size = speech_vocab_size
250
+ self.speech_pad_token = speech_pad_token
251
+ self.speech_token_range = speech_token_range
252
+ self.speech_eos_token = speech_eos_token
253
+
254
+ super().__init__(
255
+ tie_word_embeddings=tie_word_embeddings,
256
+ **kwargs,
257
+ )
258
+
259
+
260
+ __all__ = ["MossTTSDConfig"]
modeling.py ADDED
@@ -0,0 +1,426 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from dataclasses import dataclass
4
+ from transformers.utils import ModelOutput
5
+ from transformers.cache_utils import Cache
6
+ from typing import Optional, List, Tuple, Union
7
+ from transformers.loss.loss_utils import ForCausalLMLoss
8
+ from transformers.generation.streamers import BaseStreamer
9
+ from transformers.modeling_outputs import BaseModelOutputWithPast
10
+ from transformers.generation.configuration_utils import GenerationConfig
11
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
12
+ from transformers import PreTrainedModel, GenerationMixin, Qwen3Config, Qwen3Model
13
+ from transformers.generation.logits_process import (
14
+ LogitsProcessorList,
15
+ RepetitionPenaltyLogitsProcessor,
16
+ TopKLogitsWarper,
17
+ TopPLogitsWarper,
18
+ TemperatureLogitsWarper
19
+ )
20
+
21
+
22
+ class AsteroidTTSConfig(Qwen3Config):
23
+ def __init__(self,
24
+ channels = 8,
25
+ speech_pad_token = 1024,
26
+ speech_vocab_size = 1025,
27
+ speech_token_range = [],
28
+ **kwargs):
29
+ super().__init__(**kwargs)
30
+ self.channels = channels
31
+ self.speech_pad_token = speech_pad_token
32
+ self.speech_vocab_size = speech_vocab_size
33
+ self.speech_token_range = speech_token_range
34
+
35
+
36
+ @dataclass
37
+ class AsteroidTTSOutputWithPast(ModelOutput):
38
+ loss: Optional[torch.FloatTensor] = None
39
+ logits: torch.FloatTensor = None
40
+ loss_all: Optional[Tuple[torch.FloatTensor]] = None
41
+ logits_all: Optional[Tuple[torch.FloatTensor]] = None
42
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
43
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
44
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
45
+
46
+
47
+ @dataclass
48
+ class GenerateDecoderOnlyOutput(ModelOutput):
49
+ sequences: torch.LongTensor = None
50
+ scores: Optional[Tuple[torch.FloatTensor]] = None
51
+ logits: Optional[Tuple[torch.FloatTensor]] = None
52
+ attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
53
+ hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
54
+ past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
55
+
56
+
57
+ class AsteroidTTSPretrainedModel(PreTrainedModel):
58
+ config_class = AsteroidTTSConfig
59
+ base_model_prefix = "model"
60
+ supports_gradient_checkpointing = True
61
+ _no_split_modules = ["Qwen3DecoderLayer"]
62
+ _skip_keys_device_placement = ["past_key_values"]
63
+ _supports_flash_attn_2 = True
64
+ _supports_sdpa = True
65
+ _supports_flex_attn = True
66
+ _supports_cache_class = True
67
+ _supports_quantized_cache = True
68
+ _supports_static_cache = True
69
+ _supports_attention_backend = True
70
+
71
+
72
+ class AsteroidTTSModel(AsteroidTTSPretrainedModel):
73
+ def __init__(self, config: AsteroidTTSConfig):
74
+ super().__init__(config)
75
+ self.text_pad_idx = config.pad_token_id
76
+ self.speech_pad_idx = config.speech_pad_token
77
+ self.embedding_list = nn.ModuleList([])
78
+ self.embedding_list.append(nn.Embedding(config.vocab_size, config.hidden_size, self.text_pad_idx))
79
+ # Channels 1 to channels-1: Speech tokens only
80
+ for _ in range(1, config.channels):
81
+ self.embedding_list.append(nn.Embedding(config.speech_vocab_size, config.hidden_size, self.speech_pad_idx))
82
+
83
+ self.language_model = Qwen3Model(config)
84
+ self.post_init()
85
+
86
+ def get_input_embeddings(self):
87
+ return self.embedding_list[0]
88
+
89
+ def set_input_embeddings(self, value: nn.Embedding):
90
+ self.embedding_list[0] = value
91
+
92
+ def _prepare_multi_modal_inputs(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
93
+ """
94
+ Prepares multi-modal embeddings from input_ids of shape (batch_size, sequence_length, channels).
95
+ For channel 0: text + speech tokens, for channels 1 to channels-1: speech tokens padded with speech_pad_token.
96
+ """
97
+ batch_size, seq_length, channels = input_ids.shape
98
+ if channels != self.config.channels:
99
+ raise ValueError(f"Expected {self.config.channels} channels, got {channels}")
100
+
101
+ inputs_embeds = torch.zeros(batch_size, seq_length, self.config.hidden_size, device=input_ids.device, dtype=self.embedding_list[0].weight.dtype)
102
+ for i in range(channels):
103
+ embed_layer = self.embedding_list[i]
104
+ channel_input = input_ids[...,i]
105
+ inputs_embeds += embed_layer(channel_input)
106
+
107
+ return inputs_embeds
108
+
109
+ def forward(
110
+ self,
111
+ input_ids: torch.LongTensor = None, # Shape: (batch_size, sequence_length, channels)
112
+ attention_mask: Optional[torch.Tensor] = None,
113
+ position_ids: Optional[torch.LongTensor] = None,
114
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
115
+ inputs_embeds: Optional[torch.FloatTensor] = None,
116
+ use_cache: Optional[bool] = None,
117
+ output_attentions: Optional[bool] = None,
118
+ output_hidden_states: Optional[bool] = None,
119
+ return_dict: Optional[bool] = None,
120
+ cache_position: Optional[torch.LongTensor] = None,
121
+ **kwargs,
122
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
123
+
124
+ if (input_ids is None) ^ (inputs_embeds is not None):
125
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
126
+
127
+ if input_ids is not None:
128
+ inputs_embeds = self._prepare_multi_modal_inputs(input_ids)
129
+
130
+ outputs = self.language_model(
131
+ input_ids=None,
132
+ attention_mask=attention_mask,
133
+ position_ids=position_ids,
134
+ past_key_values=past_key_values,
135
+ inputs_embeds=inputs_embeds,
136
+ use_cache=use_cache,
137
+ output_attentions=output_attentions,
138
+ output_hidden_states=output_hidden_states,
139
+ return_dict=return_dict,
140
+ cache_position=cache_position,
141
+ )
142
+ return outputs
143
+
144
+
145
+ class AsteroidTTSInstruct(AsteroidTTSPretrainedModel, GenerationMixin):
146
+ _tied_weights_keys = []
147
+ _tp_plan = {"lm_head": "colwise_rep"}
148
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
149
+
150
+ def __init__(self, config: AsteroidTTSConfig):
151
+ super().__init__(config)
152
+ self.model = AsteroidTTSModel(config)
153
+ self.channels = config.channels
154
+ self.weights = [1 for _ in range(self.channels)]
155
+ self._tied_weights_keys = [f"lm_heads.{i}.weight" for i in range(self.channels)]
156
+ self.vocab_size = config.vocab_size
157
+ self.lm_heads = nn.ModuleList([])
158
+ self.lm_heads.append(nn.Linear(config.hidden_size, config.vocab_size, bias=False))
159
+ for _ in range(1, config.channels):
160
+ self.lm_heads.append(nn.Linear(config.hidden_size, config.speech_vocab_size, bias=False))
161
+ self.post_init()
162
+
163
+ def get_input_embeddings(self):
164
+ return self.model.embedding_list[0]
165
+
166
+ def can_generate(self):
167
+ return True
168
+
169
+ def is_speech_token(self, tokens):
170
+ return (tokens >= self.config.speech_token_range[0]) & (tokens < self.config.speech_token_range[1])
171
+
172
+ def tie_weights(self):
173
+ for i in range(self.config.channels):
174
+ self._tie_or_clone_weights(self.lm_heads[i], self.model.embedding_list[i])
175
+
176
+ def set_input_embeddings(self, value):
177
+ self.model.embedding_list[0] = value
178
+
179
+ def get_output_embeddings(self):
180
+ return self.lm_heads[0]
181
+
182
+ def set_output_embeddings(self, new_embeddings):
183
+ self.lm_heads[0] = new_embeddings
184
+
185
+ def set_decoder(self, decoder):
186
+ self.model = decoder
187
+
188
+ def get_decoder(self):
189
+ return self.model
190
+
191
+ def set_weights(self, weights):
192
+ self.weights = weights
193
+
194
+ def forward(
195
+ self,
196
+ input_ids: torch.LongTensor = None,
197
+ attention_mask: Optional[torch.Tensor] = None,
198
+ position_ids: Optional[torch.LongTensor] = None,
199
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
200
+ inputs_embeds: Optional[torch.FloatTensor] = None,
201
+ labels: Optional[torch.LongTensor] = None,
202
+ use_cache: Optional[bool] = None,
203
+ output_attentions: Optional[bool] = None,
204
+ output_hidden_states: Optional[bool] = None,
205
+ return_dict: Optional[bool] = None,
206
+ cache_position: Optional[torch.LongTensor] = None,
207
+ **kwargs,
208
+ ) -> Union[Tuple, AsteroidTTSOutputWithPast]:
209
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
210
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
211
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
212
+
213
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
214
+ outputs = self.model(
215
+ input_ids=input_ids,
216
+ attention_mask=attention_mask,
217
+ position_ids=position_ids,
218
+ past_key_values=past_key_values,
219
+ inputs_embeds=inputs_embeds,
220
+ use_cache=use_cache,
221
+ output_attentions=output_attentions,
222
+ output_hidden_states=output_hidden_states,
223
+ return_dict=return_dict,
224
+ cache_position=cache_position,
225
+ **kwargs,
226
+ )
227
+
228
+ hidden_states = outputs[0]
229
+ logits_all = [lm_head(hidden_states) for lm_head in self.lm_heads]
230
+
231
+ loss_all = torch.empty(self.channels, device=input_ids.device if not input_ids is None else inputs_embeds.device)
232
+
233
+ if labels is not None:
234
+ for i in range(self.config.channels):
235
+ vocab_size = self.config.vocab_size if i == 0 else self.config.speech_vocab_size
236
+ loss_all[i] = ForCausalLMLoss(logits_all[i], labels[..., i], vocab_size)
237
+
238
+ # total_weight = sum(self.weights)
239
+ # normalized_weights = [w / total_weight for w in self.weights]
240
+ normalized_weights = self.weights
241
+
242
+ total_loss = 0
243
+ for w, loss in zip(normalized_weights, loss_all):
244
+ total_loss += w * loss
245
+
246
+ if not return_dict:
247
+ output = (logits_all,) + outputs[1:]
248
+ return (total_loss, loss_all, ) + output if loss is not None else output
249
+
250
+ return AsteroidTTSOutputWithPast(
251
+ loss=total_loss,
252
+ logits=logits_all[0],
253
+ loss_all=loss_all,
254
+ logits_all=logits_all,
255
+ past_key_values=outputs.past_key_values,
256
+ hidden_states=outputs.hidden_states,
257
+ attentions=outputs.attentions,
258
+ )
259
+
260
+ @torch.no_grad()
261
+ def generate(
262
+ self,
263
+ input_ids: Optional[torch.Tensor] = None,
264
+ output_only: bool = True,
265
+ **kwargs,
266
+ ):
267
+ batch_size, seq_len, channels = input_ids.shape
268
+ start_id = seq_len - channels + 1
269
+ outputs = super().generate(input_ids, **kwargs)
270
+ return_dict_in_generate = kwargs.get("return_dict_in_generate", False)
271
+ if return_dict_in_generate:
272
+ output_ids = outputs["sequences"]
273
+ else:
274
+ output_ids = outputs
275
+ if output_only:
276
+ output_ids = output_ids[:, start_id:]
277
+ if return_dict_in_generate:
278
+ outputs["sequences"] = output_ids
279
+ else:
280
+ outputs = output_ids
281
+ return outputs
282
+
283
+ def _sample(
284
+ self,
285
+ input_ids: torch.LongTensor,
286
+ logits_processor: LogitsProcessorList,
287
+ stopping_criteria: StoppingCriteriaList,
288
+ generation_config: GenerationConfig,
289
+ synced_gpus: bool,
290
+ streamer: Optional["BaseStreamer"],
291
+ **model_kwargs,
292
+ ) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]:
293
+ # Extract configuration parameters
294
+ speech_pad_idx = self.config.speech_pad_token
295
+
296
+ eos_token_id = generation_config.eos_token_id
297
+ output_attentions = generation_config.output_attentions
298
+ output_hidden_states = generation_config.output_hidden_states
299
+ output_scores = generation_config.output_scores
300
+ output_logits = generation_config.output_logits
301
+ return_dict_in_generate = generation_config.return_dict_in_generate
302
+ max_length = generation_config.max_length
303
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
304
+ do_sample = generation_config.do_sample
305
+
306
+ # Initialize output tuples
307
+ scores = () if (return_dict_in_generate and output_scores) else None
308
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
309
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
310
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
311
+
312
+ # Initialize tracking variables
313
+ batch_size, cur_len, channels = input_ids.shape # channels = 8
314
+ this_peer_finished = False
315
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
316
+ needs_additional_steps = -1 * torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
317
+ tf_inputs = input_ids[:]
318
+ input_ids = input_ids[:, :-(channels - 1)]
319
+ model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :-(channels - 1)]
320
+ base_length = input_ids.shape[1]
321
+ model_kwargs = self._get_initial_cache_position(base_length, input_ids.device, model_kwargs)
322
+
323
+ # Define per-channel logits processors
324
+ if generation_config.do_samples is not None:
325
+ do_samples = generation_config.do_samples
326
+ realprocessor = [LogitsProcessorList() for _ in range(channels)]
327
+ for i, layer_config in enumerate(generation_config.layers):
328
+ if layer_config.get("repetition_penalty") is not None:
329
+ realprocessor[i].append(RepetitionPenaltyLogitsProcessor(penalty=layer_config.get("repetition_penalty")))
330
+ if layer_config.get("temperature") is not None:
331
+ realprocessor[i].append(TemperatureLogitsWarper(temperature=layer_config.get("temperature")))
332
+ if layer_config.get("top_k") is not None:
333
+ realprocessor[i].append(TopKLogitsWarper(top_k=layer_config.get("top_k")))
334
+ if layer_config.get("top_p") is not None:
335
+ realprocessor[i].append(TopPLogitsWarper(top_p=layer_config.get("top_p")))
336
+ else:
337
+ do_samples = [do_sample for _ in range(channels)]
338
+ realprocessor = [logits_processor for _ in range(channels)]
339
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
340
+ # Prepare model inputs
341
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
342
+ model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
343
+ model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
344
+ # Forward pass
345
+ outputs = self(**model_inputs, return_dict=True)
346
+ model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
347
+
348
+ if synced_gpus and this_peer_finished:
349
+ continue
350
+
351
+ # Get the next-token logits
352
+ next_token_logits = [logits[:, -1, :].clone().float().to(input_ids.device) for logits in outputs.logits_all]
353
+ for i, channel_logits in enumerate(next_token_logits):
354
+ if i != 0 and input_ids.shape[1] + 1 > tf_inputs.shape[1] - 7 + i:
355
+ channel_logits[:, 1024] = - torch.inf
356
+ if i == 0 and input_ids.shape[1] + 1 <= tf_inputs.shape[1]:
357
+ channel_logits[:, 152694] = - torch.inf
358
+ next_token_scores = [realprocessor[i](input_ids[..., i], logits) for i, logits in enumerate(next_token_logits)]
359
+ # Generate the next tokens
360
+ next_tokens = []
361
+ for i, channel_score in enumerate(next_token_scores):
362
+ if do_samples[i]:
363
+ channel_ntk = torch.multinomial(nn.functional.softmax(channel_score, dim=-1), num_samples=1).squeeze(1)
364
+ elif not do_samples[i]:
365
+ channel_ntk = torch.argmax(channel_score, dim=-1)
366
+ next_tokens.append(channel_ntk)
367
+ next_tokens = torch.stack(next_tokens, dim=-1) # [batch_size, channels]
368
+ # Additional-steps logic
369
+ indices = (~self.is_speech_token(next_tokens[:, 0])) & (needs_additional_steps < 0)
370
+ needs_additional_steps[indices] = channels - 1 # For 8 channels, 7 additional steps are needed
371
+
372
+ if input_ids.shape[1] + 1 <= tf_inputs.shape[1]:
373
+ i = input_ids.shape[1] + 1 - base_length
374
+ next_tokens[:, i:] = tf_inputs[:, input_ids.shape[1], i:]
375
+
376
+ # Replace tokens during the additional steps
377
+ mask = (needs_additional_steps > 0) & (needs_additional_steps < 7)
378
+ if mask.any().item():
379
+ next_tokens[mask, 0] = self.config.eos_token_id
380
+ for i in range(1, channels):
381
+ mask_i = mask & (needs_additional_steps < channels - i)
382
+ next_tokens[mask_i, i] = speech_pad_idx
383
+
384
+ if has_eos_stopping_criteria:
385
+ for i in range(channels):
386
+ pddp = self.config.eos_token_id if i == 0 else speech_pad_idx
387
+ next_tokens[:, i] = next_tokens[:, i] * unfinished_sequences + pddp * (1 - unfinished_sequences)
388
+
389
+ input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)
390
+ if streamer is not None:
391
+ streamer.put(next_tokens[:, 0].cpu())
392
+
393
+ # Update unfinished_sequences
394
+ needs_additional_steps = torch.where(needs_additional_steps > 0, needs_additional_steps - 1, needs_additional_steps)
395
+ stopping = stopping_criteria(input_ids[..., 0], scores) | (needs_additional_steps == 0)
396
+ unfinished_sequences = unfinished_sequences & ~stopping
397
+ unfinished_sequences = unfinished_sequences | (needs_additional_steps > 0)
398
+ this_peer_finished = unfinished_sequences.max() == 0
399
+
400
+ if return_dict_in_generate:
401
+ if output_scores:
402
+ scores += (next_token_scores,)
403
+ if output_logits:
404
+ raw_logits += (next_token_logits,)
405
+ if output_attentions:
406
+ decoder_attentions += (outputs.attentions,)
407
+ if output_hidden_states:
408
+ decoder_hidden_states += (outputs.hidden_states,)
409
+
410
+ cur_len += 1
411
+ del outputs
412
+
413
+ if streamer is not None:
414
+ streamer.end()
415
+
416
+ if return_dict_in_generate:
417
+ return GenerateDecoderOnlyOutput(
418
+ sequences=input_ids,
419
+ scores=scores,
420
+ logits=raw_logits,
421
+ attentions=decoder_attentions,
422
+ hidden_states=decoder_hidden_states,
423
+ past_key_values=model_kwargs.get("past_key_values"),
424
+ )
425
+ else:
426
+ return input_ids
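`modeling.py` keeps the legacy `AsteroidTTS*` classes, whose core trick is summing one embedding table per channel over inputs shaped `(batch_size, sequence_length, channels)`. A standalone toy sketch of that summation (shrunken hidden size and placeholder ids for illustration; not the actual model):

```python
# Toy sketch of the per-channel embedding sum in _prepare_multi_modal_inputs:
# channel 0 uses the text vocabulary, the other channels use the speech
# vocabulary, and the per-channel embeddings are added together.
import torch
import torch.nn as nn

channels, text_vocab, speech_vocab, hidden = 8, 152697, 1025, 16  # hidden shrunk for the toy
embedding_list = nn.ModuleList(
    [nn.Embedding(text_vocab, hidden)]
    + [nn.Embedding(speech_vocab, hidden) for _ in range(channels - 1)]
)

input_ids = torch.zeros(2, 5, channels, dtype=torch.long)  # (batch, seq, channels)
inputs_embeds = sum(embedding_list[i](input_ids[..., i]) for i in range(channels))
print(inputs_embeds.shape)  # torch.Size([2, 5, 16])
```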
modeling_moss_ttsd.py ADDED
@@ -0,0 +1,611 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch MOSS-TTSD model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Optional, Union
19
+
20
+ from transformers.cache_utils import Cache
21
+ from transformers.generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList
22
+ from transformers.generation.logits_process import (
23
+ RepetitionPenaltyLogitsProcessor,
24
+ TemperatureLogitsWarper,
25
+ TopKLogitsWarper,
26
+ TopPLogitsWarper,
27
+ )
28
+ from transformers.generation.streamers import BaseStreamer
29
+ from transformers.generation.utils import GenerateDecoderOnlyOutput
30
+ from transformers.loss.loss_utils import ForCausalLMLoss
31
+ from transformers.modeling_outputs import BaseModelOutputWithPast
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Model
34
+ from transformers.utils import ModelOutput, auto_docstring, is_torch_available
35
+ from .configuration_moss_ttsd import MossTTSDConfig
36
+
37
+
38
+ if is_torch_available():
39
+ import torch
40
+ import torch.nn as nn
41
+
42
+ _CHECKPOINT_FOR_DOC = "fnlp/MOSS-TTSD-v0.5"
43
+
44
+
45
+ @dataclass
46
+ @auto_docstring(
47
+ custom_intro="""
48
+ Base class for MOSS-TTSD outputs, with hidden states and attentions.
49
+ """
50
+ )
51
+ class MossTTSDOutputWithPast(ModelOutput):
52
+ """Base class for MOSS-TTSD outputs with past key values."""
53
+
54
+ loss: Optional[torch.FloatTensor] = None
55
+ logits: torch.FloatTensor = None
56
+ loss_all: Optional[tuple[torch.FloatTensor, ...]] = None
57
+ logits_all: Optional[tuple[torch.FloatTensor, ...]] = None
58
+ past_key_values: Optional[tuple[tuple[torch.FloatTensor, ...], ...]] = None
59
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
60
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
61
+
62
+
63
+ @dataclass
64
+ @auto_docstring(
65
+ custom_intro="""
66
+ Base class for MOSS-TTSD causal language model (or autoregressive) outputs.
67
+ """
68
+ )
69
+ class MossTTSDCausalLMOutputWithPast(ModelOutput):
70
+ r"""
71
+ Base class for MOSS-TTSD causal language model outputs.
72
+
73
+ Args:
74
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
75
+ Language modeling loss (for next-token prediction).
76
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
77
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
78
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
79
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
80
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
81
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
82
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
83
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
84
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
85
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
86
+ sequence_length)`.
87
+ """
88
+
89
+ loss: Optional[torch.FloatTensor] = None
90
+ logits: torch.FloatTensor = None
91
+ past_key_values: Optional[Cache] = None
92
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
93
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
94
+
95
+
96
+ class MossTTSDGenerationMixin(GenerationMixin):
97
+ """
98
+ Generation mixin for MossTTSD model with multi-channel support.
99
+ """
100
+
101
+ def _setup_channel_processors(
102
+ self, generation_config: GenerationConfig, channels: int
103
+ ) -> list[LogitsProcessorList]:
104
+ """Setup logits processors for each channel based on generation config."""
105
+ realprocessor = [LogitsProcessorList() for _ in range(channels)]
106
+
107
+ if hasattr(generation_config, "layers"):
108
+ for i, layer_config in enumerate(generation_config.layers):
109
+ if i >= channels:
110
+ break
111
+
112
+ if layer_config.get("repetition_penalty") is not None:
113
+ realprocessor[i].append(
114
+ RepetitionPenaltyLogitsProcessor(penalty=layer_config.get("repetition_penalty"))
115
+ )
116
+ if layer_config.get("temperature") is not None:
117
+ realprocessor[i].append(TemperatureLogitsWarper(temperature=layer_config.get("temperature")))
118
+ if layer_config.get("top_k") is not None:
119
+ realprocessor[i].append(TopKLogitsWarper(top_k=layer_config.get("top_k")))
120
+ if layer_config.get("top_p") is not None:
121
+ realprocessor[i].append(TopPLogitsWarper(top_p=layer_config.get("top_p")))
122
+
123
+ return realprocessor
124
+
125
+ def _generate_next_tokens_with_scores(
126
+ self,
127
+ logits_all: tuple[torch.Tensor, ...],
128
+ input_ids: torch.LongTensor,
129
+ tf_inputs: torch.LongTensor,
130
+ channels: int,
131
+ realprocessor: list[LogitsProcessorList],
132
+ do_samples: list[bool],
133
+ speech_pad_idx: int,
134
+ ) -> tuple[torch.LongTensor, tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
135
+ """Generate next tokens for all channels with scores and logits."""
136
+ # Get next token logits
137
+ next_token_logits = tuple(logits[:, -1, :].clone().float().to(input_ids.device) for logits in logits_all)
138
+
139
+ # Apply channel-specific constraints
140
+ for i, channel_logits in enumerate(next_token_logits):
141
+ if i != 0 and input_ids.shape[1] + 1 > tf_inputs.shape[1] - 7 + i:
142
+ channel_logits[:, speech_pad_idx] = -torch.inf
143
+ if i == 0 and input_ids.shape[1] + 1 <= tf_inputs.shape[1]:
144
+ channel_logits[:, self.config.speech_eos_token] = -torch.inf
145
+
146
+ # Process logits
147
+ next_token_scores = tuple(
148
+ realprocessor[i](input_ids[..., i], logits) for i, logits in enumerate(next_token_logits)
149
+ )
150
+
151
+ # Sample or select tokens
152
+ next_tokens = []
153
+ for i, channel_score in enumerate(next_token_scores):
154
+ if do_samples[i]:
155
+ channel_ntk = torch.multinomial(nn.functional.softmax(channel_score, dim=-1), num_samples=1).squeeze(1)
156
+ else:
157
+ channel_ntk = torch.argmax(channel_score, dim=-1)
158
+ next_tokens.append(channel_ntk)
159
+
160
+ return torch.stack(next_tokens, dim=-1), next_token_scores, next_token_logits
161
+
162
+ def _process_multi_channel_tokens(
163
+ self,
164
+ next_tokens: torch.LongTensor,
165
+ needs_additional_steps: torch.LongTensor,
166
+ input_ids: torch.LongTensor,
167
+ tf_inputs: torch.LongTensor,
168
+ base_length: int,
169
+ channels: int,
170
+ eos_token_id: Optional[int],
171
+ speech_pad_idx: int,
172
+ unfinished_sequences: torch.LongTensor,
173
+ has_eos_stopping_criteria: bool,
174
+ ) -> tuple[torch.LongTensor, torch.LongTensor]:
175
+ """Process tokens for multi-channel TTS generation."""
176
+ # Additional steps logic
177
+ indices = (~self.is_speech_token(next_tokens[:, 0])) & (needs_additional_steps < 0)
178
+ needs_additional_steps[indices] = channels - 1 # For 8 channels, need 7 steps
179
+
180
+ if input_ids.shape[1] + 1 <= tf_inputs.shape[1]:
181
+ i = input_ids.shape[1] + 1 - base_length
182
+ next_tokens[:, i:] = tf_inputs[:, input_ids.shape[1], i:]
183
+
184
+ # Replace tokens in additional steps
185
+ mask = (needs_additional_steps > 0) & (needs_additional_steps < 7)
186
+ if mask.any().item():
187
+ next_tokens[mask, 0] = eos_token_id
188
+ for i in range(1, channels):
189
+ mask_i = mask & (needs_additional_steps < channels - i)
190
+ next_tokens[mask_i, i] = speech_pad_idx
191
+
192
+ if has_eos_stopping_criteria:
193
+ for i in range(channels):
194
+ pddp = eos_token_id if i == 0 else speech_pad_idx
195
+ next_tokens[:, i] = next_tokens[:, i] * unfinished_sequences + pddp * (1 - unfinished_sequences)
196
+
197
+ return next_tokens, needs_additional_steps
198
+
199
+ def _sample(
200
+ self,
201
+ input_ids: torch.LongTensor,
202
+ logits_processor: LogitsProcessorList,
203
+ stopping_criteria: StoppingCriteriaList,
204
+ generation_config: GenerationConfig,
205
+ synced_gpus: bool,
206
+ streamer: Optional[BaseStreamer],
207
+ **model_kwargs,
208
+ ) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]:
209
+ """Sample method for multi-channel TTS generation."""
210
+ # Extract configuration parameters
211
+ speech_pad_idx = getattr(self.config, "speech_pad_token", 1024)
212
+ eos_token_id = generation_config.eos_token_id
213
+ channels = getattr(self.config, "channels", 8)
214
+
215
+ # Generation config parameters
216
+ output_attentions = generation_config.output_attentions
217
+ output_hidden_states = generation_config.output_hidden_states
218
+ output_scores = generation_config.output_scores
219
+ output_logits = generation_config.output_logits
220
+ return_dict_in_generate = generation_config.return_dict_in_generate
221
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
222
+ do_sample = generation_config.do_sample
223
+
224
+ # Initialize output tuples
225
+ scores = () if (return_dict_in_generate and output_scores) else None
226
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
227
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
228
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
229
+
230
+ # Initialize tracking variables
231
+ batch_size, cur_len, input_channels = input_ids.shape
232
+ this_peer_finished = False
233
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
234
+ needs_additional_steps = -1 * torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
235
+
236
+ # Adjust input for generation
237
+ tf_inputs = input_ids.clone()
238
+ input_ids = input_ids[:, : -(channels - 1)]
239
+ cur_len = input_ids.shape[1]
240
+ model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, : -(channels - 1)]
241
+ base_length = input_ids.shape[1]
242
+ model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
243
+
244
+ # Setup logits processors and sampling config
245
+ if hasattr(generation_config, "do_samples") and generation_config.do_samples is not None:
246
+ do_samples = generation_config.do_samples
247
+ realprocessor = self._setup_channel_processors(generation_config, channels)
248
+ else:
249
+ do_samples = [do_sample for _ in range(channels)]
250
+ realprocessor = [logits_processor for _ in range(channels)]
251
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
252
+ # Prepare model inputs
253
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
254
+ model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
255
+ model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
256
+ # Forward pass
257
+ outputs = self(**model_inputs, return_dict=True)
258
+ model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
259
+
260
+ if synced_gpus and this_peer_finished:
261
+ continue
262
+
263
+ # Generate next tokens for all channels
264
+ next_tokens, next_token_scores, next_token_logits = self._generate_next_tokens_with_scores(
265
+ outputs.logits_all, input_ids, tf_inputs, channels, realprocessor, do_samples, speech_pad_idx
266
+ )
267
+ # Process tokens for multi-channel TTS
268
+ next_tokens, needs_additional_steps = self._process_multi_channel_tokens(
269
+ next_tokens,
270
+ needs_additional_steps,
271
+ input_ids,
272
+ tf_inputs,
273
+ base_length,
274
+ channels,
275
+ eos_token_id,
276
+ speech_pad_idx,
277
+ unfinished_sequences,
278
+ has_eos_stopping_criteria,
279
+ )
280
+
281
+ input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)
282
+ if streamer is not None:
283
+ streamer.put(next_tokens[:, 0].cpu())
284
+
285
+ # Update unfinished_sequences
286
+ needs_additional_steps = torch.where(
287
+ needs_additional_steps > 0, needs_additional_steps - 1, needs_additional_steps
288
+ )
289
+ stopping = stopping_criteria(input_ids[..., 0], scores) | (needs_additional_steps == 0)
290
+ unfinished_sequences = unfinished_sequences & ~stopping
291
+ unfinished_sequences = unfinished_sequences | (needs_additional_steps > 0)
292
+ this_peer_finished = unfinished_sequences.max() == 0
293
+
294
+ if return_dict_in_generate:
295
+ if output_scores:
296
+ scores += (next_token_scores,)
297
+ if output_logits:
298
+ raw_logits += (next_token_logits,)
299
+ if output_attentions:
300
+ decoder_attentions += (outputs.attentions,)
301
+ if output_hidden_states:
302
+ decoder_hidden_states += (outputs.hidden_states,)
303
+
304
+ cur_len += 1
305
+ del outputs
306
+
307
+ if streamer is not None:
308
+ streamer.end()
309
+
310
+ if return_dict_in_generate:
311
+ return GenerateDecoderOnlyOutput(
312
+ sequences=input_ids,
313
+ scores=scores,
314
+ logits=raw_logits,
315
+ attentions=decoder_attentions,
316
+ hidden_states=decoder_hidden_states,
317
+ past_key_values=model_kwargs.get("past_key_values"),
318
+ )
319
+ else:
320
+ return input_ids
321
+
322
+ @torch.no_grad()
323
+ def generate(
324
+ self,
325
+ input_ids: Optional[torch.Tensor] = None,
326
+ output_only: bool = True,
327
+ **kwargs,
328
+ ):
329
+ batch_size, seq_len, channels = input_ids.shape
330
+ start_id = seq_len - channels + 1
331
+ outputs = super().generate(input_ids, **kwargs)
332
+ return_dict_in_generate = kwargs.get("return_dict_in_generate", False)
333
+ if return_dict_in_generate:
334
+ output_ids = outputs["sequences"]
335
+ else:
336
+ output_ids = outputs
337
+ if output_only:
338
+ output_ids = output_ids[:, start_id:]
339
+ if return_dict_in_generate:
340
+ outputs["sequences"] = output_ids
341
+ else:
342
+ outputs = output_ids
343
+ return outputs
344
+
345
+
346
+
347
+ class MossTTSDPretrainedModel(PreTrainedModel):
348
+ """Base class for MOSS-TTSD pretrained models."""
349
+
350
+ config_class = MossTTSDConfig
351
+ base_model_prefix = "model"
352
+ supports_gradient_checkpointing = True
353
+ _no_split_modules = ["Qwen3DecoderLayer"]
354
+ _skip_keys_device_placement = ["past_key_values"]
355
+ _supports_flash_attn_2 = True
356
+ _supports_sdpa = True
357
+ _supports_flex_attn = True
358
+ _supports_cache_class = True
359
+ _supports_quantized_cache = True
360
+ _supports_static_cache = True
361
+ _supports_attention_backend = True
362
+
363
+
364
+ class MossTTSDModel(MossTTSDPretrainedModel):
365
+ """MOSS-TTSD model for text-to-speech synthesis."""
366
+
367
+ def __init__(self, config: MossTTSDConfig):
368
+ super().__init__(config)
369
+ self.text_pad_idx = config.pad_token_id
370
+ self.speech_pad_idx = config.speech_pad_token
371
+
372
+ self.embedding_list = nn.ModuleList([])
373
+ self.embedding_list.append(nn.Embedding(config.vocab_size, config.hidden_size, self.text_pad_idx))
374
+ # Channels 1 to channels-1: Speech tokens only
375
+ for _ in range(1, config.channels):
376
+ self.embedding_list.append(nn.Embedding(config.speech_vocab_size, config.hidden_size, self.speech_pad_idx))
377
+
378
+ self.language_model = Qwen3Model(config)
379
+ self.post_init()
380
+
381
+ def get_input_embeddings(self):
382
+ """Get the input embeddings for the model."""
383
+ return self.embedding_list[0]
384
+
385
+ def set_input_embeddings(self, value: nn.Embedding):
386
+ """Set the input embeddings for the model."""
387
+ self.embedding_list[0] = value
388
+
389
+ def _prepare_multi_modal_inputs(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
390
+ """
391
+ Prepare multi-modal embeddings from input_ids of shape (batch_size, sequence_length, channels).
392
+
393
+ For channel 0: text + speech tokens, for channels 1 to channels-1: speech tokens padded with speech_pad_token.
394
+ """
395
+ batch_size, seq_length, channels = input_ids.shape
396
+ if channels != self.config.channels:
397
+ raise ValueError(f"Expected {self.config.channels} channels, got {channels}")
398
+
399
+ inputs_embeds = torch.zeros(
400
+ batch_size,
401
+ seq_length,
402
+ self.config.hidden_size,
403
+ device=input_ids.device,
404
+ dtype=self.embedding_list[0].weight.dtype,
405
+ )
406
+ for i in range(channels):
407
+ embed_layer = self.embedding_list[i]
408
+ channel_input = input_ids[..., i]
409
+ inputs_embeds += embed_layer(channel_input)
410
+
411
+ return inputs_embeds
412
+
413
+ def forward(
414
+ self,
415
+ input_ids: Optional[torch.LongTensor] = None,
416
+ attention_mask: Optional[torch.Tensor] = None,
417
+ position_ids: Optional[torch.LongTensor] = None,
418
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
419
+ inputs_embeds: Optional[torch.FloatTensor] = None,
420
+ use_cache: Optional[bool] = None,
421
+ output_attentions: Optional[bool] = None,
422
+ output_hidden_states: Optional[bool] = None,
423
+ return_dict: Optional[bool] = None,
424
+ cache_position: Optional[torch.LongTensor] = None,
425
+ **kwargs,
426
+ ) -> Union[tuple, BaseModelOutputWithPast]:
427
+ """Forward pass for MOSS-TTSD model."""
428
+ if (input_ids is None) ^ (inputs_embeds is not None):
429
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
430
+
431
+ if input_ids is not None:
432
+ inputs_embeds = self._prepare_multi_modal_inputs(input_ids)
433
+
434
+ return self.language_model(
435
+ input_ids=None,
436
+ attention_mask=attention_mask,
437
+ position_ids=position_ids,
438
+ past_key_values=past_key_values,
439
+ inputs_embeds=inputs_embeds,
440
+ use_cache=use_cache,
441
+ output_attentions=output_attentions,
442
+ output_hidden_states=output_hidden_states,
443
+ return_dict=return_dict,
444
+ cache_position=cache_position,
445
+ )
446
+
447
+
448
+ class MossTTSDForCausalLM(MossTTSDPretrainedModel, MossTTSDGenerationMixin):
449
+ """MOSS-TTSD model for causal language modeling with multi-channel support."""
450
+
451
+ _tied_weights_keys = []
452
+ _tp_plan = {"lm_head": "colwise_rep"}
453
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
454
+
455
+ def __init__(self, config: MossTTSDConfig):
456
+ super().__init__(config)
457
+ self.model = MossTTSDModel(config)
458
+ self.channels = config.channels
459
+ self.weights = [1 for _ in range(self.channels)]
460
+ self._tied_weights_keys = [f"lm_heads.{i}.weight" for i in range(self.channels)]
461
+ self.vocab_size = config.vocab_size
462
+ self.lm_heads = nn.ModuleList([])
463
+ self.lm_heads.append(nn.Linear(config.hidden_size, config.vocab_size, bias=False))
464
+ for _ in range(1, config.channels):
465
+ self.lm_heads.append(nn.Linear(config.hidden_size, config.speech_vocab_size, bias=False))
466
+ self.post_init()
467
+
468
+ def get_input_embeddings(self):
469
+ """Get the input embeddings for the model."""
470
+ return self.model.embedding_list[0]
471
+
472
+ def can_generate(self):
473
+ """Check if the model can generate."""
474
+ return True
475
+
476
+ def is_speech_token(self, tokens: torch.Tensor) -> torch.Tensor:
477
+ """Check if tokens are speech tokens."""
478
+ return (tokens >= self.config.speech_token_range[0]) & (tokens < self.config.speech_token_range[1])
479
+
480
+ def tie_weights(self):
481
+ """Tie the weights between input embeddings and output embeddings."""
482
+ for i in range(self.config.channels):
483
+ self._tie_or_clone_weights(self.lm_heads[i], self.model.embedding_list[i])
484
+
485
+ def set_input_embeddings(self, value: nn.Embedding):
486
+ """Set the input embeddings for the model."""
487
+ self.model.embedding_list[0] = value
488
+
489
+ def get_output_embeddings(self):
490
+ """Get the output embeddings for the model."""
491
+ return self.lm_heads[0]
492
+
493
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
494
+ """Set the output embeddings for the model."""
495
+ self.lm_heads[0] = new_embeddings
496
+
497
+ def set_decoder(self, decoder: MossTTSDModel):
498
+ """Set the decoder for the model."""
499
+ self.model = decoder
500
+
501
+ def get_decoder(self):
502
+ """Get the decoder for the model."""
503
+ return self.model
504
+
505
+ def set_weights(self, weights: list[float]):
506
+ """Set the weights for different channels."""
507
+ self.weights = weights
508
+
509
+ def _compute_loss(
510
+ self, hidden_states: torch.Tensor, labels: torch.LongTensor, skip_logits: bool, **kwargs
511
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[tuple[torch.Tensor, ...]]]:
512
+ """Compute loss for all channels."""
513
+ device = hidden_states.device
514
+ loss_all = torch.empty(self.channels, device=device)
515
+ logits_list = []
516
+
517
+ for i in range(self.config.channels):
518
+ vocab_size = self.config.vocab_size if i == 0 else self.config.speech_vocab_size
519
+ logits = self.lm_heads[i](hidden_states)
520
+ loss_all[i] = ForCausalLMLoss(logits, labels[..., i], vocab_size)
521
+ if not skip_logits:
522
+ logits_list.append(logits)
523
+
524
+ logits_all = tuple(logits_list) if logits_list else None
525
+
526
+ # Compute weighted total loss
527
+ total_weight = sum(self.weights)
528
+ normalized_weights = [w / total_weight for w in self.weights]
529
+ total_loss = sum(w * loss for w, loss in zip(normalized_weights, loss_all))
530
+
531
+ return total_loss, loss_all, logits_all
532
+
533
+ def forward(
534
+ self,
535
+ input_ids: Optional[torch.LongTensor] = None,
536
+ attention_mask: Optional[torch.Tensor] = None,
537
+ position_ids: Optional[torch.LongTensor] = None,
538
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
539
+ inputs_embeds: Optional[torch.FloatTensor] = None,
540
+ labels: Optional[torch.LongTensor] = None,
541
+ use_cache: Optional[bool] = None,
542
+ output_attentions: Optional[bool] = None,
543
+ output_hidden_states: Optional[bool] = None,
544
+ return_dict: Optional[bool] = None,
545
+ cache_position: Optional[torch.LongTensor] = None,
546
+ skip_logits: Optional[bool] = None,
547
+ **kwargs,
548
+ ) -> Union[tuple, MossTTSDOutputWithPast]:
549
+ """Forward pass for MOSS-TTSD causal language model."""
550
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
551
+ output_hidden_states = (
552
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
553
+ )
554
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
555
+
556
+ skip_logits = skip_logits if skip_logits is not None else (self.training and labels is not None)
557
+ if skip_logits and labels is None:
558
+ skip_logits = False
559
+
560
+ # Decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
561
+ outputs = self.model(
562
+ input_ids=input_ids,
563
+ attention_mask=attention_mask,
564
+ position_ids=position_ids,
565
+ past_key_values=past_key_values,
566
+ inputs_embeds=inputs_embeds,
567
+ use_cache=use_cache,
568
+ output_attentions=output_attentions,
569
+ output_hidden_states=output_hidden_states,
570
+ return_dict=return_dict,
571
+ cache_position=cache_position,
572
+ **kwargs,
573
+ )
574
+
575
+ hidden_states = outputs[0]
576
+
577
+ logits_all = None
578
+ loss_all = None
579
+ total_loss = None
580
+
581
+ if labels is not None:
582
+ total_loss, loss_all, logits_all = self._compute_loss(hidden_states, labels, skip_logits, **kwargs)
583
+ else:
584
+ logits_all = [lm_head(hidden_states) for lm_head in self.lm_heads]
585
+ total_loss = None
586
+ loss_all = None
587
+
588
+ if not return_dict:
589
+ output = (logits_all,) + outputs[1:]
590
+ return (
591
+ (
592
+ total_loss,
593
+ loss_all,
594
+ )
595
+ + output
596
+ if total_loss is not None
597
+ else output
598
+ )
599
+
600
+ return MossTTSDOutputWithPast(
601
+ loss=total_loss,
602
+ logits=logits_all[0] if logits_all is not None else None,
603
+ loss_all=loss_all,
604
+ logits_all=logits_all,
605
+ past_key_values=outputs.past_key_values,
606
+ hidden_states=outputs.hidden_states,
607
+ attentions=outputs.attentions,
608
+ )
609
+
610
+
611
+ __all__ = ["MossTTSDModel", "MossTTSDForCausalLM"]
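The model above embeds a `(batch, seq_len, channels)` grid by summing one embedding per channel into a single hidden vector per timestep before passing it to the Qwen3 backbone. A minimal standalone sketch of that summation (the sizes below are illustrative placeholders, not the released config values):

```python
import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from MossTTSDConfig.
vocab_size, speech_vocab_size, hidden_size, channels = 1000, 1025, 16, 8

embeddings = nn.ModuleList(
    [nn.Embedding(vocab_size, hidden_size)]  # channel 0: text vocabulary
    + [nn.Embedding(speech_vocab_size, hidden_size) for _ in range(channels - 1)]  # channels 1..C-1: codec codebooks
)

# (batch, seq_len, channels): column 0 holds text ids, columns 1..C-1 hold codec ids.
input_ids = torch.randint(0, speech_vocab_size, (2, 5, channels))
input_ids[..., 0] = torch.randint(0, vocab_size, (2, 5))

# Sum the per-channel embeddings, as _prepare_multi_modal_inputs does.
inputs_embeds = sum(embeddings[i](input_ids[..., i]) for i in range(channels))
print(inputs_embeds.shape)  # torch.Size([2, 5, 16])
```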
processing_moss_ttsd.py ADDED
@@ -0,0 +1,914 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for MOSS-TTSD.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import math
22
+ import os
23
+ import re
24
+ from dataclasses import asdict, dataclass
25
+ from typing import Any, Callable, Optional, Union
26
+
27
+ import numpy as np
28
+
29
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
30
+ from transformers.tokenization_utils_base import BatchEncoding
31
+ from transformers.utils import is_torch_available, is_torchaudio_available
32
+ from transformers import AutoFeatureExtractor, AutoTokenizer, AutoModel
33
+ #from transformers.models.xy_tokenizer.modeling_xy_tokenizer import XYTokenizer
34
+
35
+
36
+ if is_torch_available():
37
+ import torch
38
+
39
+ if is_torchaudio_available():
40
+ import torchaudio
41
+
42
+
43
+ class MossTTSDProcessorKwargs(ProcessingKwargs, total=False):
44
+ """
45
+ Arguments for configuring MOSS-TTSD processing operations.
46
+
47
+ Inherits from ProcessingKwargs and provides structured configuration for text and audio processing.
48
+ """
49
+
50
+ _defaults = {
51
+ "text_kwargs": {
52
+ "pad_token_id": 0, # Fallback pad token ID, actual value comes from tokenizer.pad_token_id
53
+ },
54
+ "audio_kwargs": {
55
+ "max_channels": 8, # Maximum number of quantization channels
56
+ "audio_pad_token_id": 1024, # Padding token ID for non-text channels
57
+ "silence_duration": 0.0, # Duration of silence to append for encoder segmentation
58
+ "input_sample_rate": 16000, # Input audio sampling rate (fallback, inferred from audio_tokenizer.config)
59
+ "encoder_downsample_rate": 320, # Encoder downsampling rate (fallback, inferred from audio_tokenizer.config)
60
+ "speech_token_range": [151665, 152689], # Token range for speech tokens (first codebook offset mapping)
61
+ "audio_bos_token": "<|begin_of_speech|>",
62
+ "audio_eos_token": "<|end_of_speech|>",
63
+ },
64
+ "common_kwargs": {
65
+ "return_tensors": "pt",
66
+ "padding": True,
67
+ "use_normalize": False,
68
+ },
69
+ }
70
+
71
+
72
+ @dataclass
73
+ class MossTTSDChatSample:
74
+ """
75
+ Intermediate representation of a single sample with T×C grid layout and metadata.
76
+
77
+ Args:
78
+ input_ids_2d (`torch.LongTensor`):
79
+ Shape (T, C) tensor where column 0 contains text tokens and columns 1..C-1 contain
80
+ quantized audio codebooks (or padding token 1024 for empty slots).
81
+ label_ids_2d (`torch.LongTensor`, *optional*):
82
+ Optional label tensor for training, same shape as input_ids_2d.
83
+ meta (`dict`):
84
+ Dictionary containing metadata for debugging and tracking purposes.
85
+ """
86
+
87
+ input_ids_2d: "torch.LongTensor"
88
+ label_ids_2d: Optional["torch.LongTensor"]
89
+ meta: dict
90
+
91
+ @dataclass
92
+ class MossTTSDBatchInput:
93
+ """
94
+ Batched input tensors for MOSS-TTSD model.
95
+
96
+ Args:
97
+ input_ids (`torch.LongTensor`):
98
+ Shape (B, T, C) tensor containing batched input token IDs.
99
+ attention_mask (`torch.LongTensor`):
100
+ Shape (B, T) tensor containing attention mask for valid tokens.
101
+ labels (`torch.LongTensor`, *optional*):
102
+ Optional shape (B, T, C) tensor containing label token IDs for training.
103
+ """
104
+
105
+ input_ids: "torch.LongTensor"
106
+ attention_mask: "torch.LongTensor"
107
+ labels: Optional["torch.LongTensor"]
108
+
109
+
110
+ @dataclass
111
+ class MossTTSDResponse:
112
+ """
113
+ Unified response container for MOSS-TTSD inference outputs.
114
+
115
+ Args:
116
+ audio (`np.ndarray`, *optional*):
117
+ Optional numpy array containing generated audio waveform.
118
+ generated_text (`str`, *optional*, defaults to `""`):
119
+ String containing generated text output.
120
+ sampling_rate (`int`, *optional*):
121
+ Optional integer specifying the sampling rate of the generated audio.
122
+ """
123
+
124
+ audio: Optional[np.ndarray] = None
125
+ generated_text: str = ""
126
+ sampling_rate: Optional[int] = None
127
+
128
+
129
+ class MossTTSDSampleProcessor:
130
+ """
131
+ Sample-level processor for MOSS-TTSD that handles individual sample processing without batch padding.
132
+
133
+ This class handles per-sample processing logic:
134
+ - Parses JSONL items (text/prompt_text/prompt_audio)
135
+ - Optional text normalization
136
+ - Audio loading/resampling/merging, feature extraction and encoding
137
+ - Generates T×C grid and performs multi-channel shifting
138
+
139
+ Args:
140
+ tokenizer (`AutoTokenizer`):
141
+ The text tokenizer for encoding text tokens.
142
+ feature_extractor (`AutoFeatureExtractor`, *optional*):
143
+ Optional feature extractor for audio preprocessing.
144
+ audio_tokenizer (`AutoModel`, *optional*):
145
+ Optional audio tokenizer for audio encoding/decoding.
146
+ chat_template (`str`, *optional*):
147
+ Optional chat template string for conversation formatting.
148
+ speech_token_range (`List[int]`):
149
+ List of [start, end] token IDs for speech token mapping.
150
+ audio_bos_token (`str`):
151
+ Beginning of speech token string.
152
+ audio_eos_token (`str`):
153
+ End of speech token string.
154
+ audio_pad_token_id (`int`):
155
+ Padding token ID for audio channels.
156
+ max_channels (`int`):
157
+ Maximum number of quantization channels.
158
+ input_sample_rate (`int`):
159
+ Target sample rate for input audio.
160
+ encoder_downsample_rate (`int`):
161
+ Downsampling rate of the audio encoder.
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ tokenizer,
167
+ feature_extractor: Optional[Any] = None,
168
+ audio_tokenizer: Optional[Any] = None,
169
+ *,
170
+ chat_template: Optional[str],
171
+ speech_token_range: list[int],
172
+ audio_bos_token: str,
173
+ audio_eos_token: str,
174
+ audio_pad_token_id: int,
175
+ max_channels: int,
176
+ input_sample_rate: int,
177
+ encoder_downsample_rate: int,
178
+ ) -> None:
179
+ self.tokenizer = tokenizer
180
+ self.feature_extractor = feature_extractor
181
+ self.audio_tokenizer = audio_tokenizer
182
+ self.chat_template = chat_template
183
+ self.speech_token_range = speech_token_range
184
+ self.audio_bos_token = audio_bos_token
185
+ self.audio_eos_token = audio_eos_token
186
+ self.audio_pad_token_id = audio_pad_token_id
187
+ self.max_channels = max_channels
188
+ self.input_sample_rate = input_sample_rate
189
+ self.encoder_downsample_rate = encoder_downsample_rate
190
+
191
+ def prepare_sample(
192
+ self,
193
+ item: dict[str, Any],
194
+ *,
195
+ apply_chat_template: Callable[[str, dict], str],
196
+ use_normalize: bool = False,
197
+ silence_duration: float = 0.0,
198
+ **kwargs,
199
+ ) -> MossTTSDChatSample:
200
+ """
201
+ Prepare a single sample from JSONL item into MossTTSDChatSample format.
202
+
203
+ Args:
204
+ item (`dict`):
205
+ Dictionary containing the input data (text, prompt_audio, etc.).
206
+ apply_chat_template (`callable`):
207
+ Function to apply chat template formatting.
208
+ use_normalize (`bool`, *optional*, defaults to `False`):
209
+ Whether to apply text normalization.
210
+ silence_duration (`float`, *optional*, defaults to `0.0`):
211
+ Duration of silence to append to audio for encoder segmentation.
212
+ **kwargs:
213
+ Additional keyword arguments passed to chat template.
214
+
215
+ Returns:
216
+ `MossTTSDChatSample`: Processed sample with 2D input tensor and metadata.
217
+ """
218
+ processed = self._process_jsonl_item(item)
219
+ system_prompt = item.get("system_prompt")
220
+ if isinstance(system_prompt, str):
221
+ kwargs["system_prompt"] = system_prompt
222
+
223
+ full_text = (processed["prompt_text"] or "") + processed["text"]
224
+ original_full_text = full_text
225
+ if use_normalize:
226
+ full_text = self._normalize_text(full_text)
227
+ final_text = full_text.replace("[S1]", "<speaker1>").replace("[S2]", "<speaker2>")
228
+
229
+ # Load and resample audio (may be None)
230
+ wav = self._process_audio_data(processed["prompt_audio"], target_sample_rate=self.input_sample_rate)
231
+
232
+ # Assemble into grid (T, C)
233
+ inputs_2d = self._build_inputs(
234
+ text=final_text,
235
+ audio_data=wav,
236
+ apply_chat_template=apply_chat_template,
237
+ silence_duration=silence_duration,
238
+ **kwargs,
239
+ )
240
+ inputs_2d = self._shift_inputs(inputs_2d, pad_token_id=self.tokenizer.pad_token_id, max_channels=self.max_channels)
241
+
242
+ meta = {
243
+ "original_text": original_full_text,
244
+ "normalized_text": self._normalize_text(original_full_text) if use_normalize else None,
245
+ "final_text": final_text,
246
+ "use_normalize": use_normalize,
247
+ }
248
+ ids_t = torch.tensor(inputs_2d, dtype=torch.long)
249
+ return MossTTSDChatSample(input_ids_2d=ids_t, label_ids_2d=None, meta=meta)
250
+
251
+ def collate(
252
+ self,
253
+ samples: list[MossTTSDChatSample],
254
+ *,
255
+ pad_token_id: int,
256
+ audio_pad_token_id: int,
257
+ ) -> MossTTSDBatchInput:
258
+ """
259
+ Collate multiple samples into a batch with proper padding.
260
+
261
+ Args:
262
+ samples (`List[MossTTSDChatSample]`):
263
+ List of MossTTSDChatSample objects to collate.
264
+ pad_token_id (`int`):
265
+ Padding token ID for text tokens.
266
+ audio_pad_token_id (`int`):
267
+ Padding token ID for audio tokens.
268
+
269
+ Returns:
270
+ `MossTTSDBatchInput`: Batched input with padded tensors.
271
+ """
272
+ assert is_torch_available(), "PyTorch is required for collation."
273
+ ids_list = [s.input_ids_2d for s in samples]
274
+ labels_list = [s.label_ids_2d for s in samples]
275
+
276
+ C = ids_list[0].shape[1]
277
+ max_len = max(x.shape[0] for x in ids_list)
278
+ padded_ids, padded_labels, padded_attn = [], [], []
279
+
280
+ for ids, labels in zip(ids_list, labels_list):
281
+ pad_len = max_len - ids.shape[0]
282
+ pad_grid = torch.full((pad_len, C), audio_pad_token_id, dtype=torch.long)
283
+ pad_grid[:, 0] = pad_token_id # Text column uses tokenizer pad
284
+ ids_padded = torch.cat([pad_grid, ids], dim=0)
285
+ padded_ids.append(ids_padded)
286
+
287
+ attn = torch.ones(ids.shape[0], dtype=torch.long)
288
+ a_pad = torch.zeros(pad_len, dtype=torch.long)
289
+ padded_attn.append(torch.cat([a_pad, attn], dim=0))
290
+
291
+ if labels is None:
292
+ padded_labels.append(None)
293
+ else:
294
+ lab_pad = torch.full((pad_len, C), audio_pad_token_id, dtype=torch.long)
295
+ lab_pad[:, 0] = -100 # Text labels are ignored by default
296
+ padded_labels.append(torch.cat([lab_pad, labels], dim=0))
297
+
298
+ input_ids = torch.stack(padded_ids) # (B, T, C)
299
+ attention_mask = torch.stack(padded_attn) # (B, T)
300
+ labels = torch.stack([l if l is not None else torch.full_like(input_ids[0], -100) for l in padded_labels]) \
301
+ if any(l is not None for l in padded_labels) else None
302
+
303
+ return MossTTSDBatchInput(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
304
+
305
+ @staticmethod
306
+ def _process_jsonl_item(item: dict[str, Any]) -> dict[str, Any]:
307
+ """
308
+ Process a JSONL item to extract text and audio data.
309
+
310
+ Supports both single-speaker and multi-speaker formats:
311
+ - Single: {"prompt_audio": path, "prompt_text": text}
312
+ - Multi: {"prompt_audio_speaker1": path1, "prompt_text_speaker1": text1, ...}
313
+
314
+ Args:
315
+ item: Dictionary containing the JSONL item data.
316
+
317
+ Returns:
318
+ Dictionary with extracted "text", "prompt_text", and "prompt_audio" fields.
319
+ """
320
+ base_path = item.get("base_path", "")
321
+ text = item.get("text", "")
322
+
323
+ prompt_audio = None
324
+ prompt_text = ""
325
+
326
+ if "prompt_audio" in item and "prompt_text" in item:
327
+ pa = item["prompt_audio"]
328
+ if pa:
329
+ prompt_audio = os.path.join(base_path, pa) if isinstance(pa, str) and base_path else pa
330
+ prompt_text = item.get("prompt_text", "")
331
+ else:
332
+ pa1, pt1 = item.get("prompt_audio_speaker1", ""), item.get("prompt_text_speaker1", "")
333
+ pa2, pt2 = item.get("prompt_audio_speaker2", ""), item.get("prompt_text_speaker2", "")
334
+ has1 = (isinstance(pa1, str) and pa1) or isinstance(pa1, tuple)
335
+ has2 = (isinstance(pa2, str) and pa2) or isinstance(pa2, tuple)
336
+ if has1 or has2:
337
+ spk1 = os.path.join(base_path, pa1) if isinstance(pa1, str) and base_path and pa1 else pa1
338
+ spk2 = os.path.join(base_path, pa2) if isinstance(pa2, str) and base_path and pa2 else pa2
339
+ prompt_audio = {"speaker1": spk1, "speaker2": spk2}
340
+ tmp = ""
341
+ if pt1:
342
+ tmp += f"[S1]{pt1}"
343
+ if pt2:
344
+ tmp += f"[S2]{pt2}"
345
+ prompt_text = tmp.strip()
346
+
347
+ return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
348
+
349
+ @staticmethod
350
+ def _normalize_text(text: str) -> str:
351
+ """
352
+ Normalize text by applying various transformations for TTS processing.
353
+
354
+ Performs speaker tag conversion, punctuation normalization, laughter conversion,
355
+ and other text cleaning operations suitable for speech synthesis.
356
+
357
+ Args:
358
+ text: Input text string to normalize.
359
+
360
+ Returns:
361
+ Normalized text string.
362
+ """
363
+ text = re.sub(r"\[(\d+)\]", r"[S\1]", text)
364
+ remove_chars = '【】《》()『』「」"-""~~'
365
+ text = re.sub(r"\[(?!S\d+\])([^\]]*)\]", r"\1", text)
366
+ segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " "))
367
+ out = []
368
+ for seg in segments:
369
+ seg = seg.strip()
370
+ if not seg:
371
+ continue
372
+ m = re.match(r"^(\[S\d+\])\s*(.*)", seg)
373
+ tag, content = m.groups() if m else ("", seg)
374
+ content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
375
+ content = re.sub(r"哈{2,}", "(笑)", content)
376
+ content = re.sub(r"\b(ha(\s*ha)+)\b", "(laughs)", content, flags=re.IGNORECASE)
377
+ content = content.replace("——", ",").replace("……", ",")
378
+ trans = str.maketrans({"!": ",", "!": ",", ";": ",", ";": ",", ":": ",", ":": ",", "、": ",", "?": ",", "?": ","})
379
+ content = content.translate(trans).strip()
380
+ if len(content) > 1:
381
+ last = "。" if content[-1] == "," else ("." if content[-1] == "," else content[-1])
382
+ body = content[:-1].replace("。", ",")
383
+ content = body + last
384
+ out.append(f"{tag}{content}".strip())
385
+ return "".join(out)
386
+
387
+ @staticmethod
388
+ def _load_single_audio(audio_input: Union[str, tuple["torch.Tensor", int]]):
389
+ """
390
+ Load audio from file path or tensor tuple.
391
+
392
+ Args:
393
+ audio_input: Either a file path string or a tuple of (tensor, sample_rate).
394
+
395
+ Returns:
396
+ Tuple of (audio_tensor, sample_rate).
397
+
398
+ Raises:
399
+ ValueError: If audio input format is unsupported.
400
+ """
401
+ if isinstance(audio_input, tuple) and len(audio_input) == 2:
402
+ return audio_input
403
+ if isinstance(audio_input, str):
404
+ try:
405
+ return torchaudio.load(audio_input)
406
+ except Exception:
407
+ import soundfile as sf # type: ignore
408
+ data, sr = sf.read(audio_input, always_2d=True)
409
+ data_t = torch.from_numpy(np.transpose(data)) # (C, T)
410
+ return data_t, int(sr)
411
+ raise ValueError(f"Unsupported audio input format: {type(audio_input)}")
412
+
413
+ @staticmethod
414
+ def _resample(audio: "torch.Tensor", sr: int, target_sr: int) -> tuple["torch.Tensor", int]:
415
+ """
416
+ Resample audio to target sample rate and convert to mono if needed.
417
+
418
+ Args:
419
+ audio: Input audio tensor with shape (channels, time).
420
+ sr: Current sample rate.
421
+ target_sr: Target sample rate.
422
+
423
+ Returns:
424
+ Tuple of (resampled_audio, target_sr) where audio is mono with shape (1, time).
425
+ """
426
+ if sr != target_sr:
427
+ audio = torchaudio.functional.resample(audio, sr, target_sr)
428
+ if audio.shape[0] > 1:
429
+ audio = audio.mean(dim=0, keepdim=True)
430
+ if audio.ndim == 1:
431
+ audio = audio.unsqueeze(0)
432
+ return audio, target_sr
433
+
434
+ @classmethod
435
+ def _load_audio_data(
436
+ cls, audio_input: Union[str, tuple["torch.Tensor", int]], target_sample_rate: int
437
+ ) -> tuple["torch.Tensor", int]:
438
+ """
439
+ Load and resample audio data to target sample rate.
440
+
441
+ Args:
442
+ audio_input: Audio file path or tensor tuple.
443
+ target_sample_rate: Target sample rate for resampling.
444
+
445
+ Returns:
446
+ Tuple of (audio_tensor, target_sample_rate).
447
+ """
448
+ audio, sr = cls._load_single_audio(audio_input)
449
+ return cls._resample(audio, sr, target_sample_rate)
450
+
451
+ @classmethod
452
+ def _merge_speaker_audios(
453
+ cls,
454
+ wav1: Union[str, tuple["torch.Tensor", int]],
455
+ wav2: Union[str, tuple["torch.Tensor", int]],
456
+ target_sample_rate: int,
457
+ ) -> "torch.Tensor":
458
+ """
459
+ Merge two speaker audio inputs by concatenation.
460
+
461
+ Args:
462
+ wav1: Audio input for speaker 1.
463
+ wav2: Audio input for speaker 2.
464
+ target_sample_rate: Target sample rate for both audio inputs.
465
+
466
+ Returns:
467
+ Concatenated audio tensor.
468
+ """
469
+ a1, _ = cls._load_audio_data(wav1, target_sample_rate)
470
+ a2, _ = cls._load_audio_data(wav2, target_sample_rate)
471
+ return torch.cat([a1, a2], dim=1)
472
+
473
+ @classmethod
474
+ def _process_audio_data(
475
+ cls, prompt_audio: Optional[Union[str, dict[str, Any], tuple["torch.Tensor", int]]], target_sample_rate: int
476
+ ) -> Optional["torch.Tensor"]:
477
+ """
478
+ Process audio data from various input formats.
479
+
480
+ Handles single audio files, multi-speaker audio dictionaries, or None input.
481
+
482
+ Args:
483
+ prompt_audio: Audio input in various formats (path, dict, tensor tuple, or None).
484
+ target_sample_rate: Target sample rate for processing.
485
+
486
+ Returns:
487
+ Processed audio tensor or None if no audio provided.
488
+ """
489
+ if prompt_audio is None:
490
+ return None
491
+ if isinstance(prompt_audio, dict) and "speaker1" in prompt_audio and "speaker2" in prompt_audio:
492
+ return cls._merge_speaker_audios(prompt_audio["speaker1"], prompt_audio["speaker2"], target_sample_rate)
493
+ wav, _ = cls._load_audio_data(prompt_audio, target_sample_rate)
494
+ return wav
495
+
496
+ def _build_inputs(
497
+ self,
498
+ text: str,
499
+ audio_data: Optional["torch.Tensor"],
500
+ apply_chat_template: Callable[[str, dict], str],
501
+ silence_duration: float,
502
+ **kwargs,
503
+ ) -> np.ndarray:
504
+ """
505
+ Build input grid from text and optional audio data.
506
+
507
+ Creates a TxC grid where column 0 contains text tokens and columns 1..C-1 contain
508
+ quantized audio codebook tokens. Audio tokens are mapped to speech token range.
509
+
510
+ Args:
511
+ text: Input text string to process.
512
+ audio_data: Optional audio tensor with shape (channels, time).
513
+ apply_chat_template: Function to apply chat template formatting.
514
+ silence_duration: Duration of silence to append for encoder segmentation.
515
+ **kwargs: Additional arguments for chat template.
516
+
517
+ Returns:
518
+ NumPy array with shape (T, max_channels) containing the input grid.
519
+ """
520
+ assert isinstance(text, str), "text must be a string"
521
+ prompt = apply_chat_template(text, kwargs)
522
+
523
+ text_ids = np.array(self.tokenizer.encode(prompt, add_special_tokens=False))
524
+ grid = np.full((text_ids.shape[0], self.max_channels), self.audio_pad_token_id, dtype=np.int64)
525
+ grid[:, 0] = text_ids
526
+
527
+ if audio_data is not None:
528
+ silence_samples = int(max(0.0, silence_duration) * self.input_sample_rate)
529
+ silence = torch.zeros(audio_data.shape[0], silence_samples, device=audio_data.device)
530
+ wav = torch.cat([audio_data, silence], dim=1)
531
+
532
+ feat = self.feature_extractor(
533
+ wav, sampling_rate=self.input_sample_rate, return_attention_mask=True, return_tensors="pt"
534
+ )
535
+ with torch.no_grad():
536
+ enc = self.audio_tokenizer.encode(feat)
537
+ # (time, codebooks)
538
+ audio_codes = enc["audio_codes"][:, 0].permute(1, 0).cpu().numpy()
539
+ # Map first codebook to speech token range
540
+ audio_codes[:, 0] = audio_codes[:, 0] + self.speech_token_range[0]
541
+ grid = np.concatenate([grid, audio_codes], axis=0)
542
+
543
+ # Trim silence tokens at the end based on encoder downsampling
544
+ silence_tokens = silence_duration * self.input_sample_rate / self.encoder_downsample_rate
545
+ cut = math.floor(silence_tokens / 10) * 10
546
+ if cut > 0:
547
+ grid = grid[:-cut]
548
+
549
+ return grid
550
+
551
+ @staticmethod
552
+ def _shift_inputs(input_ids: np.ndarray, pad_token_id: int, max_channels: int) -> np.ndarray:
553
+ """
554
+ Convert (T, C) grid to time-shifted multi-channel layout (preserving original implementation logic).
555
+
556
+ Creates a shifted layout where new_len = T + C - 1, with column j shifted backwards by j positions.
557
+ This enables the model to process multiple codebook channels with temporal alignment.
558
+
559
+ Args:
560
+ input_ids: Input grid with shape (T, C).
561
+ pad_token_id: Padding token ID for text tokens.
562
+ max_channels: Maximum number of channels.
563
+
564
+ Returns:
565
+ Shifted array with shape (T + max_channels - 1, max_channels).
566
+ """
567
+ T, _ = input_ids.shape
568
+ new_len = T + max_channels - 1
569
+ shifted = np.full((new_len, max_channels), fill_value=1024, dtype=np.int64)
570
+ shifted[:, 0] = np.full(new_len, pad_token_id, dtype=np.int64)
571
+ for j in range(max_channels):
572
+ shifted[j : (T + j), j] = input_ids[:, j]
573
+ return shifted
574
+
575
+
576
+ class MossTTSDProcessor(ProcessorMixin):
577
+ r"""
578
+ Constructs a MOSS-TTSD processor which wraps a tokenizer, feature extractor, and audio tokenizer into a single
579
+ processor. It provides unified text-speech processing capabilities while maintaining backward compatibility with
580
+ previous API versions.
581
+
582
+ [`MossTTSDProcessor`] offers all the functionalities of [`AutoTokenizer`], [`AutoFeatureExtractor`] and
583
+ [`XYTokenizer`]. See the [`~MossTTSDProcessor.__call__`] and [`~MossTTSDProcessor.decode`] for more information.
584
+
585
+ Args:
586
+ tokenizer ([`AutoTokenizer`]):
587
+ An instance of [`AutoTokenizer`]. The tokenizer is a required input.
588
+ feature_extractor ([`AutoFeatureExtractor`]):
589
+ An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input.
590
+ audio_tokenizer ([`XYTokenizer`]):
591
+ An instance of [`XYTokenizer`]. The audio tokenizer is a required input.
592
+ chat_template (`str`, *optional*):
593
+ A template string for chat formatting when combining text and audio interactions.
594
+ speech_token_range (`List[int]`, *optional*, defaults to `[151665, 152689]`):
595
+ Token range [start, end] for mapping speech tokens.
596
+ audio_bos_token (`str`, *optional*, defaults to `"<|begin_of_speech|>"`):
597
+ Beginning of speech token string.
598
+ audio_eos_token (`str`, *optional*, defaults to `"<|end_of_speech|>"`):
599
+ End of speech token string.
600
+ audio_pad_token_id (`int`, *optional*, defaults to `1024`):
601
+ Padding token ID for audio channels.
602
+ """
603
+ feature_extractor_class = "AutoFeatureExtractor"
604
+ tokenizer_class = "AutoTokenizer"
605
+ audio_tokenizer_class = "PreTrainedModel"
606
+
607
+ def __init__(
608
+ self,
609
+ tokenizer,
610
+ feature_extractor,
611
+ audio_tokenizer,
612
+ chat_template: Optional[str] = None,
613
+ speech_token_range: Optional[list[int]] = None,
614
+ audio_bos_token: str = "<|begin_of_speech|>",
615
+ audio_eos_token: str = "<|end_of_speech|>",
616
+ audio_pad_token_id: int = 1024,
617
+ **kwargs,
618
+ ) -> None:
619
+ super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, audio_tokenizer=audio_tokenizer, **kwargs)
620
+
621
+ self.max_channels = (audio_tokenizer.quantizer.num_quantizers if audio_tokenizer else None) or 8
622
+ self.input_sample_rate = (getattr(audio_tokenizer, "config", None).input_sample_rate if audio_tokenizer else None) or 16000
623
+ self.output_sample_rate = (getattr(audio_tokenizer, "config", None).output_sample_rate if audio_tokenizer else None) or 16000
624
+ self.encoder_downsample_rate = (getattr(audio_tokenizer, "config", None).encoder_downsample_rate if audio_tokenizer else None) or 320
625
+
626
+ # Use tokenizer's built-in chat template as primary
627
+ self.chat_template = getattr(tokenizer, "chat_template", None) or chat_template
628
+
629
+ # Read speech token range from tokenizer with fallback
630
+ self.speech_token_range = (
631
+ getattr(tokenizer, "speech_token_range", None) or speech_token_range or [151665, 152689]
632
+ )
633
+ self.audio_bos_token = getattr(tokenizer, "audio_bos_token", None) or audio_bos_token
634
+ self.audio_eos_token = getattr(tokenizer, "audio_eos_token", None) or audio_eos_token
635
+ self.audio_pad_token_id = getattr(tokenizer, "audio_pad_token_id", None) or audio_pad_token_id
636
+
637
+ # Sample-level processor
638
+ self.sample_processor = MossTTSDSampleProcessor(
639
+ tokenizer=self.tokenizer,
640
+ feature_extractor=self.feature_extractor,
641
+ audio_tokenizer=self.audio_tokenizer,
642
+ chat_template=self.chat_template,
643
+ speech_token_range=self.speech_token_range,
644
+ audio_bos_token=self.audio_bos_token,
645
+ audio_eos_token=self.audio_eos_token,
646
+ audio_pad_token_id=self.audio_pad_token_id,
647
+ max_channels=self.max_channels,
648
+ input_sample_rate=self.input_sample_rate,
649
+ encoder_downsample_rate=self.encoder_downsample_rate,
650
+ )
651
+
652
+ @classmethod
653
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], trust_remote_code=True, **kwargs):
654
+ """
655
+ Instantiate a processor from a pretrained model.
656
+
657
+ Args:
658
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
659
+ The name of or path to the pretrained model.
660
+ **kwargs:
661
+ Additional keyword arguments passed to the respective component loaders.
662
+
663
+ Returns:
664
+ [`MossTTSDProcessor`]: A new instance of the processor.
665
+ """
666
+ kwargs.pop("_from_auto", None)
667
+ audio_tokenizer_path = kwargs.pop("codec_path", os.path.join(pretrained_model_name_or_path, "XY_Tokenizer"))
668
+ assert isinstance(audio_tokenizer_path, str), f"Unsupported audio_tokenizer_path input format: {type(audio_tokenizer_path)}"
669
+
670
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
671
+ feature_extractor = AutoFeatureExtractor.from_pretrained(audio_tokenizer_path, trust_remote_code=trust_remote_code, **kwargs)
672
+ audio_tokenizer = AutoModel.from_pretrained(audio_tokenizer_path, trust_remote_code=trust_remote_code, **kwargs)
673
+
674
+ return cls(
675
+ tokenizer=tokenizer,
676
+ feature_extractor=feature_extractor,
677
+ audio_tokenizer=audio_tokenizer,
678
+ **kwargs,
679
+ )
680
+
681
+ @classmethod
682
+ def get_processor_dict(
683
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
684
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
685
+ proc_dict, rest = super().get_processor_dict(pretrained_model_name_or_path, **kwargs)
686
+ if "audio_tokenizer" in rest:
687
+ proc_dict["audio_tokenizer"] = rest.pop("audio_tokenizer")
688
+ for key in ("speech_token_range", "audio_bos_token", "audio_eos_token", "audio_pad_token_id"):
689
+ if key in rest:
690
+ proc_dict[key] = rest.pop(key)
691
+ return proc_dict, rest
692
+
693
+ def __call__(
694
+ self,
695
+ data: Union[dict[str, Any], list[dict[str, Any]]],
696
+ **kwargs: Unpack[MossTTSDProcessorKwargs],
697
+ ) -> BatchEncoding:
698
+ """
699
+ Main method to prepare inputs for the model from structured data.
700
+
701
+ This method forwards the `data` and `kwargs` arguments to prepare inputs for MOSS-TTSD model. Please refer to the
702
+ docstring of the respective methods for more information.
703
+
704
+ Args:
705
+ data (`dict` or `list[dict]`):
706
+ Single dictionary or list of dictionaries containing input data. Expected keys include 'text',
707
+ 'prompt_text', 'prompt_audio', etc.
708
+ **kwargs (`MossTTSDProcessorKwargs`):
709
+ Additional processing arguments.
710
+
711
+ Returns:
712
+ [`BatchEncoding`]: Processed inputs ready for model consumption.
713
+ """
714
+ if isinstance(data, dict):
715
+ data = [data]
716
+
717
+ out_kwargs = self._merge_kwargs(MossTTSDProcessorKwargs, **kwargs)
718
+ text_kwargs = out_kwargs["text_kwargs"]
719
+ audio_kwargs = out_kwargs["audio_kwargs"]
720
+ common_kwargs = out_kwargs["common_kwargs"]
721
+
722
+ return_tensors = common_kwargs.get("return_tensors", "pt")
723
+ padding = common_kwargs.get("padding", True)
724
+ use_normalize = common_kwargs.get("use_normalize", False)
725
+
726
+ pad_token_id = int(text_kwargs.get("pad_token_id", self.tokenizer.pad_token_id or 0))
727
+ max_channels = int(audio_kwargs.get("max_channels", self.max_channels))
728
+ audio_pad_token_id = int(audio_kwargs.get("audio_pad_token_id", self.audio_pad_token_id))
729
+ silence_duration = float(audio_kwargs.get("silence_duration", 0.0))
730
+
731
+ def _apply_chat_template(text: str, extra: dict) -> str:
732
+ return self.apply_chat_template(conversation=None, text=text, **extra)
733
+
734
+ samples: list[MossTTSDChatSample] = []
735
+ for item in data:
736
+ sample = self.sample_processor.prepare_sample(
737
+ item,
738
+ apply_chat_template=_apply_chat_template,
739
+ use_normalize=use_normalize,
740
+ silence_duration=silence_duration,
741
+ )
742
+ # Override with call-time max_channels (may differ from component initialization)
743
+ if sample.input_ids_2d.shape[1] != max_channels:
744
+ # Simplified: for clipping/extending channels, only pad/clip on the right side
745
+ T, C = sample.input_ids_2d.shape
746
+ if C > max_channels:
747
+ sample.input_ids_2d = sample.input_ids_2d[:, :max_channels]
748
+ else:
749
+ pad = torch.full((T, max_channels - C), audio_pad_token_id, dtype=torch.long)
750
+ sample.input_ids_2d = torch.cat([sample.input_ids_2d, pad], dim=1)
751
+ samples.append(sample)
752
+
753
+ if not padding:
754
+ raise NotImplementedError("Unpadded batches are not supported yet.")
755
+
756
+ batch = self.sample_processor.collate(
757
+ samples,
758
+ pad_token_id=pad_token_id,
759
+ audio_pad_token_id=audio_pad_token_id,
760
+ )
761
+ # Align with HiggsAudioProcessor: explicit dict -> BatchEncoding/Feature
762
+ inputs = asdict(batch)
763
+ inputs = {k: v for k, v in inputs.items() if v is not None}
764
+ return BatchEncoding(inputs, tensor_type=return_tensors)
765
+
766
+ def shifting_outputs(
767
+ self,
768
+ output_ids: "torch.Tensor",
769
+ speech_token_range: list[int],
770
+ max_channels: int = 8,
771
+ ) -> "torch.Tensor":
772
+ """
773
+ Restore time-shifted layout to per-timestep C-channel arrangement and reverse-offset first codebook.
774
+
775
+ Converts the time-shifted multi-channel output back to standard (batch, time, channels) format
776
+ and maps the first codebook tokens back to their original space by subtracting the speech token offset.
777
+
778
+ Args:
779
+ output_ids: Time-shifted output tensor.
780
+ speech_token_range: Speech token range for reverse mapping.
781
+ max_channels: Number of codebook channels.
782
+
783
+ Returns:
784
+ Restored tensor with shape (batch, seq_len, max_channels).
785
+ """
786
+ seq_len = output_ids.shape[1] - max_channels + 1
787
+ speech_ids = torch.full((output_ids.shape[0], seq_len, max_channels), 0, dtype=output_ids.dtype, device=output_ids.device)
788
+ for j in range(max_channels):
789
+ speech_ids[..., j] = output_ids[:, j : seq_len + j, j]
790
+ if j == 0:
791
+ speech_ids[..., j] = speech_ids[..., j] - speech_token_range[0]
792
+ return speech_ids
793
+
794
+ def _find_max_valid_positions(self, data: "torch.Tensor", invalid_value: int = 1024):
795
+ """
796
+ Locate continuous valid audio segment intervals in each sequence (all non-text channels valid simultaneously).
797
+
798
+ Identifies contiguous spans where all audio channels (columns 1+) contain valid tokens
799
+ (not the invalid_value padding token).
800
+
801
+ Args:
802
+ data: Input tensor with shape (batch, time, channels).
803
+ invalid_value: Token ID considered as invalid/padding.
804
+
805
+ Returns:
806
+ List of lists containing valid audio segments for each sequence in the batch.
807
+ """
808
+ mask = torch.all(data[:, :, 1:] != invalid_value, dim=2)
809
+ valid_indices = torch.where(mask)
810
+ result = [[] for _ in range(len(data))]
811
+ if valid_indices[0].numel() == 0:
812
+ return result
813
+ grouped = []
814
+ group_ids = []
815
+ for i, seq_no in enumerate(valid_indices[0]):
816
+ pos = valid_indices[1][i]
817
+ if not group_ids or seq_no > group_ids[-1]:
818
+ group_ids.append(seq_no)
819
+ grouped.append([[pos, pos + 1]])
820
+ elif pos == grouped[-1][-1][-1]:
821
+ grouped[-1][-1][-1] += 1
822
+ else:
823
+ grouped[-1].append([pos, pos + 1])
824
+ for gid, spans in zip(group_ids, grouped):
825
+ for s, e in spans:
826
+ result[gid].append(data[gid, s:e, :])
827
+ return result
828
+
829
+ def batch_decode(self, token_ids: "torch.Tensor", *args, **kwargs):
830
+ """
831
+ Decode a batch of token sequences into text and audio outputs.
832
+
833
+ This method forwards the `token_ids` and `kwargs` arguments to decode text and audio outputs from the model.
834
+ Please refer to the docstring of the respective methods for more information.
835
+
836
+ Args:
837
+ token_ids (`torch.Tensor`):
838
+ Token tensor with shape (batch, time, channels).
839
+ *args:
840
+ Additional arguments passed to tokenizer.batch_decode.
841
+ **kwargs:
842
+ Additional keyword arguments passed to tokenizer.batch_decode.
843
+
844
+ Returns:
845
+ `tuple`: Tuple of (text_list, audio_list) where text_list contains decoded text strings and audio_list
846
+ contains decoded audio arrays for each sequence.
847
+ """
848
+ assert token_ids.ndim == 3 and token_ids.shape[2] == self.max_channels
849
+ text = self.tokenizer.batch_decode(token_ids[:, :, 0], *args, **kwargs)
850
+ normal = self.shifting_outputs(token_ids, self.speech_token_range, self.max_channels)
851
+ audio_frags = self._find_max_valid_positions(normal, self.audio_pad_token_id)
852
+ decode_audio = []
853
+ for seq_frags in audio_frags:
854
+ if len(seq_frags):
855
+ frag = torch.cat([f.permute(1, 0).unsqueeze(1) for f in seq_frags], dim=1)
856
+ decode_audio.append(self.audio_tokenizer.decode(frag, overlap_seconds=10)["audio_values"])
857
+ else:
858
+ decode_audio.append([])
859
+ return text, decode_audio
860
+
861
+ def decode(self, token_ids: "torch.Tensor", *args, **kwargs) -> MossTTSDResponse:
862
+ """
863
+ Decode a single sequence of token IDs into text and audio.
864
+
865
+ This method forwards the `token_ids` and `kwargs` arguments to decode a single sequence. Please refer to the
866
+ docstring of the respective methods for more information.
867
+
868
+ Args:
869
+ token_ids (`torch.Tensor`):
870
+ Token tensor with shape (time, channels).
871
+ *args:
872
+ Additional arguments passed to tokenizer.decode.
873
+ **kwargs:
874
+ Additional keyword arguments passed to tokenizer.decode.
875
+
876
+ Returns:
877
+ [`MossTTSDResponse`]: Response object containing generated text, audio, and sampling rate.
878
+ """
879
+ assert token_ids.ndim == 2 and token_ids.shape[1] == self.max_channels
880
+ text = self.tokenizer.decode(token_ids[:, 0].squeeze(-1), *args, **kwargs)
881
+ normal = self.shifting_outputs(token_ids.unsqueeze(0), self.speech_token_range, self.max_channels)
882
+ audio_frags = self._find_max_valid_positions(normal, self.audio_pad_token_id)[0]
883
+ if len(audio_frags):
884
+ frag = torch.cat([f.permute(1, 0).unsqueeze(1) for f in audio_frags], dim=1)
885
+ audio = self.audio_tokenizer.decode(frag, overlap_seconds=10)["audio_values"]
886
+ else:
887
+ audio = None
888
+ return MossTTSDResponse(
889
+ audio=None if audio is None else audio.detach().cpu().numpy(),
890
+ generated_text=text,
891
+ sampling_rate=self.output_sample_rate,
892
+ )
893
+
894
+ def save_audio(self, audios, output_dir="output", prefix="audio"):
895
+ """
896
+ Save multiple audio fragments to files.
897
+
898
+ Args:
899
+ audios: List of audio data fragments from batch_decode
900
+ output_dir (str): Directory to save audio files
901
+ prefix (str): Prefix for audio filenames
902
+ """
903
+ if not is_torchaudio_available():
904
+ raise ImportError("Please install `torchaudio` to save audio files.")
905
+
906
+ os.makedirs(output_dir, exist_ok=True)
907
+
908
+ for i, data in enumerate(audios):
909
+ for j, fragment in enumerate(data):
910
+ filename = f"{output_dir}/{prefix}_{i}_{j}.wav"
911
+ torchaudio.save(filename, fragment.cpu(), self.output_sample_rate)
912
+
913
+
914
+ __all__ = ["MossTTSDProcessor"]
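The delay pattern applied by `_shift_inputs` and undone by `shifting_outputs` is easiest to see on a toy grid. A small sketch under assumed toy sizes (pad values mirror the defaults above; no model or tokenizer is needed, and the speech-token offset on channel 0 is omitted):

```python
import numpy as np

# Toy grid: T timesteps x C channels; values are arbitrary and only make the shift visible.
T, C, text_pad, audio_pad = 4, 3, 0, 1024
grid = np.arange(1, T * C + 1).reshape(T, C)

# Forward shift (mirrors _shift_inputs): channel j is delayed by j steps.
shifted = np.full((T + C - 1, C), audio_pad, dtype=np.int64)
shifted[:, 0] = text_pad
for j in range(C):
    shifted[j : T + j, j] = grid[:, j]

# Inverse (mirrors shifting_outputs without the speech-token offset).
restored = np.stack([shifted[j : T + j, j] for j in range(C)], axis=1)
assert np.array_equal(restored, grid)
```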
processor_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "processor_class": "processing_moss_ttsd.MossTTSDProcessor",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_moss_ttsd.MossTTSDProcessor"
5
+ }
6
+ }
tokenizer_config.json CHANGED
@@ -8451,12 +8451,20 @@
8451
  "<|video_pad|>"
8452
  ],
8453
  "bos_token": null,
8454
- "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
8455
  "clean_up_tokenization_spaces": false,
8456
  "eos_token": "<|endoftext|>",
8457
  "errors": "replace",
8458
  "extra_special_tokens": {},
8459
- "model_max_length": 131072,
8460
  "pad_token": "<|endoftext|>",
8461
  "padding_side": "right",
8462
  "split_special_tokens": false,
 
8451
  "<|video_pad|>"
8452
  ],
8453
  "bos_token": null,
8454
+ "chat_template": "<|begin_of_style|>{{ system_prompt | default('You are a speech synthesizer that generates natural, realistic, and human-like conversational audio from dialogue text.') }}<|end_of_style|>\n<|begin_of_text|>{{ text }}<|end_of_text|>\n<|begin_of_speech|>",
8455
  "clean_up_tokenization_spaces": false,
8456
  "eos_token": "<|endoftext|>",
8457
  "errors": "replace",
8458
  "extra_special_tokens": {},
8459
+ "model_max_length": 16384,
8460
+ "processor_class": "processing_moss_ttsd.MossTTSDProcessor",
8461
+ "speech_token_range": [
8462
+ 151665,
8463
+ 152689
8464
+ ],
8465
+ "audio_bos_token": "<|begin_of_speech|>",
8466
+ "audio_eos_token": "<|end_of_speech|>",
8467
+ "audio_pad_token_id": 1024,
8468
  "pad_token": "<|endoftext|>",
8469
  "padding_side": "right",
8470
  "split_special_tokens": false,