# handler.py
import os
from typing import Any, Dict, List

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

from colbert_configuration import ColBERTConfig  # local module shipped alongside the model

# modeling.py and colbert_configuration.py are assumed to live in the same directory:
# the handler runs inside the model repository, so local imports resolve against it.
# If ConstBERT itself needs to be importable, consider a custom Docker image, or make the
# model loadable via AutoModel.from_pretrained by adding an auto_map entry to config.json.
# For simplicity, this handler only relies on the tokenizer plus the exported ONNX graph.

# Note: Inference Endpoints require this class to be named exactly EndpointHandler.
class EndpointHandler:
    def __init__(self, path: str = ""):  # path is '/repository' on HF Endpoints
        # `path` is the directory holding the model files (model.onnx, tokenizer files).
        # Load the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        print(f"Tokenizer loaded from: {path}")

        # Use the doc_maxlen the ONNX model was *actually exported with* (250) so the
        # handler's tokenizer and the ONNX graph agree on sequence length.
        self.doc_max_length = 250
        print(f"Hardcoded doc_maxlen for tokenizer: {self.doc_max_length}")
        # NOTE: If other colbert_config parameters are needed, load the checkpoint config here.
        # For doc_max_length we set the value explicitly to avoid mismatches:
        # self.colbert_config = ColBERTConfig.load_from_checkpoint(path)
        # self.doc_max_length = self.colbert_config.doc_maxlen

        # Load the ONNX model.
        onnx_model_path = os.path.join(path, "model.onnx")
        self.session = ort.InferenceSession(onnx_model_path)
        print(f"ONNX model loaded from: {onnx_model_path}")
        # Get the input names from the ONNX graph.
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        print(f"ONNX input names: {self.input_names}")
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Inference call for the endpoint.

        Args:
            data (Dict[str, Any]): The request payload. Expected to contain
                "inputs" (str or list of str).

        Returns:
            List[Dict[str, Any]]: One dictionary per input, each holding the raw
            multi-vector output for that input.
            Example: [{"embedding": [[...], [...], ...]}, ...]
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            raise ValueError("No 'inputs' found in the request payload.")

        # Ensure inputs is a list.
        if isinstance(inputs, str):
            inputs = [inputs]

        # Tokenize the inputs with consistent padding/truncation to doc_max_length.
        tokenized_inputs = self.tokenizer(
            inputs,
            padding="max_length",            # pad every sequence to max_length
            truncation=True,
            max_length=self.doc_max_length,  # the length the ONNX model was exported with
            return_tensors="np",
        )
        input_ids = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]

        # Prepare the ONNX input dictionary.
        onnx_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
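        # Alternative (sketch, assumption about the exported graph): if the graph also
        # expects other inputs such as token_type_ids, the dictionary can be built from
        # self.input_names instead of hardcoding the two keys above:
        #
        #     onnx_inputs = {
        #         name: tokenized_inputs[name]
        #         for name in self.input_names
        #         if name in tokenized_inputs
        #     }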
        # Run ONNX inference.
        outputs = self.session.run(None, onnx_inputs)

        # The first output is the multi-vector embedding.
        multi_vector_embeddings = outputs[0]

        # Convert to lists of lists (JSON serializable). Batch size is usually 1 for
        # endpoint requests, but client-side batching is handled for robustness.
        result_list = []
        for i in range(multi_vector_embeddings.shape[0]):
            # One dictionary per input, holding its fixed-size (32 x 128) multi-vector embedding.
            result_list.append({"embedding": multi_vector_embeddings[i].tolist()})

        return result_list
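

# Minimal local smoke test (a sketch, assuming it is run from a directory that contains
# model.onnx and the tokenizer files, i.e. the same layout the endpoint sees under
# /repository). The sample text below is purely illustrative.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    sample = {"inputs": ["ColBERT produces one vector per token."]}
    predictions = handler(sample)
    # Each prediction holds a list of per-token vectors; print a shape-like summary.
    print(
        len(predictions),
        len(predictions[0]["embedding"]),
        len(predictions[0]["embedding"][0]),
    )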