# Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SparkTTS model configuration."""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class SparkTTSConfig(PretrainedConfig):
    """
    Configuration class holding the settings of a [`SparkTTSModel`].

    An instance records where the sub-component checkpoints (LLM, BiCodec, Wav2Vec2) live together with the audio
    processing parameters, and is used to instantiate a SparkTTS model from those arguments.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`):
            Path to the pretrained LLM model, or a model identifier from huggingface.co/models.
        bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`):
            Path to the pretrained BiCodec model directory.
        wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`):
            Path to the pretrained Wav2Vec2 model directory.
        sample_rate (`int`, *optional*, defaults to 16000):
            Sampling rate of the audio files.
        highpass_cutoff_freq (`int`, *optional*, defaults to 40):
            Cutoff frequency of the highpass filter used for audio processing.
        latent_hop_length (`int`, *optional*, defaults to 320):
            Hop length used in BiCodec processing.
        ref_segment_duration (`float`, *optional*, defaults to 6.0):
            Duration, in seconds, of the reference audio clip used for the speaker embedding.
        volume_normalize (`bool`, *optional*, defaults to `True`):
            Whether the volume of audio inputs is normalized.
        bicodec_config (`dict`, *optional*):
            Configuration of the BiCodec model components (encoder, decoder, etc.), typically loaded originally from
            `BiCodec/config.yaml`. When `None`, an empty dictionary is stored instead.
        **kwargs:
            Additional keyword arguments forwarded to [`PretrainedConfig`].
    """

    model_type = "spark-tts"
    processor_class = "SparkTTSProcessor"
    config_files = ["config.json"]
    # Empty for now; add entries here if attributes ever need to be renamed.
    attribute_map = {}

    def __init__(
        self,
        llm_model_name_or_path="./LLM",
        bicodec_model_name_or_path="./BiCodec",
        wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
        sample_rate=16000,
        highpass_cutoff_freq=40,
        latent_hop_length=320,
        ref_segment_duration=6.0,
        volume_normalize=True,
        bicodec_config=None,
        **kwargs,
    ):
        # Use a fresh dict per instance when no BiCodec configuration is given.
        if bicodec_config is None:
            bicodec_config = {}

        # Locations of the pretrained sub-components.
        self.llm_model_name_or_path = llm_model_name_or_path
        self.bicodec_model_name_or_path = bicodec_model_name_or_path
        self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path

        # Audio processing parameters.
        self.sample_rate = sample_rate
        self.highpass_cutoff_freq = highpass_cutoff_freq
        self.latent_hop_length = latent_hop_length
        self.ref_segment_duration = ref_segment_duration
        self.volume_normalize = volume_normalize

        # An empty BiCodec configuration is deliberately not reported here; per the
        # original author's note, that check is left to SparkTTSModel.
        self.bicodec_config = bicodec_config

        super().__init__(**kwargs)