|
|
|
"""
|
|
Author: 一铭
|
|
Date : 2024-08-28
|
|
|
|
Github: https://github.com/HG-ha
|
|
Home : https://api2.wer.plus
|
|
|
|
Description:
|
|
From ali dharma school project: https://github.com/FunAudioLLM/SenseVoice
|
|
|
|
This program is distributed using ONNX-encapsulated fastapi,Provides an interface for reading audio from a network or file and predicting content.
|
|
|
|
If you need to use cuda, you need to install the OnnxRun-time gpu, not the onnxruntime.
|
|
"""
|
|
|
|
import librosa
|
|
import numpy as np
|
|
import aiohttp
|
|
from fastapi import FastAPI, Form, UploadFile, HTTPException
|
|
from pydantic import HttpUrl, ValidationError, BaseModel, Field
|
|
from typing import List, Union
|
|
from funasr_onnx import SenseVoiceSmall
|
|
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
|
|
from io import BytesIO
|
|
|
|
|
|
class ApiResponse(BaseModel):
|
|
message: str = Field(..., description="Status message indicating the success of the operation.")
|
|
results: str = Field(..., description="Remove label output")
|
|
label_result: str = Field(..., description="Default output")
|
|
|
|
|
|
app = FastAPI()
|
|
|
|
async def from_url_load_audio(audio: HttpUrl) -> np.array:
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.get(
|
|
audio,
|
|
headers={
|
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36 Edg/127.0.0.0"
|
|
},
|
|
) as response:
|
|
if response.status != 200:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Failed to download image: {response.status}",
|
|
)
|
|
image_bytes = await response.read()
|
|
return BytesIO(image_bytes)
|
|
|
|
@app.post("/extract_text",response_model=ApiResponse)
|
|
async def upload_url(url: Union[HttpUrl, None] = Form(None), file: Union[UploadFile, None] = Form(None)):
|
|
if file:
|
|
audio = BytesIO(await file.read())
|
|
elif url:
|
|
try:
|
|
audio = await from_url_load_audio(str(url))
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
else:
|
|
return HTTPException(400,{"error": "No valid audio source provided."})
|
|
try:
|
|
res = model(audio, language=language, use_itn=True)
|
|
return {
|
|
"message": "input processed successfully",
|
|
"results": rich_transcription_postprocess(res[0]),
|
|
"label_result": res[0]
|
|
}
|
|
except ValidationError as e:
|
|
raise HTTPException(status_code=400, detail=e.errors())
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
model_dir = "iic/SenseVoiceSmall"
|
|
device_id = 0
|
|
batch_size = 16
|
|
language = "auto"
|
|
quantize = True
|
|
|
|
|
|
|
|
|
|
|
|
def load_data(self, wav_content: Union[str, np.ndarray, List[str], BytesIO], fs: int = None) -> List:
|
|
def load_wav(path: str) -> np.ndarray:
|
|
waveform, _ = librosa.load(path, sr=fs)
|
|
return waveform
|
|
|
|
if isinstance(wav_content, np.ndarray):
|
|
return [wav_content]
|
|
|
|
if isinstance(wav_content, str):
|
|
return [load_wav(wav_content)]
|
|
|
|
if isinstance(wav_content, list):
|
|
return [load_wav(path) for path in wav_content]
|
|
|
|
if isinstance(wav_content, BytesIO):
|
|
return [load_wav(wav_content)]
|
|
|
|
raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
|
|
|
|
SenseVoiceSmall.load_data = load_data
|
|
|
|
model = SenseVoiceSmall(
|
|
model_dir,
|
|
quantize=quantize,
|
|
device_id=device_id,
|
|
batch_size=batch_size
|
|
)
|
|
|
|
print("\n\nDocs: http://127.0.0.1:8000/docs\n")
|
|
import uvicorn
|
|
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|