ankitkushwaha90 committed on
Commit 5d64821 · verified · 1 Parent(s): 5f75e42

Upload 9 files

environment.yml ADDED
@@ -0,0 +1,16 @@
+ name: rag_demo
+ channels:
+   - conda-forge
+   - defaults
+ dependencies:
+   - python=3.10
+   - pip
+   - numpy=1.26.0
+   - scikit-learn=1.3.0
+   - pip:
+       - torch>=2.0.0
+       - transformers==4.40.0
+       - sentence-transformers==2.2.2
+       - faiss-cpu==1.7.4
+       - flask==2.2.5
+       - uvicorn==0.23.0
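Note: assuming conda is available on your machine, this environment can be created with "conda env create -f environment.yml" and activated with "conda activate rag_demo".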
rag.py ADDED
@@ -0,0 +1,277 @@
+ """
+ RAG minimal working example (indexing + retrieval + generation)
+
+ Features:
+ - Embed documents with sentence-transformers
+ - Index with FAISS (falling back to sklearn if FAISS is unavailable)
+ - Generate answers using Hugging Face transformers (FLAN-T5)
+ - CLI demo and a tiny Flask API endpoint
+
+ Usage:
+     pip install -r requirements.txt
+     python rag.py --build-corpus sample_docs/   # builds the index
+     python rag.py --ask "What is RAG?"          # asks a question against the built index
+
+ Author: ChatGPT (example)
+ """
+
+ import os
+ import argparse
+ import json
+ from typing import List, Tuple
+ import numpy as np
+
+ # Embeddings
+ from sentence_transformers import SentenceTransformer
+
+ # Generation
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
+
+ # Indexing: try faiss, fall back to sklearn
+ try:
+     import faiss
+     _HAS_FAISS = True
+ except Exception:
+     _HAS_FAISS = False
+     from sklearn.neighbors import NearestNeighbors
+
+ # Simple Flask API (optional)
+ from flask import Flask, request, jsonify
+
+ # ---------- Configuration ----------
+ EMBEDDING_MODEL = "all-MiniLM-L6-v2"   # small, fast, accurate for semantic search
+ GEN_MODEL = "google/flan-t5-small"     # small and CPU-friendly; swap in a larger model if a GPU is available
+ INDEX_DIR = "rag_index"                # where vectors + docs are stored
+ DOCS_JSON = os.path.join(INDEX_DIR, "documents.json")
+ VECTORS_NPY = os.path.join(INDEX_DIR, "vectors.npy")
+ FAISS_INDEX_FILE = os.path.join(INDEX_DIR, "faiss.index")
+
+ # RAG hyperparameters
+ K = 5  # number of documents to retrieve
+
+
+ # ---------- Document Store ----------
+ class DocumentStore:
+     def __init__(self):
+         self.docs: List[dict] = []  # each doc: {"id": str, "text": str, "meta": {...}}
+         self.embeddings = None
+
+     def add_documents(self, texts: List[str], metas: List[dict] = None):
+         if metas is None:
+             metas = [{} for _ in texts]
+         start_id = len(self.docs)
+         for i, (t, m) in enumerate(zip(texts, metas)):
+             self.docs.append({"id": str(start_id + i), "text": t, "meta": m})
+
+     def save(self):
+         os.makedirs(INDEX_DIR, exist_ok=True)
+         with open(DOCS_JSON, "w", encoding="utf-8") as f:
+             json.dump(self.docs, f, ensure_ascii=False, indent=2)
+         if self.embeddings is not None:
+             np.save(VECTORS_NPY, self.embeddings)
+
+     def load(self):
+         if os.path.exists(DOCS_JSON):
+             with open(DOCS_JSON, "r", encoding="utf-8") as f:
+                 self.docs = json.load(f)
+         if os.path.exists(VECTORS_NPY):
+             self.embeddings = np.load(VECTORS_NPY)
+
+
+ # ---------- Indexer ----------
+ class VectorIndex:
+     def __init__(self, dim: int):
+         self.dim = dim
+         self._use_faiss = _HAS_FAISS
+         if self._use_faiss:
+             # IndexFlatIP with L2-normalized vectors gives cosine similarity
+             self.index = faiss.IndexFlatIP(dim)
+         else:
+             self._nn = None  # created during fit
+             self.index = None
+
+     def fit(self, vectors: np.ndarray):
+         # vectors assumed shape (n, dim)
+         if self._use_faiss:
+             # faiss requires contiguous float32; normalize for cosine similarity via inner product
+             vectors = np.ascontiguousarray(vectors, dtype=np.float32)
+             faiss.normalize_L2(vectors)
+             self.index.add(vectors)
+             # save index
+             faiss.write_index(self.index, FAISS_INDEX_FILE)
+         else:
+             # sklearn NearestNeighbors with cosine metric
+             self._nn = NearestNeighbors(n_neighbors=min(10, len(vectors)), metric="cosine")
+             self._nn.fit(vectors)
+             # keep vectors in memory for later
+             self.index = vectors
+
+     def query(self, qvec: np.ndarray, top_k: int = 5) -> List[Tuple[int, float]]:
+         """
+         Returns a list of (doc_idx, score) sorted by highest similarity.
+         Score is cosine similarity (higher is better); the sklearn fallback converts distance.
+         """
+         if self._use_faiss:
+             q = np.ascontiguousarray(qvec, dtype=np.float32)
+             faiss.normalize_L2(q)
+             distances, indices = self.index.search(q, top_k)
+             # distances are inner products (cosine, because vectors are normalized)
+             results = []
+             for idx, dist in zip(indices[0], distances[0]):
+                 if idx == -1:
+                     continue
+                 results.append((int(idx), float(dist)))
+             return results
+         else:
+             # sklearn returns cosine distances; similarity = 1 - distance
+             distances, indices = self._nn.kneighbors(qvec, n_neighbors=min(top_k, len(self.index)))
+             res = []
+             for idx, d in zip(indices[0], distances[0]):
+                 sim = 1.0 - float(d)
+                 res.append((int(idx), sim))
+             return res
+
+
+ # ---------- RAG System ----------
+ class RAG:
+     def __init__(self, embedding_model_name=EMBEDDING_MODEL, gen_model_name=GEN_MODEL, device=-1):
+         # device: -1 = CPU, otherwise GPU device id (int)
+         print("Loading embedding model:", embedding_model_name)
+         self.embedder = SentenceTransformer(embedding_model_name, device="cpu" if device == -1 else f"cuda:{device}")
+
+         print("Loading generator model:", gen_model_name)
+         # For seq2seq models like FLAN-T5
+         self.tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
+         self.gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
+         # Use a pipeline for simplicity; pass the device id through (-1 = CPU)
+         self.generator = pipeline("text2text-generation", model=self.gen_model, tokenizer=self.tokenizer, device=device)
+
+         # placeholders
+         self.store = DocumentStore()
+         self.index = None
+
+     def build_index(self, documents: List[str], metas: List[dict] = None):
+         self.store.add_documents(documents, metas)
+         # embed
+         print(f"Embedding {len(documents)} documents...")
+         vectors = self.embedder.encode(documents, convert_to_numpy=True, show_progress_bar=True)
+         self.store.embeddings = vectors.astype(np.float32)
+         # build the index
+         dim = vectors.shape[1]
+         self.index = VectorIndex(dim)
+         self.index.fit(self.store.embeddings)
+         # save docs & vectors
+         self.store.save()
+         print("Index built and saved to disk.")
+
+     def load_index(self):
+         self.store.load()
+         if self.store.embeddings is None:
+             raise RuntimeError("No embeddings found on disk. Build the index first.")
+         dim = self.store.embeddings.shape[1]
+         self.index = VectorIndex(dim)
+         self.index.fit(self.store.embeddings)
+         print("Index loaded from disk. Documents:", len(self.store.docs))
+
+     def retrieve(self, query: str, k: int = K):
+         qvec = self.embedder.encode([query], convert_to_numpy=True)
+         results = self.index.query(qvec, top_k=k)
+         docs = []
+         for idx, score in results:
+             doc = self.store.docs[idx]
+             docs.append({"id": doc["id"], "text": doc["text"], "score": score, "meta": doc.get("meta", {})})
+         return docs
+
+     def generate(self, question: str, retrieved_docs: List[dict], max_length=256):
+         # Combine retrieved docs into one context string (simple stacking);
+         # more sophisticated chunking / prompt engineering is possible
+         context_texts = "\n\n---\n\n".join([f"[{d['id']}] {d['text']}" for d in retrieved_docs])
+         prompt = (
+             "You are an assistant that answers questions using the provided context.\n"
+             "If the answer is not contained in the context, say 'I don't know.'\n\n"
+             "Context:\n"
+             f"{context_texts}\n\n"
+             "Question:\n"
+             f"{question}\n\nAnswer:"
+         )
+         # Greedy decoding (do_sample=False), so no temperature is passed
+         out = self.generator(prompt, max_length=max_length, do_sample=False, num_return_sequences=1)
+         answer = out[0]["generated_text"].strip()
+         return {"answer": answer, "prompt": prompt, "retrieved": retrieved_docs}
+
+
+ # ---------- Utilities ----------
+ def read_text_files_from_dir(dir_path: str) -> List[str]:
+     texts = []
+     for fname in sorted(os.listdir(dir_path)):
+         fp = os.path.join(dir_path, fname)
+         if os.path.isfile(fp) and fname.lower().endswith((".txt", ".md")):
+             with open(fp, "r", encoding="utf-8") as f:
+                 texts.append(f.read())
+     return texts
+
+
+ # ---------- CLI & Flask API ----------
+ app = Flask(__name__)
+ rag_system: RAG = None  # will be set in main()
+
+
+ @app.route("/api/ask", methods=["POST"])
+ def api_ask():
+     payload = request.get_json(silent=True)
+     if not payload or "question" not in payload:
+         return jsonify({"error": "Send JSON with a 'question' key."}), 400
+     question = payload["question"]
+     k = payload.get("k", K)
+     retrieved = rag_system.retrieve(question, k=k)
+     gen = rag_system.generate(question, retrieved, max_length=256)
+     return jsonify(gen)
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Simple RAG example")
+     parser.add_argument("--build-corpus", type=str, help="Directory of .txt/.md files to build the index from")
+     parser.add_argument("--ask", type=str, help="Ask a question against the built index")
+     parser.add_argument("--host", type=str, default="127.0.0.1", help="Flask host")
+     parser.add_argument("--port", type=int, default=5000, help="Flask port")
+     parser.add_argument("--use-gpu", action="store_true", help="Use GPU if available")
+     args = parser.parse_args()
+
+     global rag_system
+     device = 0 if args.use_gpu else -1
+     rag_system = RAG(device=device)
+
+     if args.build_corpus:
+         docs = read_text_files_from_dir(args.build_corpus)
+         if not docs:
+             print("No .txt or .md files found in", args.build_corpus)
+             return
+         rag_system.build_index(docs)
+         print("Built index from", len(docs), "documents.")
+         return
+
+     # load the index, then answer a question or serve the API
+     try:
+         rag_system.load_index()
+     except Exception as e:
+         print("Failed to load index:", e)
+         print("If you haven't built the index yet, run: python rag.py --build-corpus ./sample_docs/")
+         return
+
+     if args.ask:
+         question = args.ask
+         retrieved = rag_system.retrieve(question, k=K)
+         print("Retrieved docs (id, score):")
+         for d in retrieved:
+             preview = d["text"][:160].replace("\n", " ")
+             print(f"- id={d['id']} score={d['score']:.4f} preview={preview}")
+
+         res = rag_system.generate(question, retrieved)
+         print("\n=== ANSWER ===")
+         print(res["answer"])
+         return
+
+     # otherwise run the API
+     print(f"Starting Flask API on http://{args.host}:{args.port}")
+     app.run(host=args.host, port=args.port)
+
+
+ if __name__ == "__main__":
+     main()
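For readers who prefer to drive the pipeline from Python instead of the CLI, here is a minimal sketch of the call sequence, assuming rag.py is importable from the working directory; the two strings passed to build_index are placeholder documents, and the urllib snippet assumes the Flask server has already been started on its default host and port. build_index, retrieve, and generate are the RAG methods defined above.

    # Sketch: programmatic use of the RAG class from rag.py
    from rag import RAG

    rag = RAG()  # loads all-MiniLM-L6-v2 and google/flan-t5-small on CPU
    rag.build_index([
        "Retrieval-Augmented Generation (RAG) pairs a retriever with a generator.",
        "FAISS performs efficient similarity search over dense vectors.",
    ])
    docs = rag.retrieve("What is RAG?", k=2)     # top-2 docs with cosine scores
    result = rag.generate("What is RAG?", docs)  # dict with "answer", "prompt", "retrieved"
    print(result["answer"])

    # Sketch: querying the Flask endpoint (after "python rag.py" has started the server)
    import json, urllib.request

    req = urllib.request.Request(
        "http://127.0.0.1:5000/api/ask",
        data=json.dumps({"question": "What is RAG?", "k": 2}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(json.load(resp)["answer"])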
rag_index/documents.json ADDED
@@ -0,0 +1,12 @@
+ [
+   {
+     "id": "0",
+     "text": "Hi, I am doc1.txt. This file exists to fulfill the sample-document requirement for the RAG project. It is the first file.",
+     "meta": {}
+   },
+   {
+     "id": "1",
+     "text": "Hi, I am ankitkushwaha90.\nWelcome to this page. It is a simple txt file with random text written for the RAG project.\n",
+     "meta": {}
+   }
+ ]
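Note: documents.json, vectors.npy, and faiss.index under rag_index/ are build artifacts; running python rag.py --build-corpus regenerates all three from the source documents (DocumentStore.save writes the first two, VectorIndex.fit writes the FAISS index).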
rag_index/faiss.index ADDED
Binary file (3.12 kB).
rag_index/vectors.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ce7c611ace7997756f5647d38eeacc8f1f8fa9482f88b00572468e4a9295bec8
+ size 3200
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ # Core
+ transformers==4.40.0
+ sentence-transformers==2.2.2
+ torch>=2.0.0
+ faiss-cpu==1.7.4      # CPU build; swap to faiss-gpu if you have a GPU
+ flask==2.2.5
+ numpy==1.26.0
+ scikit-learn==1.3.0
+ uvicorn==0.23.0       # optional, if you prefer uvicorn for running the API
+ # NOTE: version pins are indicative – adjust for your environment.
run_rag.bat ADDED
@@ -0,0 +1,11 @@
+ @echo off
+ REM Activate your conda environment (replace rag_demo with your env name)
+ call conda activate rag_demo
+
+ REM Build the index from the sample_docs folder
+ python rag.py --build-corpus sample_docs
+
+ REM Ask a sample question
+ python rag.py --ask "What is Retrieval-Augmented Generation?"
+
+ pause
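Note: this batch file is Windows-specific (call, REM, pause); on Linux/macOS the same two python commands can be run directly after activating the rag_demo environment.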
sample_docs/doc1.txt ADDED
@@ -0,0 +1 @@
+ Hi, I am doc1.txt. This file exists to fulfill the sample-document requirement for the RAG project. It is the first file.
sample_docs/doc2.txt ADDED
@@ -0,0 +1,2 @@
+ Hi, I am ankitkushwaha90.
+ Welcome to this page. It is a simple txt file with random text written for the RAG project.