|
"""
|
|
RAG minimal working example (indexing + retrieval + generation)
|
|
|
|
Features:
|
|
- Embed documents with sentence-transformers
|
|
- Index with FAISS (fallback to sklearn)
|
|
- Generate answers using Hugging Face transformers (FLAN-T5)
|
|
- CLI demo and a tiny Flask API endpoint
|
|
|
|
Usage:
|
|
pip install -r requirements.txt
|
|
python rag.py --build-corpus sample_docs/ # builds index
|
|
python rag.py --ask "What is RAG?" # ask question using built index
|
|
|
|
Author: ChatGPT (example)
|
|
"""
|
|
|
|
import os
|
|
import argparse
|
|
import json
|
|
from typing import List, Tuple
|
|
import numpy as np
|
|
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
|
|
try:
|
|
import faiss
|
|
_HAS_FAISS = True
|
|
except Exception:
|
|
_HAS_FAISS = False
|
|
from sklearn.neighbors import NearestNeighbors
|
|
|
|
|
|
from flask import Flask, request, jsonify
|
|
|
|
|
|
# Sentence-transformers model used to embed both documents and queries.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"


# Seq2seq model used to generate answers from the retrieved context.
GEN_MODEL = "google/flan-t5-small"


# On-disk layout of the persisted index.
INDEX_DIR = "rag_index"


DOCS_JSON = os.path.join(INDEX_DIR, "documents.json")


VECTORS_NPY = os.path.join(INDEX_DIR, "vectors.npy")


FAISS_INDEX_FILE = os.path.join(INDEX_DIR, "faiss.index")


# Default number of documents retrieved per query.
K = 5
|
|
|
|
|
|
|
|
class DocumentStore:
    """In-memory list of documents plus their embedding matrix.

    Persisted as a JSON file (documents) and a .npy file (embeddings)
    under INDEX_DIR.
    """

    def __init__(self):
        # Each entry is {"id": str, "text": str, "meta": dict}.
        self.docs: List[dict] = []
        # np.ndarray of shape (n_docs, dim), or None before embedding.
        self.embeddings = None

    def add_documents(self, texts: List[str], metas: List[dict] = None):
        """Append texts (with optional per-document metadata) to the store.

        Ids are assigned sequentially, continuing from the current count.

        Raises:
            ValueError: if ``metas`` is given but its length differs from
                ``texts``. (The previous zip-based loop silently dropped
                documents on a length mismatch.)
        """
        if metas is None:
            metas = [{} for _ in texts]
        elif len(metas) != len(texts):
            raise ValueError(
                f"len(metas)={len(metas)} does not match len(texts)={len(texts)}"
            )
        start_id = len(self.docs)
        for i, (text, meta) in enumerate(zip(texts, metas)):
            self.docs.append({"id": str(start_id + i), "text": text, "meta": meta})

    def save(self):
        """Write documents (JSON) and, if present, embeddings (.npy) to INDEX_DIR."""
        os.makedirs(INDEX_DIR, exist_ok=True)
        with open(DOCS_JSON, "w", encoding="utf-8") as f:
            json.dump(self.docs, f, ensure_ascii=False, indent=2)
        if self.embeddings is not None:
            np.save(VECTORS_NPY, self.embeddings)

    def load(self):
        """Load documents and embeddings from disk; missing files are skipped."""
        if os.path.exists(DOCS_JSON):
            with open(DOCS_JSON, "r", encoding="utf-8") as f:
                self.docs = json.load(f)
        if os.path.exists(VECTORS_NPY):
            self.embeddings = np.load(VECTORS_NPY)
