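"""
Gradio web demo for 3-second fast voice cloning (zero-shot synthesis) with a
fine-tuned CosyVoice / CosyVoice2 model.

The user uploads or records a short prompt audio clip (sample rate >= 16 kHz),
enters the matching prompt text, and the target text is synthesised in the
prompt speaker's voice.

Example launch (the script name is illustrative; the port and model_dir values
are the argparse defaults defined below):

    python webui.py --port 8000 --model_dir pretrained_models/CosyVoice2-0.5B
"""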
import os
import sys
import argparse
import random

import gradio as gr
import numpy as np
import torch
import torchaudio
import librosa

# Make the bundled Matcha-TTS third-party package importable
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, 'third_party', 'Matcha-TTS'))

from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav, logging
from cosyvoice.utils.common import set_all_random_seed

# Global configuration
max_val = 0.8        # peak level used when normalising the prompt audio
cosyvoice = None     # model handle, initialised in the __main__ block
prompt_sr = 16000    # sample rate (Hz) expected for the prompt audio
default_data = None  # silent fallback audio returned when inference fails


def postprocess(speech, top_db=60, hop_length=220, win_length=440):
    """
    Post-process audio data: trim silence, normalise the peak level, and
    append a short trailing silence.
    Input:  speech (torch.Tensor), shape (N,) or (C, N)
    Output: out (torch.Tensor), always shape (1, N')
    """
    # librosa works on numpy arrays; keep only the first channel if multi-channel
    speech_np = speech.cpu().numpy()
    if speech_np.ndim > 1:
        speech_np = speech_np[0]

    # Trim leading and trailing silence
    speech_trimmed_np, _ = librosa.effects.trim(
        speech_np, top_db=top_db,
        frame_length=win_length,
        hop_length=hop_length
    )

    speech_trimmed_tensor = torch.from_numpy(speech_trimmed_np).to(speech.device).float()
    if speech_trimmed_tensor.ndim == 1:
        speech_trimmed_tensor = speech_trimmed_tensor.unsqueeze(0)

    # Normalise the peak amplitude down to max_val
    if speech_trimmed_tensor.abs().max() > max_val:
        speech_trimmed_tensor = speech_trimmed_tensor / speech_trimmed_tensor.abs().max() * max_val

    # Append 0.2 s of trailing silence
    pad_tensor = torch.zeros(1, int(cosyvoice.sample_rate * 0.2),
                             device=speech_trimmed_tensor.device,
                             dtype=speech_trimmed_tensor.dtype)
    out = torch.cat([speech_trimmed_tensor, pad_tensor], dim=1)
    return out


def generate_audio(
    tts_text: str,
    prompt_wav_upload: str,
    prompt_wav_record: str,
    prompt_text: str
):
    """
    Generate speech from the input text and the prompt audio
    (only the 3-second fast voice-cloning / zero-shot mode is supported).
    """
    global cosyvoice, default_data

    if cosyvoice is None:
        gr.Info("The model is not initialised, please check the launch configuration.")
        return None

    # Prefer the uploaded prompt audio over the recorded one
    if prompt_wav_upload is not None:
        prompt_wav = prompt_wav_upload
    elif prompt_wav_record is not None:
        prompt_wav = prompt_wav_record
    else:
        prompt_wav = None

    if prompt_wav is None:
        gr.Info('The prompt audio is empty, did you forget to provide a prompt audio clip?')
        return None

    # Reject prompt audio whose sample rate is below prompt_sr
    try:
        info = torchaudio.info(prompt_wav)
        if info.sample_rate < prompt_sr:
            gr.Info(f"The prompt audio sample rate is too low: {info.sample_rate} < {prompt_sr}")
            return None
    except Exception as e:
        gr.Info(f"Failed to read the prompt audio info, please check that the file is not corrupted: {e}")
        return None

    if not prompt_text:
        gr.Info('The prompt text is empty, did you forget to enter the prompt text?')
        return None

    # Load the prompt audio at prompt_sr and clean it up (trim, normalise, pad)
    try:
        wav_tensor = load_wav(prompt_wav, prompt_sr)
        prompt_speech_16k = postprocess(wav_tensor)
    except Exception as e:
        gr.Info(f"Error while processing the prompt audio: {e}")
        return None

    set_all_random_seed(0)
    logging.info("Running 3s fast voice-cloning inference")

    try:
        # inference_zero_shot returns a generator; with stream=False the first
        # item contains the full synthesised utterance
        result = next(cosyvoice.inference_zero_shot(
            tts_text,
            prompt_text,
            prompt_speech_16k,
            stream=False,
            speed=1.0
        ))
        audio = result["tts_speech"].numpy().flatten()
        return cosyvoice.sample_rate, audio
    except Exception as e:
        gr.Info(f"An error occurred during inference: {e}")
        # Fall back to silence so the UI still receives valid audio
        return cosyvoice.sample_rate, default_data


def main():
    with gr.Blocks() as demo:
        gr.Markdown("### SMIIP-NV finetune CosyVoice2")
        gr.Markdown("#### Upload a prompt audio clip of at most 30 s, enter the corresponding text, and synthesise the target speech.")

        # Default example text (contains <crying> emotion tags)
        tts_text = gr.Textbox(label="Text to synthesise", lines=1, value="在这个孤独的夜晚<crying>,窗外的雨声让我想起了你,<crying>我真的好想你。")

        with gr.Row():
            prompt_wav_upload = gr.Audio(sources=['upload'], type='filepath', label='Select a prompt audio file (sample rate must be at least 16 kHz)')
            prompt_wav_record = gr.Audio(sources=['microphone'], type='filepath', label='Record a prompt audio clip')
            prompt_text = gr.Textbox(label="Prompt text", lines=1, placeholder="Enter the prompt text; it must match the prompt audio content (automatic transcription is not supported yet)...", value='')

        generate_button = gr.Button("Generate audio")
        audio_output = gr.Audio(label="Synthesised audio", autoplay=True, streaming=True)

        generate_button.click(generate_audio,
                              inputs=[tts_text, prompt_wav_upload, prompt_wav_record, prompt_text],
                              outputs=[audio_output])

    demo.queue(max_size=4, default_concurrency_limit=2)
    demo.launch(server_name='0.0.0.0', server_port=args.port)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--port',
                        type=int,
                        default=8000,
                        help="port to serve the web UI on")
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_models/CosyVoice2-0.5B',
                        help='local path or modelscope repo id')
    args = parser.parse_args()

    # Try loading the model as CosyVoice first, then fall back to CosyVoice2
    try:
        cosyvoice = CosyVoice(args.model_dir)
        print("CosyVoice model loaded successfully!")
    except Exception as e:
        print(f"Failed to load as a CosyVoice model: {e}, trying CosyVoice2...")
        try:
            cosyvoice = CosyVoice2(args.model_dir)
            print("CosyVoice2 model loaded successfully!")
        except Exception as e2:
            print(f"Failed to load as a CosyVoice2 model as well: {e2}")
            raise TypeError('no valid model_type found for model_dir: ' + args.model_dir + f'\nError: {e2}')

    # One second of silence at the model sample rate, used as fallback output
    default_data = np.zeros(cosyvoice.sample_rate, dtype=np.float32)

    main()