qinxy
upload
21c58e8
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import gradio as gr
import numpy as np
import torch
import torchaudio
import random # 即使没有随机种子UI,set_all_random_seed可能还用
import librosa
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, 'third_party', 'Matcha-TTS')) # 使用os.path.join更安全
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav, logging
from cosyvoice.utils.common import set_all_random_seed
# 移除与多模式相关的全局变量,这些在简化版中不再需要
# inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
# instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮', ...}
# stream_mode_list = [('否', False), ('是', True)]
max_val = 0.8 # 保持音频归一化参数
# 全局变量,在 __main__ 中初始化
cosyvoice = None
prompt_sr = 16000 # prompt 音频采样率
default_data = None # 默认静音音频数据,在 cosyvoice 初始化后定义
# generate_seed 函数不再需要,因为随机种子UI已被移除
# def generate_seed():
# seed = random.randint(1, 100000000)
# return {
# "__type__": "update",
# "value": seed
# }
def postprocess(speech, top_db=60, hop_length=220, win_length=440):
"""
后处理函数,处理音频数据(包括归一化、去除静音、添加尾部静音)。
输入: speech (torch.Tensor), 可能是 (N,) 或 (C, N)
输出: out (torch.Tensor), 始终为 (1, N')
"""
# 核心修复点:将 torch.Tensor 转换为 numpy.ndarray 以便 librosa 处理
# 并确保是单声道
speech_np = speech.cpu().numpy()
if speech_np.ndim > 1: # 如果是多声道 (C, N)
speech_np = speech_np[0] # 取第一个通道,变为 (N,)
# 去除开头结尾静音 (librosa 操作 numpy 数组)
speech_trimmed_np, _ = librosa.effects.trim(
speech_np, top_db=top_db,
frame_length=win_length,
hop_length=hop_length
)
# 核心修复点:将 numpy.ndarray 转换回 torch.Tensor 进行后续操作
speech_trimmed_tensor = torch.from_numpy(speech_trimmed_np).to(speech.device).float()
# 核心修复点:确保张量是二维的 (1, samples) 以兼容后面的 dim=1 拼接
if speech_trimmed_tensor.ndim == 1:
speech_trimmed_tensor = speech_trimmed_tensor.unsqueeze(0) # 从 (N,) 变为 (1, N)
# 归一化 (torch 操作)
if speech_trimmed_tensor.abs().max() > max_val:
speech_trimmed_tensor = speech_trimmed_tensor / speech_trimmed_tensor.abs().max() * max_val
# 尾部加 0.2s 静音 (torch 操作)
# 创建一个 (1, M) 形状的零张量以兼容 speech_trimmed_tensor 的 (1, N) 形状进行 dim=1 拼接
pad_tensor = torch.zeros(1, int(cosyvoice.sample_rate * 0.2), device=speech_trimmed_tensor.device, dtype=speech_trimmed_tensor.dtype)
# 拼接在 dim=1 上 (水平拼接)
out = torch.cat([speech_trimmed_tensor, pad_tensor], dim=1)
return out
# change_instruction 函数不再需要
# def change_instruction(mode_checkbox_group):
# return instruct_dict[mode_checkbox_group]
def generate_audio(
tts_text: str,
prompt_wav_upload: str,
prompt_wav_record: str,
prompt_text: str
# 移除 sft_dropdown, mode_checkbox_group, instruct_text, seed, stream, speed
# 这些参数因为 UI 元素被移除,所以不再需要作为函数的输入
):
"""
根据输入文本和prompt音频生成语音(仅支持3s极速复刻模式)。
"""
global cosyvoice, default_data # 确保能访问全局变量
if cosyvoice is None:
gr.Info("模型未初始化,请检查启动配置。")
# yield (cosyvoice.sample_rate, default_data) # yield 仅用于生成器函数
return None # 对于非生成器函数,返回 None 清除输出
if prompt_wav_upload is not None:
prompt_wav = prompt_wav_upload
elif prompt_wav_record is not None:
prompt_wav = prompt_wav_record
else:
prompt_wav = None
# 移除所有与多模式相关的条件判断和警告,只保留 3s极速复刻 的核心逻辑
# 例如:
# if mode_checkbox_group in ['自然语言控制']: ...
# if mode_checkbox_group in ['跨语种复刻']: ...
# if mode_checkbox_group in ['预训练音色']: ...
# if mode_checkbox_group in ['3s极速复刻']: (保留其内部检查)
# 针对 3s极速复刻 模式的检查
if prompt_wav is None:
gr.Info('prompt音频为空,您是否忘记输入prompt音频?') # 使用 gr.Info 弹窗
return None
# 检查采样率
try:
# 核心修复点:torchaudio.info 返回 AudioMetaData,从中获取采样率
info = torchaudio.info(prompt_wav)
if info.sample_rate < prompt_sr:
gr.Info(f"prompt 音频采样率过低:{info.sample_rate} < {prompt_sr}") # 使用 gr.Info 弹窗
return None
except Exception as e:
gr.Info(f"无法读取 prompt 音频信息,请检查文件格式或损坏:{e}") # 使用 gr.Info 弹窗
return None
if not prompt_text:
gr.Info('prompt文本为空,您是否忘记输入prompt文本?') # 使用 gr.Info 弹窗
return None
# 移除其他不必要的 Info 提示,例如关于 instruct_text 会被忽略的提示
# 处理 prompt 音频
try:
# 核心修复点:load_wav(filepath, sr) 返回一个 torch.Tensor,不是 (wav, sr) 元组
wav_tensor = load_wav(prompt_wav, prompt_sr)
prompt_speech_16k = postprocess(wav_tensor) # postprocess 现在可以处理 torch.Tensor
except Exception as e:
gr.Info(f"处理 prompt 音频时出错:{e}")
return None
# 固定种子 & 非流式、速度 1.0 (因为相关UI已被移除,所以硬编码)
set_all_random_seed(0) # 对应 generate_seed 函数的移除
logging.info("执行 3s 极速复刻 推理")
try:
# 仅保留 3s极速复刻 模式的推理逻辑 (zero_shot)
# 移除其他模式的推理分支
# 这里使用 next() 来从生成器获取结果,因为 Gradio 接口不是生成器
result = next(cosyvoice.inference_zero_shot(
tts_text,
prompt_text,
prompt_speech_16k,
stream=False, # 硬编码为 False (因为流式UI已被移除)
speed=1.0 # 硬编码为 1.0 (因为速度UI已被移除)
))
audio = result["tts_speech"].numpy().flatten()
return cosyvoice.sample_rate, audio
except Exception as e:
gr.Info(f"推理过程中发生错误:{e}")
# 发生错误时返回静音数据,而不是 None,这样 Gradio Audio 组件不会报错
return cosyvoice.sample_rate, default_data
def main():
with gr.Blocks() as demo:
# 简化 Gradio Markdown 提示
gr.Markdown("### SMIIP-NV finetune CosyVoice2")
gr.Markdown("#### 上传一段 ≤30s 的 prompt 音频,填写对应文本,合成目标语音。")
tts_text = gr.Textbox(label="输入合成文本", lines=1, value="在这个孤独的夜晚<crying>,窗外的雨声让我想起了你,<crying>我真的好想你。")
# 移除与多模式、速度、流式、随机种子、instruct文本相关的 UI 元素
# with gr.Row(): 中的模式选择、指导文本、预训练音色、流式、速度、随机种子按钮/数字都被移除
with gr.Row():
# Gradio 4.x 更改:sources 参数使用列表
prompt_wav_upload = gr.Audio(sources=['upload'], type='filepath', label='选择prompt音频文件,注意采样率不低于16khz')
prompt_wav_record = gr.Audio(sources=['microphone'], type='filepath', label='录制prompt音频文件')
prompt_text = gr.Textbox(label="输入prompt文本", lines=1, placeholder="请输入prompt文本,需与prompt音频内容一致,暂时不支持自动识别...", value='')
# 移除 instruct_text
generate_button = gr.Button("生成音频")
# 保持 audio_output 定义,streaming=True 意味着 Gradio UI 可以处理流式输出
# 虽然当前 generate_audio 已被修改为非生成器,但保留此参数无害
audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)
# 移除 seed_button.click 绑定
# seed_button.click(generate_seed, inputs=[], outputs=seed)
# 调整 generate_audio 的 inputs 参数,只保留实际的 UI 输入
generate_button.click(generate_audio,
inputs=[tts_text, prompt_wav_upload, prompt_wav_record, prompt_text],
outputs=[audio_output])
# 移除 mode_checkbox_group.change 绑定
# mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
demo.queue(max_size=4, default_concurrency_limit=2)
demo.launch(server_name='0.0.0.0', server_port=args.port)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--port',
type=int,
default=8000,
help="服务启动端口")
parser.add_argument('--model_dir',
type=str,
default='pretrained_models/CosyVoice2-0.5B',
help='local path or modelscope repo id')
args = parser.parse_args()
# 初始化模型,并捕获错误
try:
cosyvoice = CosyVoice(args.model_dir)
print("CosyVoice 模型加载成功!")
except Exception as e:
print(f"加载 CosyVoice 模型失败:{e},尝试加载 CosyVoice2...")
try:
cosyvoice = CosyVoice2(args.model_dir)
print("CosyVoice2 模型加载成功!")
except Exception as e2:
print(f"加载 CosyVoice2 模型也失败了:{e2}")
# 如果模型加载失败,程序退出,并给出详细错误信息
raise TypeError('no valid model_type found for model_dir: ' + args.model_dir + f'\nError: {e2}')
# 移除 sft_spk 的初始化,因为预训练音色 UI 已被移除
# sft_spk = cosyvoice.list_available_spks()
# if len(sft_spk) == 0:
# sft_spk = ['']
# 确保 default_data 在 cosyvoice 初始化后定义,用于错误返回时的静音音频
default_data = np.zeros(cosyvoice.sample_rate, dtype=np.float32)
main()