|
|
|
|
|
|
|
|
class VectorIndex:
    """Cosine-similarity search: FAISS inner-product index with an sklearn fallback."""

    def __init__(self, dim: int):
        self.dim = dim
        self._use_faiss = _HAS_FAISS
        if self._use_faiss:
            # Inner product over L2-normalized vectors equals cosine similarity.
            self.index = faiss.IndexFlatIP(dim)
        else:
            self._nn = None
            # In the sklearn path, self.index holds the raw vector matrix.
            self.index = None

    def fit(self, vectors: np.ndarray):
        """Build the index from an (n, dim) embedding matrix and (FAISS) persist it."""
        if self._use_faiss:
            # Normalize a float32 COPY: faiss.normalize_L2 mutates its argument
            # in place, and the caller's array (DocumentStore.embeddings) must
            # keep its unnormalized values for saving to disk.
            vecs = np.array(vectors, dtype=np.float32, copy=True)
            faiss.normalize_L2(vecs)
            self.index.add(vecs)
            # fit() can run before DocumentStore.save() creates the directory
            # (build_index does exactly that), so ensure it exists here too.
            os.makedirs(INDEX_DIR, exist_ok=True)
            faiss.write_index(self.index, FAISS_INDEX_FILE)
        else:
            self._nn = NearestNeighbors(n_neighbors=min(10, len(vectors)), metric="cosine")
            self._nn.fit(vectors)
            self.index = vectors

    def query(self, qvec: np.ndarray, top_k: int = 5) -> List[Tuple[int, float]]:
        """
        Returns list of (doc_idx, score) sorted by highest similarity.
        Score is cosine similarity (higher is better). For the sklearn
        fallback, cosine *distance* d is converted via 1 - d.
        """
        if self._use_faiss:
            # Copy before normalizing — normalize_L2 works in place.
            q = qvec.astype(np.float32, copy=True)
            faiss.normalize_L2(q)
            distances, indices = self.index.search(q, top_k)
            results = []
            for idx, score in zip(indices[0], distances[0]):
                # FAISS pads with -1 when fewer than top_k vectors exist.
                if idx == -1:
                    continue
                results.append((int(idx), float(score)))
            return results
        distances, indices = self._nn.kneighbors(qvec, n_neighbors=min(top_k, len(self.index)))
        return [(int(idx), 1.0 - float(d)) for idx, d in zip(indices[0], distances[0])]
