# handler.py
import os
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from typing import Dict, List, Any
from colbert_configuration import ColBERTConfig  # Only needed if loading the ColBERT config below

# Assuming modeling.py and colbert_configuration.py are in the same directory.
# We use local imports since this handler runs within the model's directory context.
# For ConstBERT to be recognized, these modules must be importable. If you run into
# issues, consider a custom Docker image, or make the model loadable via
# AutoModel.from_pretrained by adding an auto_map entry to config.json.
# For simplicity, this handler runs the exported ONNX model at `path` directly with
# ONNX Runtime rather than going through ConstBERT.from_pretrained.

# Note: The handler class must be named exactly EndpointHandler.
class EndpointHandler:
    def __init__(self, path=""):  # path will be '/repository' on HF Endpoints
        # `path` is the directory where your model files (model.onnx, tokenizer files) are located.

        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        print(f"Tokenizer loaded from: {path}")

        # Use the doc_maxlen that the ONNX model was *actually exported with* (250).
        # This keeps the handler's tokenizer consistent with the ONNX model's expected input length.
        self.doc_max_length = 250
        print(f"Hardcoded doc_maxlen for tokenizer as: {self.doc_max_length}")

        # NOTE: If you need other colbert_config parameters, load the config here;
        # for doc_max_length we set it explicitly to avoid mismatches.
        # self.colbert_config = ColBERTConfig.load_from_checkpoint(path)
        # self.doc_max_length = self.colbert_config.doc_maxlen

        # Load the ONNX model
        onnx_model_path = os.path.join(path, "model.onnx")
        self.session = ort.InferenceSession(onnx_model_path)
        print(f"ONNX model loaded from: {onnx_model_path}")

        # Get input names from the ONNX model
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        print(f"ONNX input names: {self.input_names}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Inference call for the endpoint.

        Args:
            data (Dict[str, Any]): The request payload. Expected to contain
                "inputs" (str or list of str).

        Returns:
            List[Dict[str, Any]]: A list of dictionaries, one per input, each
                containing the raw multi-vector output for that input.
                Example: [{"embedding": [[...], [...], ...]}, ...]
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            raise ValueError("No 'inputs' found in the request payload.")

        # Ensure inputs is a list
        if isinstance(inputs, str):
            inputs = [inputs]

        # Tokenize the inputs, padding/truncating consistently to doc_max_length
        tokenized_inputs = self.tokenizer(
            inputs,
            padding="max_length",            # Pad every sequence to max_length
            truncation=True,
            max_length=self.doc_max_length,  # Use the hardcoded doc_max_length
            return_tensors="np"
        )

        input_ids = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]

        # Prepare ONNX input dictionary
        onnx_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }

        # Run ONNX inference
        outputs = self.session.run(None, onnx_inputs)

        # The first output is the multi-vector embedding
        multi_vector_embeddings = outputs[0]

        # Convert to lists of lists (JSON serializable). Batch size is typically 1
        # for endpoint requests, but client-side batching is handled for robustness.
        result_list = []
        for i in range(multi_vector_embeddings.shape[0]):
            # Each element of result_list is a dictionary for one input,
            # containing its multi-vector embedding (fixed 32 x 128).
            result_list.append({"embedding": multi_vector_embeddings[i].tolist()})

        return result_list
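

# --- Optional local smoke test -------------------------------------------------
# A minimal sketch of how the handler could be exercised outside the endpoint,
# assuming the exported model.onnx and tokenizer files live in the current
# directory ("."); both the path and the sample sentences below are illustrative.
# The payload mirrors what the endpoint sends: a dict with an "inputs" key
# holding a string or a list of strings.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    payload = {"inputs": ["What is ColBERT?", "ConstBERT produces fixed-size multi-vector embeddings."]}
    results = handler(payload)
    # Each result holds one multi-vector embedding; print its dimensions (e.g. 32 x 128).
    for idx, item in enumerate(results):
        vectors = item["embedding"]
        print(f"Input {idx}: {len(vectors)} vectors of dimension {len(vectors[0])}")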