# handler.py
import os
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from typing import Dict, List, Any
# from colbert_configuration import ColBERTConfig  # only needed for the commented-out config loading below

# modeling.py and colbert_configuration.py are assumed to live in the same
# directory; the handler runs inside the model repository, so local imports
# resolve. For ConstBERT to be recognized by transformers' Auto classes, either
# ship a custom Docker image or add an auto_map entry to config.json (sketched
# below). This handler itself relies only on the tokenizer and the exported
# ONNX graph, so neither is strictly required here.
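#
# A hypothetical auto_map entry for config.json (assuming the ConstBERT class
# lives in modeling.py; adjust the module/class names to your repo):
#
#   "auto_map": {
#       "AutoModel": "modeling.ConstBERT"
#   }
#
# This is only needed if you load the PyTorch model through transformers; the
# ONNX path used below does not require it.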

# NOTE: Hugging Face Inference Endpoints requires this class to be named exactly EndpointHandler.
class EndpointHandler:
    def __init__(self, path=""): # path will be '/repository' on HF Endpoints
        # `path` is the directory where your model files (model.onnx, tokenizer files) are located.

        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        print(f"Tokenizer loaded from: {path}")

        # Use the doc_maxlen that the ONNX model was *actually exported with* (250).
        # This ensures consistency between the handler's tokenizer and the ONNX model's expectation.
        self.doc_max_length = 250
        print(f"Hardcoded doc_maxlen for tokenizer as: {self.doc_max_length}")

        # NOTE: If you need other ColBERTConfig parameters, load the checkpoint
        # config here (and re-enable the colbert_configuration import above);
        # for doc_max_length we set the value explicitly to avoid mismatches.
        # self.colbert_config = ColBERTConfig.load_from_checkpoint(path)
        # self.doc_max_length = self.colbert_config.doc_maxlen

        # Load the ONNX model; providers are specified explicitly to avoid
        # provider-selection issues on recent onnxruntime versions.
        # Swap in CUDAExecutionProvider on GPU hardware.
        onnx_model_path = os.path.join(path, "model.onnx")
        self.session = ort.InferenceSession(
            onnx_model_path, providers=["CPUExecutionProvider"]
        )
        print(f"ONNX model loaded from: {onnx_model_path}")

        # Get input names from the ONNX graph (used to build the feed dict in __call__)
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        print(f"ONNX input names: {self.input_names}")


    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Inference call for the endpoint.

        Args:
            data (Dict[str, Any]): The request payload.
                                   Expected to contain "inputs" (str or list of str).

        Returns:
            List[Dict[str, Any]]: A list of dictionaries, where each dict
                                  contains the raw multi-vector output for an input.
                                  Example: [{"embedding": [[...], [...], ...]}, ...]
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            raise ValueError("No 'inputs' found in the request payload.")

        # Ensure inputs is a list
        if isinstance(inputs, str):
            inputs = [inputs]

        # Tokenize, padding/truncating every input to the fixed export length
        tokenized_inputs = self.tokenizer(
            inputs,
            padding="max_length",
            truncation=True,
            max_length=self.doc_max_length,
            return_tensors="np"
        )

        # Build the ONNX feed dict from the inputs the graph actually declares
        # (covers token_type_ids too, if the export included it). Graphs exported
        # from PyTorch expect int64 token tensors, so cast explicitly.
        onnx_inputs = {
            name: tokenized_inputs[name].astype(np.int64)
            for name in self.input_names
            if name in tokenized_inputs
        }

        # Run ONNX inference
        outputs = self.session.run(None, onnx_inputs)

        # The first output is your multi-vector embedding
        multi_vector_embeddings = outputs[0]

        # Convert to JSON-serializable lists. Typical endpoint requests carry a
        # single input, but client-side batching is handled for robustness.
        result_list = []
        for i in range(multi_vector_embeddings.shape[0]):
            # One dict per input, holding its fixed-size (32 x 128) multi-vector embedding
            result_list.append({"embedding": multi_vector_embeddings[i].tolist()})

        return result_list
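

# Minimal local smoke test (not part of the endpoint contract). It assumes the
# current directory holds model.onnx plus the tokenizer files, and mimics the
# payload Hugging Face Inference Endpoints would send to __call__.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    result = handler({"inputs": "What does multi-vector retrieval return?"})
    embedding = result[0]["embedding"]
    # Each entry holds one input's multi-vector embedding, e.g. 32 vectors of dim 128
    print(f"{len(result)} result(s); {len(embedding)} vectors of dim {len(embedding[0])}")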