# handler.py
import os
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from typing import Dict, List, Any
from colbert_configuration import ColBERTConfig  # only needed for the optional config loading below

# modeling.py and colbert_configuration.py are assumed to sit in the same directory.
# Local imports work because this handler runs inside the model repository's directory.
# If the imports fail, consider a custom Docker image, or make the model loadable via
# AutoModel.from_pretrained by adding an auto_map entry to config.json.
# For simplicity, this handler runs inference through ONNX Runtime rather than
# loading ConstBERT via from_pretrained, so only the tokenizer files and model.onnx are required.


# Note: Hugging Face Inference Endpoints require the handler class to be named exactly EndpointHandler.
class EndpointHandler:
    def __init__(self, path=""):  # path is '/repository' on HF Inference Endpoints
        # `path` is the directory that holds the model files (model.onnx, tokenizer files).

        # Load the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        print(f"Tokenizer loaded from: {path}")

        # Use the doc_maxlen the ONNX model was *actually exported with* (250).
        # This keeps the handler's tokenizer consistent with what the ONNX model expects.
        self.doc_max_length = 250
        print(f"Hardcoded doc_maxlen for tokenizer as: {self.doc_max_length}")
        # NOTE: If other colbert_config parameters are needed, load the config here;
        # doc_max_length is set explicitly above to avoid mismatches.
        # self.colbert_config = ColBERTConfig.load_from_checkpoint(path)
        # self.doc_max_length = self.colbert_config.doc_maxlen

        # Load the ONNX model.
        onnx_model_path = os.path.join(path, "model.onnx")
        self.session = ort.InferenceSession(onnx_model_path)
        print(f"ONNX model loaded from: {onnx_model_path}")

        # Record the input names the ONNX graph expects.
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        print(f"ONNX input names: {self.input_names}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Inference call for the endpoint.

        Args:
            data (Dict[str, Any]): The request payload. Expected to contain
                "inputs" (str or list of str).

        Returns:
            List[Dict[str, Any]]: One dictionary per input, each holding the raw
                multi-vector output for that input.
                Example: [{"embedding": [[...], [...], ...]}, ...]
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            raise ValueError("No 'inputs' found in the request payload.")

        # Ensure inputs is a list.
        if isinstance(inputs, str):
            inputs = [inputs]

        # Tokenize the inputs, padding/truncating consistently to doc_max_length.
        tokenized_inputs = self.tokenizer(
            inputs,
            padding="max_length",            # pad every sequence to max_length
            truncation=True,
            max_length=self.doc_max_length,  # the doc_maxlen set in __init__
            return_tensors="np",
        )
        input_ids = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]

        # Prepare the ONNX input dictionary.
        onnx_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

        # Run ONNX inference.
        outputs = self.session.run(None, onnx_inputs)
        # The first output is the multi-vector embedding.
        multi_vector_embeddings = outputs[0]

        # Convert to lists of lists so the response is JSON serializable.
        # Batch size is typically 1 for endpoint requests, but iterate over the batch
        # dimension so client-side batching is handled as well.
        result_list = []
        for i in range(multi_vector_embeddings.shape[0]):
            # One dictionary per input, holding its fixed-size (32 x 128) multi-vector embedding.
            result_list.append({"embedding": multi_vector_embeddings[i].tolist()})
        return result_list