from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from langchain_community.llms import LlamaCpp
from nemoguardrails import LLMRails, RailsConfig


app = FastAPI()

MODEL_PATH = "./kai-model-7.2B-Q4_0.gguf"

# Load the local GGUF model through the LangChain LlamaCpp wrapper so it can
# be handed to NeMo Guardrails as the underlying LLM.
llm = LlamaCpp(
    model_path=MODEL_PATH,
    temperature=0.7,
    top_k=40,
    top_p=0.95,
)

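# "./config" is expected to contain a NeMo Guardrails configuration: typically
# a config.yml plus optional Colang (.co) flow files. A minimal illustrative
# sketch follows; the exact keys and flow names depend on your rails setup and
# nemoguardrails version, and no `models:` entry is shown here because the LLM
# is supplied programmatically via `llm=llm` below.
#
#   # config/config.yml
#   instructions:
#     - type: general
#       content: |
#         You are a helpful assistant. Politely decline unsafe requests.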
config = RailsConfig.from_path("./config")
rails = LLMRails(config, llm=llm)


class ChatRequest(BaseModel):
    message: str


@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Run the user message through the guardrails and return the model reply."""
    try:
        response = await rails.generate_async(
            messages=[{"role": "user", "content": request.message}]
        )
        return {"response": response["content"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
def health_check():
    return {"status": "ok", "model": MODEL_PATH}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)
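
# Example client call, assuming the server is running locally on port 8000.
# The `requests` dependency and the prompt text are illustrative only:
#
#   import requests
#
#   r = requests.post(
#       "http://127.0.0.1:8000/chat",
#       json={"message": "Hello!"},
#   )
#   print(r.json()["response"])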