"""
Advanced RAG + LoRA Fine-tuning Example
File: advanced_rag_lora_finetune.py
What this provides (end-to-end):
1. Build a FAISS-backed document store with chunking and embeddings.
2. Create an instruction-style training dataset that includes: query, retrieved_chunks, target_answer.
3. LoRA (PEFT) fine-tuning that trains the LM to reason with retrieved context.
4. Inference pipeline that does retrieval -> prompt assembly -> generation.
Notes:
- This is a runnable blueprint, but it requires environment setup and a reasonably capable GPU for training.
- Core building blocks: sentence-transformers for embeddings, FAISS for vector search, Hugging Face Transformers + PEFT for LoRA.
Requirements (pip):
transformers>=4.30.0
sentence-transformers
faiss-cpu (or faiss-gpu)
datasets
peft
accelerate
torch
Adjust model names, batch sizes, and local paths as needed.
"""
import os
import json
import math
from typing import List, Dict
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
Trainer,
TrainingArguments,
DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from sentence_transformers import SentenceTransformer
import faiss
from datasets import Dataset
# -------------------------
# Config
# -------------------------
CONFIG = {
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"base_lm": "EleutherAI/gpt-neo-1.3B", # pick a causal LM you can run
"index_path": "./faiss_index.ivf",
"docs_path": "./docs.jsonl", # one doc per line: {"id":..., "text":...}
"chunk_size": 512,
"chunk_overlap": 64,
"top_k": 4,
"output_dir": "./rag_lora_checkpoints",
}
# -------------------------
# Utilities: chunking and building doc store
# -------------------------
def chunk_text(text: str, chunk_size: int = 512, overlap: int = 64) -> List[str]:
tokens = text.split()
chunks = []
i = 0
while i < len(tokens):
chunk = tokens[i : i + chunk_size]
chunks.append(" ".join(chunk))
i += chunk_size - overlap
return chunks
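# Quick sanity check for chunk_text on a toy string (illustrative values, not part of the pipeline):
#   chunk_text("a b c d e f g h", chunk_size=4, overlap=1)
#   -> ["a b c d", "d e f g", "g h"]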
def build_faiss_from_docs(docs_path: str, embedder: SentenceTransformer, cfg=CONFIG):
# read docs
chunks = []
meta = []
with open(docs_path, "r", encoding="utf-8") as f:
for line in f:
obj = json.loads(line)
doc_id = obj.get("id")
text = obj.get("text", "")
for c in chunk_text(text, cfg["chunk_size"], cfg["chunk_overlap"]):
meta.append({"doc_id": doc_id, "text": c})
chunks.append(c)
if len(chunks) == 0:
raise ValueError("No chunks found in docs file")
embeddings = embedder.encode(chunks, show_progress_bar=True, convert_to_numpy=True)
d = embeddings.shape[1]
# Use IVF for larger corpora; for small corpora IndexFlatL2 is fine
index = faiss.IndexFlatL2(d)
index.add(embeddings)
# Save metadata mapping
faiss.write_index(index, cfg.get("index_path", "./faiss_index.ivf"))
meta_path = cfg.get("index_path", "./faiss_index.ivf") + ".meta.json"
with open(meta_path, "w", encoding="utf-8") as mf:
json.dump(meta, mf)
print(f"Built FAISS with {len(chunks)} chunks, dim={d}")
return index, meta
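# Optional: the IVF variant mentioned above for larger corpora. A minimal sketch only;
# nlist/nprobe values are illustrative and should be tuned to corpus size.
def build_ivf_index(embeddings, nlist: int = 100):
    d = embeddings.shape[1]
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
    index.train(embeddings)  # IVF indexes must be trained on (a sample of) the vectors before adding
    index.add(embeddings)
    index.nprobe = 8  # number of inverted lists probed at query time (recall/speed trade-off)
    return index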
# -------------------------
# Create instruction dataset for retrieval-aware fine-tuning
# -------------------------
# Format: each example is stored as {"input": prompt, "target": answer}, where the prompt embeds
# the retrieved chunks (joined with '---') followed by the question.
# We assume the user supplies a ground-truth answer for each query (supervised setup).
def create_instruction_examples(queries_with_answers: List[Dict], embedder: SentenceTransformer, cfg=CONFIG):
# queries_with_answers: list of {"query":..., "answer":...}
# Must have a built index & meta
index = faiss.read_index(cfg["index_path"])
with open(cfg["index_path"] + ".meta.json", "r", encoding="utf-8") as mf:
meta = json.load(mf)
examples = []
for qa in queries_with_answers:
q = qa["query"]
ans = qa["answer"]
qvec = embedder.encode([q], convert_to_numpy=True)
D, I = index.search(qvec, cfg["top_k"]) # I is indices
        # faiss returns -1 for missing neighbors when the index holds fewer than top_k vectors
        retrieved = [meta[idx]["text"] for idx in I[0] if idx != -1]
retrieved_joined = "\n---\n".join(retrieved)
# prompt template: instruct the model that retrieved context follows and answer concisely
prompt = (
"You are given retrieved documents separated by '---'. Use them to answer the question."
"\nContext:\n" + retrieved_joined + "\nQuestion: " + q + "\nAnswer:"
)
examples.append({"input": prompt, "target": ans})
return examples
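# Illustrative shape of one returned item (abridged; real retrieved chunks come from the index):
#   {"input": "You are given retrieved documents separated by '---'. ...\nQuestion: What is LoRA?\nAnswer:",
#    "target": "LoRA (Low-Rank Adaptation) is a technique to fine-tune large models ..."}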
# -------------------------
# Prepare HF Dataset for Causal LM training
# -------------------------
def prepare_hf_dataset_for_lm(examples: List[Dict], tokenizer: AutoTokenizer, max_length: int = 1024):
# We will concatenate input + target as a single sequence, and compute labels that mask the input tokens
inputs = []
labels = []
for ex in examples:
inp = ex["input"]
tgt = " " + ex["target"].strip()
        full = inp + tgt + (tokenizer.eos_token or "")  # append EOS so the model learns to stop after the answer
enc = tokenizer(full, truncation=True, max_length=max_length)
input_ids = enc["input_ids"]
# find the split point: encode the input alone
enc_inp = tokenizer(inp, truncation=True, max_length=max_length)
len_inp = len(enc_inp["input_ids"])
# labels: -100 for input token positions so loss only computed on target
label_ids = [-100] * len_inp + input_ids[len_inp:]
# pad label_ids to match input length
if len(label_ids) < len(input_ids):
label_ids += [-100] * (len(input_ids) - len(label_ids))
inputs.append({"input_ids": input_ids, "labels": label_ids})
ds = Dataset.from_list(inputs)
return ds
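# Worked masking example (hypothetical lengths): if the prompt encodes to 5 tokens and the target
# to 3, then label_ids == [-100, -100, -100, -100, -100, t1, t2, t3], so the loss is computed only
# over the 3 target positions.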
# -------------------------
# LoRA setup and training
# -------------------------
def train_lora_on_dataset(hf_dataset: Dataset, tokenizer: AutoTokenizer, cfg=CONFIG):
model_name = cfg["base_lm"]
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
    # Optionally call prepare_model_for_kbit_training(model) if loading the base model quantized
    # (4-bit/8-bit); here we use plain float16. If fp16 training raises gradient-unscaling errors,
    # consider loading the base model in float32 or training with bf16 instead.
    # prepare_model_for_kbit_training(model)
# LoRA config
lora_config = LoraConfig(
r=8,
lora_alpha=32,
        target_modules=["q_proj", "v_proj"],  # matches GPT-Neo's attention projections; other architectures may differ (e.g., GPT-2 uses a fused c_attn)
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
print("Model wrapped with LoRA. Trainable params:", model.get_peft_config())
    # DataCollatorForSeq2Seq pads labels with -100, preserving the prompt masking from
    # prepare_hf_dataset_for_lm (the stock LM collator would rebuild labels from input_ids).
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=-100, padding=True)
training_args = TrainingArguments(
output_dir=cfg["output_dir"],
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
num_train_epochs=3,
fp16=True,
logging_steps=10,
save_strategy="epoch",
remove_unused_columns=False,
optim="adamw_torch",
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=hf_dataset,
data_collator=data_collator,
tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(cfg["output_dir"]) # saves peft adapter weights
return model
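# Sketch: reload the saved LoRA adapter later without retraining (assumes the adapter was saved
# to CONFIG["output_dir"] by trainer.save_model above).
def load_lora_for_inference(cfg=CONFIG):
    from peft import PeftModel  # local import so the sketch stays self-contained
    base = AutoModelForCausalLM.from_pretrained(cfg["base_lm"], device_map="auto", torch_dtype=torch.float16)
    model = PeftModel.from_pretrained(base, cfg["output_dir"])
    model.eval()
    return model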
# -------------------------
# Inference: RAG flow using the fine-tuned LoRA adapter
# -------------------------
def rag_infer(query: str, embedder: SentenceTransformer, tokenizer: AutoTokenizer, model: AutoModelForCausalLM, cfg=CONFIG):
# load index + meta
index = faiss.read_index(cfg["index_path"])
with open(cfg["index_path"] + ".meta.json", "r", encoding="utf-8") as mf:
meta = json.load(mf)
qvec = embedder.encode([query], convert_to_numpy=True)
D, I = index.search(qvec, cfg["top_k"]) # top_k retrieved chunks
    # faiss returns -1 for missing neighbors when the index holds fewer than top_k vectors
    retrieved = [meta[idx]["text"] for idx in I[0] if idx != -1]
retrieved_joined = "\n---\n".join(retrieved)
prompt = (
"You are given retrieved documents separated by '---'. Use them to answer the question concisely and cite which chunk you used."
"\nContext:\n" + retrieved_joined + "\nQuestion: " + query + "\nAnswer:"
)
    enc = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(
        **enc,
        max_new_tokens=256,
        do_sample=False,
        num_beams=4,
        early_stopping=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # decode only the newly generated tokens; string-matching the prompt against the decoded
    # output is brittle because detokenization may not reproduce the prompt exactly
    ans = tokenizer.decode(out[0][enc["input_ids"].shape[1]:], skip_special_tokens=True).strip()
    return ans
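# Optional: for deployment you can merge the LoRA weights into the base model so inference no
# longer needs the PEFT wrapper (sketch; `model` must be the LoRA-wrapped PeftModel):
#   merged = model.merge_and_unload()
#   merged.save_pretrained("./rag_lora_merged")  # hypothetical output path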
# -------------------------
# Example usage (not executed on import)
# -------------------------
if __name__ == "__main__":
# 1) Build embeddings & FAISS index from docs
embedder = SentenceTransformer(CONFIG["embedding_model"])
# If you haven't yet created your docs.jsonl, create sample docs now
if not os.path.exists(CONFIG["docs_path"]):
sample_docs = [
{"id": "d1", "text": "RAG stands for Retrieval-Augmented Generation. It combines a retriever with a generator."},
{"id": "d2", "text": "LoRA is a parameter efficient fine-tuning method that inserts rank-decomposition matrices into attention layers."},
{"id": "d3", "text": "FAISS is a library for fast nearest neighbor search of vector embeddings."},
]
with open(CONFIG["docs_path"], "w", encoding="utf-8") as f:
for d in sample_docs:
f.write(json.dumps(d) + "\n")
index, meta = build_faiss_from_docs(CONFIG["docs_path"], embedder)
# 2) Create training examples: you'll need labeled (query, answer) pairs.
# For demo, create a tiny dataset (in practice you will have many samples)
queries_with_answers = [
{"query": "What is LoRA?", "answer": "LoRA (Low-Rank Adaptation) is a technique to fine-tune large models by injecting small low-rank matrices into weight updates."},
{"query": "What does FAISS do?", "answer": "FAISS provides scalable similarity search and clustering of dense vectors for retrieval tasks."},
]
examples = create_instruction_examples(queries_with_answers, embedder)
# 3) Prepare HF dataset
    tokenizer = AutoTokenizer.from_pretrained(CONFIG["base_lm"])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-Neo's tokenizer ships without a pad token
hf_ds = prepare_hf_dataset_for_lm(examples, tokenizer, max_length=1024)
# 4) Train LoRA on the dataset
model = train_lora_on_dataset(hf_ds, tokenizer, CONFIG)
# 5) Inference
q = "Explain how RAG reduces hallucinations"
ans = rag_infer(q, embedder, tokenizer, model, CONFIG)
print("Answer:\n", ans)
# End of file