"""
RAG minimal working example (indexing + retrieval + generation)

Features:
- Embed documents with sentence-transformers
- Index with FAISS (falling back to sklearn)
- Generate answers with Hugging Face transformers (FLAN-T5)
- CLI demo and a tiny Flask API endpoint

Usage:
    pip install -r requirements.txt
    python rag.py --build-corpus sample_docs/   # build the index
    python rag.py --ask "What is RAG?"          # ask a question against the built index

Author: ChatGPT (example)
"""
import os
import argparse
import json
from typing import List, Optional, Tuple
import numpy as np
# Embeddings
from sentence_transformers import SentenceTransformer
# Generation
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
# Indexing: try faiss, fallback to sklearn
try:
    import faiss
    _HAS_FAISS = True
except Exception:
    _HAS_FAISS = False
    from sklearn.neighbors import NearestNeighbors
# Simple Flask API (optional)
from flask import Flask, request, jsonify
# ---------- Configuration ----------
EMBEDDING_MODEL = "all-MiniLM-L6-v2" # small, fast, accurate for semantic search
GEN_MODEL = "google/flan-t5-small" # small and CPU-friendly; swap to larger if GPU available
INDEX_DIR = "rag_index" # where vectors + docs are stored
DOCS_JSON = os.path.join(INDEX_DIR, "documents.json")
VECTORS_NPY = os.path.join(INDEX_DIR, "vectors.npy")
FAISS_INDEX_FILE = os.path.join(INDEX_DIR, "faiss.index")
# RAG hyperparams
K = 5 # number of documents to retrieve
# ---------- Document Store ----------
class DocumentStore:
    def __init__(self):
        self.docs: List[dict] = []  # each doc: {"id": str, "text": str, "meta": {...}}
        self.embeddings = None

    def add_documents(self, texts: List[str], metas: List[dict] = None):
        if metas is None:
            metas = [{} for _ in texts]
        start_id = len(self.docs)
        for i, (t, m) in enumerate(zip(texts, metas)):
            self.docs.append({"id": str(start_id + i), "text": t, "meta": m})

    def save(self):
        os.makedirs(INDEX_DIR, exist_ok=True)
        with open(DOCS_JSON, "w", encoding="utf-8") as f:
            json.dump(self.docs, f, ensure_ascii=False, indent=2)
        if self.embeddings is not None:
            np.save(VECTORS_NPY, self.embeddings)

    def load(self):
        if os.path.exists(DOCS_JSON):
            with open(DOCS_JSON, "r", encoding="utf-8") as f:
                self.docs = json.load(f)
        if os.path.exists(VECTORS_NPY):
            self.embeddings = np.load(VECTORS_NPY)
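
# Usage sketch for DocumentStore (values are illustrative):
#
#     store = DocumentStore()
#     store.add_documents(["hello world"], metas=[{"source": "demo"}])
#     store.save()  # writes rag_index/documents.json (plus vectors.npy once embeddings are set)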
# ---------- Indexer ----------
class VectorIndex:
    def __init__(self, dim: int):
        self.dim = dim
        self._use_faiss = _HAS_FAISS
        if self._use_faiss:
            # IndexFlatIP with L2-normalized vectors gives cosine similarity
            self.index = faiss.IndexFlatIP(dim)
        else:
            self._nn = None  # created during fit
            self.index = None

    def fit(self, vectors: np.ndarray):
        # vectors assumed shape (n, dim)
        if self._use_faiss:
            # astype copies, so the caller's array is not mutated when
            # normalize_L2 normalizes in place (faiss requires float32)
            vectors = vectors.astype(np.float32)
            faiss.normalize_L2(vectors)
            self.index.add(vectors)
            # persist the index (the directory may not exist yet on a fresh build)
            os.makedirs(INDEX_DIR, exist_ok=True)
            faiss.write_index(self.index, FAISS_INDEX_FILE)
        else:
            # sklearn NearestNeighbors with cosine metric
            self._nn = NearestNeighbors(n_neighbors=min(10, len(vectors)), metric="cosine")
            self._nn.fit(vectors)
            # keep vectors in memory for later queries
            self.index = vectors

    def query(self, qvec: np.ndarray, top_k: int = 5) -> List[Tuple[int, float]]:
        """
        Return a list of (doc_idx, score) pairs sorted by descending similarity.
        Scores are cosine similarities (higher is better); the sklearn fallback
        converts cosine distance to similarity.
        """
        if self._use_faiss:
            q = qvec.astype(np.float32)  # astype copies, so qvec itself is untouched
            faiss.normalize_L2(q)
            distances, indices = self.index.search(q, top_k)
            # inner products of normalized vectors are cosine similarities
            results = []
            for idx, dist in zip(indices[0], distances[0]):
                if idx == -1:  # faiss pads with -1 when fewer than top_k hits exist
                    continue
                results.append((int(idx), float(dist)))
            return results
        else:
            # sklearn returns cosine distances; similarity = 1 - distance
            distances, indices = self._nn.kneighbors(qvec, n_neighbors=min(top_k, len(self.index)))
            res = []
            for idx, d in zip(indices[0], distances[0]):
                res.append((int(idx), 1.0 - float(d)))
            return res
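
# Quick sanity check for VectorIndex (toy 3-d vectors; numbers are illustrative):
#
#     vecs = np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32)
#     vi = VectorIndex(dim=3)
#     vi.fit(vecs)  # also writes rag_index/faiss.index when faiss is installed
#     vi.query(np.array([[0.9, 0.1, 0.0]], dtype=np.float32), top_k=1)
#     # -> [(0, ~0.99)]  (doc 0, cosine similarity)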
# ---------- RAG System ----------
class RAG:
    def __init__(self, embedding_model_name=EMBEDDING_MODEL, gen_model_name=GEN_MODEL, device=-1):
        # device: -1 = CPU, otherwise a GPU device id (int)
        print("Loading embedding model:", embedding_model_name)
        self.embedder = SentenceTransformer(embedding_model_name, device="cpu" if device == -1 else f"cuda:{device}")
        print("Loading generator model:", gen_model_name)
        # For seq2seq models like FLAN-T5
        self.tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
        self.gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
        # A pipeline keeps things simple; it accepts -1 for CPU or a CUDA device id
        self.generator = pipeline("text2text-generation", model=self.gen_model, tokenizer=self.tokenizer, device=device)
        # placeholders
        self.store = DocumentStore()
        self.index = None

    def build_index(self, documents: List[str], metas: List[dict] = None):
        self.store.add_documents(documents, metas)
        # embed
        print(f"Embedding {len(documents)} documents...")
        vectors = self.embedder.encode(documents, convert_to_numpy=True, show_progress_bar=True)
        self.store.embeddings = vectors.astype(np.float32)
        # build index
        dim = vectors.shape[1]
        self.index = VectorIndex(dim)
        self.index.fit(self.store.embeddings)
        # save docs & vectors
        self.store.save()
        print("Index built and saved to disk.")

    def load_index(self):
        self.store.load()
        if self.store.embeddings is None:
            raise RuntimeError("No embeddings found on disk. Build the index first.")
        dim = self.store.embeddings.shape[1]
        self.index = VectorIndex(dim)
        self.index.fit(self.store.embeddings)
        print("Index loaded from disk. Documents:", len(self.store.docs))

    def retrieve(self, query: str, k: int = K):
        qvec = self.embedder.encode([query], convert_to_numpy=True)
        results = self.index.query(qvec, top_k=k)
        docs = []
        for idx, score in results:
            doc = self.store.docs[idx]
            docs.append({"id": doc["id"], "text": doc["text"], "score": score, "meta": doc.get("meta", {})})
        return docs

    def generate(self, question: str, retrieved_docs: List[dict], max_length=256):
        # Stack the retrieved docs into one context string; more sophisticated
        # chunking / prompt engineering is left as an exercise
        context_texts = "\n\n---\n\n".join(f"[{d['id']}] {d['text']}" for d in retrieved_docs)
        prompt = (
            "You are an assistant that answers questions using the provided context.\n"
            "If the answer is not contained in the context, say 'I don't know.'\n\n"
            "Context:\n"
            f"{context_texts}\n\n"
            "Question:\n"
            f"{question}\n\nAnswer:"
        )
        # Greedy decoding; a temperature is only meaningful with do_sample=True,
        # so none is passed here (it would trigger a transformers warning)
        out = self.generator(prompt, max_length=max_length, do_sample=False, num_return_sequences=1)
        answer = out[0]["generated_text"].strip()
        return {"answer": answer, "prompt": prompt, "retrieved": retrieved_docs}
# ---------- Utilities ----------
def read_text_files_from_dir(dir_path: str) -> List[str]:
    texts = []
    for fname in sorted(os.listdir(dir_path)):
        fp = os.path.join(dir_path, fname)
        if os.path.isfile(fp) and fname.lower().endswith((".txt", ".md")):
            with open(fp, "r", encoding="utf-8") as f:
                texts.append(f.read())
    return texts
# ---------- CLI & Flask API ----------
app = Flask(__name__)
rag_system: Optional[RAG] = None  # set in main()
@app.route("/api/ask", methods=["POST"])
def api_ask():
    payload = request.get_json(silent=True)  # None instead of an exception on bad/missing JSON
    if not payload or "question" not in payload:
        return jsonify({"error": "Send JSON with a 'question' key."}), 400
    question = payload["question"]
    k = payload.get("k", K)
    retrieved = rag_system.retrieve(question, k=k)
    gen = rag_system.generate(question, retrieved, max_length=256)
    return jsonify(gen)
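
# Example request against /api/ask (assumes the server runs on the default
# host/port; "k" is optional and falls back to K):
#
#     curl -X POST http://127.0.0.1:5000/api/ask \
#          -H "Content-Type: application/json" \
#          -d '{"question": "What is RAG?", "k": 3}'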
def main():
    parser = argparse.ArgumentParser(description="Simple RAG example")
    parser.add_argument("--build-corpus", type=str, help="Directory of .txt/.md files to build the index from")
    parser.add_argument("--ask", type=str, help="Ask a question against the built index")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Flask host")
    parser.add_argument("--port", type=int, default=5000, help="Flask port")
    parser.add_argument("--use-gpu", action="store_true", help="Use GPU 0 if available")
    args = parser.parse_args()
    global rag_system
    device = 0 if args.use_gpu else -1
    rag_system = RAG(device=device)
    if args.build_corpus:
        docs = read_text_files_from_dir(args.build_corpus)
        if not docs:
            print("No .txt or .md files found in", args.build_corpus)
            return
        rag_system.build_index(docs)
        print("Built index from", len(docs), "documents.")
        return
    # load the index, then answer one question or serve the API
    try:
        rag_system.load_index()
    except Exception as e:
        print("Failed to load index:", e)
        print("If you haven't built the index yet, run: python rag.py --build-corpus ./sample_docs/")
        return
    if args.ask:
        question = args.ask
        retrieved = rag_system.retrieve(question, k=K)
        print("Retrieved docs (id, score):")
        for d in retrieved:
            preview = d["text"][:160].replace("\n", " ")
            print(f"- id={d['id']} score={d['score']:.4f} preview={preview}")
        res = rag_system.generate(question, retrieved)
        print("\n=== ANSWER ===")
        print(res["answer"])
        return
    # otherwise run the API
    print(f"Starting Flask API on http://{args.host}:{args.port}")
    app.run(host=args.host, port=args.port)
if __name__ == "__main__":
    main()