Upload 9 files
- environment.yml +16 -0
- rag.py +277 -0
- rag_index/documents.json +12 -0
- rag_index/faiss.index +0 -0
- rag_index/vectors.npy +3 -0
- requirements.txt +10 -0
- run_rag.bat +11 -0
- sample_docs/doc1.txt +1 -0
- sample_docs/doc2.txt +2 -0
environment.yml
ADDED
@@ -0,0 +1,16 @@
name: rag_demo
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.10
  - pip
  - numpy=1.26.0
  - scikit-learn=1.3.0
  - pip:
      - torch>=2.0.0
      - transformers==4.40.0
      - sentence-transformers==2.2.2
      - faiss-cpu==1.7.4
      - flask==2.2.5
      - uvicorn==0.23.0
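The usual conda workflow applies to this file:

    conda env create -f environment.yml
    conda activate rag_demo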
rag.py
ADDED
@@ -0,0 +1,277 @@
"""
RAG minimal working example (indexing + retrieval + generation)

Features:
- Embed documents with sentence-transformers
- Index with FAISS (fallback to sklearn)
- Generate answers using Hugging Face transformers (FLAN-T5)
- CLI demo and a tiny Flask API endpoint

Usage:
    pip install -r requirements.txt
    python rag.py --build-corpus sample_docs/   # builds the index
    python rag.py --ask "What is RAG?"          # ask a question using the built index

Author: ChatGPT (example)
"""

import os
import argparse
import json
from typing import List, Tuple
import numpy as np

# Embeddings
from sentence_transformers import SentenceTransformer

# Generation
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# Indexing: try faiss, fall back to sklearn
try:
    import faiss
    _HAS_FAISS = True
except Exception:
    _HAS_FAISS = False
    from sklearn.neighbors import NearestNeighbors

# Simple Flask API (optional)
from flask import Flask, request, jsonify

# ---------- Configuration ----------
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # small, fast, accurate for semantic search
GEN_MODEL = "google/flan-t5-small"    # small and CPU-friendly; swap to a larger model if a GPU is available
INDEX_DIR = "rag_index"               # where vectors + docs are stored
DOCS_JSON = os.path.join(INDEX_DIR, "documents.json")
VECTORS_NPY = os.path.join(INDEX_DIR, "vectors.npy")
FAISS_INDEX_FILE = os.path.join(INDEX_DIR, "faiss.index")

# RAG hyperparameters
K = 5  # number of documents to retrieve


# ---------- Document Store ----------
class DocumentStore:
    def __init__(self):
        self.docs: List[dict] = []  # each doc: {"id": str, "text": str, "meta": {...}}
        self.embeddings = None

    def add_documents(self, texts: List[str], metas: List[dict] = None):
        if metas is None:
            metas = [{} for _ in texts]
        start_id = len(self.docs)
        for i, (t, m) in enumerate(zip(texts, metas)):
            self.docs.append({"id": str(start_id + i), "text": t, "meta": m})

    def save(self):
        os.makedirs(INDEX_DIR, exist_ok=True)
        with open(DOCS_JSON, "w", encoding="utf-8") as f:
            json.dump(self.docs, f, ensure_ascii=False, indent=2)
        if self.embeddings is not None:
            np.save(VECTORS_NPY, self.embeddings)

    def load(self):
        if os.path.exists(DOCS_JSON):
            with open(DOCS_JSON, "r", encoding="utf-8") as f:
                self.docs = json.load(f)
        if os.path.exists(VECTORS_NPY):
            self.embeddings = np.load(VECTORS_NPY)


# ---------- Indexer ----------
class VectorIndex:
    def __init__(self, dim: int):
        self.dim = dim
        self._use_faiss = _HAS_FAISS
        if self._use_faiss:
            # IndexFlatIP with L2-normalized vectors gives cosine similarity
            self.index = faiss.IndexFlatIP(dim)
        else:
            self._nn = None  # created during fit
            self.index = None

    def fit(self, vectors: np.ndarray):
        # vectors assumed shape (n, dim)
        if self._use_faiss:
            # faiss requires float32; normalize so inner product equals cosine similarity
            vectors = vectors.astype(np.float32)
            faiss.normalize_L2(vectors)
            self.index.add(vectors)
            # save index
            faiss.write_index(self.index, FAISS_INDEX_FILE)
        else:
            # sklearn NearestNeighbors with cosine metric
            self._nn = NearestNeighbors(n_neighbors=min(10, len(vectors)), metric="cosine")
            self._nn.fit(vectors)
            # keep vectors in memory for later
            self.index = vectors

    def query(self, qvec: np.ndarray, top_k: int = 5) -> List[Tuple[int, float]]:
        """
        Returns a list of (doc_idx, score) sorted by highest similarity.
        Score is cosine similarity (higher is better); for the sklearn
        fallback the cosine distance is converted to similarity.
        """
        if self._use_faiss:
            q = qvec.copy().astype(np.float32)
            faiss.normalize_L2(q)
            distances, indices = self.index.search(q, top_k)
            # distances are inner products (cosine, because vectors are normalized)
            results = []
            for idx, dist in zip(indices[0], distances[0]):
                if idx == -1:
                    continue
                results.append((int(idx), float(dist)))
            return results
        else:
            # sklearn returns distances; convert to similarity = 1 - distance
            distances, indices = self._nn.kneighbors(qvec, n_neighbors=min(top_k, len(self.index)))
            res = []
            for idx, d in zip(indices[0], distances[0]):
                sim = 1.0 - float(d)
                res.append((int(idx), sim))
            return res


# ---------- RAG System ----------
class RAG:
    def __init__(self, embedding_model_name=EMBEDDING_MODEL, gen_model_name=GEN_MODEL, device=-1):
        # device: -1 = CPU, otherwise GPU device id (int)
        print("Loading embedding model:", embedding_model_name)
        self.embedder = SentenceTransformer(embedding_model_name, device="cpu" if device == -1 else f"cuda:{device}")

        print("Loading generator model:", gen_model_name)
        # For seq2seq models like FLAN-T5
        self.tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
        self.gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
        # Use a pipeline for simplicity; it accepts -1 for CPU or a GPU device id
        self.generator = pipeline("text2text-generation", model=self.gen_model, tokenizer=self.tokenizer, device=device)

        # placeholders
        self.store = DocumentStore()
        self.index = None

    def build_index(self, documents: List[str], metas: List[dict] = None):
        self.store.add_documents(documents, metas)
        # embed
        print(f"Embedding {len(documents)} documents...")
        vectors = self.embedder.encode(documents, convert_to_numpy=True, show_progress_bar=True)
        self.store.embeddings = vectors.astype(np.float32)
        # build index
        dim = vectors.shape[1]
        self.index = VectorIndex(dim)
        self.index.fit(self.store.embeddings)
        # save docs & vectors
        self.store.save()
        print("Index built and saved to disk.")

    def load_index(self):
        self.store.load()
        if self.store.embeddings is None:
            raise RuntimeError("No embeddings found on disk. Build the index first.")
        dim = self.store.embeddings.shape[1]
        self.index = VectorIndex(dim)
        self.index.fit(self.store.embeddings)
        print("Index loaded from disk. Documents:", len(self.store.docs))

    def retrieve(self, query: str, k: int = K):
        qvec = self.embedder.encode([query], convert_to_numpy=True)
        results = self.index.query(qvec, top_k=k)
        docs = []
        for idx, score in results:
            doc = self.store.docs[idx]
            docs.append({"id": doc["id"], "text": doc["text"], "score": score, "meta": doc.get("meta", {})})
        return docs

    def generate(self, question: str, retrieved_docs: List[dict], max_length=256):
        # Combine retrieved docs into a single context string (simple stacking);
        # more sophisticated chunking / prompt engineering is possible here.
        context_texts = "\n\n---\n\n".join([f"[{d['id']}] {d['text']}" for d in retrieved_docs])
        prompt = (
            "You are an assistant that answers questions using the provided context.\n"
            "If the answer is not contained in the context, say 'I don't know.'\n\n"
            "Context:\n"
            f"{context_texts}\n\n"
            "Question:\n"
            f"{question}\n\nAnswer:"
        )
        # Greedy decoding; a temperature would only take effect with do_sample=True
        out = self.generator(prompt, max_length=max_length, do_sample=False, num_return_sequences=1)
        answer = out[0]["generated_text"].strip()
        return {"answer": answer, "prompt": prompt, "retrieved": retrieved_docs}


# ---------- Utilities ----------
def read_text_files_from_dir(dir_path: str) -> List[str]:
    texts = []
    for fname in sorted(os.listdir(dir_path)):
        fp = os.path.join(dir_path, fname)
        if os.path.isfile(fp) and fname.lower().endswith((".txt", ".md")):
            with open(fp, "r", encoding="utf-8") as f:
                texts.append(f.read())
    return texts


# ---------- CLI & Flask API ----------
app = Flask(__name__)
rag_system: RAG = None  # set in main()

@app.route("/api/ask", methods=["POST"])
def api_ask():
    payload = request.json
    if not payload or "question" not in payload:
        return jsonify({"error": "Send JSON with a 'question' key."}), 400
    question = payload["question"]
    k = payload.get("k", K)
    retrieved = rag_system.retrieve(question, k=k)
    gen = rag_system.generate(question, retrieved, max_length=256)
    return jsonify(gen)


def main():
    parser = argparse.ArgumentParser(description="Simple RAG example")
    parser.add_argument("--build-corpus", type=str, help="Directory of .txt/.md files to build the index from")
    parser.add_argument("--ask", type=str, help="Ask a question against the built index")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Flask host")
    parser.add_argument("--port", type=int, default=5000, help="Flask port")
    parser.add_argument("--use-gpu", action="store_true", help="Use GPU if available")
    args = parser.parse_args()

    global rag_system
    device = 0 if args.use_gpu else -1
    rag_system = RAG(device=device)

    if args.build_corpus:
        docs = read_text_files_from_dir(args.build_corpus)
        if not docs:
            print("No .txt or .md files found in", args.build_corpus)
            return
        rag_system.build_index(docs)
        print("Built index from", len(docs), "documents.")
        return

    # load index, then either answer one question or serve the API
    try:
        rag_system.load_index()
    except Exception as e:
        print("Failed to load index:", e)
        print("If you haven't built the index, run: python rag.py --build-corpus ./sample_docs/")
        return

    if args.ask:
        question = args.ask
        retrieved = rag_system.retrieve(question, k=K)
        print("Retrieved docs (id, score):")
        for d in retrieved:
            preview = d["text"][:160].replace("\n", " ")
            print(f"- id={d['id']} score={d['score']:.4f} preview={preview}")

        res = rag_system.generate(question, retrieved)
        print("\n=== ANSWER ===")
        print(res["answer"])
        return

    # otherwise run the API
    print(f"Starting Flask API on http://{args.host}:{args.port}")
    app.run(host=args.host, port=args.port)


if __name__ == "__main__":
    main()
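For callers who want to use the classes above from Python rather than the CLI, a minimal sketch (assuming the index has already been built with --build-corpus):

    from rag import RAG

    rag = RAG()        # loads the embedding and generator models (CPU by default)
    rag.load_index()   # reads rag_index/documents.json and vectors.npy
    hits = rag.retrieve("What is RAG?", k=2)
    result = rag.generate("What is RAG?", hits)
    print(result["answer"])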
rag_index/documents.json
ADDED
@@ -0,0 +1,12 @@
[
  {
    "id": "0",
    "text": "hii, i am doc1.txt. file. which that purpose of in this file requirement fullfill for RAG project file. it is first file.",
    "meta": {}
  },
  {
    "id": "1",
    "text": "hii, i am ankitkushwaha90\nwelcome in this page. it is simple txt file page. which that in this page create for random text write for RAG project.\n",
    "meta": {}
  }
]
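This file is the output of DocumentStore.save() for the two sample documents; each "id" is the row index of the corresponding embedding in vectors.npy.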
rag_index/faiss.index
ADDED
Binary file (3.12 kB)
rag_index/vectors.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ce7c611ace7997756f5647d38eeacc8f1f8fa9482f88b00572468e4a9295bec8
size 3200
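The 3200-byte size is consistent with the index contents: two documents × 384 dimensions (all-MiniLM-L6-v2's embedding width) × 4 bytes per float32 = 3072 bytes, plus the 128-byte .npy header.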
requirements.txt
ADDED
@@ -0,0 +1,10 @@
# Core
transformers==4.40.0
sentence-transformers==2.2.2
torch>=2.0.0
faiss-cpu==1.7.4   # CPU build; swap to faiss-gpu if you have a GPU
flask==2.2.5
numpy==1.26.0
scikit-learn==1.3.0
uvicorn==0.23.0    # optional, if you prefer uvicorn for running the API
# NOTE: version pinning is indicative – adjust for your environment.
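Because rag.py silently falls back to scikit-learn when FAISS fails to import, a quick post-install sanity check is worth running; a minimal sketch:

    # Verify the optional FAISS dependency imported correctly;
    # rag.py falls back to sklearn NearestNeighbors when it has not.
    try:
        import faiss
        faiss.IndexFlatIP(4)  # build a tiny index to confirm the native library works
        print("faiss OK")
    except ImportError:
        print("faiss not available; rag.py will use the sklearn fallback")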
run_rag.bat
ADDED
@@ -0,0 +1,11 @@
@echo off
REM Activate your conda environment (replace rag_demo with your env name)
call conda activate rag_demo

REM Build the index from the sample_docs folder
python rag.py --build-corpus sample_docs

REM Ask a sample question
python rag.py --ask "What is Retrieval-Augmented Generation?"

pause
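run_rag.bat exercises only the CLI paths; when rag.py is launched with neither --build-corpus nor --ask it serves the Flask API instead. A minimal client sketch using only the standard library (assumes the default host and port):

    import json
    import urllib.request

    # POST a question to the /api/ask endpoint defined in rag.py
    req = urllib.request.Request(
        "http://127.0.0.1:5000/api/ask",
        data=json.dumps({"question": "What is RAG?", "k": 2}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(json.load(resp)["answer"])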
sample_docs/doc1.txt
ADDED
@@ -0,0 +1 @@
hii, i am doc1.txt. file. which that purpose of in this file requirement fullfill for RAG project file. it is first file.
sample_docs/doc2.txt
ADDED
@@ -0,0 +1,2 @@
hii, i am ankitkushwaha90
welcome in this page. it is simple txt file page. which that in this page create for random text write for RAG project.