import itertools
import json

import faiss
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer

# NOTE: `use_gaussian_coverage_pooling` is referenced below but not imported
# here; it is assumed to be provided by the pooling module that accompanies
# this file.


class InstructionTemplateRetriever:
    FINETEMPLATES_REVISION = "831ab22c90f9da011bd972585afdf609f40fa54b"
    RETRIEVAL_EMBEDDING_NAME = "fineinstructions/matching_embedding"
    RETRIEVAL_EMBEDDING_REVISION = "db4efbde126216250ffa5a356663fc7da3bf7856"

    def __init__(
        self,
        coverage_chunks=10,
        sigma=0.05,
        alpha=1.0,
        nprobe=150,
    ):
        """
        Finds relevant instruction templates for a document using
        Gaussian-weighted embeddings that cover different parts of the
        document.

        Args:
            coverage_chunks (int): The number of equally sized chunks/sections
                to get coverage over the entire document.
            sigma (float): Standard deviation for the Gaussian weighting; this
                essentially controls how "wide" / "focused" each chunk is.
            alpha (float): A weighting factor that balances the representation
                of a single chunk against the representation of the entire
                document.
            nprobe (int): The number of probes to use when searching the FAISS
                index (larger is more accurate, but slower).
        """
        self.d = load_dataset(
            "fineinstructions/finetemplates",
            revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
            split="full",
        )
        self.m = SentenceTransformer(
            InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_NAME,
            revision=InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_REVISION,
            device="cpu",
        )
        # Wrap the encoder so a single encode() call yields one embedding for
        # the entire document plus one Gaussian-weighted embedding per
        # coverage chunk.
        self.m = use_gaussian_coverage_pooling(
            self.m, coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
        )
        # Memory-map the prebuilt FAISS index from the dataset repository.
        self.index = faiss.read_index(
            hf_hub_download(
                "fineinstructions/finetemplates",
                "faiss_index/finetemplates.index",
                revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
                repo_type="dataset",
            ),
            faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
        )
        self.index.nprobe = nprobe
        if torch.cuda.is_available():
            self.m = self.m.to("cuda")
        elif torch.backends.mps.is_available():
            self.m = self.m.to("mps")

    def _filter_rows(self, rows, filter_string):
        if not rows:
            return []
        df = pd.DataFrame(rows)
        try:
            filtered_df = df.query(filter_string)
            return filtered_df.to_dict(orient="records")
        except Exception:
            # Fall back to the unfiltered rows if the filter string is invalid.
            return rows

    def search(
        self, document, filters="", search_k=20000, max_results=250, deduplicate=True
    ):
        """
        Retrieves the instruction templates most relevant to a document.

        Args:
            document (str): The document to retrieve relevant instruction
                templates for.
            filters (str): A query string in the format of
                pandas.DataFrame.query().
            search_k (int): The number of search results to pull when
                retrieving from FAISS.
            max_results (int): The maximum number of results to return.
            deduplicate (bool): Deduplicate results between coverage sections.
""" # Search FAISS index vecs = self.m.encode([document], normalize_embeddings=False).reshape( -1, self.m[0].auto_model.config.hidden_size ) scores_batch, indices_batch = self.index.search(np.vstack(vecs), k=search_k) # Pull in FineTemplates rows into memory to_select = [i.item() for i in itertools.chain.from_iterable(indices_batch)] d_in_mem = { i: row for i, row in zip(to_select, self.d.select(to_select).to_list()) } # Group by coverage chunk true_coverage_chunks = self.m[1].coverage_chunks + 1 scores_per_input, indices_per_input = ( [ scores_batch[i : i + true_coverage_chunks] for i in range(0, len(scores_batch), true_coverage_chunks) ], [ indices_batch[i : i + true_coverage_chunks] for i in range(0, len(indices_batch), true_coverage_chunks) ], ) # Get the results for the first result in the batch (assuming bz=1) scores_per_input, indices_per_input = scores_per_input[0], indices_per_input[0] # Create result rows rows = [ [ { "coverage_section": f"{chunk_idx}/{self.m[1].coverage_chunks}" if chunk_idx > 0 else "Entire Document", "score": s.item(), **d_in_mem[i.item()], } for i, s in zip(indices, scores) ] for chunk_idx, (indices, scores) in enumerate( zip(indices_per_input, scores_per_input) ) ] # Deduplicate if deduplicate: seen = set() rows = [ r for r in itertools.chain.from_iterable(zip(*rows)) if (len(seen) != len(seen.add(r["template_id"]) or seen)) ] else: rows = list(itertools.chain.from_iterable(zip(*rows))) # Filter rows = self._filter_rows(rows, filters)[:max_results] # Return rows return rows