File size: 4,190 Bytes
cebca1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import argparse
import shutil
from pathlib import Path
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from colpali_engine.data.dataset import ColPaliEngineDataset
from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss
from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining
from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for the training run.

    Args:
        argv: Argument list to parse (without the program name). ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before; passing
            an explicit list makes the parser usable from code and tests.

    Returns:
        Namespace with ``output_dir``, ``lr``, ``tau``, ``trainer``, ``loss``
        and ``peft`` attributes.

    Raises:
        SystemExit: If the arguments are invalid (argparse's standard behavior),
            e.g. a missing ``--output-dir`` or an unsupported ``--loss`` value.
    """
    p = argparse.ArgumentParser(description="Train a ColQwen2.5-Omni retrieval model.")
    p.add_argument("--output-dir", type=str, required=True, help="where to write model + script copy")
    p.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    p.add_argument("--tau", type=float, default=0.02, help="temperature for loss function")
    p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use")
    p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use")
    p.add_argument("--peft", action="store_true", help="use PEFT for training")
    return p.parse_args(argv)
# Local checkpoint used for both the processor and the model weights.
BASE_MODEL_PATH = "./models/base_models/colqwen2.5omni-base"
# Local dataset directory; it must provide "train" and "test" splits.
DATASET_PATH = "./data_dir/colpali_train_set"


def _build_loss_func(loss_name, tau):
    """Return the late-interaction loss selected on the command line.

    Args:
        loss_name: ``"ce"`` for ``ColbertLoss`` or ``"pairwise"`` for
            ``ColbertPairwiseCELoss``.
        tau: Temperature passed to ``ColbertLoss``; the pairwise loss is
            constructed without it.

    Raises:
        ValueError: If ``loss_name`` is not a supported choice (argparse's
            ``choices`` normally prevents this; kept as a defensive check).
    """
    if loss_name == "ce":
        return ColbertLoss(
            temperature=tau,
            normalize_scores=True,
            use_smooth_max=False,
            pos_aware_negative_filtering=False,
        )
    if loss_name == "pairwise":
        return ColbertPairwiseCELoss(
            normalize_scores=False,
        )
    raise ValueError(f"Unknown loss function: {loss_name}")


def _build_lora_config():
    """Return the LoRA configuration used when ``--peft`` is passed.

    The ``target_modules`` regex selects the MLP/attention projections
    (``*_proj``) of modules under a ``model`` prefix that do not contain
    ``visual`` after it (intended to exclude the vision tower), plus the
    ``custom_text_proj`` head.
    """
    return LoraConfig(
        r=32,
        lora_alpha=32,
        lora_dropout=0.1,
        init_lora_weights="gaussian",
        bias="none",
        task_type="FEATURE_EXTRACTION",
        target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)",
    )


def main():
    """Fine-tune ColQwen2.5-Omni on the local ColPali training set.

    Parses the CLI, builds the model/processor/datasets/loss, copies this
    script into the output directory for provenance, then trains and saves
    with either the HF-based or the plain-torch trainer.
    """
    args = parse_args()

    model = ColQwen2_5Omni.from_pretrained(
        pretrained_model_name_or_path=BASE_MODEL_PATH,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    # Disable the audio tower on the *bare* model, before it is handed to the
    # config. With --peft the config presumably wraps the model (e.g. in a
    # PeftModel); assigning a submodule on the wrapper afterwards would only
    # register it on the wrapper and leave the base model's tower in place.
    model.audio_tower = torch.nn.Identity()

    # Load the dataset once and take both splits from it.
    dataset = load_dataset(DATASET_PATH)

    config = ColModelTrainingConfig(
        output_dir=args.output_dir,
        processor=ColQwen2_5OmniProcessor.from_pretrained(
            pretrained_model_name_or_path=BASE_MODEL_PATH,
        ),
        model=model,
        train_dataset=ColPaliEngineDataset(dataset["train"], pos_target_column_name="image"),
        eval_dataset=ColPaliEngineDataset(dataset["test"], pos_target_column_name="image"),
        run_eval=True,
        loss_func=_build_loss_func(args.loss, args.tau),
        tr_args=TrainingArguments(
            # Left unset on purpose; presumably ColModelTrainingConfig fills it
            # from its own output_dir — TODO confirm.
            output_dir=None,
            overwrite_output_dir=True,
            num_train_epochs=5,
            per_device_train_batch_size=64,
            gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            per_device_eval_batch_size=16,
            eval_strategy="steps",
            dataloader_num_workers=2,
            save_steps=500,
            logging_steps=10,
            eval_steps=100,
            warmup_steps=100,
            learning_rate=args.lr,
            save_total_limit=1,
            dataloader_prefetch_factor=2,
            dataloader_pin_memory=True,
            dataloader_persistent_workers=True,
        ),
        peft_config=_build_lora_config() if args.peft else None,
    )

    # config.model = torch.compile(config.model, dynamic=True, fullgraph=True, mode="max-autotune")

    # Make sure output_dir exists and copy this script there for provenance.
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    script_path = Path(__file__)
    script_copy = output_dir / script_path.name
    # shutil.copy raises SameFileError when the script is run from its own
    # provenance copy; in that case the copy is already in place.
    if script_path.resolve() != script_copy.resolve():
        shutil.copy(script_path, script_copy)

    trainer = ColModelTraining(config) if args.trainer == "hf" else ColModelTorchTraining(config)
    trainer.train()
    trainer.save()


if __name__ == "__main__":
    main()
|