Update modeling_spark_tts.py
Browse files- modeling_spark_tts.py +133 -73
modeling_spark_tts.py
CHANGED
@@ -24,6 +24,7 @@ from typing import Dict, Any, Tuple, Optional, Union
|
|
24 |
|
25 |
from transformers import PreTrainedModel, AutoModelForCausalLM, Wav2Vec2FeatureExtractor, Wav2Vec2Model
|
26 |
from transformers.utils import logging, requires_backends, cached_file
|
|
|
27 |
from transformers.generation.utils import GenerationMixin
|
28 |
from transformers.configuration_utils import PretrainedConfig
|
29 |
from safetensors.torch import load_file
|
@@ -3039,7 +3040,7 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
3039 |
"""
|
3040 |
config_class = SparkTTSConfig
|
3041 |
base_model_prefix = "spark_tts"
|
3042 |
-
_supports_load_fast =
|
3043 |
|
3044 |
def __init__(self, config: SparkTTSConfig, llm=None, wav2vec2_model=None, wav2vec2_processor=None, bicodec=None):
|
3045 |
super().__init__(config)
|
@@ -3049,9 +3050,8 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
3049 |
self.wav2vec2_processor = wav2vec2_processor
|
3050 |
self.bicodec = bicodec
|
3051 |
|
3052 |
-
#
|
3053 |
-
|
3054 |
-
self.wav2vec2_model.config.output_hidden_states = True
|
3055 |
|
3056 |
|
3057 |
@classmethod
|
@@ -3066,124 +3066,182 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
3066 |
local_files_only: bool = False,
|
3067 |
token: Optional[Union[str, bool]] = None,
|
3068 |
revision: str = "main",
|
3069 |
-
use_safetensors: bool = None,
|
3070 |
**kwargs,
|
3071 |
):
|
3072 |
-
#
|
3073 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3074 |
config, model_kwargs = cls.config_class.from_pretrained(
|
3075 |
-
|
3076 |
-
*model_args,
|
3077 |
-
cache_dir=cache_dir,
|
3078 |
force_download=force_download,
|
3079 |
local_files_only=local_files_only,
|
3080 |
token=token,
|
3081 |
revision=revision,
|
|
|
3082 |
return_unused_kwargs=True,
|
3083 |
-
**kwargs,
|
3084 |
)
|
|
|
|
|
3085 |
else:
|
3086 |
-
|
3087 |
-
|
3088 |
-
# Pop device map info - will handle placement later
|
3089 |
-
device_map = model_kwargs.pop("device_map", None)
|
3090 |
-
torch_dtype = model_kwargs.pop("torch_dtype", "auto") # Use config's or auto
|
3091 |
-
|
3092 |
-
# Check for trust_remote_code - needed for config loading if custom code involved there too
|
3093 |
-
trust_remote_code = model_kwargs.pop("trust_remote_code", False) # Important
|
3094 |
-
|
3095 |
-
|
3096 |
-
# NEW IMPROVED PATH RESOLUTION
|
3097 |
-
from huggingface_hub import snapshot_download
|
3098 |
-
import os
|
3099 |
-
# Check if it's a local path first
|
3100 |
-
if os.path.isdir(pretrained_model_name_or_path):
|
3101 |
-
resolved_model_path = Path(pretrained_model_name_or_path)
|
3102 |
-
else:
|
3103 |
-
# Try to get from Hugging Face Hub
|
3104 |
-
try:
|
3105 |
-
logger.info(f"Downloading/locating model from Hugging Face Hub: {pretrained_model_name_or_path}")
|
3106 |
-
# This will download the model if needed and return the cached path
|
3107 |
-
resolved_model_path = Path(snapshot_download(
|
3108 |
-
pretrained_model_name_or_path,
|
3109 |
-
revision=revision,
|
3110 |
-
cache_dir=cache_dir,
|
3111 |
-
force_download=force_download,
|
3112 |
-
local_files_only=local_files_only,
|
3113 |
-
token=token,
|
3114 |
-
))
|
3115 |
-
except Exception as e:
|
3116 |
-
logger.error(f"Error downloading model: {e}")
|
3117 |
-
raise EnvironmentError(f"Failed to find or download model '{pretrained_model_name_or_path}': {e}")
|
3118 |
|
3119 |
-
if not resolved_model_path.is_dir():
|
3120 |
-
raise EnvironmentError(f"Cannot find model directory at {resolved_model_path}")
|
3121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
3122 |
|
3123 |
-
# Helper function to resolve paths relative to the main model directory
|
3124 |
-
def
|
3125 |
p = Path(sub_path)
|
3126 |
if p.is_absolute():
|
3127 |
return str(p)
|
3128 |
else:
|
3129 |
-
|
3130 |
-
|
|
|
|
|
3131 |
|
3132 |
# --- Load LLM ---
|
3133 |
-
llm_path =
|
3134 |
logger.info(f"Loading LLM from resolved path: {llm_path}")
|
3135 |
try:
|
3136 |
llm = AutoModelForCausalLM.from_pretrained(
|
3137 |
llm_path,
|
3138 |
-
torch_dtype=
|
3139 |
trust_remote_code=trust_remote_code, # Pass down trust_remote_code
|
3140 |
-
|
|
|
|
|
3141 |
)
|
3142 |
except Exception as e:
|
3143 |
raise OSError(f"Failed to load LLM from {llm_path}: {e}")
|
3144 |
|
3145 |
# --- Load Wav2Vec2 ---
|
3146 |
-
w2v_path =
|
3147 |
-
logger.info(f"Loading Wav2Vec2 from resolved path: {w2v_path}")
|
3148 |
try:
|
3149 |
-
|
3150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3151 |
except Exception as e:
|
3152 |
raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}")
|
3153 |
|
3154 |
# --- Load BiCodec ---
|
3155 |
-
bicodec_path =
|
3156 |
logger.info(f"Loading BiCodec from resolved path: {bicodec_path}")
|
3157 |
-
# print(f"Loading BiCodec from resolved path: {bicodec_path}, {config}")
|
3158 |
if not config.bicodec_config or "audio_tokenizer" not in config.bicodec_config:
|
3159 |
raise ValueError("BiCodec configuration ('bicodec_config' with 'audio_tokenizer' key) not found in SparkTTSConfig.")
|
3160 |
try:
|
3161 |
-
# Assuming BiCodec class
|
|
|
3162 |
bicodec = BiCodec.load_from_config_and_checkpoint(
|
3163 |
model_dir=Path(bicodec_path),
|
3164 |
config_dict=config.bicodec_config["audio_tokenizer"]
|
3165 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
3166 |
except Exception as e:
|
3167 |
-
|
|
|
|
|
|
|
3168 |
|
3169 |
|
3170 |
-
# Instantiate the main model wrapper
|
|
|
3171 |
model = cls(config, llm=llm, wav2vec2_model=wav2vec2_model, wav2vec2_processor=wav2vec2_processor, bicodec=bicodec)
|
3172 |
|
3173 |
-
# --- Handle device placement ---
|
3174 |
-
#
|
3175 |
-
# Determine target device
|
3176 |
if torch.cuda.is_available():
|
3177 |
-
|
3178 |
-
|
|
|
3179 |
else:
|
3180 |
-
|
3181 |
-
logger.info(f"Placing SparkTTSModel and components on device: {device}")
|
3182 |
-
model.to(device) # This should move all registered nn.Module attributes
|
3183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3184 |
return model
|
3185 |
|
3186 |
|
|
|
3187 |
# --- Embedding getters/setters (delegate to LLM if loaded) ---
|
3188 |
def get_input_embeddings(self):
|
3189 |
if self.llm:
|
@@ -3212,11 +3270,13 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
3212 |
# post_init is less critical now as loading happens in from_pretrained,
|
3213 |
# but can be used for final checks or setup.
|
3214 |
def post_init(self):
|
3215 |
-
# Ensure wav2vec2 config has output_hidden_states=True
|
3216 |
if self.wav2vec2_model and hasattr(self.wav2vec2_model.config, 'output_hidden_states'):
|
3217 |
-
|
3218 |
-
|
3219 |
-
|
|
|
|
|
3220 |
|
3221 |
@property
|
3222 |
def device(self) -> torch.device:
|
|
|
24 |
|
25 |
from transformers import PreTrainedModel, AutoModelForCausalLM, Wav2Vec2FeatureExtractor, Wav2Vec2Model
|
26 |
from transformers.utils import logging, requires_backends, cached_file
|
27 |
+
from huggingface_hub import snapshot_download
|
28 |
from transformers.generation.utils import GenerationMixin
|
29 |
from transformers.configuration_utils import PretrainedConfig
|
30 |
from safetensors.torch import load_file
|
|
|
3040 |
"""
|
3041 |
config_class = SparkTTSConfig
|
3042 |
base_model_prefix = "spark_tts"
|
3043 |
+
_supports_load_fast = True
|
3044 |
|
3045 |
def __init__(self, config: SparkTTSConfig, llm=None, wav2vec2_model=None, wav2vec2_processor=None, bicodec=None):
|
3046 |
super().__init__(config)
|
|
|
3050 |
self.wav2vec2_processor = wav2vec2_processor
|
3051 |
self.bicodec = bicodec
|
3052 |
|
3053 |
+
# Ensure wav2vec2 config has output_hidden_states=True after loading
|
3054 |
+
self.post_init()
|
|
|
3055 |
|
3056 |
|
3057 |
@classmethod
|
|
|
3066 |
local_files_only: bool = False,
|
3067 |
token: Optional[Union[str, bool]] = None,
|
3068 |
revision: str = "main",
|
3069 |
+
use_safetensors: Optional[bool] = None, # Keep None to let transformers decide
|
3070 |
**kwargs,
|
3071 |
):
|
3072 |
+
# Pop device map and dtype early - handle placement later
|
3073 |
+
# Note: device_map is complex with multiple components. Manual .to(device) is simpler here.
|
3074 |
+
device_map = kwargs.pop("device_map", None)
|
3075 |
+
if device_map:
|
3076 |
+
logger.warning("`device_map` is not directly supported for this composite model. Use .to(device) after loading.")
|
3077 |
+
|
3078 |
+
torch_dtype = kwargs.pop("torch_dtype", "auto") # Can be "auto", float32, float16, bfloat16
|
3079 |
+
trust_remote_code = kwargs.pop("trust_remote_code", False) # CRITICAL for custom code
|
3080 |
+
|
3081 |
+
# --- 1. Resolve the main model directory ---
|
3082 |
+
# This handles downloading from Hub or using a local path robustly.
|
3083 |
+
if pretrained_model_name_or_path is None:
|
3084 |
+
raise ValueError("`pretrained_model_name_or_path` must be provided.")
|
3085 |
+
|
3086 |
+
model_path = Path(pretrained_model_name_or_path)
|
3087 |
+
if not model_path.is_dir():
|
3088 |
+
# If it's not a local directory, assume it's a Hub ID and download everything
|
3089 |
+
logger.info(f"{pretrained_model_name_or_path} is not a local directory. Assuming Hub ID and downloading.")
|
3090 |
+
try:
|
3091 |
+
resolved_model_path = snapshot_download(
|
3092 |
+
repo_id=str(pretrained_model_name_or_path),
|
3093 |
+
cache_dir=cache_dir,
|
3094 |
+
force_download=force_download,
|
3095 |
+
local_files_only=local_files_only,
|
3096 |
+
token=token,
|
3097 |
+
revision=revision,
|
3098 |
+
allow_patterns=["*.json", "*.safetensors", "*.bin", "*.yaml", "*.txt", "README.md"], # Be somewhat permissive
|
3099 |
+
# ignore_patterns=["*.git*"], # Optional: ignore git files
|
3100 |
+
# user_agent={"agent": "spark-tts-custom-loader"}, # Optional
|
3101 |
+
)
|
3102 |
+
resolved_model_path = Path(resolved_model_path)
|
3103 |
+
logger.info(f"Model downloaded to cache: {resolved_model_path}")
|
3104 |
+
except Exception as e:
|
3105 |
+
raise OSError(
|
3106 |
+
f"Failed to download model '{pretrained_model_name_or_path}' from Hugging Face Hub. "
|
3107 |
+
f"Ensure the ID is correct and network is available. Error: {e}"
|
3108 |
+
)
|
3109 |
+
else:
|
3110 |
+
# It's a local directory path
|
3111 |
+
resolved_model_path = model_path
|
3112 |
+
logger.info(f"Loading model from local directory: {resolved_model_path}")
|
3113 |
+
|
3114 |
+
if not resolved_model_path.is_dir():
|
3115 |
+
# This should ideally not happen after snapshot_download or initial check
|
3116 |
+
raise EnvironmentError(f"Cannot find resolved model directory at {resolved_model_path}")
|
3117 |
+
|
3118 |
+
# --- 2. Load the main configuration ---
|
3119 |
+
# The config might have been passed explicitly, otherwise load from resolved path
|
3120 |
+
if not isinstance(config, PretrainedConfig):
|
3121 |
+
config_path = config if config is not None else resolved_model_path
|
3122 |
config, model_kwargs = cls.config_class.from_pretrained(
|
3123 |
+
config_path, # Load from the resolved directory or explicit config path
|
3124 |
+
*model_args, # Pass *model_args here if they influence config loading
|
3125 |
+
cache_dir=cache_dir, # Pass relevant args down
|
3126 |
force_download=force_download,
|
3127 |
local_files_only=local_files_only,
|
3128 |
token=token,
|
3129 |
revision=revision,
|
3130 |
+
trust_remote_code=trust_remote_code, # Needed if config class itself is remote
|
3131 |
return_unused_kwargs=True,
|
3132 |
+
**kwargs, # Pass remaining kwargs
|
3133 |
)
|
3134 |
+
# Update kwargs with unused ones from config loading
|
3135 |
+
kwargs.update(model_kwargs)
|
3136 |
else:
|
3137 |
+
# Config object was passed directly
|
3138 |
+
pass # kwargs remain as they were
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3139 |
|
|
|
|
|
3140 |
|
3141 |
+
# --- Determine torch_dtype (use config value if specified and not overridden) ---
|
3142 |
+
# Priority: Explicit torch_dtype arg > config.torch_dtype > "auto" (default)
|
3143 |
+
final_torch_dtype = torch_dtype # Explicit arg has highest prio
|
3144 |
+
if final_torch_dtype == "auto":
|
3145 |
+
final_torch_dtype = getattr(config, "torch_dtype", None) # Use config value if present
|
3146 |
+
# final_torch_dtype can still be None or "auto" here, handle downstream
|
3147 |
|
3148 |
+
# --- Helper function to resolve paths relative to the main model directory ---
|
3149 |
+
def _resolve_sub_path(sub_path):
|
3150 |
p = Path(sub_path)
|
3151 |
if p.is_absolute():
|
3152 |
return str(p)
|
3153 |
else:
|
3154 |
+
# Resolve relative to the potentially cached main model path
|
3155 |
+
return str(resolved_model_path / p)
|
3156 |
+
|
3157 |
+
# --- 3. Load Sub-components ---
|
3158 |
|
3159 |
# --- Load LLM ---
|
3160 |
+
llm_path = _resolve_sub_path(config.llm_model_name_or_path)
|
3161 |
logger.info(f"Loading LLM from resolved path: {llm_path}")
|
3162 |
try:
|
3163 |
llm = AutoModelForCausalLM.from_pretrained(
|
3164 |
llm_path,
|
3165 |
+
torch_dtype=final_torch_dtype if final_torch_dtype != "auto" else None, # Pass resolved dtype or None
|
3166 |
trust_remote_code=trust_remote_code, # Pass down trust_remote_code
|
3167 |
+
# Pass remaining kwargs that might be relevant for AutoModelForCausalLM
|
3168 |
+
# Filter kwargs if necessary, but often passing them is fine
|
3169 |
+
**kwargs
|
3170 |
)
|
3171 |
except Exception as e:
|
3172 |
raise OSError(f"Failed to load LLM from {llm_path}: {e}")
|
3173 |
|
3174 |
# --- Load Wav2Vec2 ---
|
3175 |
+
w2v_path = _resolve_sub_path(config.wav2vec2_model_name_or_path)
|
3176 |
+
logger.info(f"Loading Wav2Vec2 components from resolved path: {w2v_path}")
|
3177 |
try:
|
3178 |
+
# Load feature extractor first
|
3179 |
+
wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained(
|
3180 |
+
w2v_path,
|
3181 |
+
trust_remote_code=trust_remote_code,
|
3182 |
+
# Add any relevant kwargs for feature extractor if needed
|
3183 |
+
)
|
3184 |
+
# Load model
|
3185 |
+
wav2vec2_model = Wav2Vec2Model.from_pretrained(
|
3186 |
+
w2v_path,
|
3187 |
+
trust_remote_code=trust_remote_code,
|
3188 |
+
# Add any relevant kwargs for model if needed (e.g., add_adapter=False)
|
3189 |
+
)
|
3190 |
except Exception as e:
|
3191 |
raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}")
|
3192 |
|
3193 |
# --- Load BiCodec ---
|
3194 |
+
bicodec_path = _resolve_sub_path(config.bicodec_model_name_or_path)
|
3195 |
logger.info(f"Loading BiCodec from resolved path: {bicodec_path}")
|
|
|
3196 |
if not config.bicodec_config or "audio_tokenizer" not in config.bicodec_config:
|
3197 |
raise ValueError("BiCodec configuration ('bicodec_config' with 'audio_tokenizer' key) not found in SparkTTSConfig.")
|
3198 |
try:
|
3199 |
+
# Assuming BiCodec class has the custom loading method
|
3200 |
+
# Make sure BiCodec class is imported or defined above
|
3201 |
bicodec = BiCodec.load_from_config_and_checkpoint(
|
3202 |
model_dir=Path(bicodec_path),
|
3203 |
config_dict=config.bicodec_config["audio_tokenizer"]
|
3204 |
)
|
3205 |
+
# Ensure BiCodec is an nn.Module if you want .to(device) to work easily
|
3206 |
+
if not isinstance(bicodec, torch.nn.Module):
|
3207 |
+
logger.warning("Loaded BiCodec component is not an instance of torch.nn.Module. Automatic device placement might not work.")
|
3208 |
+
|
3209 |
+
except FileNotFoundError as e:
|
3210 |
+
raise OSError(f"Failed to load BiCodec: A required file was not found in {bicodec_path}. Original error: {e}")
|
3211 |
except Exception as e:
|
3212 |
+
logger.error(f"Raw error loading BiCodec: {type(e).__name__}: {e}")
|
3213 |
+
import traceback
|
3214 |
+
traceback.print_exc() # Print full traceback for debugging BiCodec loading
|
3215 |
+
raise OSError(f"Failed to load BiCodec from {bicodec_path}. Check BiCodec implementation and file paths. Error: {e}")
|
3216 |
|
3217 |
|
3218 |
+
# --- 4. Instantiate the main model wrapper ---
|
3219 |
+
# Pass the loaded config and components
|
3220 |
model = cls(config, llm=llm, wav2vec2_model=wav2vec2_model, wav2vec2_processor=wav2vec2_processor, bicodec=bicodec)
|
3221 |
|
3222 |
+
# --- 5. Handle device placement ---
|
3223 |
+
# Move the entire model (including sub-modules if they are nn.Module attributes)
|
3224 |
+
# Determine target device based on availability
|
3225 |
if torch.cuda.is_available():
|
3226 |
+
final_device = torch.device("cuda")
|
3227 |
+
# If multiple GPUs, could select one, e.g., torch.device("cuda:0")
|
3228 |
+
# Or rely on CUDA_VISIBLE_DEVICES environment variable
|
3229 |
else:
|
3230 |
+
final_device = torch.device("cpu")
|
|
|
|
|
3231 |
|
3232 |
+
logger.info(f"Placing SparkTTSModel and components on device: {final_device}")
|
3233 |
+
# This should move all registered nn.Module attributes (llm, wav2vec2_model, bicodec if it's an nn.Module)
|
3234 |
+
try:
|
3235 |
+
model.to(final_device)
|
3236 |
+
except Exception as e:
|
3237 |
+
logger.error(f"Failed to move model to device {final_device}. Error: {e}")
|
3238 |
+
logger.warning("Device placement might be incomplete. Check component types.")
|
3239 |
+
|
3240 |
+
# --- 6. Return the loaded and prepared model ---
|
3241 |
return model
|
3242 |
|
3243 |
|
3244 |
+
|
3245 |
# --- Embedding getters/setters (delegate to LLM if loaded) ---
|
3246 |
def get_input_embeddings(self):
|
3247 |
if self.llm:
|
|
|
3270 |
# post_init is less critical now as loading happens in from_pretrained,
|
3271 |
# but can be used for final checks or setup.
|
3272 |
def post_init(self):
|
3273 |
+
# Ensure wav2vec2 config has output_hidden_states=True after loading
|
3274 |
if self.wav2vec2_model and hasattr(self.wav2vec2_model.config, 'output_hidden_states'):
|
3275 |
+
if not self.wav2vec2_model.config.output_hidden_states:
|
3276 |
+
self.wav2vec2_model.config.output_hidden_states = True
|
3277 |
+
logger.info("Set wav2vec2_model.config.output_hidden_states=True")
|
3278 |
+
else:
|
3279 |
+
logger.warning("Could not access wav2vec2_model.config to ensure output_hidden_states=True.")
|
3280 |
|
3281 |
@property
|
3282 |
def device(self) -> torch.device:
|