# my-kai-model / main.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from nemoguardrails import LLMRails, RailsConfig
from typing import Any, Dict, Union
import os
from langchain_community.llms import LlamaCpp
from langchain_openai import ChatOpenAI
# --- Configure the OpenAI-like provider (llama.cpp server) ---
# Adjust if you use another host/port. These defaults only apply when the
# variables are not already set in the environment.
os.environ.setdefault("OPENAI_API_KEY", "sk-no-key-needed")  # dummy key
os.environ.setdefault("OPENAI_API_BASE", "http://localhost:8001/v1")
os.environ.setdefault("OPENAI_BASE_URL", "http://localhost:8001/v1")  # for compatibility

# Build the raw client after the defaults above so base_url/api_key are never None.
llm = ChatOpenAI(
    base_url=os.getenv("OPENAI_API_BASE"),
    api_key=os.getenv("OPENAI_API_KEY"),
    model="kai-model:latest",  # must match what your llama_cpp.server exposes
)
# --- Load your guardrails configuration ---
# Expected layout:
#   ./config/
#       config.yml
#       rails/*.co   (your flows/policies)
config = RailsConfig.from_path("./config")
rails = LLMRails(config)  # <- do NOT pass an LLM here; use the OpenAI provider from config/env
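# For reference, a minimal config.yml for this setup could look like the sketch
# below (an assumption based on the standard NeMo Guardrails model block; the
# "openai" engine reads OPENAI_API_BASE / OPENAI_API_KEY from the environment):
#
#   models:
#     - type: main
#       engine: openai
#       model: kai-model:latest
#
# and rails/*.co would hold Colang flows; an illustrative (hypothetical) example:
#
#   define user ask off topic
#     "Can you help me with something unrelated?"
#
#   define flow off topic
#     user ask off topic
#     bot refuse to respond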
app = FastAPI(title="Guardrailed LLM API")
class ChatRequest(BaseModel):
message: str
def _normalize_response(r: Union[str, Dict[str, Any]]) -> str:
if isinstance(r, str):
return r
if isinstance(r, dict):
for k in ("content", "output", "text"): # distintas versiones/devuelven claves distintas
if k in r and isinstance(r[k], str):
return r[k]
return str(r)
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
"""
Aplica NeMo Guardrails a la petición y delega la generación al servidor OpenAI-like de llama.cpp
configurado en OPENAI_API_BASE.
"""
try:
resp = await rails.generate_async(
messages=[{"role": "user", "content": request.message}]
)
return {"response": _normalize_response(resp)}
except Exception as e:
raise HTTPException(status_code=500, detail=f"{type(e).__name__}: {e}")
@app.get("/health")
def health_check():
return {
"status": "ok",
"openai_api_base": os.getenv("OPENAI_API_BASE") or os.getenv("OPENAI_BASE_URL"),
"rails_config_loaded": True,
}
@app.post("/chat_raw")
def chat_raw(r: ChatRequest):
return {"text": llm.invoke(r.message)} # same llm instance
if __name__ == "__main__":
    # Development: uvicorn. In production, run gunicorn from the terminal.
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
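# Example usage once the API is up on port 8000 (the prompt text is illustrative):
#   curl -s http://localhost:8000/health
#   curl -s -X POST http://localhost:8000/chat \
#       -H "Content-Type: application/json" -d '{"message": "Hello"}'
#   curl -s -X POST http://localhost:8000/chat_raw \
#       -H "Content-Type: application/json" -d '{"message": "Hello"}'
# One common production pattern (an assumption, matching the comment above):
#   gunicorn -k uvicorn.workers.UvicornWorker main:app --bind 0.0.0.0:8000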