File size: 2,009 Bytes
634cac7
 
 
9cb83c5
 
634cac7
 
 
9cb83c5
 
 
 
634cac7
 
 
 
 
 
2336bf5
 
 
 
 
 
9cb83c5
634cac7
 
 
 
 
9cb83c5
 
634cac7
 
 
 
 
 
 
 
 
 
9cb83c5
634cac7
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import sys
import traceback
from pathlib import Path

import torch

from .modeling import ConstBERT
from .colbert_configuration import ColBERTConfig

# Export the ConstBERT checkpoint stored next to this script to
# ``model.onnx`` in the same directory.
#
# The script runs at import time. Because it uses relative imports it must be
# executed as part of its package (e.g. ``python -m <package>.<module>``).
# On any failure it reports the problem on stderr and exits with status 1.

# Model files are read from the directory containing this script, NOT from the
# current working directory. Computing this outside the ``try`` means the
# error handlers below can always reference it.
model_name = Path(__file__).parent  # directory containing this script
onnx_path = model_name / "model.onnx"

try:
    print("Loading model...")
    # Load ColBERTConfig for ConstBERT
    colbert_config = ColBERTConfig(str(model_name))
    model = ConstBERT.from_pretrained(str(model_name), colbert_config=colbert_config)
    print("✓ Model loaded successfully")

    print("Setting model to evaluation mode...")
    # Put modules such as dropout into inference behaviour before tracing.
    model.eval()
    print("✓ Model set to evaluation mode")

    # Use the doc_maxlen from the *loaded model's* colbert_config
    actual_doc_maxlen = model.colbert_config.doc_maxlen
    print(f"DEBUG: model.colbert_config.doc_maxlen = {actual_doc_maxlen}")
    print(f"Preparing dummy input for ONNX export with doc_maxlen={actual_doc_maxlen}...")
    # Placeholder token ids / mask used only to trace the graph. The
    # dynamic_axes below let the exported model accept other batch sizes and
    # sequence lengths.
    dummy_input_ids = torch.ones((1, actual_doc_maxlen), dtype=torch.long)
    dummy_attention_mask = torch.ones((1, actual_doc_maxlen), dtype=torch.long)
    print("✓ Dummy input prepared")

    print("Exporting model to ONNX format...")
    torch.onnx.export(
        model,
        (dummy_input_ids, dummy_attention_mask),
        str(onnx_path),
        input_names=["input_ids", "attention_mask"],
        output_names=["last_hidden_state"],
        # NOTE(review): the output's axis 1 shares the "seq" label with the
        # inputs. If ConstBERT emits a fixed number of vectors per document
        # instead of one per token, this label may be misleading. Confirm
        # against the model's forward().
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq"},
            "attention_mask": {0: "batch", 1: "seq"},
            "last_hidden_state": {0: "batch", 1: "seq"},
        },
        opset_version=14,
    )
    print("✓ Model exported to ONNX successfully")
    print(f"✓ ONNX file saved as: {onnx_path}")

except FileNotFoundError as e:
    # Report the directory actually searched; it is not necessarily the cwd.
    print(f"❌ Error: Model files not found in {model_name}: {e}", file=sys.stderr)
    sys.exit(1)
except ImportError as e:
    # The module-level imports are outside this ``try``, so this only catches
    # imports triggered lazily while loading/exporting (e.g. optional ONNX
    # dependencies pulled in by torch.onnx).
    print(f"❌ Error: Failed to import required modules: {e}", file=sys.stderr)
    sys.exit(1)
except Exception as e:
    print(f"❌ Error during model export: {e}", file=sys.stderr)
    # Keep the full traceback: export failures are otherwise hard to diagnose.
    traceback.print_exc()
    sys.exit(1)