# constbert-onnx / export_to_onnx.py
# ag-nexla's picture
# added custom handler
# 2336bf5
import sys
import traceback
from pathlib import Path

import torch

from .modeling import ConstBERT
from .colbert_configuration import ColBERTConfig
# Export the ConstBERT checkpoint stored next to this script to ONNX
# (written as model.onnx in that same directory).
#
# NOTE(review): the relative imports at the top of the file mean this must be
# run as a module inside its package (e.g. `python -m <package>.export_to_onnx`).
# Those imports execute before the `try` below, so a failure there is NOT
# reported by the `ImportError` handler in this block.
try:
    print("Loading model...")
    # Directory containing this script: the checkpoint/config are read from
    # here and the ONNX file is written here -- independent of the CWD.
    model_name = Path(__file__).parent
    # Load ColBERTConfig for ConstBERT.
    # NOTE(review): the directory is passed positionally; confirm this is the
    # intended ColBERTConfig constructor argument.
    colbert_config = ColBERTConfig(str(model_name))
    model = ConstBERT.from_pretrained(str(model_name), colbert_config=colbert_config)
    print("✓ Model loaded successfully")

    print("Setting model to evaluation mode...")
    # Put modules into inference mode before tracing the graph.
    model.eval()
    print("✓ Model set to evaluation mode")

    # Use the doc_maxlen from the *loaded model's* colbert_config rather than
    # the config object constructed above.
    actual_doc_maxlen = model.colbert_config.doc_maxlen
    print(f"DEBUG: model.colbert_config.doc_maxlen = {actual_doc_maxlen}")
    print(f"Preparing dummy input for ONNX export with doc_maxlen={actual_doc_maxlen}...")
    # All-ones tensors of shape (1, doc_maxlen). Only shape and dtype are
    # meant to matter; if forward() branches on token values, tracing records
    # just the path taken for these inputs.
    dummy_input_ids = torch.ones((1, actual_doc_maxlen), dtype=torch.long)
    dummy_attention_mask = torch.ones((1, actual_doc_maxlen), dtype=torch.long)
    print("✓ Dummy input prepared")

    print("Exporting model to ONNX format...")
    onnx_path = model_name / "model.onnx"
    torch.onnx.export(
        model,
        (dummy_input_ids, dummy_attention_mask),
        str(onnx_path),
        input_names=["input_ids", "attention_mask"],
        # NOTE(review): assumes forward() returns a single tensor whose axis 1
        # follows the input length -- confirm before relying on the dynamic
        # "seq" axis declared for the output.
        output_names=["last_hidden_state"],
        # Batch size and sequence length may vary at inference time.
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq"},
            "attention_mask": {0: "batch", 1: "seq"},
            "last_hidden_state": {0: "batch", 1: "seq"},
        },
        opset_version=14,
    )
    print("✓ Model exported to ONNX successfully")
    print(f"✓ ONNX file saved as: {onnx_path}")
except FileNotFoundError as e:
    # Files are looked up next to this script, not in the working directory,
    # so report the directory that was actually searched.
    print(
        f"❌ Error: Model files not found in {Path(__file__).parent}: {e}",
        file=sys.stderr,
    )
    sys.exit(1)
except ImportError as e:
    # Only reached for imports performed inside the try body; the
    # module-level imports have already run (or failed) before this point.
    print(f"❌ Error: Failed to import required modules: {e}", file=sys.stderr)
    sys.exit(1)
except Exception as e:
    print(f"❌ Error during model export: {e}", file=sys.stderr)
    # The one-line message alone hides where the export failed; keep the
    # full traceback for diagnosis.
    traceback.print_exc()
    sys.exit(1)