File size: 2,009 Bytes
634cac7 9cb83c5 634cac7 9cb83c5 634cac7 2336bf5 9cb83c5 634cac7 9cb83c5 634cac7 9cb83c5 634cac7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import sys
import traceback
from pathlib import Path

import torch

from .modeling import ConstBERT
from .colbert_configuration import ColBERTConfig
# Export the ConstBERT checkpoint stored next to this script to ONNX.
#
# Because of the relative imports above, this must run as part of its package
# (e.g. ``python -m <package>.<this_module>``). It loads the ColBERTConfig and
# ConstBERT weights from this script's directory, traces the model with a
# dummy batch of shape (1, doc_maxlen) and writes ``model.onnx`` into the same
# directory. Any failure is reported on stderr and the process exits with 1.

# Directory containing this script; the checkpoint is read from here and the
# ONNX file is written here. Resolved before the ``try`` so the error handlers
# below can always refer to it.
model_name = Path(__file__).parent
onnx_path = model_name / "model.onnx"

try:
    print("Loading model...")
    # Load ColBERTConfig for ConstBERT
    colbert_config = ColBERTConfig(str(model_name))
    model = ConstBERT.from_pretrained(str(model_name), colbert_config=colbert_config)
    print("✓ Model loaded successfully")

    print("Setting model to evaluation mode...")
    model.eval()  # inference behaviour (e.g. dropout off) before tracing
    print("✓ Model set to evaluation mode")

    # Use the doc_maxlen from the *loaded model's* colbert_config
    actual_doc_maxlen = model.colbert_config.doc_maxlen
    print(f"Preparing dummy input for ONNX export with doc_maxlen={actual_doc_maxlen}...")
    # All-ones token ids / mask used purely as a tracing example.
    # NOTE(review): presumably only shapes/dtypes matter here — confirm that
    # ConstBERT.forward has no control flow that depends on token values.
    dummy_input_ids = torch.ones((1, actual_doc_maxlen), dtype=torch.long)
    dummy_attention_mask = torch.ones((1, actual_doc_maxlen), dtype=torch.long)
    print("✓ Dummy input prepared")

    print("Exporting model to ONNX format...")
    # Gradients are never needed for export; avoid recording autograd state.
    with torch.no_grad():
        torch.onnx.export(
            model,
            (dummy_input_ids, dummy_attention_mask),
            str(onnx_path),
            input_names=["input_ids", "attention_mask"],
            output_names=["last_hidden_state"],
            # Batch and sequence length are declared dynamic so the exported
            # graph is not pinned to the (1, doc_maxlen) example shape.
            # NOTE(review): if ConstBERT emits a fixed number of vectors per
            # document, output axis 1 is not really "seq" — confirm.
            dynamic_axes={
                "input_ids": {0: "batch", 1: "seq"},
                "attention_mask": {0: "batch", 1: "seq"},
                "last_hidden_state": {0: "batch", 1: "seq"},
            },
            opset_version=14,
        )
    print("✓ Model exported to ONNX successfully")
    print(f"✓ ONNX file saved as: {onnx_path}")
except FileNotFoundError as e:
    print(f"❌ Error: Model files not found in {model_name}: {e}", file=sys.stderr)
    sys.exit(1)
except ImportError as e:
    # The package imports at the top of the file run before this ``try`` and
    # are not caught here; this only covers imports triggered during loading
    # or export.
    print(f"❌ Error: Failed to import required modules: {e}", file=sys.stderr)
    sys.exit(1)
except Exception as e:
    print(f"❌ Error during model export: {e}", file=sys.stderr)
    # Export failures are often raised deep inside the tracer; keep the
    # full traceback so they can actually be diagnosed.
    traceback.print_exc()
    sys.exit(1)
|