diff --git "a/modeling_spark_tts.py" "b/modeling_spark_tts.py" deleted file mode 100644--- "a/modeling_spark_tts.py" +++ /dev/null @@ -1,3499 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The SparkAudio Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch SparkTTS model.""" - -import torch -import torch.nn as nn -import numpy as np -import os -import warnings -from pathlib import Path -from typing import Dict, Any, Tuple, Optional, Union - -from transformers import PreTrainedModel, AutoModelForCausalLM, Wav2Vec2FeatureExtractor, Wav2Vec2Model -from transformers.utils import logging, requires_backends, cached_file -from huggingface_hub import snapshot_download -from transformers.generation.utils import GenerationMixin -from transformers.configuration_utils import PretrainedConfig -from safetensors.torch import load_file -import torchaudio.transforms as TT # Directly use torchaudio - -# # Import necessary components from the original codebase structure -# # These are now defined in _modeling_bicodec_components.py -# from ._modeling_bicodec_components import ( -# SpeakerEncoder, -# Encoder, -# Decoder, -# WaveGenerator, -# FactorizedVectorQuantize, -# # Include Snake1d or other base classes if BiCodec.__init__ needs them directly -# ) - - -""" Utility functions for SparkTTS """ - -import random -import soxr -import soundfile -import torch -import torchaudio -import numpy as np - -from pathlib import Path -from typing import Tuple, Dict, Any -from numpy.lib.stride_tricks import sliding_window_view -from omegaconf import OmegaConf # Keep if BiCodec config loading needs it - - -# --- Token Maps (from sparktts/utils/token_parser.py) --- -TASK_TOKEN_MAP = { - "vc": "<|task_vc|>", - "tts": "<|task_tts|>", - "asr": "<|task_asr|>", - "s2s": "<|task_s2s|>", - "t2s": "<|task_t2s|>", - "understand": "<|task_understand|>", - "caption": "<|task_cap|>", - "controllable_tts": "<|task_controllable_tts|>", - "prompt_tts": "<|task_prompt_tts|>", - "speech_edit": "<|task_edit|>", -} - -LEVELS_MAP = { - "very_low": 0, - "low": 1, - "moderate": 2, - "high": 3, - "very_high": 4, -} - -LEVELS_MAP_UI = { - 1: 'very_low', - 2: 'low', - 3: 'moderate', - 4: 'high', - 5: 'very_high' -} - -GENDER_MAP = { - "female": 0, - "male": 1, -} - -# --- Audio Utils (from sparktts/utils/audio.py) --- -def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray: - temp = np.sort(np.abs(audio)) - if len(temp) == 0: # Handle empty audio case - return audio - if temp[-1] < 0.1: - scaling_factor = max(temp[-1], 1e-3) - audio = audio / scaling_factor * 0.1 - temp = temp[temp > 0.01] - L = temp.shape[0] - if L <= 10: - return audio - volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)]) - if volume == 0: # Avoid division by zero if volume is effectively zero - return audio - audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10) - max_value = np.max(np.abs(audio)) if len(audio) > 0 else 0 - if max_value > 1: - audio = audio / max_value - return audio - -def load_audio( - adfile: Path, - sampling_rate: int = None, - length: int = None, - volume_normalize: bool = False, - segment_duration: int = None, -) -> np.ndarray: - try: - audio, sr = soundfile.read(adfile, dtype='float32') # Ensure float32 - except Exception as e: - raise IOError(f"Could not read audio file {adfile}: {e}") - - if audio is None or len(audio) == 0: - raise ValueError(f"Audio file {adfile} is empty or invalid.") - - if len(audio.shape) > 1: - audio = audio[:, 0] - - if sampling_rate is not None and sr != sampling_rate: - try: - # Ensure input is float64 for soxr - audio = audio.astype(np.float64) - audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ") - # Convert back to float32 - audio = audio.astype(np.float32) - sr = sampling_rate - except Exception as e: - raise RuntimeError(f"Failed to resample audio from {sr}Hz to {sampling_rate}Hz: {e}") - - if segment_duration is not None: - seg_length = int(sr * segment_duration) - audio = random_select_audio_segment(audio, seg_length) - - if volume_normalize: - audio = audio_volume_normalize(audio) - - if length is not None: - if audio.shape[0] > length: - audio = audio[:length] - else: - audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant') - return audio - -def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray: - if audio.shape[0] < length: - audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant') - start_index = 0 # If padded, start from beginning - elif audio.shape[0] == length: - start_index = 0 # If exact length, start from beginning - else: - start_index = random.randint(0, audio.shape[0] - length) - - end_index = int(start_index + length) - return audio[start_index:end_index] - -# --- File Utils (Minimal required) --- -def load_config_yaml(config_path: Path) -> Dict: - """Loads a YAML configuration file using OmegaConf.""" - # Check if path exists - if not Path(config_path).is_file(): - raise FileNotFoundError(f"YAML Config file not found: {config_path}") - try: - config = OmegaConf.load(config_path) - # Convert OmegaConf DictConfig to standard Python dict - return OmegaConf.to_container(config, resolve=True) - except Exception as e: - raise IOError(f"Error loading YAML config file {config_path}: {e}") - - -""" PyTorch SparkTTS BiCodec sub-module definitions.""" - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.distributed as dist -import random - -from torch.nn.utils import weight_norm, remove_weight_norm -from torch import Tensor, int32 -from torch.amp import autocast - -from typing import Any, Dict, List, Tuple, Optional -from collections import namedtuple -from functools import wraps, partial -from contextlib import nullcontext -from packaging import version - -from einops import rearrange, repeat, reduce, pack, unpack -from einops.layers.torch import Rearrange -from einx import get_at # Ensure einx is installed: pip install einx - -# =============================================================== -# Start: Content from sparktts/modules/blocks/layers.py -# =============================================================== -def WNConv1d(*args, **kwargs): - return weight_norm(nn.Conv1d(*args, **kwargs)) - - -def WNConvTranspose1d(*args, **kwargs): - return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) - - -# Scripting this brings model speed up 1.4x -@torch.jit.script -def snake(x, alpha): - shape = x.shape - x = x.reshape(shape[0], shape[1], -1) - x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) - x = x.reshape(shape) - return x - - -class Snake1d(nn.Module): - def __init__(self, channels): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1, channels, 1)) - - def forward(self, x): - return snake(x, self.alpha) - - -class ResidualUnit(nn.Module): - def __init__(self, dim: int = 16, dilation: int = 1): - super().__init__() - pad = ((7 - 1) * dilation) // 2 - self.block = nn.Sequential( - Snake1d(dim), - WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), - Snake1d(dim), - WNConv1d(dim, dim, kernel_size=1), - ) - - def forward(self, x): - y = self.block(x) - # Adjust padding handling if input and output shapes differ - diff = x.shape[-1] - y.shape[-1] - if diff > 0: - pad = diff // 2 - x = x[..., pad:pad + y.shape[-1]] # Ensure shapes match for residual connection - elif diff < 0: - pad = -diff // 2 - y = y[..., pad:pad + x.shape[-1]] - - return x + y - - -def init_weights(m): - if isinstance(m, nn.Conv1d): - nn.init.trunc_normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) -# =============================================================== -# End: Content from sparktts/modules/blocks/layers.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/blocks/samper.py -# =============================================================== -class SamplingBlock(nn.Module): - """Sampling block for upsampling or downsampling""" - - def __init__( - self, - dim: int, - groups: int = 1, - upsample_scale: int = 1, - downsample_scale: int = 1, - ) -> None: - """ - Args: - dim: input dimension - groups: number of groups - upsample_scale: upsampling scale - downsample_scale: downsampling scale - """ - super(SamplingBlock, self).__init__() - - self.upsample_scale = upsample_scale - self.downsample_scale = downsample_scale - - if self.upsample_scale > 1: - self.de_conv_upsampler = nn.Sequential( - nn.LeakyReLU(0.2), - nn.ConvTranspose1d( - dim, - dim, - kernel_size=upsample_scale * 2, - stride=upsample_scale, - padding=upsample_scale // 2 + upsample_scale % 2, - output_padding=upsample_scale % 2, - groups=groups, - ), - ) - - if self.downsample_scale > 1: - self.conv_downsampler = nn.Sequential( - nn.LeakyReLU(0.2), - nn.Conv1d( - dim, - dim, - kernel_size=2 * downsample_scale, - stride=downsample_scale, - padding=downsample_scale // 2 + downsample_scale % 2, - groups=groups, - ), - ) - - @staticmethod - def repeat_upsampler(x, upsample_scale): - return x.repeat_interleave(upsample_scale, dim=2) - - @staticmethod - def skip_downsampler(x, downsample_scale): - return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale) - - def forward(self, x): - # Input expected as (B, D, T) from VocosBackbone output (B, T, D) - # x = x.transpose(1, 2) # Remove this transpose, input should be (B, D, T) - if self.upsample_scale > 1: - repeat_res = self.repeat_upsampler(x, self.upsample_scale) - deconv_res = self.de_conv_upsampler(x) - # Ensure shapes match for addition - if deconv_res.shape[-1] > repeat_res.shape[-1]: - deconv_res = deconv_res[..., :repeat_res.shape[-1]] - elif repeat_res.shape[-1] > deconv_res.shape[-1]: - repeat_res = repeat_res[..., :deconv_res.shape[-1]] - upmerge_res = repeat_res + deconv_res - else: - upmerge_res = x - repeat_res = x - - if self.downsample_scale > 1: - conv_res = self.conv_downsampler(upmerge_res) - skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale) - skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale) - # Ensure shapes match - min_len = min(conv_res.shape[-1], skip1_res.shape[-1], skip2_res.shape[-1]) - conv_res = conv_res[..., :min_len] - skip1_res = skip1_res[..., :min_len] - skip2_res = skip2_res[..., :min_len] - else: - conv_res = upmerge_res - skip2_res = upmerge_res - skip1_res = repeat_res - - final_res = conv_res + skip1_res + skip2_res - - # Return (B, D, T) for next VocosBackbone - # return final_res.transpose(1, 2) # Remove this, keep (B, D, T) - return final_res - -# =============================================================== -# End: Content from sparktts/modules/blocks/samper.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/speaker/pooling_layers.py -# =============================================================== -class TAP(nn.Module): - """ - Temporal average pooling, only first-order mean is considered - """ - - def __init__(self, in_dim=0, **kwargs): - super(TAP, self).__init__() - self.in_dim = in_dim - - def forward(self, x): - pooling_mean = x.mean(dim=-1) - # To be compatable with 2D input - pooling_mean = pooling_mean.flatten(start_dim=1) - return pooling_mean - - def get_out_dim(self): - # This method seems specific to the original usage, might not be needed by HF - # self.out_dim = self.in_dim - # return self.out_dim - return self.in_dim - - -class TSDP(nn.Module): - """ - Temporal standard deviation pooling, only second-order std is considered - """ - - def __init__(self, in_dim=0, **kwargs): - super(TSDP, self).__init__() - self.in_dim = in_dim - - def forward(self, x): - # The last dimension is the temporal axis - pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) - pooling_std = pooling_std.flatten(start_dim=1) - return pooling_std - - def get_out_dim(self): - # self.out_dim = self.in_dim - # return self.out_dim - return self.in_dim - - -class TSTP(nn.Module): - """ - Temporal statistics pooling, concatenate mean and std, which is used in - x-vector - Comment: simple concatenation can not make full use of both statistics - """ - - def __init__(self, in_dim=0, **kwargs): - super(TSTP, self).__init__() - self.in_dim = in_dim - - def forward(self, x): - # The last dimension is the temporal axis - pooling_mean = x.mean(dim=-1) - pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) - pooling_mean = pooling_mean.flatten(start_dim=1) - pooling_std = pooling_std.flatten(start_dim=1) - stats = torch.cat((pooling_mean, pooling_std), 1) - return stats - - def get_out_dim(self): - # self.out_dim = self.in_dim * 2 - # return self.out_dim - return self.in_dim * 2 - - -class ASTP(nn.Module): - """ Attentive statistics pooling: Channel- and context-dependent - statistics pooling, first used in ECAPA_TDNN. - """ - - def __init__(self, - in_dim, - bottleneck_dim=128, - global_context_att=False, - **kwargs): - super(ASTP, self).__init__() - self.in_dim = in_dim - self.global_context_att = global_context_att - - # Use Conv1d with stride == 1 rather than Linear, then we don't - # need to transpose inputs. - if global_context_att: - self.linear1 = nn.Conv1d( - in_dim * 3, bottleneck_dim, - kernel_size=1) # equals W and b in the paper - else: - self.linear1 = nn.Conv1d( - in_dim, bottleneck_dim, - kernel_size=1) # equals W and b in the paper - self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, - kernel_size=1) # equals V and k in the paper - - def forward(self, x): - """ - x: a 3-dimensional tensor in tdnn-based architecture (B,F,T) - or a 4-dimensional tensor in resnet architecture (B,C,F,T) - 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) - """ - if len(x.shape) == 4: - x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) - assert len(x.shape) == 3 - - if self.global_context_att: - context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) - context_std = torch.sqrt( - torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x) - x_in = torch.cat((x, context_mean, context_std), dim=1) - else: - x_in = x - - # DON'T use ReLU here! ReLU may be hard to converge. - alpha = torch.tanh( - self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in)) - alpha = torch.softmax(self.linear2(alpha), dim=2) - mean = torch.sum(alpha * x, dim=2) - var = torch.sum(alpha * (x**2), dim=2) - mean**2 - std = torch.sqrt(var.clamp(min=1e-7)) - return torch.cat([mean, std], dim=1) - - def get_out_dim(self): - # self.out_dim = 2 * self.in_dim - # return self.out_dim - return self.in_dim * 2 - - -class MHASTP(torch.nn.Module): - """ Multi head attentive statistics pooling - Reference: - Self Multi-Head Attention for Speaker Recognition - https://arxiv.org/pdf/1906.09890.pdf - """ - - def __init__(self, - in_dim, - layer_num=2, - head_num=2, - d_s=1, - bottleneck_dim=64, - **kwargs): - super(MHASTP, self).__init__() - assert (in_dim % head_num - ) == 0 # make sure that head num can be divided by input_dim - self.in_dim = in_dim - self.head_num = head_num - d_model = int(in_dim / head_num) - channel_dims = [bottleneck_dim for i in range(layer_num + 1)] - if d_s > 1: - d_s = d_model - else: - d_s = 1 - self.d_s = d_s - channel_dims[0], channel_dims[-1] = d_model, d_s - heads_att_trans = [] - for i in range(self.head_num): - att_trans = nn.Sequential() - for j in range(layer_num - 1): # Use different loop variable - att_trans.add_module( - 'att_' + str(j), - nn.Conv1d(channel_dims[j], channel_dims[j + 1], 1, 1)) - att_trans.add_module('tanh' + str(j), nn.Tanh()) - att_trans.add_module( - 'att_' + str(layer_num - 1), - nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num], - 1, 1)) - heads_att_trans.append(att_trans) - self.heads_att_trans = nn.ModuleList(heads_att_trans) - - def forward(self, input): - """ - input: a 3-dimensional tensor in xvector architecture - or a 4-dimensional tensor in resnet architecture - 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) - """ - if len(input.shape) == 4: # B x C x F x T - input = input.reshape(input.shape[0], - input.shape[1] * input.shape[2], - input.shape[3]) # B x (C*F) x T - assert len(input.shape) == 3 - bs, f_dim, t_dim = input.shape - chunks = torch.chunk(input, self.head_num, 1) - # split - chunks_out = [] - for i, layer in enumerate(self.heads_att_trans): - att_score = layer(chunks[i]) - alpha = F.softmax(att_score, dim=-1) - mean = torch.sum(alpha * chunks[i], dim=2) - var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2 - std = torch.sqrt(var.clamp(min=1e-7)) - chunks_out.append(torch.cat((mean, std), dim=1)) - out = torch.cat(chunks_out, dim=1) - return out - - def get_out_dim(self): - # self.out_dim = 2 * self.in_dim - # return self.out_dim - return self.in_dim * 2 - - -class MQMHASTP(torch.nn.Module): - """ An attentive pooling - Reference: - multi query multi head attentive statistics pooling - https://arxiv.org/pdf/2110.05042.pdf - Args: - in_dim: the feature dimension of input - layer_num: the number of layer in the pooling layer - query_num: the number of querys - head_num: the number of heads - bottleneck_dim: the bottleneck dimension - - SA (H = 1, Q = 1, n = 2, d_s = 1) ref: - https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf - MHA (H > 1, Q = 1, n = 1, d_s = 1) ref: - https://arxiv.org/pdf/1906.09890.pdf - AS (H = 1, Q > 1, n = 2, d_s = 1) ref: - https://arxiv.org/pdf/1803.10963.pdf - VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref: - http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf - """ - - def __init__(self, - in_dim, - layer_num=2, - query_num=2, - head_num=8, - d_s=2, - bottleneck_dim=64, - **kwargs): - super(MQMHASTP, self).__init__() - self.n_query = nn.ModuleList([ - MHASTP(in_dim, - layer_num=layer_num, - head_num=head_num, - d_s=d_s, - bottleneck_dim=bottleneck_dim) for i in range(query_num) - ]) - self.query_num = query_num - self.in_dim = in_dim - - def forward(self, input): - """ - input: a 3-dimensional tensor in xvector architecture - or a 4-dimensional tensor in resnet architecture - 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) - """ - if len(input.shape) == 4: # B x C x F x T - input = input.reshape(input.shape[0], - input.shape[1] * input.shape[2], - input.shape[3]) # B x (C*F) x T - assert len(input.shape) == 3 - res = [] - for i, layer in enumerate(self.n_query): - res.append(layer(input)) - out = torch.cat(res, dim=-1) - return out - - def get_out_dim(self): - # self.out_dim = self.in_dim * 2 * self.query_num - # return self.out_dim - return self.in_dim * 2 * self.query_num - -# =============================================================== -# End: Content from sparktts/modules/speaker/pooling_layers.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/blocks/vocos.py -# =============================================================== -# Helper functions needed by VocosBackbone etc. -def exists(val): - return val is not None - -def default(val, d): - return val if exists(val) else d() if callable(d) else d - -class AdaLayerNorm(nn.Module): - """ - Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes - - Args: - condition_dim (int): Dimension of the condition. - embedding_dim (int): Dimension of the embeddings. - """ - - def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.dim = embedding_dim - self.scale = nn.Linear(condition_dim, embedding_dim) - self.shift = nn.Linear(condition_dim, embedding_dim) - # Initialize weights similar to original implementation if needed - # torch.nn.init.ones_(self.scale.weight) # Might be default - # torch.nn.init.zeros_(self.shift.weight) # Might be default - if self.scale.bias is not None: nn.init.zeros_(self.scale.bias) - if self.shift.bias is not None: nn.init.zeros_(self.shift.bias) - - def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor: - scale = self.scale(cond_embedding) - shift = self.shift(cond_embedding) - x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) - x = x * scale.unsqueeze(1) + shift.unsqueeze(1) - return x - - -class ConvNeXtBlock(nn.Module): - """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. - - Args: - dim (int): Number of input channels. - intermediate_dim (int): Dimensionality of the intermediate layer. - layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. - Defaults to None. - condition_dim (int, optional): Dimension for AdaLayerNorm. - None means non-conditional LayerNorm. Defaults to None. - """ - - def __init__( - self, - dim: int, - intermediate_dim: int, - layer_scale_init_value: float, - condition_dim: Optional[int] = None, - ): - super().__init__() - self.dwconv = nn.Conv1d( - dim, dim, kernel_size=7, padding=3, groups=dim - ) # depthwise conv - self.adanorm = condition_dim is not None - if self.adanorm: - self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6) - else: - self.norm = nn.LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear( - dim, intermediate_dim - ) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.pwconv2 = nn.Linear(intermediate_dim, dim) - self.gamma = ( - nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) - if layer_scale_init_value is not None and layer_scale_init_value > 0 - else None - ) - - def forward( - self, x: torch.Tensor, cond_embedding: Optional[torch.Tensor] = None - ) -> torch.Tensor: - residual = x - x = self.dwconv(x) - x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) - if self.adanorm: - assert cond_embedding is not None, "Conditioning embedding required for AdaLayerNorm" - x = self.norm(x, cond_embedding) - else: - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.pwconv2(x) - if self.gamma is not None: - x = self.gamma * x - x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) - - x = residual + x - return x - - -class ResBlock1(nn.Module): - """ - ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, - but without upsampling layers. - - Args: - dim (int): Number of input channels. - kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. - dilation (tuple[int], optional): Dilation factors for the dilated convolutions. - Defaults to (1, 3, 5). - lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. - Defaults to 0.1. - layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. - Defaults to None. - """ - - def __init__( - self, - dim: int, - kernel_size: int = 3, - dilation: Tuple[int, int, int] = (1, 3, 5), - lrelu_slope: float = 0.1, - layer_scale_init_value: Optional[float] = None, - ): - super().__init__() - self.lrelu_slope = lrelu_slope - self.convs1 = nn.ModuleList( - [ - weight_norm( - nn.Conv1d( - dim, - dim, - kernel_size, - 1, - dilation=dilation[0], - padding=self.get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - nn.Conv1d( - dim, - dim, - kernel_size, - 1, - dilation=dilation[1], - padding=self.get_padding(kernel_size, dilation[1]), - ) - ), - weight_norm( - nn.Conv1d( - dim, - dim, - kernel_size, - 1, - dilation=dilation[2], - padding=self.get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - - self.convs2 = nn.ModuleList( - [ - weight_norm( - nn.Conv1d( - dim, - dim, - kernel_size, - 1, - dilation=1, - padding=self.get_padding(kernel_size, 1), - ) - ), - weight_norm( - nn.Conv1d( - dim, - dim, - kernel_size, - 1, - dilation=1, - padding=self.get_padding(kernel_size, 1), - ) - ), - weight_norm( - nn.Conv1d( - dim, - dim, - kernel_size, - 1, - dilation=1, - padding=self.get_padding(kernel_size, 1), - ) - ), - ] - ) - - self.gamma = nn.ParameterList( - [ - ( - nn.Parameter( - layer_scale_init_value * torch.ones(dim, 1), requires_grad=True - ) - if layer_scale_init_value is not None - else None - ), - ( - nn.Parameter( - layer_scale_init_value * torch.ones(dim, 1), requires_grad=True - ) - if layer_scale_init_value is not None - else None - ), - ( - nn.Parameter( - layer_scale_init_value * torch.ones(dim, 1), requires_grad=True - ) - if layer_scale_init_value is not None - else None - ), - ] - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): - xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) - xt = c1(xt) - xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) - xt = c2(xt) - if gamma is not None: - xt = gamma * xt - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) - - @staticmethod - def get_padding(kernel_size: int, dilation: int = 1) -> int: - return int((kernel_size * dilation - dilation) / 2) - - -class Backbone(nn.Module): - """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" - - def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Args: - x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, - C denotes input features, and L is the sequence length. - - Returns: - Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, - and H denotes the model dimension. - """ - raise NotImplementedError("Subclasses must implement the forward method.") - - -class VocosBackbone(Backbone): - """ - Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization - - Args: - input_channels (int): Number of input features channels. - dim (int): Hidden dimension of the model. - intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. - num_layers (int): Number of ConvNeXtBlock layers. - layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. - condition_dim (int, optional): Dimension for AdaLayerNorm. - None means non-conditional model. Defaults to None. - """ - - def __init__( - self, - input_channels: int, - dim: int, - intermediate_dim: int, - num_layers: int, - layer_scale_init_value: Optional[float] = None, - condition_dim: Optional[int] = None, - ): - super().__init__() - self.input_channels = input_channels - self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) - self.adanorm = condition_dim is not None - if self.adanorm: - self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6) - else: - self.norm = nn.LayerNorm(dim, eps=1e-6) - layer_scale_init_value = layer_scale_init_value or 1 / num_layers if num_layers > 0 else None # Handle num_layers=0 - self.convnext = nn.ModuleList( - [ - ConvNeXtBlock( - dim=dim, - intermediate_dim=intermediate_dim, - layer_scale_init_value=layer_scale_init_value, - condition_dim=condition_dim, - ) - for _ in range(num_layers) - ] - ) - self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv1d, nn.Linear)): - nn.init.trunc_normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def forward(self, x: torch.Tensor, condition: Optional[torch.Tensor] = None) -> torch.Tensor: - # Input x: (B, C, L) - x = self.embed(x) - # After embed: (B, dim, L) - x_transposed = x.transpose(1, 2) # (B, L, dim) - if self.adanorm: - assert condition is not None - norm_out = self.norm(x_transposed, condition) - else: - norm_out = self.norm(x_transposed) - # After norm: (B, L, dim) - x = norm_out.transpose(1, 2) # (B, dim, L) - for conv_block in self.convnext: - x = conv_block(x, condition) - # After convnext blocks: (B, dim, L) - x = self.final_layer_norm(x.transpose(1, 2)) # (B, L, dim) - return x - - -class VocosResNetBackbone(Backbone): - """ - Vocos backbone module built with ResBlocks. - - Args: - input_channels (int): Number of input features channels. - dim (int): Hidden dimension of the model. - num_blocks (int): Number of ResBlock1 blocks. - layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. - """ - - def __init__( - self, - input_channels, - dim, - num_blocks, - layer_scale_init_value=None, - ): - super().__init__() - self.input_channels = input_channels - self.embed = weight_norm( - nn.Conv1d(input_channels, dim, kernel_size=3, padding=1) - ) - layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 if num_blocks > 0 else None # Handle num_blocks=0 - self.resnet = nn.Sequential( - *[ - ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) - for _ in range(num_blocks) - ] - ) - - def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - # Input x: (B, C, L) - x = self.embed(x) - # After embed: (B, dim, L) - x = self.resnet(x) - # After resnet: (B, dim, L) - x = x.transpose(1, 2) # (B, L, dim) - return x - -# =============================================================== -# End: Content from sparktts/modules/blocks/vocos.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/encoder_decoder/feat_decoder.py -# =============================================================== -class Decoder(nn.Module): - """Decoder module with convnext and upsampling blocks - - Args: - sample_ratios (List[int]): sample ratios - example: [2, 2] means upsample by 2x and then upsample by 2x - """ - - def __init__( - self, - input_channels: int, - vocos_dim: int, - vocos_intermediate_dim: int, - vocos_num_layers: int, - out_channels: int, - condition_dim: int = None, - sample_ratios: List[int] = [1, 1], - use_tanh_at_final: bool = False, - ): - super().__init__() - - self.linear_pre = nn.Linear(input_channels, vocos_dim) - - upsample_modules = [] - current_dim = vocos_dim - for i, ratio in enumerate(sample_ratios): - upsample_modules.append( - nn.Sequential( - SamplingBlock( - dim=current_dim, - groups=current_dim, # Maybe use 1 or fewer groups if dim is high? Check original intent. Using current_dim for now. - upsample_scale=ratio, - ), - # Note: The original code used VocosBackbone here, but it changes dims B,T,D -> B,D,T. - # SamplingBlock output is B,D,T, so VocosBackbone input matches. - # However, the VocosBackbone output is B,T,D, which doesn't fit the next SamplingBlock. - # Assuming the intent was to keep B,D,T format between sampling blocks. - # Replacing intermediate VocosBackbone with a simple Conv1d block to maintain format & refine. - nn.Conv1d(current_dim, current_dim, kernel_size=3, padding=1) # Simple refinement layer - # VocosBackbone( - # input_channels=current_dim, - # dim=current_dim, - # intermediate_dim=vocos_intermediate_dim // 2, # Smaller intermediate for efficiency? - # num_layers=2, # Fewer layers - # condition_dim=None, - # ) - ) - ) - # No dimension change expected here if using Conv1d refinement - # If using VocosBackbone, need transpose logic - - self.upsample = nn.Sequential(*upsample_modules) - - # Final Backbone processes the fully upsampled features - self.vocos_backbone = VocosBackbone( - input_channels=current_dim, # Use the dim after upsampling - dim=vocos_dim, # Map back to main vocos_dim or keep current_dim? Using vocos_dim - intermediate_dim=vocos_intermediate_dim, - num_layers=vocos_num_layers, - condition_dim=condition_dim, - ) - self.linear_post = nn.Linear(vocos_dim, out_channels) - self.use_tanh_at_final = use_tanh_at_final - - def forward(self, x: torch.Tensor, c: torch.Tensor = None): - """decoder forward. - - Args: - x (torch.Tensor): (batch_size, input_channels, length) - c (torch.Tensor): (batch_size, condition_dim) - Optional condition - - Returns: - x (torch.Tensor): (batch_size, out_channels, length_upsampled) - """ - # x: (B, C_in, T) - x = self.linear_pre(x.transpose(1, 2)) # (B, T, vocos_dim) - x = x.transpose(1, 2) # (B, vocos_dim, T) - - # Apply upsampling blocks - x = self.upsample(x) # (B, vocos_dim, T_upsampled) - - # Apply final backbone - x = self.vocos_backbone(x, condition=c) # (B, T_upsampled, vocos_dim) - - x = self.linear_post(x) # (B, T_upsampled, C_out) - x = x.transpose(1, 2) # (B, C_out, T_upsampled) - - if self.use_tanh_at_final: - x = torch.tanh(x) - - return x - -# =============================================================== -# End: Content from sparktts/modules/encoder_decoder/feat_decoder.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/encoder_decoder/feat_encoder.py -# =============================================================== -class Encoder(nn.Module): - """Encoder module with convnext and downsampling blocks""" - - def __init__( - self, - input_channels: int, - vocos_dim: int, - vocos_intermediate_dim: int, - vocos_num_layers: int, - out_channels: int, - sample_ratios: List[int] = [1, 1], - ): - super().__init__() - """ - Encoder module with VocosBackbone and sampling blocks. - - Args: - sample_ratios (List[int]): sample ratios - example: [2, 2] means downsample by 2x and then downsample by 2x - """ - # Initial Backbone processing - self.encoder_backbone = VocosBackbone( - input_channels=input_channels, - dim=vocos_dim, - intermediate_dim=vocos_intermediate_dim, - num_layers=vocos_num_layers, # Use main num_layers here - condition_dim=None, - ) - - downsample_modules = [] - current_dim = vocos_dim - for i, ratio in enumerate(sample_ratios): - downsample_modules.append( - nn.Sequential( - SamplingBlock( - dim=current_dim, - groups=current_dim, # Again, check group size. Using current_dim. - downsample_scale=ratio, - ), - # Add refinement layer (optional, similar to Decoder logic) - nn.Conv1d(current_dim, current_dim, kernel_size=3, padding=1) - # VocosBackbone( # Or a lighter VocosBackbone - # input_channels=current_dim, - # dim=current_dim, - # intermediate_dim=vocos_intermediate_dim // 2, - # num_layers=2, - # condition_dim=None, - # ) - ) - ) - # No dimension change expected here - - self.downsample = nn.Sequential(*downsample_modules) - - self.project = nn.Linear(current_dim, out_channels) # Project from the final dimension - - def forward(self, x: torch.Tensor, *args): - """ - Args: - x (torch.Tensor): (batch_size, input_channels, length) - - Returns: - x (torch.Tensor): (batch_size, out_channels, length_downsampled) - """ - # x: (B, C_in, T) - x = self.encoder_backbone(x) # (B, T, vocos_dim) - x = x.transpose(1, 2) # (B, vocos_dim, T) - - # Apply downsampling blocks - x = self.downsample(x) # (B, vocos_dim, T_downsampled) - - x = x.transpose(1, 2) # (B, T_downsampled, vocos_dim) - x = self.project(x) # (B, T_downsampled, C_out) - return x.transpose(1, 2) # (B, C_out, T_downsampled) - -# =============================================================== -# End: Content from sparktts/modules/encoder_decoder/feat_encoder.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/encoder_decoder/wave_generator.py -# =============================================================== -class DecoderBlock(nn.Module): - def __init__( - self, - input_dim: int = 16, - output_dim: int = 8, - kernel_size: int = 2, - stride: int = 1, - ): - super().__init__() - # Ensure stride is at least 1 - stride = max(1, stride) - # Ensure kernel_size is valid for ConvTranspose1d - if kernel_size < stride: - kernel_size = stride # Or handle differently - - padding = (kernel_size - stride) // 2 - output_padding = stride % 2 if kernel_size % 2 == 0 else 0 # Basic calculation, might need adjustment based on desired output length - - # print(f"DecoderBlock - Input: {input_dim}, Output: {output_dim}, Kernel: {kernel_size}, Stride: {stride}, Padding: {padding}, OutputPadding: {output_padding}") - - - self.block = nn.Sequential( - Snake1d(input_dim), - WNConvTranspose1d( - input_dim, - output_dim, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, # Add output_padding - ), - ResidualUnit(output_dim, dilation=1), - ResidualUnit(output_dim, dilation=3), - ResidualUnit(output_dim, dilation=9), - ) - - def forward(self, x): - return self.block(x) - - -class WaveGenerator(nn.Module): - def __init__( - self, - input_channel, - channels, - rates, - kernel_sizes, - d_out: int = 1, - ): - super().__init__() - - # Add first conv layer - layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] - - # Add upsampling + MRF blocks - current_channels = channels - for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)): - input_dim = current_channels - # Ensure output_dim doesn't go below 1 - output_dim = max(1, channels // (2 ** (i + 1))) - layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)] - current_channels = output_dim # Update for the next block's input - - # Add final conv layer - layers += [ - Snake1d(current_channels), # Use the final output_dim - WNConv1d(current_channels, d_out, kernel_size=7, padding=3), - nn.Tanh(), - ] - - self.model = nn.Sequential(*layers) - - self.apply(init_weights) # Apply weight initialization - - def forward(self, x): - return self.model(x) - -# =============================================================== -# End: Content from sparktts/modules/encoder_decoder/wave_generator.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/fsq/finite_scalar_quantization.py -# =============================================================== -# helper functions moved earlier -def round_ste(z: Tensor) -> Tensor: - """Round with straight through gradients.""" - zhat = z.round() - return z + (zhat - z).detach() - - -class FSQ(nn.Module): - def __init__( - self, - levels: List[int], - dim: int | None = None, - num_codebooks=1, - keep_num_codebooks_dim: bool | None = None, - scale: float | None = None, - allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64), - channel_first: bool = False, # Added based on usage in ResidualFSQ - projection_has_bias: bool = True, - return_indices=True, - force_quantization_f32=True, - ): - super().__init__() - _levels = torch.tensor(levels, dtype=int32) - self.register_buffer("_levels", _levels, persistent=False) - - _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) - self.register_buffer("_basis", _basis, persistent=False) - - self.scale = scale # Not used in current implementation, but kept - - codebook_dim = len(levels) - self.codebook_dim = codebook_dim - - effective_codebook_dim = codebook_dim * num_codebooks - self.num_codebooks = num_codebooks - self.effective_codebook_dim = effective_codebook_dim - - # keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) - # Force keep_num_codebooks_dim to False if num_codebooks is 1 - if num_codebooks == 1: - keep_num_codebooks_dim = False - else: - keep_num_codebooks_dim = default(keep_num_codebooks_dim, True) - - # Original assert was checking if num_codebooks > 1 and keep_num_codebooks_dim is False. Let's refine. - # If num_codebooks > 1, keep_num_codebooks_dim must be True based on how rearrange is used. - if num_codebooks > 1 and not keep_num_codebooks_dim: - raise ValueError("If num_codebooks > 1, keep_num_codebooks_dim must be True or None (defaults to True).") - self.keep_num_codebooks_dim = keep_num_codebooks_dim - - - self.dim = default(dim, len(_levels) * num_codebooks) - - self.channel_first = channel_first # Store channel_first setting - - has_projections = self.dim != effective_codebook_dim - self.project_in = ( - nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias) - if has_projections - else nn.Identity() - ) - self.project_out = ( - nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias) - if has_projections - else nn.Identity() - ) - - self.has_projections = has_projections - - self.return_indices = return_indices - if return_indices: - self.codebook_size = self._levels.prod().item() - # Calculate implicit codebook based on current device during forward pass if needed - # For now, calculate assuming CPU and move later if necessary - # implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size, device=self._levels.device)) # Calculate on device - # self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) - - self.allowed_dtypes = allowed_dtypes - self.force_quantization_f32 = force_quantization_f32 - - @property - def implicit_codebook(self): - # Calculate implicit codebook on the fly using the device of _levels - device = self._levels.device - indices = torch.arange(self.codebook_size, device=device) - return self._indices_to_codes(indices) - - - def bound(self, z, eps: float = 1e-3): - """Bound `z`, an array of shape (..., d).""" - levels = self._levels.to(z.device) # Ensure levels are on same device - half_l = (levels - 1) * (1 + eps) / 2 - offset = torch.where(levels % 2 == 0, 0.5, 0.0) - shift = (offset / half_l).atanh() if torch.any(half_l != 0) else torch.zeros_like(offset) # Avoid div by zero - # Ensure shift is compatible shape for broadcasting - shift = shift.view(1, 1, -1) if z.ndim == 3 else shift # Adjust based on z dims - half_l = half_l.view(1, 1, -1) if z.ndim == 3 else half_l - - # Clamp input to avoid inf/-inf in atanh - z_clipped = torch.clamp(z, min=-1.0 + eps, max=1.0 - eps) # Assuming input z is somewhat normalized? - - # Original formula might be sensitive, let's try direct clamping. - # return (z + shift).tanh() * half_l - offset - - # Alternative clamping approach (from original Jax version logic): - upper_bound = (levels - 1) / 2 - lower_bound = -upper_bound - upper_bound = upper_bound.view(1, 1, -1) if z.ndim == 3 else upper_bound - lower_bound = lower_bound.view(1, 1, -1) if z.ndim == 3 else lower_bound - - return torch.clamp(z, min=lower_bound, max=upper_bound) - - - def quantize(self, z): - """Quantizes z, returns quantized zhat, same shape as z.""" - quantized = round_ste(self.bound(z)) - levels = self._levels.to(z.device) - half_width = levels // 2 # Renormalize to [-1, 1]. - # Avoid division by zero if level is 1 - half_width = torch.where(half_width == 0, torch.tensor(1.0, device=z.device), half_width.float()) - half_width_view = half_width.view(1, 1, -1) if quantized.ndim == 3 else half_width - return quantized / half_width_view - - def _scale_and_shift(self, zhat_normalized): - levels = self._levels.to(zhat_normalized.device) - half_width = levels // 2 - half_width_view = half_width.view(1, 1, -1) if zhat_normalized.ndim == 3 else half_width - return (zhat_normalized * half_width_view) + half_width_view - - def _scale_and_shift_inverse(self, zhat): - levels = self._levels.to(zhat.device) - half_width = levels // 2 - # Avoid division by zero if level is 1 - half_width = torch.where(half_width == 0, torch.tensor(1.0, device=zhat.device), half_width.float()) - half_width_view = half_width.view(1, 1, -1) if zhat.ndim == 3 else half_width - return (zhat - half_width_view) / half_width_view - - def _indices_to_codes(self, indices): - level_indices = self.indices_to_level_indices(indices) - codes = self._scale_and_shift_inverse(level_indices.float()) # Convert level indices to float - return codes - - def codes_to_indices(self, zhat): - """Converts a `code` to an index in the codebook.""" - assert zhat.shape[-1] == self.codebook_dim - zhat_scaled = self._scale_and_shift(zhat) - # Ensure basis is on the correct device and dtype, handle potential shape mismatch - basis = self._basis.to(zhat.device, dtype=int32) - basis_view = basis.view(1, 1, -1) if zhat_scaled.ndim == 3 else basis # Match ndim - # Ensure zhat_scaled is integer type for multiplication with basis - product = (zhat_scaled * basis_view).round().int() - return product.sum(dim=-1).to(int32) - - def indices_to_level_indices(self, indices): - """Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings""" - indices_reshaped = rearrange(indices, "... -> ... 1") - basis = self._basis.to(indices.device) - levels = self._levels.to(indices.device) - # Ensure basis and levels match the device and potentially ndim of indices - basis_view = basis.view(*([1] * (indices_reshaped.ndim - 1)), -1) - levels_view = levels.view(*([1] * (indices_reshaped.ndim - 1)), -1) - - codes_non_centered = (indices_reshaped // basis_view) % levels_view - return codes_non_centered - - # indices_to_codes is now handled by implicit_codebook property + project_out if needed - - def forward(self, z): - """ - einstein notation - b - batch - ... - sequence, spatial dimensions - d - feature dimension - c - number of codebook dim (within a single quantizer) - g - number of quantizers (groups) - handled by ResidualFSQ/GroupedResidualFSQ - """ - - # Input z can be (b d ...) or (b ... d) - # self.channel_first determines the expected input format for projection - - if self.channel_first: - # Expects (b d ...) - if z.ndim > 2: # Has spatial/temporal dims - z = rearrange(z, "b d ... -> b ... d") - z, ps = pack([z], "b * d") - # else: z is (b d) -> processed directly by linear - else: - # Expects (b ... d) - if z.ndim > 2: - z, ps = pack([z], "b * d") - # else: z is (b d) -> processed directly by linear - - - assert ( - z.shape[-1] == self.dim - ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" - - # Project in - z_projected = self.project_in(z) # (b ... effective_codebook_dim) - - # Reshape for codebooks if num_codebooks > 1 - if self.num_codebooks > 1: - z_reshaped = rearrange(z_projected, "b ... (c d) -> b ... c d", c=self.num_codebooks) - else: - # Add a dummy codebook dim for consistent processing - z_reshaped = rearrange(z_projected, "b ... d -> b ... 1 d") - - # Force quantization step to be full precision or not - force_f32 = self.force_quantization_f32 - quantization_context = ( - partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext - ) - - codes = None - indices = None - - with quantization_context(): - orig_dtype = z_reshaped.dtype - - if force_f32 and orig_dtype not in self.allowed_dtypes: - z_for_quant = z_reshaped.float() - else: - z_for_quant = z_reshaped - - codes = self.quantize(z_for_quant) # (b ... c d) - - if self.return_indices: - indices = self.codes_to_indices(codes) # (b ... c) - - # Convert codes back to original dtype if changed - codes = codes.type(orig_dtype) - - - # Reshape codes back and project out - if self.num_codebooks > 1: - codes_reshaped = rearrange(codes, "b ... c d -> b ... (c d)") - else: - codes_reshaped = rearrange(codes, "b ... 1 d -> b ... d") - - out = self.project_out(codes_reshaped) # (b ... dim) - - # Restore original spatial/temporal dimensions - if z.ndim > 2: # If we packed dimensions - out = unpack(out, ps, "b * d")[0] - if self.return_indices: - indices = unpack(indices, ps, "b * c")[0] - - # Restore channel dimension if needed - if self.channel_first and out.ndim > 2: - out = rearrange(out, "b ... d -> b d ...") - if self.return_indices and indices.ndim > 1: # Check indices ndim - # Indices shape (b ... c), need to decide how to handle channel dim - # Often indices might not need channel dim, depends on usage - # If indices are e.g. (b H W c), permuting might be complex. - # Keeping indices as (b ... c) for now. - pass - - - # Remove the dummy codebook dim from indices if num_codebooks was 1 - if self.return_indices and self.num_codebooks == 1 and not self.keep_num_codebooks_dim: - indices = indices.squeeze(-1) - - - return out, indices -# =============================================================== -# End: Content from sparktts/modules/fsq/finite_scalar_quantization.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/fsq/residual_fsq.py -# =============================================================== -# Helper functions needed by ResidualFSQ -def is_distributed(): - return dist.is_initialized() and dist.get_world_size() > 1 - -def get_maybe_sync_seed(device, max_size=10_000): - rand_int = torch.randint(0, max_size, (), device=device) - if is_distributed(): - # Ensure rand_int is on the correct device for all_reduce - if rand_int.device != device: - rand_int = rand_int.to(device) - dist.all_reduce(rand_int) - return rand_int.item() - -def round_up_multiple(num, mult): - # Ensure mult is positive - if mult <= 0: - return num - # Use ceiling division - return (num + mult - 1) // mult * mult - - -class ResidualFSQ(nn.Module): - """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf""" - - def __init__( - self, - *, - levels: List[int], - num_quantizers, - dim=None, - # is_channel_first=False, # Handled inside FSQ now - quantize_dropout=False, - quantize_dropout_cutoff_index=0, - quantize_dropout_multiple_of=1, - channel_first: bool = False, # Pass channel_first to FSQ - **kwargs, # Pass remaining kwargs to FSQ - ): - super().__init__() - codebook_dim = len(levels) - dim = default(dim, codebook_dim) - - requires_projection = codebook_dim != dim - self.project_in = ( - nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity() - ) - self.project_out = ( - nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity() - ) - self.has_projections = requires_projection - - self.channel_first = channel_first # Store for potential shape adjustments if needed later - self.num_quantizers = num_quantizers - - self.levels = levels - self.layers = nn.ModuleList([]) - - levels_tensor = torch.Tensor(levels) - - scales = [] - - for ind in range(num_quantizers): - # Calculate scale: (levels - 1) is max value range (- (l-1)/2 to +(l-1)/2) - # Residual is divided by scale before quantization - # Effective scale for quantizer 'ind' is (levels - 1)^ind ? Needs check. - # Original paper scale seems different. Let's stick to FSQ handling scale internally if needed. - # Using scale = 1.0 for now, assuming FSQ handles normalization. - scale_value = 1.0 # ((levels_tensor - 1)**-ind) - Check this logic - scales.append(scale_value) - - # Pass channel_first to FSQ - fsq = FSQ(levels=levels, dim=codebook_dim, channel_first=channel_first, **kwargs) - - self.layers.append(fsq) - - # Check if FSQ layers have projections internally. ResidualFSQ should handle overall projection. - assert all([not fsq.has_projections for fsq in self.layers]), "FSQ layers within ResidualFSQ should not have internal projections." - - self.codebook_size = self.layers[0].codebook_size - - # Using scale = 1.0, so register_buffer might not be needed, or store 1.0s - # self.register_buffer("scales", torch.Tensor(scales), persistent=False) - # If scales are needed, they should likely be parameters or calculated differently. - # For now, assuming FSQ normalizes correctly and scale is 1.0 here. - - self.quantize_dropout = quantize_dropout and num_quantizers > 1 - - assert quantize_dropout_cutoff_index >= 0 - - self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index - self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4 - - @property - def codebooks(self): - # Codebooks are implicit in FSQ, access via property - codebooks = [layer.implicit_codebook for layer in self.layers] - codebooks = torch.stack(codebooks, dim=0) - return codebooks - - def get_codes_from_indices(self, indices): - # indices shape: (b ... q) or (b q ...) depending on usage - num_dims = indices.ndim - q_dim = -1 # Assume last dim is quantizer dim by default - - # Find the quantizer dimension (q) - for i in range(num_dims): - if indices.shape[i] == self.num_quantizers: - q_dim = i - break - if q_dim == -1 and self.num_quantizers == 1 and indices.shape[-1] != 1: - # If only 1 quantizer, indices might not have the quantizer dim explicitly - indices = indices.unsqueeze(-1) # Add the quantizer dim - q_dim = -1 - elif q_dim == -1: - raise ValueError(f"Could not find quantizer dimension ({self.num_quantizers}) in indices shape {indices.shape}") - - # Ensure q_dim is the last dimension for processing - if q_dim != num_dims - 1: - permute_dims = list(range(num_dims)) - permute_dims.pop(q_dim) - permute_dims.append(q_dim) - indices = indices.permute(*permute_dims) - - - batch_shape = indices.shape[:-1] # Shape before the quantizer dim - indices = indices.reshape(-1, self.num_quantizers) # Flatten batch/spatial dims - - # Handle dropout indices (-1) - if indices.max() >= self.codebook_size: - raise ValueError(f"Invalid index found in indices: {indices.max()}. Max allowed is {self.codebook_size - 1}.") - if indices.min() < -1: - raise ValueError(f"Invalid index found in indices: {indices.min()}. Min allowed is -1 (dropout).") - - mask = indices == -1 - effective_indices = indices.masked_fill(mask, 0) # Use 0 for dropout indices temporarily - - all_codes = [] - # Iterate through each quantizer layer - for i in range(self.num_quantizers): - layer_indices = effective_indices[:, i] - # Use the FSQ layer's method to convert indices to codes (handles normalization) - # Need to ensure indices_to_codes exists and works correctly in FSQ - # Assuming FSQ.indices_to_codes takes (batch,) indices and returns (batch, codebook_dim) codes - layer_codes = self.layers[i].indices_to_codes(layer_indices) # This needs correct FSQ method - all_codes.append(layer_codes) - - all_codes_tensor = torch.stack(all_codes, dim=0) # (q, b_flat, d) - - # Mask out dropout codes - mask_expanded = mask.permute(1, 0).unsqueeze(-1) # (q, b_flat, 1) - all_codes_tensor = all_codes_tensor.masked_fill(mask_expanded, 0.0) - - # Reshape back to original batch/spatial shape - all_codes_tensor = all_codes_tensor.reshape(self.num_quantizers, *batch_shape, -1) # (q, b ... d) - - # Restore original q_dim position if it was changed - if q_dim != num_dims - 1: - # Need inverse permutation - inv_permute_dims = list(range(num_dims)) # Start with 0, 1, ..., num_dims-1 - inv_permute_dims.insert(q_dim, num_dims) # Insert the last dim (q) at the original position - inv_permute_dims.pop() # Remove the last element - # Permute from (q, b ... d) -> (b ... q ... d) - careful with dims - # Example: Input (b h w q), processed to (q, b*h*w), output (q, b*h*w, d) - # Reshaped to (q, b, h, w, d) - # Want (b, h, w, q, d) -> Need to confirm this logic - # Let's assume output shape (q, b, ..., d) is desired for summation later. - pass # Keep as (q, b ... d) for now - - return all_codes_tensor - - - def get_output_from_indices(self, indices): - # indices shape: (b ... q) - codes = self.get_codes_from_indices(indices) # Output: (q, b ... d) - codes_summed = reduce(codes, "q b ... d -> b ... d", "sum") - # Project back to original dimension - output = self.project_out(codes_summed) - - # Handle channel first if necessary for the final output - if self.channel_first and output.ndim > 2: - # Assumes input was (b d ...), so output should be too - output = rearrange(output, "b ... d -> b d ...") - - return output - - def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None): - num_quant, quant_dropout_multiple_of, device = ( - self.num_quantizers, - self.quantize_dropout_multiple_of, - x.device, - ) - - # handle channel first input if necessary for projection - original_shape = x.shape - if self.channel_first: - if x.ndim > 2: # Has spatial/temporal dims - x = rearrange(x, "b d ... -> b ... d") - x, ps = pack([x], "b * d") - # else: x is (b d), processed directly - else: - # Input is (b ... d) - if x.ndim > 2: - x, ps = pack([x], "b * d") - # else: x is (b d), processed directly - - - # maybe project in - projected_x = self.project_in(x) # (b ... codebook_dim) - - quantized_out = 0.0 - residual = projected_x # Start residual from projected input - - all_indices = [] - - should_quantize_dropout = self.training and self.quantize_dropout - - # sample a layer index at which to dropout further residual quantization - # also prepare null indices - rand_quantize_dropout_index = num_quant # Default to no dropout - - if should_quantize_dropout: - if not exists(rand_quantize_dropout_fixed_seed): - rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) - - rand = random.Random(rand_quantize_dropout_fixed_seed) - # Ensure cutoff index is valid - valid_cutoff = max(0, self.quantize_dropout_cutoff_index) - rand_quantize_dropout_index = rand.randrange(valid_cutoff, num_quant) - - if quant_dropout_multiple_of != 1: - rand_quantize_dropout_index = ( - round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1 - ) - # Clamp index to be within valid range - rand_quantize_dropout_index = min(rand_quantize_dropout_index, num_quant - 1) - - # Null indices shape should match the batch/spatial dims of x before pack - null_indices_shape = list(x.shape[:-1]) # All dims except last feature dim - null_indices = torch.full(null_indices_shape, -1, device=device, dtype=torch.long) - - - # go through the layers - # Assuming scale is handled within FSQ or is 1.0 here - # scales = self.scales.to(device) - - for quantizer_index, layer in enumerate(self.layers): - # scale = scales[quantizer_index] # If using external scales - - if quantizer_index > rand_quantize_dropout_index: - # Append null indices matching the shape of valid indices from FSQ - # FSQ returns indices shape (b ...) or (b ... c) -> need (b ...) - # Use the pre-calculated null_indices - all_indices.append(null_indices) - continue - - # Pass residual to the quantizer layer - # Assume FSQ takes (b ... d) or (b d ...) based on its channel_first setting - # Here, residual is (b ... codebook_dim) - quantized, indices = layer(residual) # layer should handle channel_first internally - - # residual = residual - quantized.detach() # Update residual BEFORE summing output - # quantized_out = quantized_out + quantized # Sum the quantized part - - # Algorithm 1 from paper: - # Input: x - # residual = x - # codes = [] - # for q in quantizers: - # x_q, indices = q(residual) # Quantize - # residual = residual - x_q # Update residual (use x_q directly, not detached?) - Check paper/encodec. Using detached version. - # codes.append(indices) - # x_hat = sum(x_q for each layer?) - No, final quantized output is reconstructed from indices. - - # Let's follow common implementation: sum quantized outputs, update residual with detached quantized - quantized_detached = quantized.detach() - residual = residual - quantized_detached - quantized_out = quantized_out + quantized # Sum quantized outputs from each layer - - # Store indices - if indices is None: - raise ValueError(f"FSQ layer {quantizer_index} did not return indices.") - all_indices.append(indices) - - # project out the summed quantized output - final_quantized_out = self.project_out(quantized_out) # (b ... dim) - - # stack all indices - all_indices = torch.stack(all_indices, dim=-1) # (b ... q) - - # Restore original shape if packed - if x.ndim > 2: # If we packed dimensions - final_quantized_out = unpack(final_quantized_out, ps, "b * d")[0] - all_indices = unpack(all_indices, ps, "b * q")[0] - - # Restore channel dimension if needed - if self.channel_first and final_quantized_out.ndim > 2: - final_quantized_out = rearrange(final_quantized_out, "b ... d -> b d ...") - # Decide how to handle indices shape. Keep as (b ... q) or (b q ...)? - # Keeping as (b ... q) seems more common. - # all_indices = rearrange(all_indices, "b ... q -> b q ...") # Optional rearrange - - # return - ret = (final_quantized_out, all_indices) - - if not return_all_codes: - return ret - - # Return all codes (reconstructed from indices) - # Input to get_codes_from_indices should be (b ... q) - all_codes = self.get_codes_from_indices(all_indices) # Output (q, b ... d) - - # Maybe reshape all_codes to match input shape conventions? - # If input was channel_first (b d ...), maybe output codes as (q b d ...)? - if self.channel_first and all_codes.ndim > 3: - all_codes = rearrange(all_codes, "q b ... d -> q b d ...") - - return (*ret, all_codes) - -# =============================================================== -# End: Content from sparktts/modules/fsq/residual_fsq.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/speaker/ecapa_tdnn.py -# =============================================================== -class Res2Conv1dReluBn(nn.Module): - """ - in_channels == out_channels == channels - """ - - def __init__( - self, - channels, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - bias=True, - scale=4, - ): - super().__init__() - assert channels % scale == 0, "{} % {} != 0".format(channels, scale) - self.scale = scale - self.width = channels // scale - self.nums = scale if scale == 1 else scale - 1 - - self.convs = [] - self.bns = [] - for i in range(self.nums): - self.convs.append( - nn.Conv1d( - self.width, - self.width, - kernel_size, - stride, - padding, - dilation, - bias=bias, - ) - ) - self.bns.append(nn.BatchNorm1d(self.width)) - self.convs = nn.ModuleList(self.convs) - self.bns = nn.ModuleList(self.bns) - - def forward(self, x): - out = [] - spx = torch.split(x, self.width, 1) - sp = spx[0] - # Enumerate starts from 0, matching list indices - for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): - # Order: conv -> relu -> bn - if i >= 1: - sp = sp + spx[i] # Residual connection within block parts - sp = conv(sp) - sp = bn(F.relu(sp)) # Apply ReLU before BatchNorm - out.append(sp) - if self.scale != 1: - # Append the last chunk without processing if scale > 1 - out.append(spx[self.nums]) - out = torch.cat(out, dim=1) - - return out - - -""" Conv1d + BatchNorm1d + ReLU """ -class Conv1dReluBn(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - bias=True, - ): - super().__init__() - self.conv = nn.Conv1d( - in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias - ) - self.bn = nn.BatchNorm1d(out_channels) - - def forward(self, x): - # Original: bn(relu(conv(x))) - # ECAPA Paper Figure/Desc seems to suggest conv -> bn -> relu ? Check Res2Net paper/ECAPA details. - # Sticking to original code's bn(relu(conv(x))) for now. - return self.bn(F.relu(self.conv(x))) - - -""" The SE connection of 1D case. """ -class SE_Connect(nn.Module): - def __init__(self, channels, se_bottleneck_dim=128): - super().__init__() - self.linear1 = nn.Linear(channels, se_bottleneck_dim) - self.linear2 = nn.Linear(se_bottleneck_dim, channels) - - def forward(self, x): - # x shape: (B, C, T) - out = x.mean(dim=2) # Global average pooling over time -> (B, C) - out = F.relu(self.linear1(out)) - out = torch.sigmoid(self.linear2(out)) - out = x * out.unsqueeze(2) # (B, C, T) * (B, C, 1) -> (B, C, T) - - return out - - -""" SE-Res2Block of the ECAPA-TDNN architecture. """ -class SE_Res2Block(nn.Module): - def __init__(self, channels, kernel_size, stride, padding, dilation, scale): - super().__init__() - self.se_res2block = nn.Sequential( - Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), - Res2Conv1dReluBn( - channels, kernel_size, stride, padding, dilation, scale=scale - ), - Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), - SE_Connect(channels), - ) - - def forward(self, x): - return x + self.se_res2block(x) - - -class ECAPA_TDNN(nn.Module): - def __init__( - self, - channels=512, - feat_dim=80, - embed_dim=192, - pooling_func="ASTP", - global_context_att=False, - emb_bn=False, - ): - super().__init__() - - self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2) - self.layer2 = SE_Res2Block( - channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8 - ) - self.layer3 = SE_Res2Block( - channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8 - ) - self.layer4 = SE_Res2Block( - channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8 - ) - - cat_channels = channels * 3 - # The output channels after conv depends on the pooling layer input expectation - # Original paper uses 1536. Let's assume pooling expects 1536. - self.conv = nn.Conv1d(cat_channels, cat_channels, kernel_size=1) # Keep channels same for pooling - - # Dynamically get pooling class based on string name from pooling_layers (defined earlier) - if pooling_func == "TAP": pooling_layer = TAP - elif pooling_func == "TSDP": pooling_layer = TSDP - elif pooling_func == "TSTP": pooling_layer = TSTP - elif pooling_func == "ASTP": pooling_layer = ASTP - elif pooling_func == "MHASTP": pooling_layer = MHASTP - elif pooling_func == "MQMHASTP": pooling_layer = MQMHASTP - else: raise ValueError(f"Unsupported pooling function: {pooling_func}") - - self.pool = pooling_layer( - in_dim=cat_channels, # Pooling operates on the output of self.conv - global_context_att=global_context_att # Pass context flag if relevant (ASTP) - # Add other necessary kwargs for specific pooling layers if needed - ) - # self.pool_out_dim = self.pool.get_out_dim() # Get output dim from pooling layer - # Use standard way to get output dim if get_out_dim not standard - # For TSTP/ASTP etc., it's usually 2 * in_dim - if hasattr(self.pool, 'get_out_dim'): - self.pool_out_dim = self.pool.get_out_dim() - elif isinstance(self.pool, (TSTP, ASTP, MHASTP, MQMHASTP)): - # Assuming these double the input dimension - self.pool_out_dim = cat_channels * (2 * getattr(self.pool, 'query_num', 1) if isinstance(self.pool, MQMHASTP) else 2) - else: # TAP, TSDP - self.pool_out_dim = cat_channels - - self.bn = nn.BatchNorm1d(self.pool_out_dim) - self.linear = nn.Linear(self.pool_out_dim, embed_dim) - self.emb_bn = emb_bn - if emb_bn: # better in SSL for SV - self.bn2 = nn.BatchNorm1d(embed_dim) - else: - self.bn2 = nn.Identity() - - def forward(self, x, return_latent=False): - # Input x expected as (B, T, F) e.g., mels - x = x.permute(0, 2, 1) # (B, T, F) -> (B, F, T) - - out1 = self.layer1(x) - out2 = self.layer2(out1) - out3 = self.layer3(out2) - out4 = self.layer4(out3) - - # Concat features from layers 2, 3, 4 - out = torch.cat([out2, out3, out4], dim=1) # (B, 3*channels, T) - latent = F.relu(self.conv(out)) # (B, 3*channels, T) - - # Pooling expects (B, F, T) - pooled_out = self.pool(latent) # (B, pool_out_dim) - bn_out = self.bn(pooled_out) - embedding = self.linear(bn_out) # (B, embed_dim) - - if self.emb_bn: - embedding = self.bn2(embedding) - - if return_latent: - # Return the embedding and the features before pooling - return embedding, latent # latent shape (B, 3*channels, T) - return embedding # Return only the final embedding - - -# Factory functions (optional, but keep if used elsewhere) -def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): - return ECAPA_TDNN( - channels=1024, - feat_dim=feat_dim, - embed_dim=embed_dim, - pooling_func=pooling_func, - emb_bn=emb_bn, - ) - -def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): - return ECAPA_TDNN( - channels=1024, - feat_dim=feat_dim, - embed_dim=embed_dim, - pooling_func=pooling_func, - global_context_att=True, - emb_bn=emb_bn, - ) - -def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): - return ECAPA_TDNN( - channels=512, - feat_dim=feat_dim, - embed_dim=embed_dim, - pooling_func=pooling_func, - emb_bn=emb_bn, - ) - -def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): - return ECAPA_TDNN( - channels=512, - feat_dim=feat_dim, - embed_dim=embed_dim, - pooling_func=pooling_func, - global_context_att=True, - emb_bn=emb_bn, - ) - -# =============================================================== -# End: Content from sparktts/modules/speaker/ecapa_tdnn.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/speaker/perceiver_encoder.py -# =============================================================== -# Helper functions for Perceiver/Attention -def exists(val): # Redefined earlier - return val is not None - -def once(fn): # Redefined earlier - called = False - @wraps(fn) - def inner(x): - nonlocal called - if called: return - called = True - return fn(x) - return inner - -print_once = once(print) - -class Attend(nn.Module): - def __init__(self, dropout=0.0, causal=False, use_flash=False): - super().__init__() - self.dropout = dropout - self.attn_dropout = nn.Dropout(dropout) - - self.causal = causal - self.register_buffer("mask", None, persistent=False) - - self.use_flash = use_flash - can_use_flash = hasattr(F, 'scaled_dot_product_attention') and use_flash - if can_use_flash: - print_once("Using Flash Attention for Perceiver.") - else: - if use_flash: print_once("Flash Attention requested but not available/enabled.") - self.use_flash = False # Disable if not available - - # Flash attention config (simplified) - self.efficient_config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]) - # Set default configs, actual backend selection happens in F.scaled_dot_product_attention - self.cpu_config = self.efficient_config(True, True, True) # Default for CPU - self.cuda_config = self.efficient_config(True, True, True) # Default for CUDA - - - def get_mask(self, n, device): - if exists(self.mask) and self.mask.shape[-1] >= n and self.mask.device == device: - return self.mask[:n, :n] - - mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) - self.register_buffer("mask", mask, persistent=False) - return mask - - def flash_attn(self, q, k, v, mask=None): - _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda - - # Expand KV if needed (for multi-query attention, though Perceiver might use standard MHA) - if k.ndim == 3: # (b n_kv d) -> (b h n_kv d) ? No, needs (b h n_kv d_head) - # Assume k/v are already (b h n d_head) or need different handling - pass - if v.ndim == 3: - pass - - # Format mask for flash attention (B, N_q, N_kv) or (B, H, N_q, N_kv) - flash_mask = None - if exists(mask): - # mask shape (b, n_kv) -> needs (b, 1, n_q, n_kv) or (b, h, n_q, n_kv) ? - # Check documentation. For key padding mask, usually (B, N_kv). - # Needs expansion. Let's assume (B, H, N_q, N_kv) for safety. - if mask.ndim == 2: # (b, n_kv) - flash_mask = rearrange(mask, "b j -> b 1 1 j") - # Flash attention expects additive mask (-inf for masked) not boolean? Check. - # F.scaled_dot_product_attention takes boolean mask with attn_mask arg. - flash_mask = flash_mask.expand(-1, heads, q_len, -1) # (b h n_q n_kv) - # Use ~mask because True means *mask out* in flash attn's attn_mask. - flash_mask = ~flash_mask - elif mask.ndim == 4 and mask.shape[1] == 1: # Maybe already expanded (b 1 1 n_kv) - flash_mask = mask.expand(-1, heads, q_len, -1) - flash_mask = ~flash_mask - else: - # Assuming mask might already be correctly shaped (e.g., B, H, Nq, Nkv boolean) - flash_mask = ~mask # Invert mask if boolean - - # pytorch 2.0 flash attn: q, k, v, attn_mask, dropout_p, is_causal - # attn_mask should be boolean where True indicates masking. - out = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=flash_mask if exists(flash_mask) else None, # Pass boolean mask - dropout_p=self.dropout if self.training else 0.0, - is_causal=self.causal # Pass causal flag directly - ) - return out - - def forward(self, q, k, v, mask=None): - """ - einstein notation - b - batch - h - heads - n, i, j - sequence length (query, key/value) - d - feature dimension (d_head) - """ - n, device = q.shape[-2], q.device - scale = q.shape[-1] ** -0.5 - - if self.use_flash: - return self.flash_attn(q, k, v, mask=mask) - - # Manual Attention Calculation - kv_einsum_eq = "b h j d" # Assuming k, v are always (b h n d) - - # similarity - sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale - - # key padding mask - if exists(mask): - # mask shape (b, j) -> (b, 1, 1, j) - mask_value = -torch.finfo(sim.dtype).max - mask = rearrange(mask, "b j -> b 1 1 j") - sim = sim.masked_fill(~mask, mask_value) # Mask where mask is False - - # causal mask (Not typically used in Perceiver cross-attention) - if self.causal: - causal_mask = self.get_mask(n, device) # (i, j) - sim = sim.masked_fill(causal_mask, mask_value) - - # attention - attn = sim.softmax(dim=-1) - attn = self.attn_dropout(attn) - - # aggregate values - out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) - - return out - - -# Need Sequential, default, RMSNorm, GEGLU, FeedForward, Attention for PerceiverResampler -def Sequential(*mods): # Redefined earlier - return nn.Sequential(*filter(exists, mods)) - -class RMSNorm(nn.Module): - def __init__(self, dim, scale=True, dim_cond=None): - super().__init__() - self.cond = exists(dim_cond) - # Conditional LayerNorm not used in PerceiverResampler, simplify - # self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None - - self.scale = dim**0.5 - self.gamma = nn.Parameter(torch.ones(dim)) if scale else None - - def forward(self, x, cond=None): # Remove cond argument if not used - gamma = default(self.gamma, torch.tensor(1.0, device=x.device)) # Ensure gamma is tensor - # Note: F.normalize normalizes across the *last* dimension by default - normed_x = F.normalize(x, dim=-1) - return normed_x * self.scale * gamma - - -class CausalConv1d(nn.Conv1d): # Already defined earlier - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - kernel_size = self.kernel_size[0] - dilation = self.dilation[0] - stride = self.stride[0] - assert stride == 1 - self.causal_padding = dilation * (kernel_size - 1) - - def forward(self, x): - # Input x: (B, C, T) - causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0) - return super().forward(causal_padded_x) - -class GEGLU(nn.Module): # Already defined earlier - def forward(self, x): - x, gate = x.chunk(2, dim=-1) - return F.gelu(gate) * x - -def FeedForward(dim, mult=4, causal_conv=False): # Already defined earlier - dim_inner = int(dim * mult * 2 / 3) - - conv = None - if causal_conv: - conv = nn.Sequential( - Rearrange("b n d -> b d n"), - CausalConv1d(dim_inner, dim_inner, 3), - Rearrange("b d n -> b n d"), - ) - - return Sequential( - nn.Linear(dim, dim_inner * 2, bias=False), # Bias False often used in transformers - GEGLU(), - conv, - nn.Linear(dim_inner, dim, bias=False) # Bias False - ) - - -class Attention(nn.Module): - def __init__( - self, - dim, - *, - dim_context=None, - causal=False, - dim_head=64, - heads=8, - dropout=0.0, - use_flash=False, - cross_attn_include_queries=False, - ): - super().__init__() - # self.scale = dim_head**-0.5 # scale is handled by Attend or flash attn - self.heads = heads - self.cross_attn_include_queries = cross_attn_include_queries - - dim_inner = dim_head * heads - dim_context = default(dim_context, dim) - - self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash) - self.to_q = nn.Linear(dim, dim_inner, bias=False) - # Combine K and V projection for efficiency - self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False) - self.to_out = nn.Linear(dim_inner, dim, bias=False) - - def forward(self, x, context=None, mask=None): - h, has_context = self.heads, exists(context) - # x shape: (b, n_q, d) - # context shape: (b, n_kv, d_ctx) - - context = default(context, x) # Use self if context not provided - - if has_context and self.cross_attn_include_queries: - # Prepend queries to context for attention calculation - context = torch.cat((x, context), dim=-2) # (b, n_q + n_kv, d_ctx) - ensure dims match - - # Project q, k, v - q = self.to_q(x) - k, v = self.to_kv(context).chunk(2, dim=-1) - - # Reshape for multi-head attention - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - # Attend - out = self.attend(q, k, v, mask=mask) # mask should be (b, n_kv) - - # Combine heads and project out - out = rearrange(out, "b h n d -> b n (h d)") - return self.to_out(out) - - -class PerceiverResampler(nn.Module): - def __init__( - self, - *, - dim, - depth=2, - dim_context=None, - num_latents=32, - dim_head=64, - heads=8, - ff_mult=4, - use_flash_attn=False, - ): - super().__init__() - dim_context = default(dim_context, dim) - - # Project context to query dimension if different - self.proj_context = ( - nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity() - ) - - # Learnable latent queries - self.latents = nn.Parameter(torch.randn(num_latents, dim)) - nn.init.normal_(self.latents, std=0.02) # Initialize latents - - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - nn.ModuleList( - [ - # Cross-Attention from latents (queries) to context (keys/values) - Attention( - dim=dim, - dim_context=dim, # Context is projected to dim - dim_head=dim_head, - heads=heads, - use_flash=use_flash_attn, - cross_attn_include_queries=False, # Standard Perceiver cross-attn - ), - # Self-Attention within latents - # Optional: Add self-attention block here if needed - # Attention( - # dim=dim, dim_head=dim_head, heads=heads, use_flash=use_flash_attn - # ), - - # FeedForward block - FeedForward(dim=dim, mult=ff_mult), - ] - ) - ) - # Add LayerNorms (typically before attention and ff blocks) - # self.layers[-1].insert(0, RMSNorm(dim)) # Pre-Attention Norm - # self.layers[-1].insert(2, RMSNorm(dim)) # Pre-FF Norm - # Using Post-Norm structure as in original reference: - self.layers[-1].insert(1, RMSNorm(dim)) # After Attention - self.layers[-1].append(RMSNorm(dim)) # After FeedForward - - - # Final normalization of latents - # self.norm = RMSNorm(dim) # Final norm applied inside loop in original? Let's apply at end. - - def forward(self, x, mask=None): - # x shape: (b, n_ctx, d_ctx) - batch = x.shape[0] - - # Project context - x = self.proj_context(x) # (b, n_ctx, d) - - # Repeat latents for batch - latents = repeat(self.latents, "n d -> b n d", b=batch) # (b, n_lat, d) - - # Apply layers - # Original structure had norm inside loop, adapting: Attn -> Norm -> FF -> Norm - for attn, norm1, ff, norm2 in self.layers: - # Cross-Attention + Residual - latents_attn = attn(latents, x, mask=mask) # Query: latents, Context: x - latents = norm1(latents_attn + latents) - - # FeedForward + Residual - latents_ff = ff(latents) - latents = norm2(latents_ff + latents) - - # return self.norm(latents) # Apply final norm if defined outside loop - return latents # Return latents after last block norm - -# =============================================================== -# End: Content from sparktts/modules/speaker/perceiver_encoder.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/speaker/speaker_encoder.py -# =============================================================== -class SpeakerEncoder(nn.Module): - """ - Speaker Encoder using ECAPA-TDNN, Perceiver Resampler, and Residual FSQ. - - Args: - input_dim (int): acoustic feature dimension (e.g., mel bins) - out_dim (int): output dimension of the final d-vector - latent_dim (int): latent dimension for perceiver and quantization - token_num (int): number of latent tokens from perceiver - fsq_levels (List[int]): levels for finite scalar quantization - fsq_num_quantizers (int): number of residual quantizers in FSQ - ecapa_embed_dim (int): embedding dimension from ECAPA-TDNN (before projection) - """ - - def __init__( - self, - input_dim: int = 80, # Default mel bins - out_dim: int = 1024, # Target d-vector dim from config - latent_dim: int = 128, # Latent dim for perceiver/quantizer - token_num: int = 32, # Number of speaker tokens - fsq_levels: List[int] = [4, 4, 4, 4, 4, 4], - fsq_num_quantizers: int = 1, - # Add ECAPA config params if needed, or use defaults - ecapa_channels: int = 512, - ecapa_embed_dim: int = 192, # Default ECAPA embed dim - ): - super(SpeakerEncoder, self).__init__() - - # ECAPA-TDNN for initial feature extraction and x-vector (optional) - # Using the GLOB variant as in the original __main__ test - self.speaker_encoder_base = ECAPA_TDNN_GLOB_c512( - feat_dim=input_dim, - embed_dim=ecapa_embed_dim # Use specific ECAPA embed dim - ) - # Dimension of features extracted by ECAPA (latent before pooling) - ecapa_feature_dim = ecapa_channels * 3 # From concatenation in ECAPA - - # Perceiver Resampler to get fixed-length sequence from variable-length ECAPA features - self.perceiver_sampler = PerceiverResampler( - dim=latent_dim, # Output dim of perceiver latents - dim_context=ecapa_feature_dim, # Input dim from ECAPA features - num_latents=token_num, - depth=2, # Default depth, adjust if needed - dim_head=64, heads=8, ff_mult=4, # Default attention/ff params - use_flash_attn=True # Enable flash attention if available - ) - - # Residual Finite Scalar Quantizer - self.quantizer = ResidualFSQ( - levels=fsq_levels, - num_quantizers=fsq_num_quantizers, - dim=latent_dim, # Quantizer operates on perceiver output dim - channel_first=False, # Perceiver output is (B, T, D), so channel_first=False - quantize_dropout=False, # No dropout specified in config - ) - - # Final projection from flattened quantized tokens to the target output dimension - self.project = nn.Linear(latent_dim * token_num, out_dim) - - def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor: - """Reconstruct quantized vectors from indices.""" - # indices shape: (B, T_token, Q) or (B, Q, T_token)? Check ResidualFSQ output. - # Assuming (B, T_token, Q) from forward pass. - # get_output_from_indices expects (indices_chunk1, indices_chunk2, ...) if grouped. - # If not grouped, expects (B, ... Q). Let's assume (B, T_token, Q). - zq = self.quantizer.get_output_from_indices(indices) - # Output zq shape should be (B, T_token, latent_dim) - return zq - - def get_indices(self, mels: torch.Tensor) -> torch.Tensor: - """Get FSQ indices directly from mel spectrograms.""" - # mels: (B, T_mel, D_mel) - _, features = self.speaker_encoder_base(mels, return_latent=True) # features: (B, ecapa_feat_dim, T_feat) - x = self.perceiver_sampler(features.transpose(1, 2)) # Input: (B, T_feat, ecapa_feat_dim), Output: (B, token_num, latent_dim) - _, indices = self.quantizer(x) # Input: (B, token_num, latent_dim), indices: (B, token_num, Q) - return indices - - def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - mels: (B, T_mel, D_mel) - Mel spectrogram input - - Return: - x_vector: (B, ecapa_embed_dim) - Global speaker embedding from ECAPA - d_vector: (B, out_dim) - Speaker embedding derived from quantized tokens - """ - # Get base speaker embedding (x-vector) and intermediate features from ECAPA - x_vector, features = self.speaker_encoder_base(mels, return_latent=True) - # features shape: (B, ecapa_feat_dim, T_feat) - - # Resample features using Perceiver - # Perceiver expects (B, T, D), so transpose features - perceiver_latents = self.perceiver_sampler(features.transpose(1, 2)) - # perceiver_latents shape: (B, token_num, latent_dim) - - # Quantize the perceiver latents - # Quantizer expects (B, T, D) if channel_first=False - zq, indices = self.quantizer(perceiver_latents) - # zq shape: (B, token_num, latent_dim), indices shape: (B, token_num, Q) - - # Flatten quantized tokens and project to final d-vector dimension - zq_flat = rearrange(zq, 'b t d -> b (t d)') # (B, token_num * latent_dim) - d_vector = self.project(zq_flat) # (B, out_dim) - - return x_vector, d_vector - - def tokenize(self, mels: torch.Tensor) -> torch.Tensor: - """Tokenize the input mel spectrogram to get FSQ indices.""" - # Same logic as get_indices - _, features = self.speaker_encoder_base(mels, return_latent=True) # features: (B, ecapa_feat_dim, T_feat) - x = self.perceiver_sampler(features.transpose(1, 2)) # (B, token_num, latent_dim) - _, indices = self.quantizer(x) # indices: (B, token_num, Q) - return indices - - def detokenize(self, indices: torch.Tensor) -> torch.Tensor: - """Detokenize FSQ indices to get the final d-vector.""" - # indices shape: (B, token_num, Q) - # Reconstruct quantized vectors from indices - zq = self.get_codes_from_indices(indices) # (B, token_num, latent_dim) - - # Flatten and project - zq_flat = rearrange(zq, 'b t d -> b (t d)') - d_vector = self.project(zq_flat) - return d_vector - -# =============================================================== -# End: Content from sparktts/modules/speaker/speaker_encoder.py -# =============================================================== - - -# =============================================================== -# Start: Content from sparktts/modules/vq/factorized_vector_quantize.py -# =============================================================== -# Helper function from layers.py (already defined) -# def WNConv1d(*args, **kwargs): -# return weight_norm(nn.Conv1d(*args, **kwargs)) - -def ema_inplace(moving_avg, new, decay): - moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) - - -class FactorizedVectorQuantize(nn.Module): - def __init__( - self, - input_dim: int, - codebook_size: int, - codebook_dim: int, - commitment: float, - codebook_loss_weight: float = 1.0, - decay: float = 0.99, - threshold_ema_dead_code: float = 2.0, # Changed default from 2 based on config - momentum: float = 0.99, # Not used in current implementation? - use_l2_normlize: bool = True, # Added from config - **kwargs, - ): - super().__init__() - self.input_dim = input_dim - self.codebook_size = codebook_size - self.codebook_dim = codebook_dim - self.commitment = commitment - self.codebook_loss_weight = codebook_loss_weight - self.decay = decay - self.threshold_ema_dead_code = threshold_ema_dead_code - # self.momentum = momentum # Store if needed later - self.use_l2_normlize = use_l2_normlize - - if input_dim != self.codebook_dim: - self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1) - self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1) - else: - self.in_project = nn.Identity() - self.out_project = nn.Identity() - - # Codebook embedding layer - self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim) - # Initialize codebook? Often random init is fine. - - # Buffers for EMA updates (cluster size and maybe embeddings) - self.register_buffer("cluster_size", torch.zeros(self.codebook_size)) - # EMA average embeddings (optional, can use self.codebook.weight directly for loss) - # self.register_buffer("ema_embed", self.codebook.weight.clone()) - - def forward(self, z: torch.Tensor) -> Dict[str, Any]: - """Quantizes the input tensor using a fixed codebook and returns - the corresponding codebook vectors and losses. - - Parameters - ---------- - z : Tensor[B x D_in x T] - - Returns - ------- - Dict containing: - z_q (Tensor[B x D_in x T]): Quantized continuous representation (passed through out_project) - indices (Tensor[B x T]): Codebook indices - vq_loss (Tensor[1]): Combined VQ loss (codebook + commitment) - perplexity (Tensor[1]): Codebook perplexity metric - active_num (Tensor[1]): Number of active codebook entries - """ - # z: (B, D_in, T) - B, _, T = z.shape - - # Project input to codebook dimension if necessary - z_e = self.in_project(z) # (B, D_code, T) - - # Find nearest neighbors and get quantized vectors + indices - z_q, indices, dists = self.decode_latents(z_e) # z_q: (B, D_code, T), indices: (B, T) - - # Calculate statistics for perplexity and active codes - with torch.no_grad(): # Stats should not contribute to gradient - embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype) # (B, T, C) - # Flatten batch and time dims for stats - embed_onehot_flat = rearrange(embed_onehot, 'b t c -> (b t) c') - avg_probs = torch.mean(embed_onehot_flat, dim=0) # (C,) - perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) - - # EMA update for cluster size (only in training) - active_num_tensor = (embed_onehot_flat.sum(0) > 0).sum() # Before EMA - if self.training: - # Perform EMA update in place - ema_inplace(self.cluster_size, embed_onehot_flat.sum(0), self.decay) - # Calculate active codes based on EMA threshold - active_num_tensor = (self.cluster_size > self.threshold_ema_dead_code).sum() - - - # Calculate losses (only in training) - commit_loss = torch.tensor(0.0, device=z.device) - codebook_loss = torch.tensor(0.0, device=z.device) - vq_loss = torch.tensor(0.0, device=z.device) - - if self.training: - # Commitment loss (encourage encoder output E(x) to be close to codebook z_q) - # Use z_e (projected encoder output) and z_q.detach() - commit_loss = F.mse_loss(z_e, z_q.detach()) * self.commitment - - # Codebook loss (encourage codebook entries z_q to be close to encoder output E(x)) - # Use z_q and z_e.detach() - codebook_loss = F.mse_loss(z_q, z_e.detach()) * self.codebook_loss_weight - - vq_loss = commit_loss + codebook_loss - - - # Straight-through estimator: copy gradient from z_q to z_e - z_q_st = z_e + (z_q - z_e).detach() - - # Project quantized vectors back to input dimension if necessary - z_q_out = self.out_project(z_q_st) # (B, D_in, T) - - return { - "z_q": z_q_out, - "indices": indices, - # "dists": dists, # Dists might be large, exclude unless needed - "vq_loss": vq_loss, - "perplexity": perplexity, - "active_num": active_num_tensor.float(), - } - - def embed_code(self, embed_id): - """Retrieve codebook vectors for given indices.""" - return F.embedding(embed_id, self.codebook.weight) - - def decode_code(self, embed_id): - """Retrieve codebook vectors and transpose to (B, D, T) format.""" - # embed_id: (B, T) - # Embedding: (B, T, D_code) - # Transpose: (B, D_code, T) - return self.embed_code(embed_id).transpose(1, 2) - - def decode_latents(self, latents): - """Find nearest codebook entries for latent vectors.""" - # latents: (B, D_code, T) - B, D_code, T = latents.shape - encodings = rearrange(latents, "b d t -> (b t) d") # ((B*T), D_code) - codebook = self.codebook.weight # (C, D_code) - - # Normalize if required - if self.use_l2_normlize: - encodings = F.normalize(encodings, p=2, dim=-1) - codebook = F.normalize(codebook, p=2, dim=-1) - - # Compute distances (squared Euclidean or Cosine depending on normalization) - # dist = torch.cdist(encodings, codebook, p=2)**2 # Squared Euclidean - # Faster calculation using matrix multiplication if normalized: - # dist = 2 - 2 * (encodings @ codebook.t()) - # Or full squared Euclidean: - dist = ( - encodings.pow(2).sum(1, keepdim=True) # (B*T, 1) - - 2 * (encodings @ codebook.t()) # (B*T, C) - + codebook.pow(2).sum(1, keepdim=True).t() # (1, C) - ) # Result shape: (B*T, C) - - # Find nearest neighbors - indices = torch.argmin(dist, dim=-1) # (B*T) - indices = rearrange(indices, "(b t) -> b t", b=B) # (B, T) - - # Get the quantized vectors - z_q = self.decode_code(indices) # (B, D_code, T) - - return z_q, indices, dist # Return dist if needed, e.g., for debugging - - # --- Methods for inference/tokenization --- - def tokenize(self, z: torch.Tensor) -> torch.Tensor: - """Tokenize the input tensor without loss calculation.""" - # z: (B, D_in, T) - z_e = self.in_project(z) # (B, D_code, T) - _, indices, _ = self.decode_latents(z_e) # indices: (B, T) - return indices - - def detokenize(self, indices: torch.Tensor) -> torch.Tensor: - """Detokenize indices to quantized vectors in input dimension.""" - # indices: (B, T) - z_q_code_dim = self.decode_code(indices) # (B, D_code, T) - z_q_out = self.out_project(z_q_code_dim) # (B, D_in, T) - return z_q_out - -# =============================================================== -# End: Content from sparktts/modules/vq/factorized_vector_quantize.py -# =============================================================== - - - -# --- BiCodec Model Definition (Adapted from sparktts/models/bicodec.py) --- -class BiCodec(nn.Module): - """ - BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder, - quantizer, and wave generator. - """ - - def __init__( - self, - mel_params: Dict[str, Any], - encoder: nn.Module, - decoder: nn.Module, - quantizer: nn.Module, - speaker_encoder: nn.Module, - prenet: nn.Module, - postnet: nn.Module, - **kwargs - ) -> None: - """ - Initializes the BiCodec model with the required components. - - Args: - mel_params (dict): Parameters for the mel-spectrogram transformer. - encoder (nn.Module): Encoder module. - decoder (nn.Module): Decoder module. - quantizer (nn.Module): Quantizer module. - speaker_encoder (nn.Module): Speaker encoder module. - prenet (nn.Module): Prenet network. - postnet (nn.Module): Postnet network. - """ - super().__init__() - self.encoder = encoder - self.decoder = decoder - self.quantizer = quantizer - self.speaker_encoder = speaker_encoder - self.prenet = prenet - self.postnet = postnet - self._init_mel_transformer(mel_params) - - @classmethod - def load_from_config_and_checkpoint(cls, model_dir: Path, config_dict: Dict[str, Any], **kwargs) -> "BiCodec": - """Loads the model from a config dictionary and checkpoint file.""" - ckpt_path = model_dir / 'model.safetensors' - if not ckpt_path.is_file(): - raise FileNotFoundError(f"BiCodec checkpoint not found at {ckpt_path}") - - audio_tokenizer_config = config_dict # Assuming config_dict holds the relevant sub-config - - # Instantiate components using classes from _modeling_bicodec_components - mel_params = audio_tokenizer_config.get("mel_params", {}) - encoder_cfg = audio_tokenizer_config.get("encoder", {}) - quantizer_cfg = audio_tokenizer_config.get("quantizer", {}) - prenet_cfg = audio_tokenizer_config.get("prenet", {}) - postnet_cfg = audio_tokenizer_config.get("postnet", {}) - decoder_cfg = audio_tokenizer_config.get("decoder", {}) # This corresponds to WaveGenerator - speaker_encoder_cfg = audio_tokenizer_config.get("speaker_encoder", {}) - - # --- Input Validation --- - required_keys = { - "encoder": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"], - "quantizer": ["input_dim", "codebook_size", "codebook_dim", "commitment"], - "prenet": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"], - "postnet": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"], - "decoder": ["input_channel", "channels", "rates", "kernel_sizes"], # WaveGenerator keys - "speaker_encoder": ["input_dim", "out_dim", "latent_dim", "token_num"], - "mel_params": ["sample_rate", "n_fft", "win_length", "hop_length", "num_mels"] - } - for comp, keys in required_keys.items(): - cfg = audio_tokenizer_config.get(comp, {}) - if not cfg: logging.get_logger(__name__).warning(f"BiCodec config missing section: '{comp}'") - for key in keys: - if key not in cfg: - logging.get_logger(__name__).warning(f"BiCodec config missing key '{key}' in section '{comp}'") - # --- End Validation --- - - - # Instantiate modules - encoder = Encoder(**encoder_cfg) if encoder_cfg else None - quantizer = FactorizedVectorQuantize(**quantizer_cfg) if quantizer_cfg else None - prenet = Decoder(**prenet_cfg) if prenet_cfg else None - postnet = Decoder(**postnet_cfg) if postnet_cfg else None - decoder = WaveGenerator(**decoder_cfg) if decoder_cfg else None # WaveGenerator instance - speaker_encoder = SpeakerEncoder(**speaker_encoder_cfg) if speaker_encoder_cfg else None - - # Check if all components were successfully created - if not all([encoder, quantizer, prenet, postnet, decoder, speaker_encoder, mel_params]): - raise ValueError("Failed to initialize one or more BiCodec components due to missing configuration.") - - # Create the BiCodec instance - model = cls( - mel_params=mel_params, - encoder=encoder, - decoder=decoder, # Pass WaveGenerator instance as decoder - quantizer=quantizer, - speaker_encoder=speaker_encoder, - prenet=prenet, - postnet=postnet, - ) - - # Load state dict - try: - state_dict = load_file(ckpt_path, device="cpu") # Load to CPU first - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - if missing_keys: - print(f"BiCodec missing keys: {missing_keys}") - if unexpected_keys: - print(f"BiCodec unexpected keys: {unexpected_keys}") - except Exception as e: - raise IOError(f"Error loading BiCodec state dict from {ckpt_path}: {e}") - - model.eval() - # model.remove_weight_norm() # Assuming this method exists in components - - return model - - def _init_mel_transformer(self, config: Dict[str, Any]): - # Ensure required keys exist with defaults - sr = config.get("sample_rate", 16000) - n_fft = config.get("n_fft", 1024) - win_length = config.get("win_length", n_fft) - hop_length = config.get("hop_length", n_fft // 4) - fmin = config.get("mel_fmin", 0) - fmax = config.get("mel_fmax", None) - n_mels = config.get("num_mels", 80) - power = config.get("power", 2.0) # Typically 2.0 for power spectrogram - norm = config.get("norm", "slaney") - mel_scale = config.get("mel_scale", "htk") # htk or slaney - - self.mel_transformer = TT.MelSpectrogram( - sample_rate=sr, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - f_min=fmin, - f_max=fmax, - n_mels=n_mels, - power=power, - norm=norm, - mel_scale=mel_scale, - ).eval() # Set to eval mode - - def remove_weight_norm(self): - """Removes weight normalization from components that support it.""" - def _remove_wn(m): - if hasattr(m, 'remove_weight_norm'): - m.remove_weight_norm() - elif isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)): - try: - remove_weight_norm(m) - except ValueError: - pass # Module might not have weight norm applied - - self.apply(_remove_wn) - - - @torch.no_grad() - def tokenize(self, feat: torch.Tensor, ref_wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ Tokenizes input features and reference wav into semantic and global tokens. """ - # Ensure models are on the correct device - device = feat.device - self.mel_transformer.to(device) - self.encoder.to(device) - self.quantizer.to(device) - self.speaker_encoder.to(device) - - # feat: (B, D_feat, T_feat), ref_wav: (B, T_wav) - mel = self.mel_transformer(ref_wav) # (B, D_mel, T_mel) - - # Encode features to get latents for semantic tokens - z = self.encoder(feat) # (B, D_latent, T_latent) - Assuming Encoder output matches quantizer input dim - - # Quantize latents to get semantic tokens (indices) - semantic_tokens = self.quantizer.tokenize(z) # (B, T_latent) - - # Encode mel spectrogram to get global tokens (indices) - # SpeakerEncoder.tokenize expects (B, T_mel, D_mel) - global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2)) # (B, T_token, Q) - Check shape - - # Note: Original BiCodecTokenizer returned (global_tokens, semantic_tokens) - # Let's stick to that order for consistency with original SparkTTS usage. - return global_tokens, semantic_tokens - - - @torch.no_grad() - def detokenize(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> torch.Tensor: - """ Detokenizes semantic and global tokens into a waveform. """ - # Ensure models are on the correct device - device = semantic_tokens.device # Assume tokens are on target device - self.quantizer.to(device) - self.speaker_encoder.to(device) - self.prenet.to(device) - self.decoder.to(device) # WaveGenerator - - # semantic_tokens: (B, T_latent) or (B, T_latent, Q)? Check quantizer.tokenize output shape. Assuming (B, T_latent). - # global_tokens: (B, T_token, Q) - Check speaker_encoder.tokenize output shape. - - # Reconstruct quantized vectors from semantic tokens - z_q = self.quantizer.detokenize(semantic_tokens) # (B, D_latent, T_latent) - - # Reconstruct d-vector (condition) from global tokens - # SpeakerEncoder.detokenize expects (B, T_token, Q) - d_vector = self.speaker_encoder.detokenize(global_tokens) # (B, D_dvector) - - # Apply prenet conditioned on d-vector - # Prenet (Decoder class) expects input (B, D_latent, T_latent) and condition (B, D_dvector) - x = self.prenet(z_q, d_vector) # (B, D_prenet_out, T_latent) - Assuming prenet maintains time dim - - # Add condition (broadcasted) before wave generation - Check original logic - # Ensure d_vector has correct shape for broadcasting - if d_vector.ndim == 2: - d_vector_unsqueezed = d_vector.unsqueeze(-1) # (B, D_dvector, 1) - else: # Should not happen if speaker_encoder outputs (B, D) - d_vector_unsqueezed = d_vector - - # Ensure dimensions match for addition - if x.shape[1] == d_vector_unsqueezed.shape[1]: - # Broadcast d_vector across time dimension T_latent - x = x + d_vector_unsqueezed - else: - # Maybe project d_vector or x? Log a warning or adapt based on expected dims. - logging.get_logger(__name__).warning(f"Prenet output dim {x.shape[1]} != d-vector dim {d_vector_unsqueezed.shape[1]}. Skipping residual connection.") - - - # Generate waveform using the decoder (WaveGenerator) - # WaveGenerator expects (B, D_input, T_input) - wav_recon = self.decoder(x) # (B, 1, T_wav) - - return wav_recon - - -# --- Main SparkTTS Model --- -from .configuration_spark_tts import SparkTTSConfig -# from ._utils import load_audio # Use utils from _utils.py - -logger = logging.get_logger(__name__) - -class SparkTTSModel(PreTrainedModel, GenerationMixin): - """ - SparkTTS model integrating LLM, BiCodec, and Wav2Vec2 for text-to-speech. - """ - config_class = SparkTTSConfig - base_model_prefix = "spark_tts" - _supports_load_fast = True - - def __init__(self, config: SparkTTSConfig, llm=None, wav2vec2_model=None, wav2vec2_processor=None, bicodec=None): - super().__init__(config) - self.config = config - self.llm = llm - self.wav2vec2_model = wav2vec2_model - self.wav2vec2_processor = wav2vec2_processor - self.bicodec = bicodec - - # Ensure wav2vec2 config has output_hidden_states=True after loading - self.post_init() - - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - *model_args, - config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, - cache_dir: Optional[Union[str, os.PathLike]] = None, - ignore_mismatched_sizes: bool = False, - force_download: bool = False, - local_files_only: bool = False, - token: Optional[Union[str, bool]] = None, - revision: str = "main", - use_safetensors: Optional[bool] = None, # Keep None to let transformers decide - **kwargs, - ): - # Pop device map and dtype early - handle placement later - # Note: device_map is complex with multiple components. Manual .to(device) is simpler here. - device_map = kwargs.pop("device_map", None) - if device_map: - logger.warning("`device_map` is not directly supported for this composite model. Use .to(device) after loading.") - - torch_dtype = kwargs.pop("torch_dtype", "auto") # Can be "auto", float32, float16, bfloat16 - trust_remote_code = kwargs.pop("trust_remote_code", False) # CRITICAL for custom code - - # --- 1. Resolve the main model directory --- - # This handles downloading from Hub or using a local path robustly. - if pretrained_model_name_or_path is None: - raise ValueError("`pretrained_model_name_or_path` must be provided.") - - model_path = Path(pretrained_model_name_or_path) - if not model_path.is_dir(): - # If it's not a local directory, assume it's a Hub ID and download everything - logger.info(f"{pretrained_model_name_or_path} is not a local directory. Assuming Hub ID and downloading.") - try: - resolved_model_path = snapshot_download( - repo_id=str(pretrained_model_name_or_path), - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - token=token, - revision=revision, - allow_patterns=["*.json", "*.safetensors", "*.bin", "*.yaml", "*.txt", "README.md"], # Be somewhat permissive - # ignore_patterns=["*.git*"], # Optional: ignore git files - # user_agent={"agent": "spark-tts-custom-loader"}, # Optional - ) - resolved_model_path = Path(resolved_model_path) - logger.info(f"Model downloaded to cache: {resolved_model_path}") - except Exception as e: - raise OSError( - f"Failed to download model '{pretrained_model_name_or_path}' from Hugging Face Hub. " - f"Ensure the ID is correct and network is available. Error: {e}" - ) - else: - # It's a local directory path - resolved_model_path = model_path - logger.info(f"Loading model from local directory: {resolved_model_path}") - - if not resolved_model_path.is_dir(): - # This should ideally not happen after snapshot_download or initial check - raise EnvironmentError(f"Cannot find resolved model directory at {resolved_model_path}") - - # --- 2. Load the main configuration --- - # The config might have been passed explicitly, otherwise load from resolved path - if not isinstance(config, PretrainedConfig): - config_path = config if config is not None else resolved_model_path - config, model_kwargs = cls.config_class.from_pretrained( - config_path, # Load from the resolved directory or explicit config path - *model_args, # Pass *model_args here if they influence config loading - cache_dir=cache_dir, # Pass relevant args down - force_download=force_download, - local_files_only=local_files_only, - token=token, - revision=revision, - trust_remote_code=trust_remote_code, # Needed if config class itself is remote - return_unused_kwargs=True, - **kwargs, # Pass remaining kwargs - ) - # Update kwargs with unused ones from config loading - kwargs.update(model_kwargs) - else: - # Config object was passed directly - pass # kwargs remain as they were - - - # --- Determine torch_dtype (use config value if specified and not overridden) --- - # Priority: Explicit torch_dtype arg > config.torch_dtype > "auto" (default) - final_torch_dtype = torch_dtype # Explicit arg has highest prio - if final_torch_dtype == "auto": - final_torch_dtype = getattr(config, "torch_dtype", None) # Use config value if present - # final_torch_dtype can still be None or "auto" here, handle downstream - - # --- Helper function to resolve paths relative to the main model directory --- - def _resolve_sub_path(sub_path): - p = Path(sub_path) - if p.is_absolute(): - return str(p) - else: - # Resolve relative to the potentially cached main model path - return str(resolved_model_path / p) - - # --- 3. Load Sub-components --- - - # --- Load LLM --- - llm_path = _resolve_sub_path(config.llm_model_name_or_path) - logger.info(f"Loading LLM from resolved path: {llm_path}") - try: - llm = AutoModelForCausalLM.from_pretrained( - llm_path, - torch_dtype=final_torch_dtype if final_torch_dtype != "auto" else None, # Pass resolved dtype or None - trust_remote_code=trust_remote_code, # Pass down trust_remote_code - # Pass remaining kwargs that might be relevant for AutoModelForCausalLM - # Filter kwargs if necessary, but often passing them is fine - **kwargs - ) - except Exception as e: - raise OSError(f"Failed to load LLM from {llm_path}: {e}") - - # --- Load Wav2Vec2 --- - w2v_path = _resolve_sub_path(config.wav2vec2_model_name_or_path) - logger.info(f"Loading Wav2Vec2 components from resolved path: {w2v_path}") - try: - # Load feature extractor first - wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained( - w2v_path, - trust_remote_code=trust_remote_code, - # Add any relevant kwargs for feature extractor if needed - ) - # Load model - wav2vec2_model = Wav2Vec2Model.from_pretrained( - w2v_path, - trust_remote_code=trust_remote_code, - # Add any relevant kwargs for model if needed (e.g., add_adapter=False) - ) - except Exception as e: - raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}") - - # --- Load BiCodec --- - bicodec_path = _resolve_sub_path(config.bicodec_model_name_or_path) - logger.info(f"Loading BiCodec from resolved path: {bicodec_path}") - if not config.bicodec_config or "audio_tokenizer" not in config.bicodec_config: - raise ValueError("BiCodec configuration ('bicodec_config' with 'audio_tokenizer' key) not found in SparkTTSConfig.") - try: - # Assuming BiCodec class has the custom loading method - # Make sure BiCodec class is imported or defined above - bicodec = BiCodec.load_from_config_and_checkpoint( - model_dir=Path(bicodec_path), - config_dict=config.bicodec_config["audio_tokenizer"] - ) - # Ensure BiCodec is an nn.Module if you want .to(device) to work easily - if not isinstance(bicodec, torch.nn.Module): - logger.warning("Loaded BiCodec component is not an instance of torch.nn.Module. Automatic device placement might not work.") - - except FileNotFoundError as e: - raise OSError(f"Failed to load BiCodec: A required file was not found in {bicodec_path}. Original error: {e}") - except Exception as e: - logger.error(f"Raw error loading BiCodec: {type(e).__name__}: {e}") - import traceback - traceback.print_exc() # Print full traceback for debugging BiCodec loading - raise OSError(f"Failed to load BiCodec from {bicodec_path}. Check BiCodec implementation and file paths. Error: {e}") - - - # --- 4. Instantiate the main model wrapper --- - # Pass the loaded config and components - model = cls(config, llm=llm, wav2vec2_model=wav2vec2_model, wav2vec2_processor=wav2vec2_processor, bicodec=bicodec) - - # --- 5. Handle device placement --- - # Move the entire model (including sub-modules if they are nn.Module attributes) - # Determine target device based on availability - if torch.cuda.is_available(): - final_device = torch.device("cuda") - # If multiple GPUs, could select one, e.g., torch.device("cuda:0") - # Or rely on CUDA_VISIBLE_DEVICES environment variable - else: - final_device = torch.device("cpu") - - logger.info(f"Placing SparkTTSModel and components on device: {final_device}") - # This should move all registered nn.Module attributes (llm, wav2vec2_model, bicodec if it's an nn.Module) - try: - model.to(final_device) - except Exception as e: - logger.error(f"Failed to move model to device {final_device}. Error: {e}") - logger.warning("Device placement might be incomplete. Check component types.") - - # --- 6. Return the loaded and prepared model --- - return model - - - - # --- Embedding getters/setters (delegate to LLM if loaded) --- - def get_input_embeddings(self): - if self.llm: - return self.llm.get_input_embeddings() - return None # Or raise error - - def set_input_embeddings(self, value): - if self.llm: - self.llm.set_input_embeddings(value) - else: - logger.warning("LLM not loaded, cannot set input embeddings.") - - def get_output_embeddings(self): - if self.llm: - # For causal LM, output embeddings are usually tied to lm_head - return self.llm.get_output_embeddings() - return None # Or raise error - - def set_output_embeddings(self, new_embeddings): - if self.llm and hasattr(self.llm, 'set_output_embeddings'): - self.llm.set_output_embeddings(new_embeddings) - else: - logger.warning("LLM not loaded or does not support set_output_embeddings.") - # --- End Embedding methods --- - - # post_init is less critical now as loading happens in from_pretrained, - # but can be used for final checks or setup. - def post_init(self): - # Ensure wav2vec2 config has output_hidden_states=True after loading - if self.wav2vec2_model and hasattr(self.wav2vec2_model.config, 'output_hidden_states'): - if not self.wav2vec2_model.config.output_hidden_states: - self.wav2vec2_model.config.output_hidden_states = True - logger.info("Set wav2vec2_model.config.output_hidden_states=True") - else: - logger.warning("Could not access wav2vec2_model.config to ensure output_hidden_states=True.") - - @property - def device(self) -> torch.device: - """ Override device property to report the LLM's device as representative """ - if self.llm: - return self.llm.device - else: - # Fallback or default if LLM not loaded yet - # This might be called by pipeline before full init? Be cautious. - try: - return next(self.parameters()).device - except StopIteration: - # If no parameters, default to CPU - return torch.device("cpu") - - @torch.no_grad() - def _extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor: - """Extract wav2vec2 features. Input wavs: (B, T_wav)""" - if not self.wav2vec2_model or not self.wav2vec2_processor: - raise RuntimeError("Wav2Vec2 components not loaded.") - - # Use component's device - target_device = self.wav2vec2_model.device - wavs_on_device = wavs.to(target_device) # Expected shape [B, T_wav] e.g., [1, 61120] - - # Process audio using the Wav2Vec2FeatureExtractor - processor_output = self.wav2vec2_processor( - wavs_on_device, - sampling_rate=self.config.sample_rate, - return_tensors="pt", - padding=True, # Ensure padding is handled correctly - ) - inputs = processor_output.input_values # Should be shape [B, T_processed] - # --- START DEBUG & FIX --- - print(f"Shape returned by processor: {inputs.shape}") - - # Reshape if processor added extra dimensions - if inputs.ndim == 4 and inputs.shape[1] == 1 and inputs.shape[2] == 1: - print(f"Reshaping input from {inputs.shape} to 2D.") - inputs = inputs.squeeze(1).squeeze(1) # Remove the two middle dimensions - elif inputs.ndim == 3 and inputs.shape[1] == 1: - print(f"Reshaping input from {inputs.shape} to 2D.") - inputs = inputs.squeeze(1) # Remove the channel dimension - - # Ensure final shape is 2D: (batch_size, sequence_length) - if inputs.ndim != 2: - raise ValueError(f"Unexpected shape after processing/reshaping: {inputs.shape}. Expected 2D input for Wav2Vec2Model.") - - print(f"Shape BEFORE Wav2Vec2Model: {inputs.shape}") - # --- END DEBUG & FIX --- - - inputs = inputs.to(target_device) - # Ensure output_hidden_states=True during call if not set reliably in config - outputs = self.wav2vec2_model(inputs, output_hidden_states=True) - - if outputs.hidden_states is None: - raise ValueError("Wav2Vec2 model did not return hidden states. Ensure config.output_hidden_states=True.") - - # Mix specific layers - num_layers = len(outputs.hidden_states) - indices_to_mix = [11, 14, 16] - valid_indices = [i for i in indices_to_mix if i < num_layers] - - if len(valid_indices) != len(indices_to_mix): - logger.warning(f"Requested Wav2Vec2 hidden state indices {indices_to_mix} out of range (0-{num_layers-1}). Using available valid indices: {valid_indices}.") - if not valid_indices: # If no valid indices, use last hidden state - logger.warning("No valid hidden state indices for mixing. Using last hidden state.") - feats_mix = outputs.last_hidden_state - else: - # Mix available valid indices - feats_mix = torch.stack([outputs.hidden_states[i] for i in valid_indices]).mean(dim=0) - else: - # Original mixing logic - feats_mix = (outputs.hidden_states[11] + outputs.hidden_states[14] + outputs.hidden_states[16]) / 3 - - # Output shape: (B, T_feat, D_feat) - Transpose needed for BiCodec Encoder - return feats_mix.transpose(1, 2) # (B, D_feat, T_feat) - - def _get_ref_clip(self, wav: np.ndarray) -> np.ndarray: - """Get reference audio clip for speaker embedding.""" - ref_samples = int(self.config.sample_rate * self.config.ref_segment_duration) - latent_hop_length = self.config.latent_hop_length - # Ensure length is multiple of hop_length for potential downstream processing - ref_segment_length = max(latent_hop_length, (ref_samples // latent_hop_length) * latent_hop_length) # Ensure at least one hop - - wav_length = len(wav) - - if wav_length == 0: # Handle empty input - return np.zeros(ref_segment_length, dtype=np.float32) - if ref_segment_length > wav_length: - num_repeats = (ref_segment_length // wav_length) + 1 - wav = np.tile(wav, num_repeats) - - return wav[:ref_segment_length].astype(np.float32) # Ensure float32 - - - @torch.no_grad() - def _tokenize_audio(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]: - """Load audio, extract features, and tokenize using BiCodec.""" - wav_np = load_audio( - audio_path, - sampling_rate=self.config.sample_rate, - volume_normalize=self.config.volume_normalize, - ) - wav_ref_np = self._get_ref_clip(wav_np) - - # Convert to tensors, add batch dim, move to device - wav = torch.from_numpy(wav_np).unsqueeze(0).float().to(self.device) - ref_wav = torch.from_numpy(wav_ref_np).unsqueeze(0).float().to(self.device) - - # Extract Wav2Vec2 features -> (B, D_feat, T_feat) - feat = self._extract_wav2vec2_features(wav) - - # Tokenize using BiCodec -> returns (global_tokens, semantic_tokens) - # BiCodec.tokenize expects feat: (B, D_feat, T_feat), ref_wav: (B, T_wav) - global_tokens, semantic_tokens = self.bicodec.tokenize(feat, ref_wav) - # global_tokens: (B, T_token, Q), semantic_tokens: (B, T_latent) - - return global_tokens, semantic_tokens - - @torch.no_grad() - def _detokenize_audio(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> np.ndarray: - """Detokenize using BiCodec to get waveform.""" - global_tokens = global_tokens.to(self.device) - semantic_tokens = semantic_tokens.to(self.device) - self.bicodec.to(self.device) # Ensure BiCodec is on device - - # BiCodec.detokenize expects global_tokens: (B, T_token, Q), semantic_tokens: (B, T_latent) - wav_rec = self.bicodec.detokenize(global_tokens, semantic_tokens) # (B, 1, T_wav) - # Remove channel dim and batch dim, convert to numpy - return wav_rec.detach().squeeze(0).squeeze(0).cpu().numpy() - - - def forward(self, *args, **kwargs): - """ Forward pass delegates to the LLM for generation compatibility, but direct use is not intended for TTS. """ - # return self.llm(*args, **kwargs) # Option 1: Delegate fully - logger.warning("Direct forward pass on SparkTTSModel is not the intended use for TTS. Use the generate method or pipeline.") - # Option 2: Minimal implementation for compatibility if needed - if 'input_ids' in kwargs: - return self.llm(input_ids=kwargs['input_ids']) - else: - raise NotImplementedError("SparkTTSModel's forward pass requires 'input_ids' or should not be called directly for TTS.") - - - - # Use GenerationMixin's forward method by default if needed. - # Define prepare_inputs_for_generation if LLM needs specific handling. - def prepare_inputs_for_generation(self, input_ids, **kwargs): - """ Prepares inputs for the LLM's generate method. """ - if not self.llm: - raise RuntimeError("LLM component not loaded.") - - # --- START REVISED IMPLEMENTATION --- - # Delegate to the LLM's prepare_inputs_for_generation method directly. - # This ensures we use the exact logic defined for the specific LLM architecture (Qwen2). - # It should handle past_key_values, attention_mask, use_cache etc. correctly. - try: - # Pass all relevant kwargs received by the top-level generate call - # The LLM's method will select what it needs. - model_inputs = self.llm.prepare_inputs_for_generation(input_ids, **kwargs) - return model_inputs - except AttributeError: - # Fallback if the LLM doesn't have this method (unlikely for recent models) - logger.warning("LLM does not have 'prepare_inputs_for_generation'. Using basic fallback.") - model_kwargs = {} - model_kwargs["past_key_values"] = kwargs.get("past_key_values", None) - model_kwargs["use_cache"] = kwargs.get("use_cache", None) - # Ensure attention_mask is included if present in kwargs - if "attention_mask" in kwargs: - model_kwargs["attention_mask"] = kwargs["attention_mask"] - return {"input_ids": input_ids, **model_kwargs} - # --- END REVISED IMPLEMENTATION --- - - - # We need a minimal forward method compatible with GenerationMixin - # It should accept the output of prepare_inputs_for_generation - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - **kwargs # Accept other potential kwargs from prepare_inputs - ) -> Any: # Return type depends on the LLM, usually CausalLMOutputWithPast - """ - Minimal forward pass that delegates to the underlying LLM. - Required for compatibility with GenerationMixin. - Accepts arguments typically returned by prepare_inputs_for_generation. - """ - if not self.llm: - raise RuntimeError("LLM component not loaded.") - - # Filter arguments for the LLM's forward method - # (Some LLMs might not accept position_ids directly in forward when using past_key_values) - llm_kwargs = { - "past_key_values": past_key_values, - "attention_mask": attention_mask, - **kwargs # Pass through any other relevant kwargs - } - # Only pass position_ids if the LLM's forward signature accepts it - # This requires inspecting the LLM's forward signature or knowing its behavior. - # For simplicity, we might omit it if it causes issues, or handle it more dynamically. - # Let's assume the LLM forward can handle it for now if prepare_inputs included it. - if position_ids is not None: - llm_kwargs["position_ids"] = position_ids - - return self.llm(input_ids=input_ids, **llm_kwargs) - - # Add generate method to use GenerationMixin capabilities directly on SparkTTSModel if desired - # This will internally call prepare_inputs_for_generation and forward (which might need defining/adjusting) - # However, the pipeline calls self.model.llm.generate, so this might not be strictly needed unless you want `model.generate(...)` - # @torch.no_grad() - # def generate(self, *args, **kwargs): - # if not self.llm: - # raise RuntimeError("LLM component not loaded.") - # # This might need adjustments based on how GenerationMixin interacts with the overridden forward - # # return super().generate(*args, **kwargs) # Calls self.prepare_inputs + self.forward loop - # # Or directly call the LLM's generate if forward is problematic: - # return self.llm.generate(*args, **kwargs) -