"""
Advanced RAG + LoRA fine-tuning example.

File: advanced_rag_lora_finetune.py

What this provides (end to end):
1. Build a FAISS-backed document store with chunking and embeddings.
2. Create an instruction-style training dataset containing: query, retrieved chunks, target answer.
3. LoRA (PEFT) fine-tuning that trains the LM to reason over retrieved context.
4. An inference pipeline that does retrieval -> prompt assembly -> generation.

Notes:
- This is a runnable blueprint, but it needs environment setup and a GPU with
  enough memory to train the base model.
- Building blocks: sentence-transformers for embeddings, FAISS for vector search,
  Hugging Face Transformers + PEFT for LoRA.

Requirements (pip):
    transformers>=4.30.0
    sentence-transformers
    faiss-cpu (or faiss-gpu)
    datasets
    peft
    accelerate
    torch

Adjust model names, batch sizes, and local paths as needed.
"""
|
|
import os
import json
from typing import List, Dict

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model
from sentence_transformers import SentenceTransformer
import faiss
from datasets import Dataset
|
|
CONFIG = {
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    "base_lm": "EleutherAI/gpt-neo-1.3B",
    "index_path": "./faiss_index.ivf",
    "docs_path": "./docs.jsonl",
    "chunk_size": 512,      # chunk size in whitespace tokens (words)
    "chunk_overlap": 64,    # overlap between consecutive chunks, in words
    "top_k": 4,             # number of chunks retrieved per query
    "output_dir": "./rag_lora_checkpoints",
}
|
|
def chunk_text(text: str, chunk_size: int = 512, overlap: int = 64) -> List[str]:
    """Split text into overlapping chunks of whitespace-separated tokens."""
    tokens = text.split()
    chunks = []
    # Guard against a non-positive stride (which would loop forever).
    step = max(1, chunk_size - overlap)
    i = 0
    while i < len(tokens):
        chunks.append(" ".join(tokens[i : i + chunk_size]))
        i += step
    return chunks
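
# Worked example of the chunking arithmetic (illustrative numbers, not from the
# corpus): with chunk_size=512 and overlap=64 the stride is 448 words, so a
# 1000-word document yields chunks starting at words 0, 448, and 896.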
|
|
def build_faiss_from_docs(docs_path: str, embedder: SentenceTransformer, cfg=CONFIG):
    """Chunk every document in a JSONL file, embed the chunks, and persist a FAISS index."""
    chunks = []
    meta = []
    with open(docs_path, "r", encoding="utf-8") as f:
        for line in f:
            obj = json.loads(line)
            doc_id = obj.get("id")
            text = obj.get("text", "")
            for c in chunk_text(text, cfg["chunk_size"], cfg["chunk_overlap"]):
                meta.append({"doc_id": doc_id, "text": c})
                chunks.append(c)

    if len(chunks) == 0:
        raise ValueError("No chunks found in docs file")

    embeddings = embedder.encode(chunks, show_progress_bar=True, convert_to_numpy=True)
    d = embeddings.shape[1]

    # Exact L2 search is fine at this scale; swap in an IVF/HNSW index for large corpora.
    index = faiss.IndexFlatL2(d)
    index.add(embeddings)

    # Persist the index and a parallel metadata file mapping row -> (doc_id, chunk text).
    faiss.write_index(index, cfg["index_path"])
    meta_path = cfg["index_path"] + ".meta.json"
    with open(meta_path, "w", encoding="utf-8") as mf:
        json.dump(meta, mf)

    print(f"Built FAISS index with {len(chunks)} chunks, dim={d}")
    return index, meta
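
# Optional convenience helper (a sketch; the functions below inline the same
# logic rather than calling it): load the persisted index plus its chunk metadata.
def load_faiss_index(cfg=CONFIG):
    index = faiss.read_index(cfg["index_path"])
    with open(cfg["index_path"] + ".meta.json", "r", encoding="utf-8") as mf:
        meta = json.load(mf)
    return index, meta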
|
|
def create_instruction_examples(queries_with_answers: List[Dict], embedder: SentenceTransformer, cfg=CONFIG):
    """For each (query, answer) pair, retrieve top-k chunks and build an instruction-style example."""
    index = faiss.read_index(cfg["index_path"])
    with open(cfg["index_path"] + ".meta.json", "r", encoding="utf-8") as mf:
        meta = json.load(mf)

    examples = []
    for qa in queries_with_answers:
        q = qa["query"]
        ans = qa["answer"]
        qvec = embedder.encode([q], convert_to_numpy=True)
        D, I = index.search(qvec, cfg["top_k"])
        # FAISS pads results with -1 when the index holds fewer than top_k vectors; skip those slots.
        retrieved = [meta[idx]["text"] for idx in I[0] if idx != -1]
        retrieved_joined = "\n---\n".join(retrieved)

        prompt = (
            "You are given retrieved documents separated by '---'. Use them to answer the question."
            "\nContext:\n" + retrieved_joined + "\nQuestion: " + q + "\nAnswer:"
        )
        examples.append({"input": prompt, "target": ans})

    return examples
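
# Each produced training record looks roughly like (shortened):
#   {"input": "You are given retrieved documents ...\nContext:\n<chunk>\n---\n<chunk>"
#             "\nQuestion: What is LoRA?\nAnswer:",
#    "target": "LoRA (Low-Rank Adaptation) is ..."}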
|
|
def prepare_hf_dataset_for_lm(examples: List[Dict], tokenizer: AutoTokenizer, max_length: int = 1024):
    """Tokenize prompt+answer pairs and mask the prompt tokens out of the loss with -100."""
    rows = []
    for ex in examples:
        inp = ex["input"]
        tgt = " " + ex["target"].strip()
        full = inp + tgt
        enc = tokenizer(full, truncation=True, max_length=max_length)
        input_ids = enc["input_ids"]

        # Tokenize the prompt alone to find how many leading tokens to mask.
        # (Approximate: the first answer token can merge with the prompt boundary.)
        enc_inp = tokenizer(inp, truncation=True, max_length=max_length)
        len_inp = len(enc_inp["input_ids"])

        label_ids = [-100] * len_inp + input_ids[len_inp:]
        if len(label_ids) < len(input_ids):
            label_ids += [-100] * (len(input_ids) - len(label_ids))
        # Keep labels and input_ids the same length even if truncation cut into the prompt.
        label_ids = label_ids[: len(input_ids)]

        rows.append({"input_ids": input_ids, "attention_mask": enc["attention_mask"], "labels": label_ids})
    return Dataset.from_list(rows)
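
# Optional sanity check (a small sketch, not part of the pipeline): verify that
# prompt tokens are masked out of the loss (-100) and only answer tokens remain.
def check_label_masking(ds: Dataset) -> None:
    row = ds[0]
    masked = sum(1 for t in row["labels"] if t == -100)
    print(f"{masked}/{len(row['labels'])} label tokens masked (prompt portion)")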
|
|
def train_lora_on_dataset(hf_dataset: Dataset, tokenizer: AutoTokenizer, cfg=CONFIG):
    model_name = cfg["base_lm"]
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)

    # GPT-Neo's tokenizer ships without a pad token; reuse EOS so batches can be padded.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, lora_config)
    print("Model wrapped with LoRA.")
    model.print_trainable_parameters()

    # DataCollatorForLanguageModeling(mlm=False) would rebuild labels from input_ids and
    # discard the -100 prompt mask, so use DataCollatorForSeq2Seq, which pads input_ids
    # and labels while preserving the mask.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=-100)

    training_args = TrainingArguments(
        output_dir=cfg["output_dir"],
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        fp16=True,
        logging_steps=10,
        save_strategy="epoch",
        remove_unused_columns=False,
        optim="adamw_torch",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=hf_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    trainer.train()
    trainer.save_model(cfg["output_dir"])
    return model
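
# Sketch of reloading the saved adapter later for inference only. Assumes the
# adapter checkpoint written by train_lora_on_dataset above; PeftModel is the
# standard loader from peft.
def load_finetuned_model(cfg=CONFIG):
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained(cfg["base_lm"], device_map="auto", torch_dtype=torch.float16)
    model = PeftModel.from_pretrained(base, cfg["output_dir"])
    model.eval()
    return model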
|
|
def rag_infer(query: str, embedder: SentenceTransformer, tokenizer: AutoTokenizer, model: AutoModelForCausalLM, cfg=CONFIG):
    """Retrieve top-k chunks for the query, assemble the prompt, and generate an answer."""
    index = faiss.read_index(cfg["index_path"])
    with open(cfg["index_path"] + ".meta.json", "r", encoding="utf-8") as mf:
        meta = json.load(mf)

    qvec = embedder.encode([query], convert_to_numpy=True)
    D, I = index.search(qvec, cfg["top_k"])
    retrieved = [meta[idx]["text"] for idx in I[0] if idx != -1]
    retrieved_joined = "\n---\n".join(retrieved)

    prompt = (
        "You are given retrieved documents separated by '---'. Use them to answer the question "
        "concisely and cite which chunk you used."
        "\nContext:\n" + retrieved_joined + "\nQuestion: " + query + "\nAnswer:"
    )

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    out = model.generate(input_ids, max_new_tokens=256, do_sample=False, num_beams=4, early_stopping=True)
    ans = tokenizer.decode(out[0], skip_special_tokens=True)

    # Strip the echoed prompt so only the generated continuation remains.
    if prompt in ans:
        ans = ans.split(prompt)[-1].strip()
    return ans
|
|
if __name__ == "__main__":
    embedder = SentenceTransformer(CONFIG["embedding_model"])

    # Write a tiny sample corpus if none exists so the blueprint runs end to end.
    if not os.path.exists(CONFIG["docs_path"]):
        sample_docs = [
            {"id": "d1", "text": "RAG stands for Retrieval-Augmented Generation. It combines a retriever with a generator."},
            {"id": "d2", "text": "LoRA is a parameter efficient fine-tuning method that inserts rank-decomposition matrices into attention layers."},
            {"id": "d3", "text": "FAISS is a library for fast nearest neighbor search of vector embeddings."},
        ]
        with open(CONFIG["docs_path"], "w", encoding="utf-8") as f:
            for d in sample_docs:
                f.write(json.dumps(d) + "\n")

    # 1) Build the vector store.
    index, meta = build_faiss_from_docs(CONFIG["docs_path"], embedder)

    # 2) Assemble a (toy) instruction dataset of query / retrieved context / answer triples.
    queries_with_answers = [
        {"query": "What is LoRA?", "answer": "LoRA (Low-Rank Adaptation) is a technique to fine-tune large models by injecting small low-rank matrices into weight updates."},
        {"query": "What does FAISS do?", "answer": "FAISS provides scalable similarity search and clustering of dense vectors for retrieval tasks."},
    ]
    examples = create_instruction_examples(queries_with_answers, embedder)

    # 3) Tokenize and LoRA fine-tune the base LM on the retrieval-augmented examples.
    tokenizer = AutoTokenizer.from_pretrained(CONFIG["base_lm"])
    hf_ds = prepare_hf_dataset_for_lm(examples, tokenizer, max_length=1024)
    model = train_lora_on_dataset(hf_ds, tokenizer, CONFIG)

    # 4) Retrieval-augmented inference with the fine-tuned model.
    q = "Explain how RAG reduces hallucinations"
    ans = rag_infer(q, embedder, tokenizer, model, CONFIG)
    print("Answer:\n", ans)