# coding=utf-8
# Copyright 2024 The SparkAudio Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch SparkTTS model."""

import torch
import torch.nn as nn
import numpy as np
import os
import warnings

from pathlib import Path
from typing import Dict, Any, Tuple, Optional, Union

from transformers import PreTrainedModel, AutoModelForCausalLM, Wav2Vec2FeatureExtractor, Wav2Vec2Model
from transformers.utils import logging, requires_backends, cached_file
from huggingface_hub import snapshot_download
from transformers.generation.utils import GenerationMixin
from transformers.configuration_utils import PretrainedConfig
from safetensors.torch import load_file

import torchaudio.transforms as TT  # Directly use torchaudio


""" Utility functions for SparkTTS """

import random
import soxr
import soundfile
import torchaudio

from numpy.lib.stride_tricks import sliding_window_view
from omegaconf import OmegaConf  # Keep if BiCodec config loading needs it


# --- Token Maps (from sparktts/utils/token_parser.py) ---
TASK_TOKEN_MAP = {
    "vc": "<|task_vc|>",
    "tts": "<|task_tts|>",
    "asr": "<|task_asr|>",
    "s2s": "<|task_s2s|>",
    "t2s": "<|task_t2s|>",
    "understand": "<|task_understand|>",
    "caption": "<|task_cap|>",
    "controllable_tts": "<|task_controllable_tts|>",
    "prompt_tts": "<|task_prompt_tts|>",
    "speech_edit": "<|task_edit|>",
}

LEVELS_MAP = {
    "very_low": 0,
    "low": 1,
    "moderate": 2,
    "high": 3,
    "very_high": 4,
}

LEVELS_MAP_UI = {
    1: 'very_low',
    2: 'low',
    3: 'moderate',
    4: 'high',
    5: 'very_high'
}

GENDER_MAP = {
    "female": 0,
    "male": 1,
}


# --- Audio Utils (from sparktts/utils/audio.py) ---

def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
    temp = np.sort(np.abs(audio))
    if len(temp) == 0:  # Handle empty audio case
        return audio

    if temp[-1] < 0.1:
        scaling_factor = max(temp[-1], 1e-3)
        audio = audio / scaling_factor * 0.1

    temp = temp[temp > 0.01]
    L = temp.shape[0]
    if L <= 10:
        return audio

    volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
    if volume == 0:  # Avoid division by zero if volume is effectively zero
        return audio

    audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
    max_value = np.max(np.abs(audio)) if len(audio) > 0 else 0
    if max_value > 1:
        audio = audio / max_value
    return audio


def load_audio(
    adfile: Path,
    sampling_rate: int = None,
    length: int = None,
    volume_normalize: bool = False,
    segment_duration: int = None,
) -> np.ndarray:
    try:
        audio, sr = soundfile.read(adfile, dtype='float32')  # Ensure float32
    except Exception as e:
        raise IOError(f"Could not read audio file {adfile}: {e}")

    if audio is None or len(audio) == 0:
        raise ValueError(f"Audio file {adfile} is empty or invalid.")

    if len(audio.shape) > 1:
        audio = audio[:, 0]

    if sampling_rate is not None and sr != sampling_rate:
        try:
            # Ensure input is float64 for soxr
            audio = audio.astype(np.float64)
            audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
            # Convert back to float32
            audio = audio.astype(np.float32)
            sr = sampling_rate
        except Exception as e:
            raise RuntimeError(f"Failed to resample audio from {sr}Hz to {sampling_rate}Hz: {e}")

    if segment_duration is not None:
        seg_length = int(sr * segment_duration)
        audio = random_select_audio_segment(audio, seg_length)

    if volume_normalize:
        audio = audio_volume_normalize(audio)

    if length is not None:
        if audio.shape[0] > length:
            audio = audio[:length]
        else:
            audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant')

    return audio


def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
    if audio.shape[0] < length:
        audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant')
        start_index = 0  # If padded, start from beginning
    elif audio.shape[0] == length:
        start_index = 0  # If exact length, start from beginning
    else:
        start_index = random.randint(0, audio.shape[0] - length)
    end_index = int(start_index + length)
    return audio[start_index:end_index]


# --- File Utils (Minimal required) ---

def load_config_yaml(config_path: Path) -> Dict:
    """Loads a YAML configuration file using OmegaConf."""
    # Check if path exists
    if not Path(config_path).is_file():
        raise FileNotFoundError(f"YAML Config file not found: {config_path}")
    try:
        config = OmegaConf.load(config_path)
        # Convert OmegaConf DictConfig to standard Python dict
        return OmegaConf.to_container(config, resolve=True)
    except Exception as e:
        raise IOError(f"Error loading YAML config file {config_path}: {e}")


""" PyTorch SparkTTS BiCodec sub-module definitions."""

import torch.nn.functional as F
import torch.distributed as dist

from torch.nn.utils import weight_norm, remove_weight_norm
from torch import Tensor, int32
from torch.amp import autocast

from typing import Any, Dict, List, Tuple, Optional
from collections import namedtuple
from functools import wraps, partial
from contextlib import nullcontext
from packaging import version

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
from einx import get_at  # Ensure einx is installed: pip install einx


# ===============================================================
# Start: Content from sparktts/modules/blocks/layers.py
# ===============================================================

def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )
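    # Note: `snake` above is the Snake activation x + (1 / alpha) * sin^2(alpha * x)
    # (Ziyin et al., 2020), used here as in DAC/BigVGAN-style audio decoders.
    # The padding ((7 - 1) * dilation) // 2 keeps the dilated 7-tap convolution
    # length-preserving, so the residual connection in `forward` below only
    # needs to crop when the two lengths happen to differ.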
def forward(self, x): y = self.block(x) # Adjust padding handling if input and output shapes differ diff = x.shape[-1] - y.shape[-1] if diff > 0: pad = diff // 2 x = x[..., pad:pad + y.shape[-1]] # Ensure shapes match for residual connection elif diff < 0: pad = -diff // 2 y = y[..., pad:pad + x.shape[-1]] return x + y def init_weights(m): if isinstance(m, nn.Conv1d): nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) # =============================================================== # End: Content from sparktts/modules/blocks/layers.py # =============================================================== # =============================================================== # Start: Content from sparktts/modules/blocks/samper.py # =============================================================== class SamplingBlock(nn.Module): """Sampling block for upsampling or downsampling""" def __init__( self, dim: int, groups: int = 1, upsample_scale: int = 1, downsample_scale: int = 1, ) -> None: """ Args: dim: input dimension groups: number of groups upsample_scale: upsampling scale downsample_scale: downsampling scale """ super(SamplingBlock, self).__init__() self.upsample_scale = upsample_scale self.downsample_scale = downsample_scale if self.upsample_scale > 1: self.de_conv_upsampler = nn.Sequential( nn.LeakyReLU(0.2), nn.ConvTranspose1d( dim, dim, kernel_size=upsample_scale * 2, stride=upsample_scale, padding=upsample_scale // 2 + upsample_scale % 2, output_padding=upsample_scale % 2, groups=groups, ), ) if self.downsample_scale > 1: self.conv_downsampler = nn.Sequential( nn.LeakyReLU(0.2), nn.Conv1d( dim, dim, kernel_size=2 * downsample_scale, stride=downsample_scale, padding=downsample_scale // 2 + downsample_scale % 2, groups=groups, ), ) @staticmethod def repeat_upsampler(x, upsample_scale): return x.repeat_interleave(upsample_scale, dim=2) @staticmethod def skip_downsampler(x, downsample_scale): return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale) def forward(self, x): # Input expected as (B, D, T) from VocosBackbone output (B, T, D) # x = x.transpose(1, 2) # Remove this transpose, input should be (B, D, T) if self.upsample_scale > 1: repeat_res = self.repeat_upsampler(x, self.upsample_scale) deconv_res = self.de_conv_upsampler(x) # Ensure shapes match for addition if deconv_res.shape[-1] > repeat_res.shape[-1]: deconv_res = deconv_res[..., :repeat_res.shape[-1]] elif repeat_res.shape[-1] > deconv_res.shape[-1]: repeat_res = repeat_res[..., :deconv_res.shape[-1]] upmerge_res = repeat_res + deconv_res else: upmerge_res = x repeat_res = x if self.downsample_scale > 1: conv_res = self.conv_downsampler(upmerge_res) skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale) skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale) # Ensure shapes match min_len = min(conv_res.shape[-1], skip1_res.shape[-1], skip2_res.shape[-1]) conv_res = conv_res[..., :min_len] skip1_res = skip1_res[..., :min_len] skip2_res = skip2_res[..., :min_len] else: conv_res = upmerge_res skip2_res = upmerge_res skip1_res = repeat_res final_res = conv_res + skip1_res + skip2_res # Return (B, D, T) for next VocosBackbone # return final_res.transpose(1, 2) # Remove this, keep (B, D, T) return final_res # =============================================================== # End: Content from sparktts/modules/blocks/samper.py # =============================================================== # 
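# ---------------------------------------------------------------
# Illustrative usage sketch (not part of the original SparkTTS sources): a
# quick shape check for the SamplingBlock defined above. The helper name
# `_sampling_block_shape_demo` and the toy sizes are assumptions made purely
# for demonstration.
# ---------------------------------------------------------------
def _sampling_block_shape_demo():
    """Upsampling by 2x should double the time dimension of a (B, D, T) input."""
    block = SamplingBlock(dim=8, groups=1, upsample_scale=2)
    x = torch.randn(2, 8, 50)  # (batch, channels, time)
    y = block(x)  # repeat-interleave and transposed-conv paths are summed
    assert y.shape == (2, 8, 100), y.shape
    return y.shape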
=============================================================== # Start: Content from sparktts/modules/speaker/pooling_layers.py # =============================================================== class TAP(nn.Module): """ Temporal average pooling, only first-order mean is considered """ def __init__(self, in_dim=0, **kwargs): super(TAP, self).__init__() self.in_dim = in_dim def forward(self, x): pooling_mean = x.mean(dim=-1) # To be compatable with 2D input pooling_mean = pooling_mean.flatten(start_dim=1) return pooling_mean def get_out_dim(self): # This method seems specific to the original usage, might not be needed by HF # self.out_dim = self.in_dim # return self.out_dim return self.in_dim class TSDP(nn.Module): """ Temporal standard deviation pooling, only second-order std is considered """ def __init__(self, in_dim=0, **kwargs): super(TSDP, self).__init__() self.in_dim = in_dim def forward(self, x): # The last dimension is the temporal axis pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) pooling_std = pooling_std.flatten(start_dim=1) return pooling_std def get_out_dim(self): # self.out_dim = self.in_dim # return self.out_dim return self.in_dim class TSTP(nn.Module): """ Temporal statistics pooling, concatenate mean and std, which is used in x-vector Comment: simple concatenation can not make full use of both statistics """ def __init__(self, in_dim=0, **kwargs): super(TSTP, self).__init__() self.in_dim = in_dim def forward(self, x): # The last dimension is the temporal axis pooling_mean = x.mean(dim=-1) pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) pooling_mean = pooling_mean.flatten(start_dim=1) pooling_std = pooling_std.flatten(start_dim=1) stats = torch.cat((pooling_mean, pooling_std), 1) return stats def get_out_dim(self): # self.out_dim = self.in_dim * 2 # return self.out_dim return self.in_dim * 2 class ASTP(nn.Module): """ Attentive statistics pooling: Channel- and context-dependent statistics pooling, first used in ECAPA_TDNN. """ def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False, **kwargs): super(ASTP, self).__init__() self.in_dim = in_dim self.global_context_att = global_context_att # Use Conv1d with stride == 1 rather than Linear, then we don't # need to transpose inputs. if global_context_att: self.linear1 = nn.Conv1d( in_dim * 3, bottleneck_dim, kernel_size=1) # equals W and b in the paper else: self.linear1 = nn.Conv1d( in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper def forward(self, x): """ x: a 3-dimensional tensor in tdnn-based architecture (B,F,T) or a 4-dimensional tensor in resnet architecture (B,C,F,T) 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) """ if len(x.shape) == 4: x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) assert len(x.shape) == 3 if self.global_context_att: context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) context_std = torch.sqrt( torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x) x_in = torch.cat((x, context_mean, context_std), dim=1) else: x_in = x # DON'T use ReLU here! ReLU may be hard to converge. 
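        # Note: the two 1x1 convolutions implement the attention
        #   e = V @ tanh(W @ x_in + b) + k,  alpha = softmax(e, dim=time)
        # and the attention-weighted mean / std computed below are concatenated
        # into a 2 * in_dim statistics vector, as in ECAPA-TDNN.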
alpha = torch.tanh( self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in)) alpha = torch.softmax(self.linear2(alpha), dim=2) mean = torch.sum(alpha * x, dim=2) var = torch.sum(alpha * (x**2), dim=2) - mean**2 std = torch.sqrt(var.clamp(min=1e-7)) return torch.cat([mean, std], dim=1) def get_out_dim(self): # self.out_dim = 2 * self.in_dim # return self.out_dim return self.in_dim * 2 class MHASTP(torch.nn.Module): """ Multi head attentive statistics pooling Reference: Self Multi-Head Attention for Speaker Recognition https://arxiv.org/pdf/1906.09890.pdf """ def __init__(self, in_dim, layer_num=2, head_num=2, d_s=1, bottleneck_dim=64, **kwargs): super(MHASTP, self).__init__() assert (in_dim % head_num ) == 0 # make sure that head num can be divided by input_dim self.in_dim = in_dim self.head_num = head_num d_model = int(in_dim / head_num) channel_dims = [bottleneck_dim for i in range(layer_num + 1)] if d_s > 1: d_s = d_model else: d_s = 1 self.d_s = d_s channel_dims[0], channel_dims[-1] = d_model, d_s heads_att_trans = [] for i in range(self.head_num): att_trans = nn.Sequential() for j in range(layer_num - 1): # Use different loop variable att_trans.add_module( 'att_' + str(j), nn.Conv1d(channel_dims[j], channel_dims[j + 1], 1, 1)) att_trans.add_module('tanh' + str(j), nn.Tanh()) att_trans.add_module( 'att_' + str(layer_num - 1), nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num], 1, 1)) heads_att_trans.append(att_trans) self.heads_att_trans = nn.ModuleList(heads_att_trans) def forward(self, input): """ input: a 3-dimensional tensor in xvector architecture or a 4-dimensional tensor in resnet architecture 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) """ if len(input.shape) == 4: # B x C x F x T input = input.reshape(input.shape[0], input.shape[1] * input.shape[2], input.shape[3]) # B x (C*F) x T assert len(input.shape) == 3 bs, f_dim, t_dim = input.shape chunks = torch.chunk(input, self.head_num, 1) # split chunks_out = [] for i, layer in enumerate(self.heads_att_trans): att_score = layer(chunks[i]) alpha = F.softmax(att_score, dim=-1) mean = torch.sum(alpha * chunks[i], dim=2) var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2 std = torch.sqrt(var.clamp(min=1e-7)) chunks_out.append(torch.cat((mean, std), dim=1)) out = torch.cat(chunks_out, dim=1) return out def get_out_dim(self): # self.out_dim = 2 * self.in_dim # return self.out_dim return self.in_dim * 2 class MQMHASTP(torch.nn.Module): """ An attentive pooling Reference: multi query multi head attentive statistics pooling https://arxiv.org/pdf/2110.05042.pdf Args: in_dim: the feature dimension of input layer_num: the number of layer in the pooling layer query_num: the number of querys head_num: the number of heads bottleneck_dim: the bottleneck dimension SA (H = 1, Q = 1, n = 2, d_s = 1) ref: https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf MHA (H > 1, Q = 1, n = 1, d_s = 1) ref: https://arxiv.org/pdf/1906.09890.pdf AS (H = 1, Q > 1, n = 2, d_s = 1) ref: https://arxiv.org/pdf/1803.10963.pdf VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref: http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf """ def __init__(self, in_dim, layer_num=2, query_num=2, head_num=8, d_s=2, bottleneck_dim=64, **kwargs): super(MQMHASTP, self).__init__() self.n_query = nn.ModuleList([ MHASTP(in_dim, layer_num=layer_num, head_num=head_num, d_s=d_s, bottleneck_dim=bottleneck_dim) for i in range(query_num) ]) self.query_num = query_num self.in_dim = in_dim def forward(self, input): """ input: a 
3-dimensional tensor in xvector architecture or a 4-dimensional tensor in resnet architecture 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) """ if len(input.shape) == 4: # B x C x F x T input = input.reshape(input.shape[0], input.shape[1] * input.shape[2], input.shape[3]) # B x (C*F) x T assert len(input.shape) == 3 res = [] for i, layer in enumerate(self.n_query): res.append(layer(input)) out = torch.cat(res, dim=-1) return out def get_out_dim(self): # self.out_dim = self.in_dim * 2 * self.query_num # return self.out_dim return self.in_dim * 2 * self.query_num # =============================================================== # End: Content from sparktts/modules/speaker/pooling_layers.py # =============================================================== # =============================================================== # Start: Content from sparktts/modules/blocks/vocos.py # =============================================================== # Helper functions needed by VocosBackbone etc. def exists(val): return val is not None def default(val, d): return val if exists(val) else d() if callable(d) else d class AdaLayerNorm(nn.Module): """ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes Args: condition_dim (int): Dimension of the condition. embedding_dim (int): Dimension of the embeddings. """ def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.dim = embedding_dim self.scale = nn.Linear(condition_dim, embedding_dim) self.shift = nn.Linear(condition_dim, embedding_dim) # Initialize weights similar to original implementation if needed # torch.nn.init.ones_(self.scale.weight) # Might be default # torch.nn.init.zeros_(self.shift.weight) # Might be default if self.scale.bias is not None: nn.init.zeros_(self.scale.bias) if self.shift.bias is not None: nn.init.zeros_(self.shift.bias) def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor: scale = self.scale(cond_embedding) shift = self.shift(cond_embedding) x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) x = x * scale.unsqueeze(1) + shift.unsqueeze(1) return x class ConvNeXtBlock(nn.Module): """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. Args: dim (int): Number of input channels. intermediate_dim (int): Dimensionality of the intermediate layer. layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. Defaults to None. condition_dim (int, optional): Dimension for AdaLayerNorm. None means non-conditional LayerNorm. Defaults to None. 
""" def __init__( self, dim: int, intermediate_dim: int, layer_scale_init_value: float, condition_dim: Optional[int] = None, ): super().__init__() self.dwconv = nn.Conv1d( dim, dim, kernel_size=7, padding=3, groups=dim ) # depthwise conv self.adanorm = condition_dim is not None if self.adanorm: self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6) else: self.norm = nn.LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear( dim, intermediate_dim ) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Linear(intermediate_dim, dim) self.gamma = ( nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) if layer_scale_init_value is not None and layer_scale_init_value > 0 else None ) def forward( self, x: torch.Tensor, cond_embedding: Optional[torch.Tensor] = None ) -> torch.Tensor: residual = x x = self.dwconv(x) x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) if self.adanorm: assert cond_embedding is not None, "Conditioning embedding required for AdaLayerNorm" x = self.norm(x, cond_embedding) else: x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) if self.gamma is not None: x = self.gamma * x x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) x = residual + x return x class ResBlock1(nn.Module): """ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, but without upsampling layers. Args: dim (int): Number of input channels. kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. dilation (tuple[int], optional): Dilation factors for the dilated convolutions. Defaults to (1, 3, 5). lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. Defaults to 0.1. layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. Defaults to None. 
""" def __init__( self, dim: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5), lrelu_slope: float = 0.1, layer_scale_init_value: Optional[float] = None, ): super().__init__() self.lrelu_slope = lrelu_slope self.convs1 = nn.ModuleList( [ weight_norm( nn.Conv1d( dim, dim, kernel_size, 1, dilation=dilation[0], padding=self.get_padding(kernel_size, dilation[0]), ) ), weight_norm( nn.Conv1d( dim, dim, kernel_size, 1, dilation=dilation[1], padding=self.get_padding(kernel_size, dilation[1]), ) ), weight_norm( nn.Conv1d( dim, dim, kernel_size, 1, dilation=dilation[2], padding=self.get_padding(kernel_size, dilation[2]), ) ), ] ) self.convs2 = nn.ModuleList( [ weight_norm( nn.Conv1d( dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1), ) ), weight_norm( nn.Conv1d( dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1), ) ), weight_norm( nn.Conv1d( dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1), ) ), ] ) self.gamma = nn.ParameterList( [ ( nn.Parameter( layer_scale_init_value * torch.ones(dim, 1), requires_grad=True ) if layer_scale_init_value is not None else None ), ( nn.Parameter( layer_scale_init_value * torch.ones(dim, 1), requires_grad=True ) if layer_scale_init_value is not None else None ), ( nn.Parameter( layer_scale_init_value * torch.ones(dim, 1), requires_grad=True ) if layer_scale_init_value is not None else None ), ] ) def forward(self, x: torch.Tensor) -> torch.Tensor: for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) xt = c1(xt) xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) xt = c2(xt) if gamma is not None: xt = gamma * xt x = xt + x return x def remove_weight_norm(self): for l in self.convs1: remove_weight_norm(l) for l in self.convs2: remove_weight_norm(l) @staticmethod def get_padding(kernel_size: int, dilation: int = 1) -> int: return int((kernel_size * dilation - dilation) / 2) class Backbone(nn.Module): """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, C denotes input features, and L is the sequence length. Returns: Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, and H denotes the model dimension. """ raise NotImplementedError("Subclasses must implement the forward method.") class VocosBackbone(Backbone): """ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization Args: input_channels (int): Number of input features channels. dim (int): Hidden dimension of the model. intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. num_layers (int): Number of ConvNeXtBlock layers. layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. condition_dim (int, optional): Dimension for AdaLayerNorm. None means non-conditional model. Defaults to None. 
""" def __init__( self, input_channels: int, dim: int, intermediate_dim: int, num_layers: int, layer_scale_init_value: Optional[float] = None, condition_dim: Optional[int] = None, ): super().__init__() self.input_channels = input_channels self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) self.adanorm = condition_dim is not None if self.adanorm: self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6) else: self.norm = nn.LayerNorm(dim, eps=1e-6) layer_scale_init_value = layer_scale_init_value or 1 / num_layers if num_layers > 0 else None # Handle num_layers=0 self.convnext = nn.ModuleList( [ ConvNeXtBlock( dim=dim, intermediate_dim=intermediate_dim, layer_scale_init_value=layer_scale_init_value, condition_dim=condition_dim, ) for _ in range(num_layers) ] ) self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, (nn.Conv1d, nn.Linear)): nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor, condition: Optional[torch.Tensor] = None) -> torch.Tensor: # Input x: (B, C, L) x = self.embed(x) # After embed: (B, dim, L) x_transposed = x.transpose(1, 2) # (B, L, dim) if self.adanorm: assert condition is not None norm_out = self.norm(x_transposed, condition) else: norm_out = self.norm(x_transposed) # After norm: (B, L, dim) x = norm_out.transpose(1, 2) # (B, dim, L) for conv_block in self.convnext: x = conv_block(x, condition) # After convnext blocks: (B, dim, L) x = self.final_layer_norm(x.transpose(1, 2)) # (B, L, dim) return x class VocosResNetBackbone(Backbone): """ Vocos backbone module built with ResBlocks. Args: input_channels (int): Number of input features channels. dim (int): Hidden dimension of the model. num_blocks (int): Number of ResBlock1 blocks. layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. 
""" def __init__( self, input_channels, dim, num_blocks, layer_scale_init_value=None, ): super().__init__() self.input_channels = input_channels self.embed = weight_norm( nn.Conv1d(input_channels, dim, kernel_size=3, padding=1) ) layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 if num_blocks > 0 else None # Handle num_blocks=0 self.resnet = nn.Sequential( *[ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks) ] ) def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: # Input x: (B, C, L) x = self.embed(x) # After embed: (B, dim, L) x = self.resnet(x) # After resnet: (B, dim, L) x = x.transpose(1, 2) # (B, L, dim) return x # =============================================================== # End: Content from sparktts/modules/blocks/vocos.py # =============================================================== # =============================================================== # Start: Content from sparktts/modules/encoder_decoder/feat_decoder.py # =============================================================== class Decoder(nn.Module): """Decoder module with convnext and upsampling blocks Args: sample_ratios (List[int]): sample ratios example: [2, 2] means upsample by 2x and then upsample by 2x """ def __init__( self, input_channels: int, vocos_dim: int, vocos_intermediate_dim: int, vocos_num_layers: int, out_channels: int, condition_dim: int = None, sample_ratios: List[int] = [1, 1], use_tanh_at_final: bool = False, ): super().__init__() self.linear_pre = nn.Linear(input_channels, vocos_dim) upsample_modules = [] current_dim = vocos_dim for i, ratio in enumerate(sample_ratios): upsample_modules.append( nn.Sequential( SamplingBlock( dim=current_dim, groups=current_dim, # Maybe use 1 or fewer groups if dim is high? Check original intent. Using current_dim for now. upsample_scale=ratio, ), # Note: The original code used VocosBackbone here, but it changes dims B,T,D -> B,D,T. # SamplingBlock output is B,D,T, so VocosBackbone input matches. # However, the VocosBackbone output is B,T,D, which doesn't fit the next SamplingBlock. # Assuming the intent was to keep B,D,T format between sampling blocks. # Replacing intermediate VocosBackbone with a simple Conv1d block to maintain format & refine. nn.Conv1d(current_dim, current_dim, kernel_size=3, padding=1) # Simple refinement layer # VocosBackbone( # input_channels=current_dim, # dim=current_dim, # intermediate_dim=vocos_intermediate_dim // 2, # Smaller intermediate for efficiency? # num_layers=2, # Fewer layers # condition_dim=None, # ) ) ) # No dimension change expected here if using Conv1d refinement # If using VocosBackbone, need transpose logic self.upsample = nn.Sequential(*upsample_modules) # Final Backbone processes the fully upsampled features self.vocos_backbone = VocosBackbone( input_channels=current_dim, # Use the dim after upsampling dim=vocos_dim, # Map back to main vocos_dim or keep current_dim? Using vocos_dim intermediate_dim=vocos_intermediate_dim, num_layers=vocos_num_layers, condition_dim=condition_dim, ) self.linear_post = nn.Linear(vocos_dim, out_channels) self.use_tanh_at_final = use_tanh_at_final def forward(self, x: torch.Tensor, c: torch.Tensor = None): """decoder forward. 
Args: x (torch.Tensor): (batch_size, input_channels, length) c (torch.Tensor): (batch_size, condition_dim) - Optional condition Returns: x (torch.Tensor): (batch_size, out_channels, length_upsampled) """ # x: (B, C_in, T) x = self.linear_pre(x.transpose(1, 2)) # (B, T, vocos_dim) x = x.transpose(1, 2) # (B, vocos_dim, T) # Apply upsampling blocks x = self.upsample(x) # (B, vocos_dim, T_upsampled) # Apply final backbone x = self.vocos_backbone(x, condition=c) # (B, T_upsampled, vocos_dim) x = self.linear_post(x) # (B, T_upsampled, C_out) x = x.transpose(1, 2) # (B, C_out, T_upsampled) if self.use_tanh_at_final: x = torch.tanh(x) return x # =============================================================== # End: Content from sparktts/modules/encoder_decoder/feat_decoder.py # =============================================================== # =============================================================== # Start: Content from sparktts/modules/encoder_decoder/feat_encoder.py # =============================================================== class Encoder(nn.Module): """Encoder module with convnext and downsampling blocks""" def __init__( self, input_channels: int, vocos_dim: int, vocos_intermediate_dim: int, vocos_num_layers: int, out_channels: int, sample_ratios: List[int] = [1, 1], ): super().__init__() """ Encoder module with VocosBackbone and sampling blocks. Args: sample_ratios (List[int]): sample ratios example: [2, 2] means downsample by 2x and then downsample by 2x """ # Initial Backbone processing self.encoder_backbone = VocosBackbone( input_channels=input_channels, dim=vocos_dim, intermediate_dim=vocos_intermediate_dim, num_layers=vocos_num_layers, # Use main num_layers here condition_dim=None, ) downsample_modules = [] current_dim = vocos_dim for i, ratio in enumerate(sample_ratios): downsample_modules.append( nn.Sequential( SamplingBlock( dim=current_dim, groups=current_dim, # Again, check group size. Using current_dim. 
downsample_scale=ratio, ), # Add refinement layer (optional, similar to Decoder logic) nn.Conv1d(current_dim, current_dim, kernel_size=3, padding=1) # VocosBackbone( # Or a lighter VocosBackbone # input_channels=current_dim, # dim=current_dim, # intermediate_dim=vocos_intermediate_dim // 2, # num_layers=2, # condition_dim=None, # ) ) ) # No dimension change expected here self.downsample = nn.Sequential(*downsample_modules) self.project = nn.Linear(current_dim, out_channels) # Project from the final dimension def forward(self, x: torch.Tensor, *args): """ Args: x (torch.Tensor): (batch_size, input_channels, length) Returns: x (torch.Tensor): (batch_size, out_channels, length_downsampled) """ # x: (B, C_in, T) x = self.encoder_backbone(x) # (B, T, vocos_dim) x = x.transpose(1, 2) # (B, vocos_dim, T) # Apply downsampling blocks x = self.downsample(x) # (B, vocos_dim, T_downsampled) x = x.transpose(1, 2) # (B, T_downsampled, vocos_dim) x = self.project(x) # (B, T_downsampled, C_out) return x.transpose(1, 2) # (B, C_out, T_downsampled) # =============================================================== # End: Content from sparktts/modules/encoder_decoder/feat_encoder.py # =============================================================== # =============================================================== # Start: Content from sparktts/modules/encoder_decoder/wave_generator.py # =============================================================== class DecoderBlock(nn.Module): def __init__( self, input_dim: int = 16, output_dim: int = 8, kernel_size: int = 2, stride: int = 1, ): super().__init__() # Ensure stride is at least 1 stride = max(1, stride) # Ensure kernel_size is valid for ConvTranspose1d if kernel_size < stride: kernel_size = stride # Or handle differently padding = (kernel_size - stride) // 2 output_padding = stride % 2 if kernel_size % 2 == 0 else 0 # Basic calculation, might need adjustment based on desired output length # print(f"DecoderBlock - Input: {input_dim}, Output: {output_dim}, Kernel: {kernel_size}, Stride: {stride}, Padding: {padding}, OutputPadding: {output_padding}") self.block = nn.Sequential( Snake1d(input_dim), WNConvTranspose1d( input_dim, output_dim, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, # Add output_padding ), ResidualUnit(output_dim, dilation=1), ResidualUnit(output_dim, dilation=3), ResidualUnit(output_dim, dilation=9), ) def forward(self, x): return self.block(x) class WaveGenerator(nn.Module): def __init__( self, input_channel, channels, rates, kernel_sizes, d_out: int = 1, ): super().__init__() # Add first conv layer layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] # Add upsampling + MRF blocks current_channels = channels for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)): input_dim = current_channels # Ensure output_dim doesn't go below 1 output_dim = max(1, channels // (2 ** (i + 1))) layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)] current_channels = output_dim # Update for the next block's input # Add final conv layer layers += [ Snake1d(current_channels), # Use the final output_dim WNConv1d(current_channels, d_out, kernel_size=7, padding=3), nn.Tanh(), ] self.model = nn.Sequential(*layers) self.apply(init_weights) # Apply weight initialization def forward(self, x): return self.model(x) # =============================================================== # End: Content from sparktts/modules/encoder_decoder/wave_generator.py # 
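# ---------------------------------------------------------------
# Illustrative usage sketch (not part of the original SparkTTS sources): the
# total upsampling factor of WaveGenerator is the product of `rates`. The
# helper name `_wave_generator_shape_demo` and the toy sizes are assumptions
# made purely for demonstration.
# ---------------------------------------------------------------
def _wave_generator_shape_demo():
    """With rates=[2, 2] the output waveform has 2 * 2 = 4x as many samples as input frames."""
    gen = WaveGenerator(input_channel=32, channels=16, rates=[2, 2], kernel_sizes=[4, 4], d_out=1)
    feats = torch.randn(1, 32, 25)  # (batch, feature_dim, frames)
    wav = gen(feats)  # (batch, d_out, frames * prod(rates))
    assert wav.shape == (1, 1, 100), wav.shape
    return wav.shape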
=============================================================== # =============================================================== # Start: Content from sparktts/modules/fsq/finite_scalar_quantization.py # =============================================================== # helper functions moved earlier def round_ste(z: Tensor) -> Tensor: """Round with straight through gradients.""" zhat = z.round() return z + (zhat - z).detach() class FSQ(nn.Module): def __init__( self, levels: List[int], dim: int | None = None, num_codebooks=1, keep_num_codebooks_dim: bool | None = None, scale: float | None = None, allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64), channel_first: bool = False, # Added based on usage in ResidualFSQ projection_has_bias: bool = True, return_indices=True, force_quantization_f32=True, ): super().__init__() _levels = torch.tensor(levels, dtype=int32) self.register_buffer("_levels", _levels, persistent=False) _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) self.register_buffer("_basis", _basis, persistent=False) self.scale = scale # Not used in current implementation, but kept codebook_dim = len(levels) self.codebook_dim = codebook_dim effective_codebook_dim = codebook_dim * num_codebooks self.num_codebooks = num_codebooks self.effective_codebook_dim = effective_codebook_dim # keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) # Force keep_num_codebooks_dim to False if num_codebooks is 1 if num_codebooks == 1: keep_num_codebooks_dim = False else: keep_num_codebooks_dim = default(keep_num_codebooks_dim, True) # Original assert was checking if num_codebooks > 1 and keep_num_codebooks_dim is False. Let's refine. # If num_codebooks > 1, keep_num_codebooks_dim must be True based on how rearrange is used. 
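        # Worked example: levels = [8, 5, 5, 5] gives codebook_dim = 4,
        # _basis = cumprod([1, 8, 5, 5]) = [1, 8, 40, 200], and an implicit
        # codebook of prod(levels) = 8 * 5 * 5 * 5 = 1000 entries.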
if num_codebooks > 1 and not keep_num_codebooks_dim: raise ValueError("If num_codebooks > 1, keep_num_codebooks_dim must be True or None (defaults to True).") self.keep_num_codebooks_dim = keep_num_codebooks_dim self.dim = default(dim, len(_levels) * num_codebooks) self.channel_first = channel_first # Store channel_first setting has_projections = self.dim != effective_codebook_dim self.project_in = ( nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias) if has_projections else nn.Identity() ) self.project_out = ( nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias) if has_projections else nn.Identity() ) self.has_projections = has_projections self.return_indices = return_indices if return_indices: self.codebook_size = self._levels.prod().item() # Calculate implicit codebook based on current device during forward pass if needed # For now, calculate assuming CPU and move later if necessary # implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size, device=self._levels.device)) # Calculate on device # self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) self.allowed_dtypes = allowed_dtypes self.force_quantization_f32 = force_quantization_f32 @property def implicit_codebook(self): # Calculate implicit codebook on the fly using the device of _levels device = self._levels.device indices = torch.arange(self.codebook_size, device=device) return self._indices_to_codes(indices) def bound(self, z, eps: float = 1e-3): """Bound `z`, an array of shape (..., d).""" levels = self._levels.to(z.device) # Ensure levels are on same device half_l = (levels - 1) * (1 + eps) / 2 offset = torch.where(levels % 2 == 0, 0.5, 0.0) shift = (offset / half_l).atanh() if torch.any(half_l != 0) else torch.zeros_like(offset) # Avoid div by zero # Ensure shift is compatible shape for broadcasting shift = shift.view(1, 1, -1) if z.ndim == 3 else shift # Adjust based on z dims half_l = half_l.view(1, 1, -1) if z.ndim == 3 else half_l # Clamp input to avoid inf/-inf in atanh z_clipped = torch.clamp(z, min=-1.0 + eps, max=1.0 - eps) # Assuming input z is somewhat normalized? # Original formula might be sensitive, let's try direct clamping. # return (z + shift).tanh() * half_l - offset # Alternative clamping approach (from original Jax version logic): upper_bound = (levels - 1) / 2 lower_bound = -upper_bound upper_bound = upper_bound.view(1, 1, -1) if z.ndim == 3 else upper_bound lower_bound = lower_bound.view(1, 1, -1) if z.ndim == 3 else lower_bound return torch.clamp(z, min=lower_bound, max=upper_bound) def quantize(self, z): """Quantizes z, returns quantized zhat, same shape as z.""" quantized = round_ste(self.bound(z)) levels = self._levels.to(z.device) half_width = levels // 2 # Renormalize to [-1, 1]. 
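        # e.g. a level of 5 has half_width = 2, so the rounded code values
        # {-2, -1, 0, 1, 2} are mapped to {-1.0, -0.5, 0.0, 0.5, 1.0} below.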
# Avoid division by zero if level is 1 half_width = torch.where(half_width == 0, torch.tensor(1.0, device=z.device), half_width.float()) half_width_view = half_width.view(1, 1, -1) if quantized.ndim == 3 else half_width return quantized / half_width_view def _scale_and_shift(self, zhat_normalized): levels = self._levels.to(zhat_normalized.device) half_width = levels // 2 half_width_view = half_width.view(1, 1, -1) if zhat_normalized.ndim == 3 else half_width return (zhat_normalized * half_width_view) + half_width_view def _scale_and_shift_inverse(self, zhat): levels = self._levels.to(zhat.device) half_width = levels // 2 # Avoid division by zero if level is 1 half_width = torch.where(half_width == 0, torch.tensor(1.0, device=zhat.device), half_width.float()) half_width_view = half_width.view(1, 1, -1) if zhat.ndim == 3 else half_width return (zhat - half_width_view) / half_width_view def _indices_to_codes(self, indices): level_indices = self.indices_to_level_indices(indices) codes = self._scale_and_shift_inverse(level_indices.float()) # Convert level indices to float return codes def codes_to_indices(self, zhat): """Converts a `code` to an index in the codebook.""" assert zhat.shape[-1] == self.codebook_dim zhat_scaled = self._scale_and_shift(zhat) # Ensure basis is on the correct device and dtype, handle potential shape mismatch basis = self._basis.to(zhat.device, dtype=int32) basis_view = basis.view(1, 1, -1) if zhat_scaled.ndim == 3 else basis # Match ndim # Ensure zhat_scaled is integer type for multiplication with basis product = (zhat_scaled * basis_view).round().int() return product.sum(dim=-1).to(int32) def indices_to_level_indices(self, indices): """Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings""" indices_reshaped = rearrange(indices, "... -> ... 1") basis = self._basis.to(indices.device) levels = self._levels.to(indices.device) # Ensure basis and levels match the device and potentially ndim of indices basis_view = basis.view(*([1] * (indices_reshaped.ndim - 1)), -1) levels_view = levels.view(*([1] * (indices_reshaped.ndim - 1)), -1) codes_non_centered = (indices_reshaped // basis_view) % levels_view return codes_non_centered # indices_to_codes is now handled by implicit_codebook property + project_out if needed def forward(self, z): """ einstein notation b - batch ... - sequence, spatial dimensions d - feature dimension c - number of codebook dim (within a single quantizer) g - number of quantizers (groups) - handled by ResidualFSQ/GroupedResidualFSQ """ # Input z can be (b d ...) or (b ... d) # self.channel_first determines the expected input format for projection if self.channel_first: # Expects (b d ...) if z.ndim > 2: # Has spatial/temporal dims z = rearrange(z, "b d ... -> b ... d") z, ps = pack([z], "b * d") # else: z is (b d) -> processed directly by linear else: # Expects (b ... d) if z.ndim > 2: z, ps = pack([z], "b * d") # else: z is (b d) -> processed directly by linear assert ( z.shape[-1] == self.dim ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" # Project in z_projected = self.project_in(z) # (b ... effective_codebook_dim) # Reshape for codebooks if num_codebooks > 1 if self.num_codebooks > 1: z_reshaped = rearrange(z_projected, "b ... (c d) -> b ... c d", c=self.num_codebooks) else: # Add a dummy codebook dim for consistent processing z_reshaped = rearrange(z_projected, "b ... d -> b ... 
1 d") # Force quantization step to be full precision or not force_f32 = self.force_quantization_f32 quantization_context = ( partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext ) codes = None indices = None with quantization_context(): orig_dtype = z_reshaped.dtype if force_f32 and orig_dtype not in self.allowed_dtypes: z_for_quant = z_reshaped.float() else: z_for_quant = z_reshaped codes = self.quantize(z_for_quant) # (b ... c d) if self.return_indices: indices = self.codes_to_indices(codes) # (b ... c) # Convert codes back to original dtype if changed codes = codes.type(orig_dtype) # Reshape codes back and project out if self.num_codebooks > 1: codes_reshaped = rearrange(codes, "b ... c d -> b ... (c d)") else: codes_reshaped = rearrange(codes, "b ... 1 d -> b ... d") out = self.project_out(codes_reshaped) # (b ... dim) # Restore original spatial/temporal dimensions if z.ndim > 2: # If we packed dimensions out = unpack(out, ps, "b * d")[0] if self.return_indices: indices = unpack(indices, ps, "b * c")[0] # Restore channel dimension if needed if self.channel_first and out.ndim > 2: out = rearrange(out, "b ... d -> b d ...") if self.return_indices and indices.ndim > 1: # Check indices ndim # Indices shape (b ... c), need to decide how to handle channel dim # Often indices might not need channel dim, depends on usage # If indices are e.g. (b H W c), permuting might be complex. # Keeping indices as (b ... c) for now. pass # Remove the dummy codebook dim from indices if num_codebooks was 1 if self.return_indices and self.num_codebooks == 1 and not self.keep_num_codebooks_dim: indices = indices.squeeze(-1) return out, indices # =============================================================== # End: Content from sparktts/modules/fsq/finite_scalar_quantization.py # =============================================================== # =============================================================== # Start: Content from sparktts/modules/fsq/residual_fsq.py # =============================================================== # Helper functions needed by ResidualFSQ def is_distributed(): return dist.is_initialized() and dist.get_world_size() > 1 def get_maybe_sync_seed(device, max_size=10_000): rand_int = torch.randint(0, max_size, (), device=device) if is_distributed(): # Ensure rand_int is on the correct device for all_reduce if rand_int.device != device: rand_int = rand_int.to(device) dist.all_reduce(rand_int) return rand_int.item() def round_up_multiple(num, mult): # Ensure mult is positive if mult <= 0: return num # Use ceiling division return (num + mult - 1) // mult * mult class ResidualFSQ(nn.Module): """Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf""" def __init__( self, *, levels: List[int], num_quantizers, dim=None, # is_channel_first=False, # Handled inside FSQ now quantize_dropout=False, quantize_dropout_cutoff_index=0, quantize_dropout_multiple_of=1, channel_first: bool = False, # Pass channel_first to FSQ **kwargs, # Pass remaining kwargs to FSQ ): super().__init__() codebook_dim = len(levels) dim = default(dim, codebook_dim) requires_projection = codebook_dim != dim self.project_in = ( nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity() ) self.project_out = ( nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity() ) self.has_projections = requires_projection self.channel_first = channel_first # Store for potential shape adjustments if needed later self.num_quantizers = num_quantizers self.levels = levels self.layers = nn.ModuleList([]) levels_tensor = torch.Tensor(levels) scales = [] for ind in range(num_quantizers): # Calculate scale: (levels - 1) is max value range (- (l-1)/2 to +(l-1)/2) # Residual is divided by scale before quantization # Effective scale for quantizer 'ind' is (levels - 1)^ind ? Needs check. # Original paper scale seems different. Let's stick to FSQ handling scale internally if needed. # Using scale = 1.0 for now, assuming FSQ handles normalization. scale_value = 1.0 # ((levels_tensor - 1)**-ind) - Check this logic scales.append(scale_value) # Pass channel_first to FSQ fsq = FSQ(levels=levels, dim=codebook_dim, channel_first=channel_first, **kwargs) self.layers.append(fsq) # Check if FSQ layers have projections internally. ResidualFSQ should handle overall projection. assert all([not fsq.has_projections for fsq in self.layers]), "FSQ layers within ResidualFSQ should not have internal projections." self.codebook_size = self.layers[0].codebook_size # Using scale = 1.0, so register_buffer might not be needed, or store 1.0s # self.register_buffer("scales", torch.Tensor(scales), persistent=False) # If scales are needed, they should likely be parameters or calculated differently. # For now, assuming FSQ normalizes correctly and scale is 1.0 here. self.quantize_dropout = quantize_dropout and num_quantizers > 1 assert quantize_dropout_cutoff_index >= 0 self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4 @property def codebooks(self): # Codebooks are implicit in FSQ, access via property codebooks = [layer.implicit_codebook for layer in self.layers] codebooks = torch.stack(codebooks, dim=0) return codebooks def get_codes_from_indices(self, indices): # indices shape: (b ... q) or (b q ...) 
depending on usage num_dims = indices.ndim q_dim = -1 # Assume last dim is quantizer dim by default # Find the quantizer dimension (q) for i in range(num_dims): if indices.shape[i] == self.num_quantizers: q_dim = i break if q_dim == -1 and self.num_quantizers == 1 and indices.shape[-1] != 1: # If only 1 quantizer, indices might not have the quantizer dim explicitly indices = indices.unsqueeze(-1) # Add the quantizer dim q_dim = -1 elif q_dim == -1: raise ValueError(f"Could not find quantizer dimension ({self.num_quantizers}) in indices shape {indices.shape}") # Ensure q_dim is the last dimension for processing if q_dim != num_dims - 1: permute_dims = list(range(num_dims)) permute_dims.pop(q_dim) permute_dims.append(q_dim) indices = indices.permute(*permute_dims) batch_shape = indices.shape[:-1] # Shape before the quantizer dim indices = indices.reshape(-1, self.num_quantizers) # Flatten batch/spatial dims # Handle dropout indices (-1) if indices.max() >= self.codebook_size: raise ValueError(f"Invalid index found in indices: {indices.max()}. Max allowed is {self.codebook_size - 1}.") if indices.min() < -1: raise ValueError(f"Invalid index found in indices: {indices.min()}. Min allowed is -1 (dropout).") mask = indices == -1 effective_indices = indices.masked_fill(mask, 0) # Use 0 for dropout indices temporarily all_codes = [] # Iterate through each quantizer layer for i in range(self.num_quantizers): layer_indices = effective_indices[:, i] # Use the FSQ layer's method to convert indices to codes (handles normalization) # Need to ensure indices_to_codes exists and works correctly in FSQ # Assuming FSQ.indices_to_codes takes (batch,) indices and returns (batch, codebook_dim) codes layer_codes = self.layers[i].indices_to_codes(layer_indices) # This needs correct FSQ method all_codes.append(layer_codes) all_codes_tensor = torch.stack(all_codes, dim=0) # (q, b_flat, d) # Mask out dropout codes mask_expanded = mask.permute(1, 0).unsqueeze(-1) # (q, b_flat, 1) all_codes_tensor = all_codes_tensor.masked_fill(mask_expanded, 0.0) # Reshape back to original batch/spatial shape all_codes_tensor = all_codes_tensor.reshape(self.num_quantizers, *batch_shape, -1) # (q, b ... d) # Restore original q_dim position if it was changed if q_dim != num_dims - 1: # Need inverse permutation inv_permute_dims = list(range(num_dims)) # Start with 0, 1, ..., num_dims-1 inv_permute_dims.insert(q_dim, num_dims) # Insert the last dim (q) at the original position inv_permute_dims.pop() # Remove the last element # Permute from (q, b ... d) -> (b ... q ... d) - careful with dims # Example: Input (b h w q), processed to (q, b*h*w), output (q, b*h*w, d) # Reshaped to (q, b, h, w, d) # Want (b, h, w, q, d) -> Need to confirm this logic # Let's assume output shape (q, b, ..., d) is desired for summation later. pass # Keep as (q, b ... d) for now return all_codes_tensor def get_output_from_indices(self, indices): # indices shape: (b ... q) codes = self.get_codes_from_indices(indices) # Output: (q, b ... d) codes_summed = reduce(codes, "q b ... d -> b ... d", "sum") # Project back to original dimension output = self.project_out(codes_summed) # Handle channel first if necessary for the final output if self.channel_first and output.ndim > 2: # Assumes input was (b d ...), so output should be too output = rearrange(output, "b ... 
d -> b d ...") return output def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None): num_quant, quant_dropout_multiple_of, device = ( self.num_quantizers, self.quantize_dropout_multiple_of, x.device, ) # handle channel first input if necessary for projection original_shape = x.shape if self.channel_first: if x.ndim > 2: # Has spatial/temporal dims x = rearrange(x, "b d ... -> b ... d") x, ps = pack([x], "b * d") # else: x is (b d), processed directly else: # Input is (b ... d) if x.ndim > 2: x, ps = pack([x], "b * d") # else: x is (b d), processed directly # maybe project in projected_x = self.project_in(x) # (b ... codebook_dim) quantized_out = 0.0 residual = projected_x # Start residual from projected input all_indices = [] should_quantize_dropout = self.training and self.quantize_dropout # sample a layer index at which to dropout further residual quantization # also prepare null indices rand_quantize_dropout_index = num_quant # Default to no dropout if should_quantize_dropout: if not exists(rand_quantize_dropout_fixed_seed): rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) rand = random.Random(rand_quantize_dropout_fixed_seed) # Ensure cutoff index is valid valid_cutoff = max(0, self.quantize_dropout_cutoff_index) rand_quantize_dropout_index = rand.randrange(valid_cutoff, num_quant) if quant_dropout_multiple_of != 1: rand_quantize_dropout_index = ( round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1 ) # Clamp index to be within valid range rand_quantize_dropout_index = min(rand_quantize_dropout_index, num_quant - 1) # Null indices shape should match the batch/spatial dims of x before pack null_indices_shape = list(x.shape[:-1]) # All dims except last feature dim null_indices = torch.full(null_indices_shape, -1, device=device, dtype=torch.long) # go through the layers # Assuming scale is handled within FSQ or is 1.0 here # scales = self.scales.to(device) for quantizer_index, layer in enumerate(self.layers): # scale = scales[quantizer_index] # If using external scales if quantizer_index > rand_quantize_dropout_index: # Append null indices matching the shape of valid indices from FSQ # FSQ returns indices shape (b ...) or (b ... c) -> need (b ...) # Use the pre-calculated null_indices all_indices.append(null_indices) continue # Pass residual to the quantizer layer # Assume FSQ takes (b ... d) or (b d ...) based on its channel_first setting # Here, residual is (b ... codebook_dim) quantized, indices = layer(residual) # layer should handle channel_first internally # residual = residual - quantized.detach() # Update residual BEFORE summing output # quantized_out = quantized_out + quantized # Sum the quantized part # Algorithm 1 from paper: # Input: x # residual = x # codes = [] # for q in quantizers: # x_q, indices = q(residual) # Quantize # residual = residual - x_q # Update residual (use x_q directly, not detached?) - Check paper/encodec. Using detached version. # codes.append(indices) # x_hat = sum(x_q for each layer?) - No, final quantized output is reconstructed from indices. 
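            # In residual quantization each layer codes the residual left by the
            # previous layers, so summing the per-layer quantized outputs (below)
            # reconstructs the input up to the precision of the deepest layer
            # that was not dropped out.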
# Let's follow common implementation: sum quantized outputs, update residual with detached quantized quantized_detached = quantized.detach() residual = residual - quantized_detached quantized_out = quantized_out + quantized # Sum quantized outputs from each layer # Store indices if indices is None: raise ValueError(f"FSQ layer {quantizer_index} did not return indices.") all_indices.append(indices) # project out the summed quantized output final_quantized_out = self.project_out(quantized_out) # (b ... dim) # stack all indices all_indices = torch.stack(all_indices, dim=-1) # (b ... q) # Restore original shape if packed if x.ndim > 2: # If we packed dimensions final_quantized_out = unpack(final_quantized_out, ps, "b * d")[0] all_indices = unpack(all_indices, ps, "b * q")[0] # Restore channel dimension if needed if self.channel_first and final_quantized_out.ndim > 2: final_quantized_out = rearrange(final_quantized_out, "b ... d -> b d ...") # Decide how to handle indices shape. Keep as (b ... q) or (b q ...)? # Keeping as (b ... q) seems more common. # all_indices = rearrange(all_indices, "b ... q -> b q ...") # Optional rearrange # return ret = (final_quantized_out, all_indices) if not return_all_codes: return ret # Return all codes (reconstructed from indices) # Input to get_codes_from_indices should be (b ... q) all_codes = self.get_codes_from_indices(all_indices) # Output (q, b ... d) # Maybe reshape all_codes to match input shape conventions? # If input was channel_first (b d ...), maybe output codes as (q b d ...)? if self.channel_first and all_codes.ndim > 3: all_codes = rearrange(all_codes, "q b ... d -> q b d ...") return (*ret, all_codes) # =============================================================== # End: Content from sparktts/modules/fsq/residual_fsq.py # =============================================================== # =============================================================== # Start: Content from sparktts/modules/speaker/ecapa_tdnn.py # =============================================================== class Res2Conv1dReluBn(nn.Module): """ in_channels == out_channels == channels """ def __init__( self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4, ): super().__init__() assert channels % scale == 0, "{} % {} != 0".format(channels, scale) self.scale = scale self.width = channels // scale self.nums = scale if scale == 1 else scale - 1 self.convs = [] self.bns = [] for i in range(self.nums): self.convs.append( nn.Conv1d( self.width, self.width, kernel_size, stride, padding, dilation, bias=bias, ) ) self.bns.append(nn.BatchNorm1d(self.width)) self.convs = nn.ModuleList(self.convs) self.bns = nn.ModuleList(self.bns) def forward(self, x): out = [] spx = torch.split(x, self.width, 1) sp = spx[0] # Enumerate starts from 0, matching list indices for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): # Order: conv -> relu -> bn if i >= 1: sp = sp + spx[i] # Residual connection within block parts sp = conv(sp) sp = bn(F.relu(sp)) # Apply ReLU before BatchNorm out.append(sp) if self.scale != 1: # Append the last chunk without processing if scale > 1 out.append(spx[self.nums]) out = torch.cat(out, dim=1) return out """ Conv1d + BatchNorm1d + ReLU """ class Conv1dReluBn(nn.Module): def __init__( self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, ): super().__init__() self.conv = nn.Conv1d( in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias ) self.bn = nn.BatchNorm1d(out_channels) def 
forward(self, x): # Original: bn(relu(conv(x))) # ECAPA Paper Figure/Desc seems to suggest conv -> bn -> relu ? Check Res2Net paper/ECAPA details. # Sticking to original code's bn(relu(conv(x))) for now. return self.bn(F.relu(self.conv(x))) """ The SE connection of 1D case. """ class SE_Connect(nn.Module): def __init__(self, channels, se_bottleneck_dim=128): super().__init__() self.linear1 = nn.Linear(channels, se_bottleneck_dim) self.linear2 = nn.Linear(se_bottleneck_dim, channels) def forward(self, x): # x shape: (B, C, T) out = x.mean(dim=2) # Global average pooling over time -> (B, C) out = F.relu(self.linear1(out)) out = torch.sigmoid(self.linear2(out)) out = x * out.unsqueeze(2) # (B, C, T) * (B, C, 1) -> (B, C, T) return out """ SE-Res2Block of the ECAPA-TDNN architecture. """ class SE_Res2Block(nn.Module): def __init__(self, channels, kernel_size, stride, padding, dilation, scale): super().__init__() self.se_res2block = nn.Sequential( Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), Res2Conv1dReluBn( channels, kernel_size, stride, padding, dilation, scale=scale ), Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), SE_Connect(channels), ) def forward(self, x): return x + self.se_res2block(x) class ECAPA_TDNN(nn.Module): def __init__( self, channels=512, feat_dim=80, embed_dim=192, pooling_func="ASTP", global_context_att=False, emb_bn=False, ): super().__init__() self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2) self.layer2 = SE_Res2Block( channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8 ) self.layer3 = SE_Res2Block( channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8 ) self.layer4 = SE_Res2Block( channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8 ) cat_channels = channels * 3 # The output channels after conv depends on the pooling layer input expectation # Original paper uses 1536. Let's assume pooling expects 1536. 
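        # Worked example: with the default channels=512 the concatenated multi-layer
        # features have 3 * 512 = 1536 channels (the value used in the ECAPA-TDNN paper);
        # the c1024 factory variants defined below give 3 * 1024 = 3072 channels instead.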
self.conv = nn.Conv1d(cat_channels, cat_channels, kernel_size=1) # Keep channels same for pooling # Dynamically get pooling class based on string name from pooling_layers (defined earlier) if pooling_func == "TAP": pooling_layer = TAP elif pooling_func == "TSDP": pooling_layer = TSDP elif pooling_func == "TSTP": pooling_layer = TSTP elif pooling_func == "ASTP": pooling_layer = ASTP elif pooling_func == "MHASTP": pooling_layer = MHASTP elif pooling_func == "MQMHASTP": pooling_layer = MQMHASTP else: raise ValueError(f"Unsupported pooling function: {pooling_func}") self.pool = pooling_layer( in_dim=cat_channels, # Pooling operates on the output of self.conv global_context_att=global_context_att # Pass context flag if relevant (ASTP) # Add other necessary kwargs for specific pooling layers if needed ) # self.pool_out_dim = self.pool.get_out_dim() # Get output dim from pooling layer # Use standard way to get output dim if get_out_dim not standard # For TSTP/ASTP etc., it's usually 2 * in_dim if hasattr(self.pool, 'get_out_dim'): self.pool_out_dim = self.pool.get_out_dim() elif isinstance(self.pool, (TSTP, ASTP, MHASTP, MQMHASTP)): # Assuming these double the input dimension self.pool_out_dim = cat_channels * (2 * getattr(self.pool, 'query_num', 1) if isinstance(self.pool, MQMHASTP) else 2) else: # TAP, TSDP self.pool_out_dim = cat_channels self.bn = nn.BatchNorm1d(self.pool_out_dim) self.linear = nn.Linear(self.pool_out_dim, embed_dim) self.emb_bn = emb_bn if emb_bn: # better in SSL for SV self.bn2 = nn.BatchNorm1d(embed_dim) else: self.bn2 = nn.Identity() def forward(self, x, return_latent=False): # Input x expected as (B, T, F) e.g., mels x = x.permute(0, 2, 1) # (B, T, F) -> (B, F, T) out1 = self.layer1(x) out2 = self.layer2(out1) out3 = self.layer3(out2) out4 = self.layer4(out3) # Concat features from layers 2, 3, 4 out = torch.cat([out2, out3, out4], dim=1) # (B, 3*channels, T) latent = F.relu(self.conv(out)) # (B, 3*channels, T) # Pooling expects (B, F, T) pooled_out = self.pool(latent) # (B, pool_out_dim) bn_out = self.bn(pooled_out) embedding = self.linear(bn_out) # (B, embed_dim) if self.emb_bn: embedding = self.bn2(embedding) if return_latent: # Return the embedding and the features before pooling return embedding, latent # latent shape (B, 3*channels, T) return embedding # Return only the final embedding # Factory functions (optional, but keep if used elsewhere) def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): return ECAPA_TDNN( channels=1024, feat_dim=feat_dim, embed_dim=embed_dim, pooling_func=pooling_func, emb_bn=emb_bn, ) def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): return ECAPA_TDNN( channels=1024, feat_dim=feat_dim, embed_dim=embed_dim, pooling_func=pooling_func, global_context_att=True, emb_bn=emb_bn, ) def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): return ECAPA_TDNN( channels=512, feat_dim=feat_dim, embed_dim=embed_dim, pooling_func=pooling_func, emb_bn=emb_bn, ) def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): return ECAPA_TDNN( channels=512, feat_dim=feat_dim, embed_dim=embed_dim, pooling_func=pooling_func, global_context_att=True, emb_bn=emb_bn, ) # =============================================================== # End: Content from sparktts/modules/speaker/ecapa_tdnn.py # =============================================================== # =============================================================== # Start: Content from 
sparktts/modules/speaker/perceiver_encoder.py
# ===============================================================

# Helper functions for Perceiver/Attention
def exists(val):  # Redefined earlier
    return val is not None


def once(fn):  # Redefined earlier
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)

    return inner


print_once = once(print)


class Attend(nn.Module):
    def __init__(self, dropout=0.0, causal=False, use_flash=False):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.use_flash = use_flash
        can_use_flash = hasattr(F, "scaled_dot_product_attention") and use_flash
        if can_use_flash:
            print_once("Using Flash Attention for Perceiver.")
        else:
            if use_flash:
                print_once("Flash Attention requested but not available/enabled.")
            self.use_flash = False  # Disable if not available

        # Flash attention config (simplified); the actual backend selection happens
        # inside F.scaled_dot_product_attention.
        self.efficient_config = namedtuple(
            "EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
        )
        self.cpu_config = self.efficient_config(True, True, True)  # Default for CPU
        self.cuda_config = self.efficient_config(True, True, True)  # Default for CUDA

    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n and self.mask.device == device:
            return self.mask[:n, :n]
        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def flash_attn(self, q, k, v, mask=None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # k and v are expected to arrive as (b, h, n_kv, d_head); grouped or multi-query
        # layouts would need explicit expansion here, which this Perceiver usage does not require.
        if k.ndim == 3:
            pass
        if v.ndim == 3:
            pass

        # Build the attention mask for F.scaled_dot_product_attention.
        # The incoming `mask` is a key-padding mask with True marking valid positions,
        # which already matches the boolean attn_mask convention (True = may attend),
        # so it must NOT be inverted before being passed on.
        flash_mask = None
        if exists(mask):
            if mask.ndim == 2:  # (b, n_kv)
                flash_mask = rearrange(mask, "b j -> b 1 1 j")
                flash_mask = flash_mask.expand(-1, heads, q_len, -1)  # (b, h, n_q, n_kv)
            elif mask.ndim == 4 and mask.shape[1] == 1:  # already (b, 1, 1, n_kv)
                flash_mask = mask.expand(-1, heads, q_len, -1)
            else:
                # Assume the mask is already shaped (b, h, n_q, n_kv) boolean.
                flash_mask = mask
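        # Note on torch.nn.functional.scaled_dot_product_attention semantics (PyTorch >= 2.0):
        # a boolean attn_mask marks positions that may attend with True (False positions are
        # masked out), while a floating-point attn_mask is added to the attention scores.
        # The documentation also advises against combining attn_mask with is_causal=True;
        # that combination never occurs here because the Perceiver cross-attention uses causal=False.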
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=flash_mask if exists(flash_mask) else None,  # Boolean mask, True = attend
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=self.causal,  # Pass causal flag directly
        )
        return out

    def forward(self, q, k, v, mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (query, key/value)
        d - feature dimension (d_head)
        """
        n, device = q.shape[-2], q.device
        scale = q.shape[-1] ** -0.5

        if self.use_flash:
            return self.flash_attn(q, k, v, mask=mask)

        # Manual attention calculation
        kv_einsum_eq = "b h j d"  # k and v are expected as (b, h, n, d)

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # Define the fill value up front so both the padding and causal branches can use it.
        mask_value = -torch.finfo(sim.dtype).max

        # key padding mask: True marks valid key positions
        if exists(mask):
            mask = rearrange(mask, "b j -> b 1 1 j")
            sim = sim.masked_fill(~mask, mask_value)  # Mask where mask is False

        # causal mask (not typically used in Perceiver cross-attention)
        if self.causal:
            causal_mask = self.get_mask(n, device)  # (i, j), True above the diagonal
            sim = sim.masked_fill(causal_mask, mask_value)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out


# Need Sequential, default, RMSNorm, GEGLU, FeedForward, Attention for PerceiverResampler


def Sequential(*mods):  # Redefined earlier
    return nn.Sequential(*filter(exists, mods))


class RMSNorm(nn.Module):
    def __init__(self, dim, scale=True, dim_cond=None):
        super().__init__()
        self.cond = exists(dim_cond)
        # Conditional scaling is not used by the PerceiverResampler, so it is omitted here.
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim)) if scale else None

    def forward(self, x, cond=None):
        gamma = default(self.gamma, torch.tensor(1.0, device=x.device))  # Ensure gamma is a tensor
        # F.normalize normalizes across the last dimension by default.
        normed_x = F.normalize(x, dim=-1)
        return normed_x * self.scale * gamma


class CausalConv1d(nn.Conv1d):  # Already defined earlier
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        kernel_size = self.kernel_size[0]
        dilation = self.dilation[0]
        stride = self.stride[0]
        assert stride == 1
        self.causal_padding = dilation * (kernel_size - 1)

    def forward(self, x):
        # Input x: (B, C, T); pad only on the left so the convolution stays causal.
        causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
        return super().forward(causal_padded_x)


class GEGLU(nn.Module):  # Already defined earlier
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.gelu(gate) * x


def FeedForward(dim, mult=4, causal_conv=False):  # Already defined earlier
    dim_inner = int(dim * mult * 2 / 3)
    conv = None
    if causal_conv:
        conv = nn.Sequential(
            Rearrange("b n d -> b d n"),
            CausalConv1d(dim_inner, dim_inner, 3),
            Rearrange("b d n -> b n d"),
        )
    return Sequential(
        nn.Linear(dim, dim_inner * 2, bias=False),  # Bias-free linears, as is common in transformers
        GEGLU(),
        conv,
        nn.Linear(dim_inner, dim, bias=False),
    )


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_context=None,
        causal=False,
        dim_head=64,
        heads=8,
        dropout=0.0,
        use_flash=False,
        cross_attn_include_queries=False,
    ):
        super().__init__()
        # The softmax scale is handled inside Attend / scaled_dot_product_attention.
        self.heads = heads
        self.cross_attn_include_queries = cross_attn_include_queries

        dim_inner = dim_head * heads
        dim_context = default(dim_context, dim)

        self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
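        # Dimension example for the PerceiverResampler configuration used further below
        # (dim=latent_dim, dim_head=64, heads=8): dim_inner = 64 * 8 = 512, so to_q maps
        # dim -> 512 and to_kv maps dim_context -> 1024 (keys and values concatenated).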
self.to_q = nn.Linear(dim, dim_inner, bias=False) # Combine K and V projection for efficiency self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False) self.to_out = nn.Linear(dim_inner, dim, bias=False) def forward(self, x, context=None, mask=None): h, has_context = self.heads, exists(context) # x shape: (b, n_q, d) # context shape: (b, n_kv, d_ctx) context = default(context, x) # Use self if context not provided if has_context and self.cross_attn_include_queries: # Prepend queries to context for attention calculation context = torch.cat((x, context), dim=-2) # (b, n_q + n_kv, d_ctx) - ensure dims match # Project q, k, v q = self.to_q(x) k, v = self.to_kv(context).chunk(2, dim=-1) # Reshape for multi-head attention q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) # Attend out = self.attend(q, k, v, mask=mask) # mask should be (b, n_kv) # Combine heads and project out out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) class PerceiverResampler(nn.Module): def __init__( self, *, dim, depth=2, dim_context=None, num_latents=32, dim_head=64, heads=8, ff_mult=4, use_flash_attn=False, ): super().__init__() dim_context = default(dim_context, dim) # Project context to query dimension if different self.proj_context = ( nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity() ) # Learnable latent queries self.latents = nn.Parameter(torch.randn(num_latents, dim)) nn.init.normal_(self.latents, std=0.02) # Initialize latents self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append( nn.ModuleList( [ # Cross-Attention from latents (queries) to context (keys/values) Attention( dim=dim, dim_context=dim, # Context is projected to dim dim_head=dim_head, heads=heads, use_flash=use_flash_attn, cross_attn_include_queries=False, # Standard Perceiver cross-attn ), # Self-Attention within latents # Optional: Add self-attention block here if needed # Attention( # dim=dim, dim_head=dim_head, heads=heads, use_flash=use_flash_attn # ), # FeedForward block FeedForward(dim=dim, mult=ff_mult), ] ) ) # Add LayerNorms (typically before attention and ff blocks) # self.layers[-1].insert(0, RMSNorm(dim)) # Pre-Attention Norm # self.layers[-1].insert(2, RMSNorm(dim)) # Pre-FF Norm # Using Post-Norm structure as in original reference: self.layers[-1].insert(1, RMSNorm(dim)) # After Attention self.layers[-1].append(RMSNorm(dim)) # After FeedForward # Final normalization of latents # self.norm = RMSNorm(dim) # Final norm applied inside loop in original? Let's apply at end. 
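        # After the inserts above, every entry of self.layers is the 4-tuple
        # [cross_attention, RMSNorm, FeedForward, RMSNorm]; forward() below applies them
        # in a post-norm residual pattern: latents = norm1(attn(latents, ctx) + latents),
        # then latents = norm2(ff(latents) + latents).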
def forward(self, x, mask=None): # x shape: (b, n_ctx, d_ctx) batch = x.shape[0] # Project context x = self.proj_context(x) # (b, n_ctx, d) # Repeat latents for batch latents = repeat(self.latents, "n d -> b n d", b=batch) # (b, n_lat, d) # Apply layers # Original structure had norm inside loop, adapting: Attn -> Norm -> FF -> Norm for attn, norm1, ff, norm2 in self.layers: # Cross-Attention + Residual latents_attn = attn(latents, x, mask=mask) # Query: latents, Context: x latents = norm1(latents_attn + latents) # FeedForward + Residual latents_ff = ff(latents) latents = norm2(latents_ff + latents) # return self.norm(latents) # Apply final norm if defined outside loop return latents # Return latents after last block norm # =============================================================== # End: Content from sparktts/modules/speaker/perceiver_encoder.py # =============================================================== # =============================================================== # Start: Content from sparktts/modules/speaker/speaker_encoder.py # =============================================================== class SpeakerEncoder(nn.Module): """ Speaker Encoder using ECAPA-TDNN, Perceiver Resampler, and Residual FSQ. Args: input_dim (int): acoustic feature dimension (e.g., mel bins) out_dim (int): output dimension of the final d-vector latent_dim (int): latent dimension for perceiver and quantization token_num (int): number of latent tokens from perceiver fsq_levels (List[int]): levels for finite scalar quantization fsq_num_quantizers (int): number of residual quantizers in FSQ ecapa_embed_dim (int): embedding dimension from ECAPA-TDNN (before projection) """ def __init__( self, input_dim: int = 80, # Default mel bins out_dim: int = 1024, # Target d-vector dim from config latent_dim: int = 128, # Latent dim for perceiver/quantizer token_num: int = 32, # Number of speaker tokens fsq_levels: List[int] = [4, 4, 4, 4, 4, 4], fsq_num_quantizers: int = 1, # Add ECAPA config params if needed, or use defaults ecapa_channels: int = 512, ecapa_embed_dim: int = 192, # Default ECAPA embed dim ): super(SpeakerEncoder, self).__init__() # ECAPA-TDNN for initial feature extraction and x-vector (optional) # Using the GLOB variant as in the original __main__ test self.speaker_encoder_base = ECAPA_TDNN_GLOB_c512( feat_dim=input_dim, embed_dim=ecapa_embed_dim # Use specific ECAPA embed dim ) # Dimension of features extracted by ECAPA (latent before pooling) ecapa_feature_dim = ecapa_channels * 3 # From concatenation in ECAPA # Perceiver Resampler to get fixed-length sequence from variable-length ECAPA features self.perceiver_sampler = PerceiverResampler( dim=latent_dim, # Output dim of perceiver latents dim_context=ecapa_feature_dim, # Input dim from ECAPA features num_latents=token_num, depth=2, # Default depth, adjust if needed dim_head=64, heads=8, ff_mult=4, # Default attention/ff params use_flash_attn=True # Enable flash attention if available ) # Residual Finite Scalar Quantizer self.quantizer = ResidualFSQ( levels=fsq_levels, num_quantizers=fsq_num_quantizers, dim=latent_dim, # Quantizer operates on perceiver output dim channel_first=False, # Perceiver output is (B, T, D), so channel_first=False quantize_dropout=False, # No dropout specified in config ) # Final projection from flattened quantized tokens to the target output dimension self.project = nn.Linear(latent_dim * token_num, out_dim) def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor: """Reconstruct quantized vectors from 
indices.""" # indices shape: (B, T_token, Q) or (B, Q, T_token)? Check ResidualFSQ output. # Assuming (B, T_token, Q) from forward pass. # get_output_from_indices expects (indices_chunk1, indices_chunk2, ...) if grouped. # If not grouped, expects (B, ... Q). Let's assume (B, T_token, Q). zq = self.quantizer.get_output_from_indices(indices) # Output zq shape should be (B, T_token, latent_dim) return zq def get_indices(self, mels: torch.Tensor) -> torch.Tensor: """Get FSQ indices directly from mel spectrograms.""" # mels: (B, T_mel, D_mel) _, features = self.speaker_encoder_base(mels, return_latent=True) # features: (B, ecapa_feat_dim, T_feat) x = self.perceiver_sampler(features.transpose(1, 2)) # Input: (B, T_feat, ecapa_feat_dim), Output: (B, token_num, latent_dim) _, indices = self.quantizer(x) # Input: (B, token_num, latent_dim), indices: (B, token_num, Q) return indices def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: mels: (B, T_mel, D_mel) - Mel spectrogram input Return: x_vector: (B, ecapa_embed_dim) - Global speaker embedding from ECAPA d_vector: (B, out_dim) - Speaker embedding derived from quantized tokens """ # Get base speaker embedding (x-vector) and intermediate features from ECAPA x_vector, features = self.speaker_encoder_base(mels, return_latent=True) # features shape: (B, ecapa_feat_dim, T_feat) # Resample features using Perceiver # Perceiver expects (B, T, D), so transpose features perceiver_latents = self.perceiver_sampler(features.transpose(1, 2)) # perceiver_latents shape: (B, token_num, latent_dim) # Quantize the perceiver latents # Quantizer expects (B, T, D) if channel_first=False zq, indices = self.quantizer(perceiver_latents) # zq shape: (B, token_num, latent_dim), indices shape: (B, token_num, Q) # Flatten quantized tokens and project to final d-vector dimension zq_flat = rearrange(zq, 'b t d -> b (t d)') # (B, token_num * latent_dim) d_vector = self.project(zq_flat) # (B, out_dim) return x_vector, d_vector def tokenize(self, mels: torch.Tensor) -> torch.Tensor: """Tokenize the input mel spectrogram to get FSQ indices.""" # Same logic as get_indices _, features = self.speaker_encoder_base(mels, return_latent=True) # features: (B, ecapa_feat_dim, T_feat) x = self.perceiver_sampler(features.transpose(1, 2)) # (B, token_num, latent_dim) _, indices = self.quantizer(x) # indices: (B, token_num, Q) return indices def detokenize(self, indices: torch.Tensor) -> torch.Tensor: """Detokenize FSQ indices to get the final d-vector.""" # indices shape: (B, token_num, Q) # Reconstruct quantized vectors from indices zq = self.get_codes_from_indices(indices) # (B, token_num, latent_dim) # Flatten and project zq_flat = rearrange(zq, 'b t d -> b (t d)') d_vector = self.project(zq_flat) return d_vector # =============================================================== # End: Content from sparktts/modules/speaker/speaker_encoder.py # =============================================================== # =============================================================== # Start: Content from sparktts/modules/vq/factorized_vector_quantize.py # =============================================================== # Helper function from layers.py (already defined) # def WNConv1d(*args, **kwargs): # return weight_norm(nn.Conv1d(*args, **kwargs)) def ema_inplace(moving_avg, new, decay): moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) class FactorizedVectorQuantize(nn.Module): def __init__( self, input_dim: int, codebook_size: int, codebook_dim: int, 
commitment: float, codebook_loss_weight: float = 1.0, decay: float = 0.99, threshold_ema_dead_code: float = 2.0, # Changed default from 2 based on config momentum: float = 0.99, # Not used in current implementation? use_l2_normlize: bool = True, # Added from config **kwargs, ): super().__init__() self.input_dim = input_dim self.codebook_size = codebook_size self.codebook_dim = codebook_dim self.commitment = commitment self.codebook_loss_weight = codebook_loss_weight self.decay = decay self.threshold_ema_dead_code = threshold_ema_dead_code # self.momentum = momentum # Store if needed later self.use_l2_normlize = use_l2_normlize if input_dim != self.codebook_dim: self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1) self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1) else: self.in_project = nn.Identity() self.out_project = nn.Identity() # Codebook embedding layer self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim) # Initialize codebook? Often random init is fine. # Buffers for EMA updates (cluster size and maybe embeddings) self.register_buffer("cluster_size", torch.zeros(self.codebook_size)) # EMA average embeddings (optional, can use self.codebook.weight directly for loss) # self.register_buffer("ema_embed", self.codebook.weight.clone()) def forward(self, z: torch.Tensor) -> Dict[str, Any]: """Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors and losses. Parameters ---------- z : Tensor[B x D_in x T] Returns ------- Dict containing: z_q (Tensor[B x D_in x T]): Quantized continuous representation (passed through out_project) indices (Tensor[B x T]): Codebook indices vq_loss (Tensor[1]): Combined VQ loss (codebook + commitment) perplexity (Tensor[1]): Codebook perplexity metric active_num (Tensor[1]): Number of active codebook entries """ # z: (B, D_in, T) B, _, T = z.shape # Project input to codebook dimension if necessary z_e = self.in_project(z) # (B, D_code, T) # Find nearest neighbors and get quantized vectors + indices z_q, indices, dists = self.decode_latents(z_e) # z_q: (B, D_code, T), indices: (B, T) # Calculate statistics for perplexity and active codes with torch.no_grad(): # Stats should not contribute to gradient embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype) # (B, T, C) # Flatten batch and time dims for stats embed_onehot_flat = rearrange(embed_onehot, 'b t c -> (b t) c') avg_probs = torch.mean(embed_onehot_flat, dim=0) # (C,) perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) # EMA update for cluster size (only in training) active_num_tensor = (embed_onehot_flat.sum(0) > 0).sum() # Before EMA if self.training: # Perform EMA update in place ema_inplace(self.cluster_size, embed_onehot_flat.sum(0), self.decay) # Calculate active codes based on EMA threshold active_num_tensor = (self.cluster_size > self.threshold_ema_dead_code).sum() # Calculate losses (only in training) commit_loss = torch.tensor(0.0, device=z.device) codebook_loss = torch.tensor(0.0, device=z.device) vq_loss = torch.tensor(0.0, device=z.device) if self.training: # Commitment loss (encourage encoder output E(x) to be close to codebook z_q) # Use z_e (projected encoder output) and z_q.detach() commit_loss = F.mse_loss(z_e, z_q.detach()) * self.commitment # Codebook loss (encourage codebook entries z_q to be close to encoder output E(x)) # Use z_q and z_e.detach() codebook_loss = F.mse_loss(z_q, z_e.detach()) * self.codebook_loss_weight vq_loss = commit_loss + 
codebook_loss # Straight-through estimator: copy gradient from z_q to z_e z_q_st = z_e + (z_q - z_e).detach() # Project quantized vectors back to input dimension if necessary z_q_out = self.out_project(z_q_st) # (B, D_in, T) return { "z_q": z_q_out, "indices": indices, # "dists": dists, # Dists might be large, exclude unless needed "vq_loss": vq_loss, "perplexity": perplexity, "active_num": active_num_tensor.float(), } def embed_code(self, embed_id): """Retrieve codebook vectors for given indices.""" return F.embedding(embed_id, self.codebook.weight) def decode_code(self, embed_id): """Retrieve codebook vectors and transpose to (B, D, T) format.""" # embed_id: (B, T) # Embedding: (B, T, D_code) # Transpose: (B, D_code, T) return self.embed_code(embed_id).transpose(1, 2) def decode_latents(self, latents): """Find nearest codebook entries for latent vectors.""" # latents: (B, D_code, T) B, D_code, T = latents.shape encodings = rearrange(latents, "b d t -> (b t) d") # ((B*T), D_code) codebook = self.codebook.weight # (C, D_code) # Normalize if required if self.use_l2_normlize: encodings = F.normalize(encodings, p=2, dim=-1) codebook = F.normalize(codebook, p=2, dim=-1) # Compute distances (squared Euclidean or Cosine depending on normalization) # dist = torch.cdist(encodings, codebook, p=2)**2 # Squared Euclidean # Faster calculation using matrix multiplication if normalized: # dist = 2 - 2 * (encodings @ codebook.t()) # Or full squared Euclidean: dist = ( encodings.pow(2).sum(1, keepdim=True) # (B*T, 1) - 2 * (encodings @ codebook.t()) # (B*T, C) + codebook.pow(2).sum(1, keepdim=True).t() # (1, C) ) # Result shape: (B*T, C) # Find nearest neighbors indices = torch.argmin(dist, dim=-1) # (B*T) indices = rearrange(indices, "(b t) -> b t", b=B) # (B, T) # Get the quantized vectors z_q = self.decode_code(indices) # (B, D_code, T) return z_q, indices, dist # Return dist if needed, e.g., for debugging # --- Methods for inference/tokenization --- def tokenize(self, z: torch.Tensor) -> torch.Tensor: """Tokenize the input tensor without loss calculation.""" # z: (B, D_in, T) z_e = self.in_project(z) # (B, D_code, T) _, indices, _ = self.decode_latents(z_e) # indices: (B, T) return indices def detokenize(self, indices: torch.Tensor) -> torch.Tensor: """Detokenize indices to quantized vectors in input dimension.""" # indices: (B, T) z_q_code_dim = self.decode_code(indices) # (B, D_code, T) z_q_out = self.out_project(z_q_code_dim) # (B, D_in, T) return z_q_out # =============================================================== # End: Content from sparktts/modules/vq/factorized_vector_quantize.py # =============================================================== # --- BiCodec Model Definition (Adapted from sparktts/models/bicodec.py) --- class BiCodec(nn.Module): """ BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder, quantizer, and wave generator. """ def __init__( self, mel_params: Dict[str, Any], encoder: nn.Module, decoder: nn.Module, quantizer: nn.Module, speaker_encoder: nn.Module, prenet: nn.Module, postnet: nn.Module, **kwargs ) -> None: """ Initializes the BiCodec model with the required components. Args: mel_params (dict): Parameters for the mel-spectrogram transformer. encoder (nn.Module): Encoder module. decoder (nn.Module): Decoder module. quantizer (nn.Module): Quantizer module. speaker_encoder (nn.Module): Speaker encoder module. prenet (nn.Module): Prenet network. postnet (nn.Module): Postnet network. 
""" super().__init__() self.encoder = encoder self.decoder = decoder self.quantizer = quantizer self.speaker_encoder = speaker_encoder self.prenet = prenet self.postnet = postnet self._init_mel_transformer(mel_params) @classmethod def load_from_config_and_checkpoint(cls, model_dir: Path, config_dict: Dict[str, Any], **kwargs) -> "BiCodec": """Loads the model from a config dictionary and checkpoint file.""" ckpt_path = model_dir / 'model.safetensors' if not ckpt_path.is_file(): raise FileNotFoundError(f"BiCodec checkpoint not found at {ckpt_path}") audio_tokenizer_config = config_dict # Assuming config_dict holds the relevant sub-config # Instantiate components using classes from _modeling_bicodec_components mel_params = audio_tokenizer_config.get("mel_params", {}) encoder_cfg = audio_tokenizer_config.get("encoder", {}) quantizer_cfg = audio_tokenizer_config.get("quantizer", {}) prenet_cfg = audio_tokenizer_config.get("prenet", {}) postnet_cfg = audio_tokenizer_config.get("postnet", {}) decoder_cfg = audio_tokenizer_config.get("decoder", {}) # This corresponds to WaveGenerator speaker_encoder_cfg = audio_tokenizer_config.get("speaker_encoder", {}) # --- Input Validation --- required_keys = { "encoder": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"], "quantizer": ["input_dim", "codebook_size", "codebook_dim", "commitment"], "prenet": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"], "postnet": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"], "decoder": ["input_channel", "channels", "rates", "kernel_sizes"], # WaveGenerator keys "speaker_encoder": ["input_dim", "out_dim", "latent_dim", "token_num"], "mel_params": ["sample_rate", "n_fft", "win_length", "hop_length", "num_mels"] } for comp, keys in required_keys.items(): cfg = audio_tokenizer_config.get(comp, {}) if not cfg: logging.get_logger(__name__).warning(f"BiCodec config missing section: '{comp}'") for key in keys: if key not in cfg: logging.get_logger(__name__).warning(f"BiCodec config missing key '{key}' in section '{comp}'") # --- End Validation --- # Instantiate modules encoder = Encoder(**encoder_cfg) if encoder_cfg else None quantizer = FactorizedVectorQuantize(**quantizer_cfg) if quantizer_cfg else None prenet = Decoder(**prenet_cfg) if prenet_cfg else None postnet = Decoder(**postnet_cfg) if postnet_cfg else None decoder = WaveGenerator(**decoder_cfg) if decoder_cfg else None # WaveGenerator instance speaker_encoder = SpeakerEncoder(**speaker_encoder_cfg) if speaker_encoder_cfg else None # Check if all components were successfully created if not all([encoder, quantizer, prenet, postnet, decoder, speaker_encoder, mel_params]): raise ValueError("Failed to initialize one or more BiCodec components due to missing configuration.") # Create the BiCodec instance model = cls( mel_params=mel_params, encoder=encoder, decoder=decoder, # Pass WaveGenerator instance as decoder quantizer=quantizer, speaker_encoder=speaker_encoder, prenet=prenet, postnet=postnet, ) # Load state dict try: state_dict = load_file(ckpt_path, device="cpu") # Load to CPU first missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) if missing_keys: print(f"BiCodec missing keys: {missing_keys}") if unexpected_keys: print(f"BiCodec unexpected keys: {unexpected_keys}") except Exception as e: raise IOError(f"Error loading BiCodec state dict from {ckpt_path}: {e}") model.eval() # model.remove_weight_norm() # 
Assuming this method exists in components return model def _init_mel_transformer(self, config: Dict[str, Any]): # Ensure required keys exist with defaults sr = config.get("sample_rate", 16000) n_fft = config.get("n_fft", 1024) win_length = config.get("win_length", n_fft) hop_length = config.get("hop_length", n_fft // 4) fmin = config.get("mel_fmin", 0) fmax = config.get("mel_fmax", None) n_mels = config.get("num_mels", 80) power = config.get("power", 2.0) # Typically 2.0 for power spectrogram norm = config.get("norm", "slaney") mel_scale = config.get("mel_scale", "htk") # htk or slaney self.mel_transformer = TT.MelSpectrogram( sample_rate=sr, n_fft=n_fft, win_length=win_length, hop_length=hop_length, f_min=fmin, f_max=fmax, n_mels=n_mels, power=power, norm=norm, mel_scale=mel_scale, ).eval() # Set to eval mode def remove_weight_norm(self): """Removes weight normalization from components that support it.""" def _remove_wn(m): if hasattr(m, 'remove_weight_norm'): m.remove_weight_norm() elif isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)): try: remove_weight_norm(m) except ValueError: pass # Module might not have weight norm applied self.apply(_remove_wn) @torch.no_grad() def tokenize(self, feat: torch.Tensor, ref_wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Tokenizes input features and reference wav into semantic and global tokens. """ # Ensure models are on the correct device device = feat.device self.mel_transformer.to(device) self.encoder.to(device) self.quantizer.to(device) self.speaker_encoder.to(device) # feat: (B, D_feat, T_feat), ref_wav: (B, T_wav) mel = self.mel_transformer(ref_wav) # (B, D_mel, T_mel) # Encode features to get latents for semantic tokens z = self.encoder(feat) # (B, D_latent, T_latent) - Assuming Encoder output matches quantizer input dim # Quantize latents to get semantic tokens (indices) semantic_tokens = self.quantizer.tokenize(z) # (B, T_latent) # Encode mel spectrogram to get global tokens (indices) # SpeakerEncoder.tokenize expects (B, T_mel, D_mel) global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2)) # (B, T_token, Q) - Check shape # Note: Original BiCodecTokenizer returned (global_tokens, semantic_tokens) # Let's stick to that order for consistency with original SparkTTS usage. return global_tokens, semantic_tokens @torch.no_grad() def detokenize(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> torch.Tensor: """ Detokenizes semantic and global tokens into a waveform. """ # Ensure models are on the correct device device = semantic_tokens.device # Assume tokens are on target device self.quantizer.to(device) self.speaker_encoder.to(device) self.prenet.to(device) self.decoder.to(device) # WaveGenerator # semantic_tokens: (B, T_latent) or (B, T_latent, Q)? Check quantizer.tokenize output shape. Assuming (B, T_latent). # global_tokens: (B, T_token, Q) - Check speaker_encoder.tokenize output shape. 
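        # Overall detokenize data flow (shapes as produced by tokenize() above):
        #   semantic_tokens (B, T_latent) --quantizer.detokenize--> z_q (B, D_latent, T_latent)
        #   global_tokens (B, T_token, Q) --speaker_encoder.detokenize--> d_vector (B, D_dvector)
        #   prenet(z_q, d_vector) (+ optional d-vector residual) --decoder/WaveGenerator--> (B, 1, T_wav)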
# Reconstruct quantized vectors from semantic tokens z_q = self.quantizer.detokenize(semantic_tokens) # (B, D_latent, T_latent) # Reconstruct d-vector (condition) from global tokens # SpeakerEncoder.detokenize expects (B, T_token, Q) d_vector = self.speaker_encoder.detokenize(global_tokens) # (B, D_dvector) # Apply prenet conditioned on d-vector # Prenet (Decoder class) expects input (B, D_latent, T_latent) and condition (B, D_dvector) x = self.prenet(z_q, d_vector) # (B, D_prenet_out, T_latent) - Assuming prenet maintains time dim # Add condition (broadcasted) before wave generation - Check original logic # Ensure d_vector has correct shape for broadcasting if d_vector.ndim == 2: d_vector_unsqueezed = d_vector.unsqueeze(-1) # (B, D_dvector, 1) else: # Should not happen if speaker_encoder outputs (B, D) d_vector_unsqueezed = d_vector # Ensure dimensions match for addition if x.shape[1] == d_vector_unsqueezed.shape[1]: # Broadcast d_vector across time dimension T_latent x = x + d_vector_unsqueezed else: # Maybe project d_vector or x? Log a warning or adapt based on expected dims. logging.get_logger(__name__).warning(f"Prenet output dim {x.shape[1]} != d-vector dim {d_vector_unsqueezed.shape[1]}. Skipping residual connection.") # Generate waveform using the decoder (WaveGenerator) # WaveGenerator expects (B, D_input, T_input) wav_recon = self.decoder(x) # (B, 1, T_wav) return wav_recon # --- Main SparkTTS Model --- from .configuration_spark_tts import SparkTTSConfig # from ._utils import load_audio # Use utils from _utils.py logger = logging.get_logger(__name__) class SparkTTSModel(PreTrainedModel, GenerationMixin): """ SparkTTS model integrating LLM, BiCodec, and Wav2Vec2 for text-to-speech. """ config_class = SparkTTSConfig base_model_prefix = "spark_tts" _supports_load_fast = True def __init__(self, config: SparkTTSConfig, llm=None, wav2vec2_model=None, wav2vec2_processor=None, bicodec=None): super().__init__(config) self.config = config self.llm = llm self.wav2vec2_model = wav2vec2_model self.wav2vec2_processor = wav2vec2_processor self.bicodec = bicodec # Ensure wav2vec2 config has output_hidden_states=True after loading self.post_init() @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None, ignore_mismatched_sizes: bool = False, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = "main", use_safetensors: Optional[bool] = None, # Keep None to let transformers decide **kwargs, ): # Pop device map and dtype early - handle placement later # Note: device_map is complex with multiple components. Manual .to(device) is simpler here. device_map = kwargs.pop("device_map", None) if device_map: logger.warning("`device_map` is not directly supported for this composite model. Use .to(device) after loading.") torch_dtype = kwargs.pop("torch_dtype", "auto") # Can be "auto", float32, float16, bfloat16 trust_remote_code = kwargs.pop("trust_remote_code", False) # CRITICAL for custom code # --- 1. Resolve the main model directory --- # This handles downloading from Hub or using a local path robustly. 
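        # Expected layout of the resolved directory (illustrative; the sub-folder names come
        # from the loaded SparkTTSConfig rather than being fixed here):
        #   <model_dir>/config.json                              -> SparkTTSConfig
        #   <model_dir>/<config.llm_model_name_or_path>          -> causal LM weights
        #   <model_dir>/<config.wav2vec2_model_name_or_path>     -> wav2vec2 model + feature extractor
        #   <model_dir>/<config.bicodec_model_name_or_path>/model.safetensors -> BiCodec checkpoint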
if pretrained_model_name_or_path is None: raise ValueError("`pretrained_model_name_or_path` must be provided.") model_path = Path(pretrained_model_name_or_path) if not model_path.is_dir(): # If it's not a local directory, assume it's a Hub ID and download everything logger.info(f"{pretrained_model_name_or_path} is not a local directory. Assuming Hub ID and downloading.") try: resolved_model_path = snapshot_download( repo_id=str(pretrained_model_name_or_path), cache_dir=cache_dir, force_download=force_download, local_files_only=local_files_only, token=token, revision=revision, allow_patterns=["*.json", "*.safetensors", "*.bin", "*.yaml", "*.txt", "README.md"], # Be somewhat permissive # ignore_patterns=["*.git*"], # Optional: ignore git files # user_agent={"agent": "spark-tts-custom-loader"}, # Optional ) resolved_model_path = Path(resolved_model_path) logger.info(f"Model downloaded to cache: {resolved_model_path}") except Exception as e: raise OSError( f"Failed to download model '{pretrained_model_name_or_path}' from Hugging Face Hub. " f"Ensure the ID is correct and network is available. Error: {e}" ) else: # It's a local directory path resolved_model_path = model_path logger.info(f"Loading model from local directory: {resolved_model_path}") if not resolved_model_path.is_dir(): # This should ideally not happen after snapshot_download or initial check raise EnvironmentError(f"Cannot find resolved model directory at {resolved_model_path}") # --- 2. Load the main configuration --- # The config might have been passed explicitly, otherwise load from resolved path if not isinstance(config, PretrainedConfig): config_path = config if config is not None else resolved_model_path config, model_kwargs = cls.config_class.from_pretrained( config_path, # Load from the resolved directory or explicit config path *model_args, # Pass *model_args here if they influence config loading cache_dir=cache_dir, # Pass relevant args down force_download=force_download, local_files_only=local_files_only, token=token, revision=revision, trust_remote_code=trust_remote_code, # Needed if config class itself is remote return_unused_kwargs=True, **kwargs, # Pass remaining kwargs ) # Update kwargs with unused ones from config loading kwargs.update(model_kwargs) else: # Config object was passed directly pass # kwargs remain as they were # --- Determine torch_dtype (use config value if specified and not overridden) --- # Priority: Explicit torch_dtype arg > config.torch_dtype > "auto" (default) final_torch_dtype = torch_dtype # Explicit arg has highest prio if final_torch_dtype == "auto": final_torch_dtype = getattr(config, "torch_dtype", None) # Use config value if present # final_torch_dtype can still be None or "auto" here, handle downstream # --- Helper function to resolve paths relative to the main model directory --- def _resolve_sub_path(sub_path): p = Path(sub_path) if p.is_absolute(): return str(p) else: # Resolve relative to the potentially cached main model path return str(resolved_model_path / p) # --- 3. 
Load Sub-components --- # --- Load LLM --- llm_path = _resolve_sub_path(config.llm_model_name_or_path) logger.info(f"Loading LLM from resolved path: {llm_path}") try: llm = AutoModelForCausalLM.from_pretrained( llm_path, torch_dtype=final_torch_dtype if final_torch_dtype != "auto" else None, # Pass resolved dtype or None trust_remote_code=trust_remote_code, # Pass down trust_remote_code # Pass remaining kwargs that might be relevant for AutoModelForCausalLM # Filter kwargs if necessary, but often passing them is fine **kwargs ) except Exception as e: raise OSError(f"Failed to load LLM from {llm_path}: {e}") # --- Load Wav2Vec2 --- w2v_path = _resolve_sub_path(config.wav2vec2_model_name_or_path) logger.info(f"Loading Wav2Vec2 components from resolved path: {w2v_path}") try: # Load feature extractor first wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained( w2v_path, trust_remote_code=trust_remote_code, # Add any relevant kwargs for feature extractor if needed ) # Load model wav2vec2_model = Wav2Vec2Model.from_pretrained( w2v_path, trust_remote_code=trust_remote_code, # Add any relevant kwargs for model if needed (e.g., add_adapter=False) ) except Exception as e: raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}") # --- Load BiCodec --- bicodec_path = _resolve_sub_path(config.bicodec_model_name_or_path) logger.info(f"Loading BiCodec from resolved path: {bicodec_path}") if not config.bicodec_config or "audio_tokenizer" not in config.bicodec_config: raise ValueError("BiCodec configuration ('bicodec_config' with 'audio_tokenizer' key) not found in SparkTTSConfig.") try: # Assuming BiCodec class has the custom loading method # Make sure BiCodec class is imported or defined above bicodec = BiCodec.load_from_config_and_checkpoint( model_dir=Path(bicodec_path), config_dict=config.bicodec_config["audio_tokenizer"] ) # Ensure BiCodec is an nn.Module if you want .to(device) to work easily if not isinstance(bicodec, torch.nn.Module): logger.warning("Loaded BiCodec component is not an instance of torch.nn.Module. Automatic device placement might not work.") except FileNotFoundError as e: raise OSError(f"Failed to load BiCodec: A required file was not found in {bicodec_path}. Original error: {e}") except Exception as e: logger.error(f"Raw error loading BiCodec: {type(e).__name__}: {e}") import traceback traceback.print_exc() # Print full traceback for debugging BiCodec loading raise OSError(f"Failed to load BiCodec from {bicodec_path}. Check BiCodec implementation and file paths. Error: {e}") # --- 4. Instantiate the main model wrapper --- # Pass the loaded config and components model = cls(config, llm=llm, wav2vec2_model=wav2vec2_model, wav2vec2_processor=wav2vec2_processor, bicodec=bicodec) # --- 5. Handle device placement --- # Move the entire model (including sub-modules if they are nn.Module attributes) # Determine target device based on availability if torch.cuda.is_available(): final_device = torch.device("cuda") # If multiple GPUs, could select one, e.g., torch.device("cuda:0") # Or rely on CUDA_VISIBLE_DEVICES environment variable else: final_device = torch.device("cpu") logger.info(f"Placing SparkTTSModel and components on device: {final_device}") # This should move all registered nn.Module attributes (llm, wav2vec2_model, bicodec if it's an nn.Module) try: model.to(final_device) except Exception as e: logger.error(f"Failed to move model to device {final_device}. Error: {e}") logger.warning("Device placement might be incomplete. Check component types.") # --- 6. 
Return the loaded and prepared model --- return model # --- Embedding getters/setters (delegate to LLM if loaded) --- def get_input_embeddings(self): if self.llm: return self.llm.get_input_embeddings() return None # Or raise error def set_input_embeddings(self, value): if self.llm: self.llm.set_input_embeddings(value) else: logger.warning("LLM not loaded, cannot set input embeddings.") def get_output_embeddings(self): if self.llm: # For causal LM, output embeddings are usually tied to lm_head return self.llm.get_output_embeddings() return None # Or raise error def set_output_embeddings(self, new_embeddings): if self.llm and hasattr(self.llm, 'set_output_embeddings'): self.llm.set_output_embeddings(new_embeddings) else: logger.warning("LLM not loaded or does not support set_output_embeddings.") # --- End Embedding methods --- # post_init is less critical now as loading happens in from_pretrained, # but can be used for final checks or setup. def post_init(self): # Ensure wav2vec2 config has output_hidden_states=True after loading if self.wav2vec2_model and hasattr(self.wav2vec2_model.config, 'output_hidden_states'): if not self.wav2vec2_model.config.output_hidden_states: self.wav2vec2_model.config.output_hidden_states = True logger.info("Set wav2vec2_model.config.output_hidden_states=True") else: logger.warning("Could not access wav2vec2_model.config to ensure output_hidden_states=True.") @property def device(self) -> torch.device: """ Override device property to report the LLM's device as representative """ if self.llm: return self.llm.device else: # Fallback or default if LLM not loaded yet # This might be called by pipeline before full init? Be cautious. try: return next(self.parameters()).device except StopIteration: # If no parameters, default to CPU return torch.device("cpu") @torch.no_grad() def _extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor: """Extract wav2vec2 features. Input wavs: (B, T_wav)""" if not self.wav2vec2_model or not self.wav2vec2_processor: raise RuntimeError("Wav2Vec2 components not loaded.") # Use component's device target_device = self.wav2vec2_model.device wavs_on_device = wavs.to(target_device) # Expected shape [B, T_wav] e.g., [1, 61120] # Process audio using the Wav2Vec2FeatureExtractor processor_output = self.wav2vec2_processor( wavs_on_device, sampling_rate=self.config.sample_rate, return_tensors="pt", padding=True, # Ensure padding is handled correctly ) inputs = processor_output.input_values # Should be shape [B, T_processed] # --- START DEBUG & FIX --- print(f"Shape returned by processor: {inputs.shape}") # Reshape if processor added extra dimensions if inputs.ndim == 4 and inputs.shape[1] == 1 and inputs.shape[2] == 1: print(f"Reshaping input from {inputs.shape} to 2D.") inputs = inputs.squeeze(1).squeeze(1) # Remove the two middle dimensions elif inputs.ndim == 3 and inputs.shape[1] == 1: print(f"Reshaping input from {inputs.shape} to 2D.") inputs = inputs.squeeze(1) # Remove the channel dimension # Ensure final shape is 2D: (batch_size, sequence_length) if inputs.ndim != 2: raise ValueError(f"Unexpected shape after processing/reshaping: {inputs.shape}. 
Expected 2D input for Wav2Vec2Model.") print(f"Shape BEFORE Wav2Vec2Model: {inputs.shape}") # --- END DEBUG & FIX --- inputs = inputs.to(target_device) # Ensure output_hidden_states=True during call if not set reliably in config outputs = self.wav2vec2_model(inputs, output_hidden_states=True) if outputs.hidden_states is None: raise ValueError("Wav2Vec2 model did not return hidden states. Ensure config.output_hidden_states=True.") # Mix specific layers num_layers = len(outputs.hidden_states) indices_to_mix = [11, 14, 16] valid_indices = [i for i in indices_to_mix if i < num_layers] if len(valid_indices) != len(indices_to_mix): logger.warning(f"Requested Wav2Vec2 hidden state indices {indices_to_mix} out of range (0-{num_layers-1}). Using available valid indices: {valid_indices}.") if not valid_indices: # If no valid indices, use last hidden state logger.warning("No valid hidden state indices for mixing. Using last hidden state.") feats_mix = outputs.last_hidden_state else: # Mix available valid indices feats_mix = torch.stack([outputs.hidden_states[i] for i in valid_indices]).mean(dim=0) else: # Original mixing logic feats_mix = (outputs.hidden_states[11] + outputs.hidden_states[14] + outputs.hidden_states[16]) / 3 # Output shape: (B, T_feat, D_feat) - Transpose needed for BiCodec Encoder return feats_mix.transpose(1, 2) # (B, D_feat, T_feat) def _get_ref_clip(self, wav: np.ndarray) -> np.ndarray: """Get reference audio clip for speaker embedding.""" ref_samples = int(self.config.sample_rate * self.config.ref_segment_duration) latent_hop_length = self.config.latent_hop_length # Ensure length is multiple of hop_length for potential downstream processing ref_segment_length = max(latent_hop_length, (ref_samples // latent_hop_length) * latent_hop_length) # Ensure at least one hop wav_length = len(wav) if wav_length == 0: # Handle empty input return np.zeros(ref_segment_length, dtype=np.float32) if ref_segment_length > wav_length: num_repeats = (ref_segment_length // wav_length) + 1 wav = np.tile(wav, num_repeats) return wav[:ref_segment_length].astype(np.float32) # Ensure float32 @torch.no_grad() def _tokenize_audio(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]: """Load audio, extract features, and tokenize using BiCodec.""" wav_np = load_audio( audio_path, sampling_rate=self.config.sample_rate, volume_normalize=self.config.volume_normalize, ) wav_ref_np = self._get_ref_clip(wav_np) # Convert to tensors, add batch dim, move to device wav = torch.from_numpy(wav_np).unsqueeze(0).float().to(self.device) ref_wav = torch.from_numpy(wav_ref_np).unsqueeze(0).float().to(self.device) # Extract Wav2Vec2 features -> (B, D_feat, T_feat) feat = self._extract_wav2vec2_features(wav) # Tokenize using BiCodec -> returns (global_tokens, semantic_tokens) # BiCodec.tokenize expects feat: (B, D_feat, T_feat), ref_wav: (B, T_wav) global_tokens, semantic_tokens = self.bicodec.tokenize(feat, ref_wav) # global_tokens: (B, T_token, Q), semantic_tokens: (B, T_latent) return global_tokens, semantic_tokens @torch.no_grad() def _detokenize_audio(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> np.ndarray: """Detokenize using BiCodec to get waveform.""" global_tokens = global_tokens.to(self.device) semantic_tokens = semantic_tokens.to(self.device) self.bicodec.to(self.device) # Ensure BiCodec is on device # BiCodec.detokenize expects global_tokens: (B, T_token, Q), semantic_tokens: (B, T_latent) wav_rec = self.bicodec.detokenize(global_tokens, semantic_tokens) # (B, 1, T_wav) # Remove 
channel dim and batch dim, convert to numpy
        return wav_rec.detach().squeeze(0).squeeze(0).cpu().numpy()

    # Define prepare_inputs_for_generation so generation delegates to the LLM's own logic.
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Prepares inputs for the LLM's generate method."""
        if not self.llm:
            raise RuntimeError("LLM component not loaded.")
        # Delegate to the LLM's prepare_inputs_for_generation directly so we reuse the exact
        # logic of the underlying architecture (e.g. Qwen2): past_key_values, attention_mask,
        # use_cache, etc. are all handled there.
        try:
            # Pass all relevant kwargs received by the top-level generate call;
            # the LLM's method selects what it needs.
            return self.llm.prepare_inputs_for_generation(input_ids, **kwargs)
        except AttributeError:
            # Fallback if the LLM does not expose this method (unlikely for recent models).
            logger.warning("LLM does not have 'prepare_inputs_for_generation'. Using basic fallback.")
            model_kwargs = {
                "past_key_values": kwargs.get("past_key_values", None),
                "use_cache": kwargs.get("use_cache", None),
            }
            # Ensure attention_mask is included if present in kwargs
            if "attention_mask" in kwargs:
                model_kwargs["attention_mask"] = kwargs["attention_mask"]
            return {"input_ids": input_ids, **model_kwargs}

    # Minimal forward method compatible with GenerationMixin. It accepts the output of
    # prepare_inputs_for_generation and delegates to the underlying LLM. Calling forward
    # directly is not the intended entry point for TTS; use generate or the pipeline.
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,  # Accept other potential kwargs from prepare_inputs
    ) -> Any:  # Return type depends on the LLM, usually CausalLMOutputWithPast
        """
        Minimal forward pass that delegates to the underlying LLM.
        Required for compatibility with GenerationMixin; accepts the arguments returned by
        prepare_inputs_for_generation. Not intended to be called directly for TTS.
        """
        if not self.llm:
            raise RuntimeError("LLM component not loaded.")
        # Collect arguments for the LLM's forward method.
        llm_kwargs = {
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            **kwargs,  # Pass through any other relevant kwargs
        }
        # Pass position_ids through only when provided; some LLM forward signatures
        # do not accept it when past_key_values are used.
if position_ids is not None: llm_kwargs["position_ids"] = position_ids return self.llm(input_ids=input_ids, **llm_kwargs) # Add generate method to use GenerationMixin capabilities directly on SparkTTSModel if desired # This will internally call prepare_inputs_for_generation and forward (which might need defining/adjusting) # However, the pipeline calls self.model.llm.generate, so this might not be strictly needed unless you want `model.generate(...)` # @torch.no_grad() # def generate(self, *args, **kwargs): # if not self.llm: # raise RuntimeError("LLM component not loaded.") # # This might need adjustments based on how GenerationMixin interacts with the overridden forward # # return super().generate(*args, **kwargs) # Calls self.prepare_inputs + self.forward loop # # Or directly call the LLM's generate if forward is problematic: # return self.llm.generate(*args, **kwargs)
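

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): round-trips a reference recording
# through the BiCodec tokenizer/detokenizer exposed by SparkTTSModel.
# The Hub id and file names below are placeholders, and _tokenize_audio /
# _detokenize_audio are internal helpers used here purely to illustrate the
# audio <-> token path; full TTS prompting is handled by the accompanying
# processor/pipeline code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Placeholder checkpoint id or local directory with the layout described in from_pretrained.
    model = SparkTTSModel.from_pretrained("SparkAudio/Spark-TTS-0.5B")
    model.eval()

    # Audio -> (global speaker tokens, semantic tokens) via wav2vec2 features + BiCodec.
    global_tokens, semantic_tokens = model._tokenize_audio("reference.wav")  # placeholder path
    print("global tokens:", tuple(global_tokens.shape), "semantic tokens:", tuple(semantic_tokens.shape))

    # Tokens -> waveform via BiCodec (prenet + WaveGenerator), then write to disk.
    wav = model._detokenize_audio(global_tokens, semantic_tokens)
    soundfile.write("reconstruction.wav", wav, model.config.sample_rate)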