|
|
|
|
|
|
|
|
class RAG:
    """Ties together document embedding, vector search, and seq2seq generation."""

    def __init__(self, embedding_model_name=EMBEDDING_MODEL, gen_model_name=GEN_MODEL, device=-1):
        """Load the embedding and generation models.

        device: -1 for CPU, otherwise the CUDA device index.
        """
        print("Loading embedding model:", embedding_model_name)
        self.embedder = SentenceTransformer(embedding_model_name, device="cpu" if device == -1 else f"cuda:{device}")

        print("Loading generator model:", gen_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
        self.gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
        self.generator = pipeline("text2text-generation", model=self.gen_model, tokenizer=self.tokenizer, device=0 if device != -1 else -1)

        self.store = DocumentStore()
        # VectorIndex instance; populated by build_index() or load_index().
        self.index = None

    def build_index(self, documents: List[str], metas: List[dict] = None):
        """Embed documents, build the vector index, and persist everything to disk."""
        self.store.add_documents(documents, metas)

        print(f"Embedding {len(documents)} documents...")
        vectors = self.embedder.encode(documents, convert_to_numpy=True, show_progress_bar=True)
        self.store.embeddings = vectors.astype(np.float32)

        self.index = VectorIndex(vectors.shape[1])
        self.index.fit(self.store.embeddings)

        self.store.save()
        print("Index built and saved to disk.")

    def load_index(self):
        """Load documents/embeddings from disk and rebuild the vector index.

        Raises:
            RuntimeError: if no embeddings have been saved yet.
        """
        self.store.load()
        if self.store.embeddings is None:
            raise RuntimeError("No embeddings found on disk. Build the index first.")
        self.index = VectorIndex(self.store.embeddings.shape[1])
        self.index.fit(self.store.embeddings)
        print("Index loaded from disk. Documents:", len(self.store.docs))

    def retrieve(self, query: str, k: int = K):
        """Return the top-k documents for `query` as dicts with id/text/score/meta.

        Raises:
            RuntimeError: if the index was never built or loaded (previously
                this surfaced as an opaque AttributeError on None).
        """
        if self.index is None:
            raise RuntimeError("Index not available. Call build_index() or load_index() first.")
        qvec = self.embedder.encode([query], convert_to_numpy=True)
        hits = self.index.query(qvec, top_k=k)
        return [
            {
                "id": self.store.docs[idx]["id"],
                "text": self.store.docs[idx]["text"],
                "score": score,
                "meta": self.store.docs[idx].get("meta", {}),
            }
            for idx, score in hits
        ]

    def generate(self, question: str, retrieved_docs: List[dict], max_length=256, temperature=0.1):
        """Build a grounded prompt from the retrieved docs and generate an answer.

        `temperature` is kept for interface compatibility but is NOT forwarded:
        decoding is greedy (do_sample=False), where transformers ignores
        temperature and emits a warning if it is passed.
        """
        context_texts = "\n\n---\n\n".join([f"[{d['id']}] {d['text']}" for d in retrieved_docs])
        prompt = (
            "You are an assistant that answers questions using the provided context.\n"
            "If the answer is not contained in the context, say 'I don't know.'\n\n"
            "Context:\n"
            f"{context_texts}\n\n"
            "Question:\n"
            f"{question}\n\nAnswer:"
        )

        out = self.generator(prompt, max_length=max_length, do_sample=False, num_return_sequences=1)
        answer = out[0]["generated_text"].strip()
        return {"answer": answer, "prompt": prompt, "retrieved": retrieved_docs}
|
|
|
|
|
|
def read_text_files_from_dir(dir_path: str) -> List[str]:
    """Read every .txt/.md file directly inside dir_path, sorted by filename.

    Subdirectories and files with other extensions are skipped.
    """
    collected = []
    for entry in sorted(os.listdir(dir_path)):
        path = os.path.join(dir_path, entry)
        # Skip subdirectories and anything that is not a .txt/.md file.
        if not os.path.isfile(path):
            continue
        if not entry.lower().endswith((".txt", ".md")):
            continue
        with open(path, "r", encoding="utf-8") as fh:
            collected.append(fh.read())
    return collected
|
|
|
|
|
|
|
|
# Flask application exposing the /api/ask endpoint.
app = Flask(__name__)


# Global RAG instance; initialized by main() before the server starts.
rag_system: RAG = None
|
|
|
|
@app.route("/api/ask", methods=["POST"])
|
|
def api_ask():
|
|
payload = request.json
|
|
if not payload or "question" not in payload:
|
|
return jsonify({"error": "Send JSON with 'question' key."}), 400
|
|
question = payload["question"]
|
|
k = payload.get("k", K)
|
|
retrieved = rag_system.retrieve(question, k=k)
|
|
gen = rag_system.generate(question, retrieved, max_length=256)
|
|
return jsonify(gen)
|
|
|
|
|
|
def main():
    """CLI entry point: build an index, answer a one-off question, or serve the API."""
    parser = argparse.ArgumentParser(description="Simple RAG example")
    parser.add_argument("--build-corpus", type=str, help="Directory of .txt/.md files to build index from")
    parser.add_argument("--ask", type=str, help="Ask a question against the built index")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Flask host")
    parser.add_argument("--port", type=int, default=5000, help="Flask port")
    parser.add_argument("--use-gpu", action="store_true", help="Use GPU if available (careful)")
    args = parser.parse_args()

    # The Flask handler reads the module-level rag_system, so assign it here.
    global rag_system
    rag_system = RAG(device=0 if args.use_gpu else -1)

    # Mode 1: build the index from a directory of text files, then exit.
    if args.build_corpus:
        corpus = read_text_files_from_dir(args.build_corpus)
        if not corpus:
            print("No .txt or .md files found in", args.build_corpus)
            return
        rag_system.build_index(corpus)
        print("Built index from", len(corpus), "documents.")
        return

    # Modes 2 and 3 both require a previously built index on disk.
    try:
        rag_system.load_index()
    except Exception as e:
        print("Failed to load index:", e)
        print("If you haven't built the index run: python rag.py --build-corpus ./sample_docs/")
        return

    # Mode 2: answer a single question on the command line, then exit.
    if args.ask:
        hits = rag_system.retrieve(args.ask, k=K)
        print("Retrieved docs (id, score):")
        for hit in hits:
            snippet = hit['text'][:160].replace('\n', ' ')
            print(f"- id={hit['id']} score={hit['score']:.4f} preview={snippet}")

        result = rag_system.generate(args.ask, hits)
        print("\n=== ANSWER ===")
        print(result["answer"])
        return

    # Mode 3 (default): serve the HTTP API.
    print(f"Starting Flask API on http://{args.host}:{args.port}")
    app.run(host=args.host, port=args.port)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|
|
|