|
|
|
from typing import Optional |
|
from nemoguardrails.actions import action |
|
from llama_index.core import SimpleDirectoryReader |
|
from llama_index.packs.recursive_retriever import RecursiveRetrieverSmallToBigPack |
|
from llama_index.core.base.base_query_engine import BaseQueryEngine |
|
from llama_index.core.base.response.schema import StreamingResponse |
|
import traceback |
|
import logging |
|
|
|
# Configure root logging for this process. logging.basicConfig() does nothing
# when the root logger already has handlers, so a host application's own
# logging configuration takes precedence over this default.
logging.basicConfig(level=logging.INFO)

# Module-level logger used by this module's actions.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Shared query engine, built lazily by init_query_engine(); it remains None
# until the first build completes without raising.
query_engine_cache: Optional[BaseQueryEngine]
query_engine_cache = None
|
|
|
|
|
|
|
@action(name="simple_response")
async def simple_response_action(context: dict):
    """Acknowledge the user's question directly, bypassing retrieval (no RAG).

    The question is read from ``context["user_message"]`` (empty string when
    the key is absent) and echoed back under the ``"result"`` key.
    """
    question = context.get("user_message", "")
    reply = f"I received your question: '{question}'. Let me think about that."
    return {"result": reply}
|
|
|
def init_query_engine() -> BaseQueryEngine:
    """Return the shared query engine, building it on the first call.

    On first use, the documents in the ``data`` directory are loaded with
    ``SimpleDirectoryReader`` and wrapped in a
    ``RecursiveRetrieverSmallToBigPack``. That pack's query engine is stored
    in the module-level ``query_engine_cache``, and every later call returns
    the cached engine without rebuilding it.
    """
    global query_engine_cache

    # Fast path: the engine has already been built.
    if query_engine_cache is not None:
        return query_engine_cache

    documents = SimpleDirectoryReader("data").load_data()
    small_to_big_pack = RecursiveRetrieverSmallToBigPack(documents)
    query_engine_cache = small_to_big_pack.query_engine
    return query_engine_cache
|
|
|
def get_query_response(engine: BaseQueryEngine, query: str) -> str:
    """Run *query* against *engine* and return the answer as plain text.

    A ``StreamingResponse`` is first resolved via ``get_response()``. When the
    resulting response has no (or empty) text, an empty string is returned.
    """
    raw = engine.query(query)
    final = raw.get_response() if isinstance(raw, StreamingResponse) else raw
    text = final.response
    return text if text else ""
|
|
|
@action(name="user_query", execute_async=True)
async def UserQueryAction(context: dict):
    """Answer the user's latest message with the cached RAG query engine.

    Args:
        context: Action context; the question is read from its
            ``"user_message"`` entry. A ``None`` context is treated as empty.

    Returns:
        The query engine's answer text. If the message is missing, not a
        string, or blank, returns a prompt asking for a valid question. If
        anything fails, returns a generic apology; the failure is logged
        together with its traceback.
    """
    try:
        user_message = (context or {}).get("user_message", "")
        # Reject missing, non-string and whitespace-only messages before
        # paying for engine initialisation or a retrieval round-trip.
        if not isinstance(user_message, str) or not user_message.strip():
            return "Please provide a valid question."

        # NOTE(review): init_query_engine() and the engine query are
        # synchronous and run on the event-loop thread. If this stalls other
        # requests, consider offloading them (e.g. asyncio.to_thread); confirm
        # the engine and the global cache are safe to use from a worker
        # thread first.
        engine = init_query_engine()
        return get_query_response(engine, user_message)

    except Exception as e:
        # Top-level action boundary: record the full traceback in one log
        # entry and return a user-safe message instead of propagating.
        logger.exception("Error in UserQueryAction: %s", e)
        return "I encountered an error processing your request. Please try again later."
|
|
|
@action(name="simple_query")
async def SimpleQueryAction(context: dict):
    """Acknowledge the user's question without running retrieval.

    The question is read from ``context["user_message"]``. A missing or
    ``None`` value, or a ``None`` context, is rendered as an empty string.
    Previously a ``None`` value raised ``TypeError`` during string
    concatenation.
    """
    user_message = (context or {}).get("user_message")
    if user_message is None:
        user_message = ""
    return "I received your question about: " + str(user_message)
|
|
|
@action(name="dummy_query")
async def DummyQueryAction(context: dict):
    """Return a fixed placeholder reply; *context* is accepted but ignored."""
    placeholder = "This is a test response"
    return placeholder