import sys
from pathlib import Path

import torch

from .modeling import ConstBERT
from .colbert_configuration import ColBERTConfig

# Export the ConstBERT checkpoint that lives next to this file to ONNX.
#
# Steps: build a ColBERTConfig and load the model from this file's directory,
# switch the model to evaluation mode, hand torch.onnx.export a dummy
# (1, doc_maxlen) batch of ones, and write ``model.onnx`` into the same
# directory. Any failure is reported on stderr and the process exits with
# status 1.

# NOTE: despite its name, ``model_name`` is the directory containing this
# file. It is resolved before the ``try`` so the error handlers can report
# exactly where the model files were looked for.
model_name = Path(__file__).parent
onnx_path = model_name / "model.onnx"

try:
    print("Loading model...")
    colbert_config = ColBERTConfig(str(model_name))
    model = ConstBERT.from_pretrained(str(model_name), colbert_config=colbert_config)
    print("✓ Model loaded successfully")

    print("Setting model to evaluation mode...")
    model.eval()
    print("✓ Model set to evaluation mode")

    # Size the dummy batch from the config attached to the loaded model
    # rather than from the locally constructed ``colbert_config``.
    actual_doc_maxlen = model.colbert_config.doc_maxlen
    print(f"DEBUG: model.colbert_config.doc_maxlen = {actual_doc_maxlen}")
    print(f"Preparing dummy input for ONNX export with doc_maxlen={actual_doc_maxlen}...")
    dummy_input_ids = torch.ones((1, actual_doc_maxlen), dtype=torch.long)
    dummy_attention_mask = torch.ones((1, actual_doc_maxlen), dtype=torch.long)
    print("✓ Dummy input prepared")

    print("Exporting model to ONNX format...")
    torch.onnx.export(
        model,
        (dummy_input_ids, dummy_attention_mask),
        str(onnx_path),
        input_names=["input_ids", "attention_mask"],
        output_names=["last_hidden_state"],
        # Batch (axis 0) and sequence (axis 1) are declared dynamic.
        # NOTE(review): the output's axis 1 shares the "seq" name with the
        # inputs; confirm the model's output length really follows the
        # input length before relying on it.
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq"},
            "attention_mask": {0: "batch", 1: "seq"},
            "last_hidden_state": {0: "batch", 1: "seq"},
        },
        opset_version=14,
    )
    print("✓ Model exported to ONNX successfully")
    print(f"✓ ONNX file saved as: {onnx_path}")

except FileNotFoundError as e:
    # The files are looked up next to this script, not in the process's
    # working directory, so report the directory that was actually searched.
    print(f"❌ Error: Model files not found in {model_name}: {e}", file=sys.stderr)
    sys.exit(1)
except ImportError as e:
    print(f"❌ Error: Failed to import required modules: {e}", file=sys.stderr)
    sys.exit(1)
except Exception as e:
    # Top-level boundary of the script: report the failure and exit non-zero.
    print(f"❌ Error during model export: {e}", file=sys.stderr)
    sys.exit(1)