my-kai-model / config /actions.py
aferrmt's picture
Merge branch 'main' of https://huggingface.co/rmtlabs/my-kai-model
9b2dfd6
# config/actions.py
from typing import Optional
from nemoguardrails.actions import action
from llama_index.core import SimpleDirectoryReader
from llama_index.packs.recursive_retriever import RecursiveRetrieverSmallToBigPack
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.base.response.schema import StreamingResponse
import traceback
import logging
# Set up logging.
# basicConfig configures the root logger at INFO when this module is imported
# (it is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Cache for the query engine: None until init_query_engine() first builds it,
# then reused by every later call.
query_engine_cache: Optional[BaseQueryEngine] = None
@action(name="simple_response")
async def simple_response_action(context: dict):
    """Acknowledge the user's message without using the RAG pipeline.

    Args:
        context: Action context. Only the ``"user_message"`` key is read, and
            it defaults to an empty string when absent.

    Returns:
        A dict whose ``"result"`` entry is an acknowledgement that quotes the
        user's message.
    """
    # Placeholder behaviour: custom logic could go here, but by default the
    # LLM is left to produce the real answer.
    message = context.get("user_message", "")
    reply = f"I received your question: '{message}'. Let me think about that."
    return {"result": reply}
def init_query_engine() -> BaseQueryEngine:
    """Return the shared query engine, building it on first use.

    On the first call, the documents in the relative ``data`` directory are
    loaded with ``SimpleDirectoryReader`` and wrapped in a
    ``RecursiveRetrieverSmallToBigPack``. The pack's ``query_engine`` is
    stored in the module-level ``query_engine_cache`` and returned on every
    later call. The cache is assigned only after construction succeeds, so a
    failed build leaves it ``None`` and the next call retries.
    """
    global query_engine_cache
    if query_engine_cache is not None:
        return query_engine_cache
    documents = SimpleDirectoryReader("data").load_data()
    pack = RecursiveRetrieverSmallToBigPack(documents)
    query_engine_cache = pack.query_engine
    return query_engine_cache
def get_query_response(engine: BaseQueryEngine, query: str) -> str:
    """Run ``query`` against ``engine`` and return the answer text.

    A streaming result is first drained with ``get_response()`` so that the
    final text is available. A missing or empty answer is returned as ``""``.
    """
    result = engine.query(query)
    if isinstance(result, StreamingResponse):
        # Collapse the token stream into a regular response object.
        result = result.get_response()
    text = result.response
    return text if text else ""
@action(name="user_query", execute_async=True)
async def UserQueryAction(context: dict):
    """Answer the user's message with the cached RAG query engine.

    Args:
        context: Action context. The ``"user_message"`` key holds the question.

    Returns:
        One of three strings:
        - the engine's answer text;
        - ``"Please provide a valid question."`` when the message is missing,
          empty or whitespace-only;
        - a generic apology if anything raises (the error and its traceback
          are logged).
    """
    import asyncio  # local import keeps this block self-contained

    try:
        user_message = context.get("user_message", "")
        if not user_message or not str(user_message).strip():
            return "Please provide a valid question."
        loop = asyncio.get_running_loop()
        # Both the first-call index build and the query itself are synchronous
        # and may do slow disk/network work. Run them in the default executor
        # so this coroutine does not stall the shared event loop.
        # NOTE: concurrent first calls may each build an engine; the last one
        # to finish is what ends up in the cache.
        # NOTE(review): assumes the query engine tolerates use from executor
        # threads -- confirm, or switch to `await engine.aquery(...)`.
        engine = await loop.run_in_executor(None, init_query_engine)
        return await loop.run_in_executor(
            None, get_query_response, engine, user_message
        )
    except Exception as e:
        # logger.exception records the message and the traceback together.
        logger.exception("Error in UserQueryAction: %s", e)
        return "I encountered an error processing your request. Please try again later."
@action(name="simple_query")
async def SimpleQueryAction(context: dict):
    """Acknowledge the user's message without querying any engine.

    Args:
        context: Action context. ``"user_message"`` is read; a missing key or
            an explicit ``None`` value is treated as an empty string.

    Returns:
        ``"I received your question about: "`` followed by the message.
    """
    # Use `or ""` rather than only a .get() default: an explicit None value
    # would otherwise break the string building with a TypeError.
    user_message = context.get("user_message") or ""
    return f"I received your question about: {user_message}"
@action(name="dummy_query")
async def DummyQueryAction(context: dict):
    """Return a fixed placeholder reply; ``context`` is not used."""
    placeholder = "This is a test response"
    return placeholder