"""
Advanced RAG + LoRA Fine-tuning Example
File: advanced_rag_lora_finetune.py
What this provides (end-to-end):
1. Build a FAISS-backed document store with chunking and embeddings.
2. Create an instruction-style training dataset that includes: query, retrieved_chunks, target_answer.
3. LoRA (PEFT) fine-tuning that trains the LM to reason with retrieved context.
4. Inference pipeline that does retrieval -> prompt assembly -> generation.
Notes:
- This is a runnable blueprint, but it needs environment setup and a decent GPU for model training.
- Stack: sentence-transformers for embeddings, FAISS for vector search, Hugging Face Transformers + PEFT for LoRA.
Requirements (pip):
transformers>=4.30.0
sentence-transformers
faiss-cpu (or faiss-gpu)
datasets
peft
accelerate
torch
Adjust model names, batch sizes, and local paths as needed.
"""
import os
import json
from typing import List, Dict
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from sentence_transformers import SentenceTransformer
import faiss
from datasets import Dataset
# -------------------------
# Config
# -------------------------
CONFIG = {
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    "base_lm": "EleutherAI/gpt-neo-1.3B",  # pick a causal LM you can run
    "index_path": "./faiss_index.ivf",
    "docs_path": "./docs.jsonl",  # one doc per line: {"id":..., "text":...}
    "chunk_size": 512,
    "chunk_overlap": 64,
    "top_k": 4,
    "output_dir": "./rag_lora_checkpoints",
}
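# Note: "chunk_size" and "chunk_overlap" are counted in whitespace-delimited words
# (see chunk_text below), not in model tokens, so the effective token length per
# chunk will vary with the tokenizer you pick.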
# -------------------------
# Utilities: chunking and building doc store
# -------------------------
def chunk_text(text: str, chunk_size: int = 512, overlap: int = 64) -> List[str]:
    tokens = text.split()
    chunks = []
    i = 0
    while i < len(tokens):
        chunk = tokens[i : i + chunk_size]
        chunks.append(" ".join(chunk))
        i += chunk_size - overlap
    return chunks
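# Illustrative example (toy values, not the config defaults); chunking is over
# whitespace "words", so:
#   chunk_text("a b c d e", chunk_size=2, overlap=1) -> ["a b", "b c", "c d", "d e", "e"]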
def build_faiss_from_docs(docs_path: str, embedder: SentenceTransformer, cfg=CONFIG):
    # read docs and split each one into overlapping chunks
    chunks = []
    meta = []
    with open(docs_path, "r", encoding="utf-8") as f:
        for line in f:
            obj = json.loads(line)
            doc_id = obj.get("id")
            text = obj.get("text", "")
            for c in chunk_text(text, cfg["chunk_size"], cfg["chunk_overlap"]):
                meta.append({"doc_id": doc_id, "text": c})
                chunks.append(c)
    if len(chunks) == 0:
        raise ValueError("No chunks found in docs file")
    embeddings = embedder.encode(chunks, show_progress_bar=True, convert_to_numpy=True)
    d = embeddings.shape[1]
    # Use IVF for larger corpora (see the build_ivf_index sketch below); for small corpora IndexFlatL2 is fine
    index = faiss.IndexFlatL2(d)
    index.add(embeddings)
    # Save the index and the metadata mapping next to it
    faiss.write_index(index, cfg.get("index_path", "./faiss_index.ivf"))
    meta_path = cfg.get("index_path", "./faiss_index.ivf") + ".meta.json"
    with open(meta_path, "w", encoding="utf-8") as mf:
        json.dump(meta, mf)
    print(f"Built FAISS index with {len(chunks)} chunks, dim={d}")
    return index, meta
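# Optional variant (a sketch under assumed settings, not used by the pipeline above):
# for corpora with many thousands of chunks, an IVF index is usually a better fit than
# IndexFlatL2. `nlist` and `nprobe` below are illustrative values to tune, not values
# taken from this file.
def build_ivf_index(embeddings, nlist: int = 100):
    d = embeddings.shape[1]
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
    index.train(embeddings)  # IVF indexes must be trained before add()
    index.add(embeddings)
    index.nprobe = 8  # number of inverted lists probed at query time
    return index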
# -------------------------
# Create instruction dataset for retrieval-aware fine-tuning
# -------------------------
# Format: For each example, we store {query, retrieved_texts (string joined), answer}
# We assume the user supplies a ground-truth answer for each query (supervised fine-tuning).
def create_instruction_examples(queries_with_answers: List[Dict], embedder: SentenceTransformer, cfg=CONFIG):
    # queries_with_answers: list of {"query": ..., "answer": ...}
    # Requires a previously built index & meta (see build_faiss_from_docs)
    index = faiss.read_index(cfg["index_path"])
    with open(cfg["index_path"] + ".meta.json", "r", encoding="utf-8") as mf:
        meta = json.load(mf)
    examples = []
    for qa in queries_with_answers:
        q = qa["query"]
        ans = qa["answer"]
        qvec = embedder.encode([q], convert_to_numpy=True)
        D, I = index.search(qvec, cfg["top_k"])  # I holds chunk indices
        # FAISS pads with -1 when fewer than top_k results exist; skip those slots
        retrieved = [meta[idx]["text"] for idx in I[0] if idx != -1]
        retrieved_joined = "\n---\n".join(retrieved)
        # prompt template: tell the model that retrieved context follows and to answer concisely
        prompt = (
            "You are given retrieved documents separated by '---'. Use them to answer the question."
            "\nContext:\n" + retrieved_joined + "\nQuestion: " + q + "\nAnswer:"
        )
        examples.append({"input": prompt, "target": ans})
    return examples
# -------------------------
# Prepare HF Dataset for Causal LM training
# -------------------------
def prepare_hf_dataset_for_lm(examples: List[Dict], tokenizer: AutoTokenizer, max_length: int = 1024):
    # Concatenate input + target into a single sequence and build labels that mask
    # the input tokens, so the loss is only computed on the target span.
    rows = []
    for ex in examples:
        inp = ex["input"]
        tgt = " " + ex["target"].strip()
        full = inp + tgt
        enc = tokenizer(full, truncation=True, max_length=max_length)
        input_ids = enc["input_ids"]
        # Find the split point by encoding the input alone. With BPE tokenizers this prefix
        # length is an approximation (the boundary token can merge differently), which is
        # usually close enough for this masking scheme.
        enc_inp = tokenizer(inp, truncation=True, max_length=max_length)
        len_inp = min(len(enc_inp["input_ids"]), len(input_ids))
        # labels: -100 for input token positions so loss is only computed on the target
        label_ids = [-100] * len_inp + input_ids[len_inp:]
        rows.append(
            {
                "input_ids": input_ids,
                "attention_mask": enc["attention_mask"],
                "labels": label_ids,
            }
        )
    ds = Dataset.from_list(rows)
    return ds
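# Sketch of one prepared row (token IDs below are placeholders, not real values):
#   input_ids: [p1, p2, p3, t1, t2]        # prompt tokens followed by target tokens
#   labels:    [-100, -100, -100, t1, t2]  # only the target positions contribute to the loss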
# -------------------------
# LoRA setup and training
# -------------------------
def train_lora_on_dataset(hf_dataset: Dataset, tokenizer: AutoTokenizer, cfg=CONFIG):
    model_name = cfg["base_lm"]
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
    # prepare for k-bit training if using quantization (optional); here we use plain float16
    # model = prepare_model_for_kbit_training(model)
    # LoRA config
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],  # matches GPT-Neo attention projections; adjust for other architectures
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    print("Model wrapped with LoRA. Trainable params:")
    model.print_trainable_parameters()
    # Note: if training fails with "Attempting to unscale FP16 gradients", either load the
    # base model in float32 or upcast the trainable LoRA parameters to float32 beforehand.
    # DataCollatorForLanguageModeling(mlm=False) would overwrite our masked labels with a copy
    # of input_ids; DataCollatorForSeq2Seq pads input_ids/attention_mask and pads the
    # precomputed labels with -100, which is what we want here.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=-100, padding=True)
    training_args = TrainingArguments(
        output_dir=cfg["output_dir"],
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        fp16=True,
        logging_steps=10,
        save_strategy="epoch",
        remove_unused_columns=False,
        optim="adamw_torch",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=hf_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )
    trainer.train()
    trainer.save_model(cfg["output_dir"])  # saves the PEFT adapter weights
    return model
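# Sketch for reloading the saved adapter in a fresh process (assumes the adapter was
# written to CONFIG["output_dir"] by train_lora_on_dataset; merging is optional).
def load_finetuned_model(cfg=CONFIG):
    from peft import PeftModel
    base = AutoModelForCausalLM.from_pretrained(cfg["base_lm"], device_map="auto", torch_dtype=torch.float16)
    model = PeftModel.from_pretrained(base, cfg["output_dir"])
    # model = model.merge_and_unload()  # optionally fold the LoRA weights into the base model
    model.eval()
    return model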
# -------------------------
# Inference: RAG flow using the fine-tuned LoRA adapter
# -------------------------
def rag_infer(query: str, embedder: SentenceTransformer, tokenizer: AutoTokenizer, model: AutoModelForCausalLM, cfg=CONFIG):
    # load index + meta
    index = faiss.read_index(cfg["index_path"])
    with open(cfg["index_path"] + ".meta.json", "r", encoding="utf-8") as mf:
        meta = json.load(mf)
    qvec = embedder.encode([query], convert_to_numpy=True)
    D, I = index.search(qvec, cfg["top_k"])  # top_k retrieved chunks
    retrieved = [meta[idx]["text"] for idx in I[0] if idx != -1]  # -1 means "no result"
    retrieved_joined = "\n---\n".join(retrieved)
    prompt = (
        "You are given retrieved documents separated by '---'. Use them to answer the question concisely and cite which chunk you used."
        "\nContext:\n" + retrieved_joined + "\nQuestion: " + query + "\nAnswer:"
    )
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    out = model.generate(input_ids, max_new_tokens=256, do_sample=False, num_beams=4, early_stopping=True)
    # the generated sequence includes the prompt tokens; decode only the continuation
    gen_tokens = out[0][input_ids.shape[-1]:]
    ans = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
    return ans
# -------------------------
# Example usage (not executed on import)
# -------------------------
if __name__ == "__main__":
    # 1) Build embeddings & FAISS index from docs
    embedder = SentenceTransformer(CONFIG["embedding_model"])
    # If you haven't yet created your docs.jsonl, create sample docs now
    if not os.path.exists(CONFIG["docs_path"]):
        sample_docs = [
            {"id": "d1", "text": "RAG stands for Retrieval-Augmented Generation. It combines a retriever with a generator."},
            {"id": "d2", "text": "LoRA is a parameter efficient fine-tuning method that inserts rank-decomposition matrices into attention layers."},
            {"id": "d3", "text": "FAISS is a library for fast nearest neighbor search of vector embeddings."},
        ]
        with open(CONFIG["docs_path"], "w", encoding="utf-8") as f:
            for d in sample_docs:
                f.write(json.dumps(d) + "\n")
    index, meta = build_faiss_from_docs(CONFIG["docs_path"], embedder)
    # 2) Create training examples: you'll need labeled (query, answer) pairs.
    # For the demo, create a tiny dataset (in practice you will have many samples).
    queries_with_answers = [
        {"query": "What is LoRA?", "answer": "LoRA (Low-Rank Adaptation) is a technique to fine-tune large models by injecting small low-rank matrices into weight updates."},
        {"query": "What does FAISS do?", "answer": "FAISS provides scalable similarity search and clustering of dense vectors for retrieval tasks."},
    ]
    examples = create_instruction_examples(queries_with_answers, embedder)
    # 3) Prepare HF dataset
    tokenizer = AutoTokenizer.from_pretrained(CONFIG["base_lm"])
    if tokenizer.pad_token is None:
        # GPT-Neo's tokenizer has no pad token; reuse EOS so the collator can pad batches
        tokenizer.pad_token = tokenizer.eos_token
    hf_ds = prepare_hf_dataset_for_lm(examples, tokenizer, max_length=1024)
    # 4) Train LoRA on the dataset
    model = train_lora_on_dataset(hf_ds, tokenizer, CONFIG)
    # 5) Inference
    q = "Explain how RAG reduces hallucinations"
    ans = rag_infer(q, embedder, tokenizer, model, CONFIG)
    print("Answer:\n", ans)
# End of file