ancv committed
Commit d16a5a4 · verified · 1 Parent(s): c394c41

Update modeling_spark_tts.py

Files changed (1)
  1. modeling_spark_tts.py +133 -73
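This commit reworks SparkTTSModel.from_pretrained: the model directory is now resolved once up front (a local directory is used as-is, anything else is treated as a Hub repo id and fetched with huggingface_hub.snapshot_download), the SparkTTSConfig and the LLM, Wav2Vec2 and BiCodec sub-components are loaded relative to that resolved directory, device_map is popped with a warning in favour of a plain .to(device), _supports_load_fast is flipped to True, and the wav2vec2 output_hidden_states adjustment moves into post_init(). A minimal usage sketch, assuming modeling_spark_tts.py is importable locally; the repo id below is a placeholder:

import torch
from modeling_spark_tts import SparkTTSModel  # the file updated in this commit

# Placeholder repo id; a local directory containing the model files also works.
model = SparkTTSModel.from_pretrained(
    "your-namespace/spark-tts",
    torch_dtype=torch.float16,  # forwarded to the underlying LLM; "auto" falls back to the config value
    trust_remote_code=True,     # passed down to the LLM / Wav2Vec2 loaders
)

# device_map is not supported for this composite model; the loader already calls
# .to(cuda or cpu), but the model can be moved again explicitly if needed.
model = model.to("cuda:0" if torch.cuda.is_available() else "cpu")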
modeling_spark_tts.py CHANGED
@@ -24,6 +24,7 @@ from typing import Dict, Any, Tuple, Optional, Union
 
 from transformers import PreTrainedModel, AutoModelForCausalLM, Wav2Vec2FeatureExtractor, Wav2Vec2Model
 from transformers.utils import logging, requires_backends, cached_file
+from huggingface_hub import snapshot_download
 from transformers.generation.utils import GenerationMixin
 from transformers.configuration_utils import PretrainedConfig
 from safetensors.torch import load_file
@@ -3039,7 +3040,7 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
     """
     config_class = SparkTTSConfig
     base_model_prefix = "spark_tts"
-    _supports_load_fast = False
+    _supports_load_fast = True
 
     def __init__(self, config: SparkTTSConfig, llm=None, wav2vec2_model=None, wav2vec2_processor=None, bicodec=None):
         super().__init__(config)
@@ -3049,9 +3050,8 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
         self.wav2vec2_processor = wav2vec2_processor
         self.bicodec = bicodec
 
-        # Wav2Vec2 specific config adjustment (needs to happen after loading)
-        if self.wav2vec2_model and hasattr(self.wav2vec2_model.config, 'output_hidden_states'):
-            self.wav2vec2_model.config.output_hidden_states = True
+        # Ensure wav2vec2 config has output_hidden_states=True after loading
+        self.post_init()
 
 
     @classmethod
@@ -3066,124 +3066,182 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
         local_files_only: bool = False,
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
-        use_safetensors: bool = None,
+        use_safetensors: Optional[bool] = None, # Keep None to let transformers decide
         **kwargs,
     ):
-        # 1. Load Config
-        if config is None:
+        # Pop device map and dtype early - handle placement later
+        # Note: device_map is complex with multiple components. Manual .to(device) is simpler here.
+        device_map = kwargs.pop("device_map", None)
+        if device_map:
+            logger.warning("`device_map` is not directly supported for this composite model. Use .to(device) after loading.")
+
+        torch_dtype = kwargs.pop("torch_dtype", "auto") # Can be "auto", float32, float16, bfloat16
+        trust_remote_code = kwargs.pop("trust_remote_code", False) # CRITICAL for custom code
+
+        # --- 1. Resolve the main model directory ---
+        # This handles downloading from Hub or using a local path robustly.
+        if pretrained_model_name_or_path is None:
+            raise ValueError("`pretrained_model_name_or_path` must be provided.")
+
+        model_path = Path(pretrained_model_name_or_path)
+        if not model_path.is_dir():
+            # If it's not a local directory, assume it's a Hub ID and download everything
+            logger.info(f"{pretrained_model_name_or_path} is not a local directory. Assuming Hub ID and downloading.")
+            try:
+                resolved_model_path = snapshot_download(
+                    repo_id=str(pretrained_model_name_or_path),
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.yaml", "*.txt", "README.md"], # Be somewhat permissive
+                    # ignore_patterns=["*.git*"], # Optional: ignore git files
+                    # user_agent={"agent": "spark-tts-custom-loader"}, # Optional
+                )
+                resolved_model_path = Path(resolved_model_path)
+                logger.info(f"Model downloaded to cache: {resolved_model_path}")
+            except Exception as e:
+                raise OSError(
+                    f"Failed to download model '{pretrained_model_name_or_path}' from Hugging Face Hub. "
+                    f"Ensure the ID is correct and network is available. Error: {e}"
+                )
+        else:
+            # It's a local directory path
+            resolved_model_path = model_path
+            logger.info(f"Loading model from local directory: {resolved_model_path}")
+
+        if not resolved_model_path.is_dir():
+            # This should ideally not happen after snapshot_download or initial check
+            raise EnvironmentError(f"Cannot find resolved model directory at {resolved_model_path}")
+
+        # --- 2. Load the main configuration ---
+        # The config might have been passed explicitly, otherwise load from resolved path
+        if not isinstance(config, PretrainedConfig):
+            config_path = config if config is not None else resolved_model_path
             config, model_kwargs = cls.config_class.from_pretrained(
-                pretrained_model_name_or_path,
-                *model_args,
-                cache_dir=cache_dir,
+                config_path, # Load from the resolved directory or explicit config path
+                *model_args, # Pass *model_args here if they influence config loading
+                cache_dir=cache_dir, # Pass relevant args down
                 force_download=force_download,
                 local_files_only=local_files_only,
                 token=token,
                 revision=revision,
+                trust_remote_code=trust_remote_code, # Needed if config class itself is remote
                 return_unused_kwargs=True,
-                **kwargs,
+                **kwargs, # Pass remaining kwargs
            )
+            # Update kwargs with unused ones from config loading
+            kwargs.update(model_kwargs)
         else:
-            model_kwargs = kwargs
-
-        # Pop device map info - will handle placement later
-        device_map = model_kwargs.pop("device_map", None)
-        torch_dtype = model_kwargs.pop("torch_dtype", "auto") # Use config's or auto
-
-        # Check for trust_remote_code - needed for config loading if custom code involved there too
-        trust_remote_code = model_kwargs.pop("trust_remote_code", False) # Important
-
-
-        # NEW IMPROVED PATH RESOLUTION
-        from huggingface_hub import snapshot_download
-        import os
-        # Check if it's a local path first
-        if os.path.isdir(pretrained_model_name_or_path):
-            resolved_model_path = Path(pretrained_model_name_or_path)
-        else:
-            # Try to get from Hugging Face Hub
-            try:
-                logger.info(f"Downloading/locating model from Hugging Face Hub: {pretrained_model_name_or_path}")
-                # This will download the model if needed and return the cached path
-                resolved_model_path = Path(snapshot_download(
-                    pretrained_model_name_or_path,
-                    revision=revision,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    local_files_only=local_files_only,
-                    token=token,
-                ))
-            except Exception as e:
-                logger.error(f"Error downloading model: {e}")
-                raise EnvironmentError(f"Failed to find or download model '{pretrained_model_name_or_path}': {e}")
-
-        if not resolved_model_path.is_dir():
-            raise EnvironmentError(f"Cannot find model directory at {resolved_model_path}")
-
-        # Helper function to resolve paths relative to the main model directory
-        def _resolve_path(sub_path):
+            # Config object was passed directly
+            pass # kwargs remain as they were
+
+        # --- Determine torch_dtype (use config value if specified and not overridden) ---
+        # Priority: Explicit torch_dtype arg > config.torch_dtype > "auto" (default)
+        final_torch_dtype = torch_dtype # Explicit arg has highest prio
+        if final_torch_dtype == "auto":
+            final_torch_dtype = getattr(config, "torch_dtype", None) # Use config value if present
+        # final_torch_dtype can still be None or "auto" here, handle downstream
+
+        # --- Helper function to resolve paths relative to the main model directory ---
+        def _resolve_sub_path(sub_path):
             p = Path(sub_path)
             if p.is_absolute():
                 return str(p)
             else:
-                # Resolve relative to the potentially cached main model path
-                return str(resolved_model_path / p)
+                # Resolve relative to the potentially cached main model path
+                return str(resolved_model_path / p)
+
+        # --- 3. Load Sub-components ---
 
         # --- Load LLM ---
-        llm_path = _resolve_path(config.llm_model_name_or_path)
+        llm_path = _resolve_sub_path(config.llm_model_name_or_path)
         logger.info(f"Loading LLM from resolved path: {llm_path}")
         try:
             llm = AutoModelForCausalLM.from_pretrained(
                 llm_path,
-                torch_dtype=torch_dtype if torch_dtype != "auto" else config.torch_dtype, # Prioritize explicit dtype
+                torch_dtype=final_torch_dtype if final_torch_dtype != "auto" else None, # Pass resolved dtype or None
                 trust_remote_code=trust_remote_code, # Pass down trust_remote_code
-                **model_kwargs # Pass remaining kwargs
+                # Pass remaining kwargs that might be relevant for AutoModelForCausalLM
+                # Filter kwargs if necessary, but often passing them is fine
+                **kwargs
             )
         except Exception as e:
             raise OSError(f"Failed to load LLM from {llm_path}: {e}")
 
         # --- Load Wav2Vec2 ---
-        w2v_path = _resolve_path(config.wav2vec2_model_name_or_path)
-        logger.info(f"Loading Wav2Vec2 from resolved path: {w2v_path}")
+        w2v_path = _resolve_sub_path(config.wav2vec2_model_name_or_path)
+        logger.info(f"Loading Wav2Vec2 components from resolved path: {w2v_path}")
         try:
-            wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained(w2v_path, trust_remote_code=trust_remote_code)
-            wav2vec2_model = Wav2Vec2Model.from_pretrained(w2v_path, trust_remote_code=trust_remote_code)
+            # Load feature extractor first
+            wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained(
+                w2v_path,
+                trust_remote_code=trust_remote_code,
+                # Add any relevant kwargs for feature extractor if needed
+            )
+            # Load model
+            wav2vec2_model = Wav2Vec2Model.from_pretrained(
+                w2v_path,
+                trust_remote_code=trust_remote_code,
+                # Add any relevant kwargs for model if needed (e.g., add_adapter=False)
+            )
         except Exception as e:
             raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}")
 
         # --- Load BiCodec ---
-        bicodec_path = _resolve_path(config.bicodec_model_name_or_path)
+        bicodec_path = _resolve_sub_path(config.bicodec_model_name_or_path)
         logger.info(f"Loading BiCodec from resolved path: {bicodec_path}")
-        # print(f"Loading BiCodec from resolved path: {bicodec_path}, {config}")
         if not config.bicodec_config or "audio_tokenizer" not in config.bicodec_config:
             raise ValueError("BiCodec configuration ('bicodec_config' with 'audio_tokenizer' key) not found in SparkTTSConfig.")
         try:
-            # Assuming BiCodec class is defined above in this file
+            # Assuming BiCodec class has the custom loading method
+            # Make sure BiCodec class is imported or defined above
             bicodec = BiCodec.load_from_config_and_checkpoint(
                 model_dir=Path(bicodec_path),
                 config_dict=config.bicodec_config["audio_tokenizer"]
             )
+            # Ensure BiCodec is an nn.Module if you want .to(device) to work easily
+            if not isinstance(bicodec, torch.nn.Module):
+                logger.warning("Loaded BiCodec component is not an instance of torch.nn.Module. Automatic device placement might not work.")
+
+        except FileNotFoundError as e:
+            raise OSError(f"Failed to load BiCodec: A required file was not found in {bicodec_path}. Original error: {e}")
         except Exception as e:
-            raise OSError(f"Failed to load BiCodec from {bicodec_path}: {e}")
+            logger.error(f"Raw error loading BiCodec: {type(e).__name__}: {e}")
+            import traceback
+            traceback.print_exc() # Print full traceback for debugging BiCodec loading
+            raise OSError(f"Failed to load BiCodec from {bicodec_path}. Check BiCodec implementation and file paths. Error: {e}")
 
 
-        # Instantiate the main model wrapper, passing the loaded components
+        # --- 4. Instantiate the main model wrapper ---
+        # Pass the loaded config and components
         model = cls(config, llm=llm, wav2vec2_model=wav2vec2_model, wav2vec2_processor=wav2vec2_processor, bicodec=bicodec)
 
-        # --- Handle device placement ---
-        # Note: device_map is complex; simple .to(device) is easier if not using accelerate
-        # Determine target device
+        # --- 5. Handle device placement ---
+        # Move the entire model (including sub-modules if they are nn.Module attributes)
+        # Determine target device based on availability
         if torch.cuda.is_available():
-            current_device = torch.cuda.current_device()
-            device = torch.device(f"cuda:{current_device}")
+            final_device = torch.device("cuda")
+            # If multiple GPUs, could select one, e.g., torch.device("cuda:0")
+            # Or rely on CUDA_VISIBLE_DEVICES environment variable
         else:
-            device = torch.device("cpu")
-        logger.info(f"Placing SparkTTSModel and components on device: {device}")
-        model.to(device) # This should move all registered nn.Module attributes
+            final_device = torch.device("cpu")
+
+        logger.info(f"Placing SparkTTSModel and components on device: {final_device}")
+        # This should move all registered nn.Module attributes (llm, wav2vec2_model, bicodec if it's an nn.Module)
+        try:
+            model.to(final_device)
+        except Exception as e:
+            logger.error(f"Failed to move model to device {final_device}. Error: {e}")
+            logger.warning("Device placement might be incomplete. Check component types.")
+
+        # --- 6. Return the loaded and prepared model ---
         return model
 
 
+
     # --- Embedding getters/setters (delegate to LLM if loaded) ---
     def get_input_embeddings(self):
         if self.llm:
@@ -3212,11 +3270,13 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
     # post_init is less critical now as loading happens in from_pretrained,
     # but can be used for final checks or setup.
     def post_init(self):
-        # Ensure wav2vec2 config has output_hidden_states=True
+        # Ensure wav2vec2 config has output_hidden_states=True after loading
         if self.wav2vec2_model and hasattr(self.wav2vec2_model.config, 'output_hidden_states'):
-            if not self.wav2vec2_model.config.output_hidden_states:
-                self.wav2vec2_model.config.output_hidden_states = True
-                logger.info("Set wav2vec2_model.config.output_hidden_states=True")
+            if not self.wav2vec2_model.config.output_hidden_states:
+                self.wav2vec2_model.config.output_hidden_states = True
+                logger.info("Set wav2vec2_model.config.output_hidden_states=True")
+        else:
+            logger.warning("Could not access wav2vec2_model.config to ensure output_hidden_states=True.")
 
     @property
     def device(self) -> torch.device:
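The path-resolution logic introduced above can also be exercised on its own; a standalone sketch of the same idea, with a placeholder repo id and sub-path (huggingface_hub assumed to be installed):

from pathlib import Path
from huggingface_hub import snapshot_download

def resolve_model_dir(name_or_path: str) -> Path:
    """Return an existing local directory, else download/locate the Hub snapshot."""
    p = Path(name_or_path)
    if p.is_dir():
        return p
    # Not a local directory: treat it as a Hub repo id and fetch the cached snapshot.
    return Path(snapshot_download(repo_id=name_or_path))

def resolve_sub_path(model_dir: Path, sub_path: str) -> str:
    """Resolve a component path from the config relative to the resolved model directory."""
    p = Path(sub_path)
    return str(p if p.is_absolute() else model_dir / p)

# Placeholder values for illustration only.
model_dir = resolve_model_dir("your-namespace/spark-tts")
llm_dir = resolve_sub_path(model_dir, "LLM")  # e.g. whatever config.llm_model_name_or_path contains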