from datasets import load_dataset
import torch
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import get_peft_model, LoraConfig

# Choose one training mode: full fine-tuning (default), LoRA, or QLoRA.
USE_LORA = False
USE_QLORA = False
FREEZE_VISION = False

# Small VQAv2 subset; keep half of the validation split for training.
ds = load_dataset("merve/vqav2-small", split="validation")
ds = ds.train_test_split(test_size=0.5)["train"]

model_id = "google/paligemma2-3b-pt-448"
processor = PaliGemmaProcessor.from_pretrained(model_id)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Id of the <image> placeholder token (kept for reference).
image_token = processor.tokenizer.convert_tokens_to_ids("<image>")


def collate_fn(examples):
    # PaliGemma uses task prefixes; "answer en" marks English VQA.
    texts = ["answer en " + example["question"] for example in examples]
    labels = [example["multiple_choice_answer"] for example in examples]
    images = [example["image"].convert("RGB") for example in examples]
    tokens = processor(
        text=texts,
        images=images,
        suffix=labels,  # the suffix becomes the target the model learns to generate
        return_tensors="pt",
        padding="longest",
    )
    tokens = tokens.to(torch.bfloat16).to(device)
    return tokens


if USE_LORA or USE_QLORA:
    lora_config = LoraConfig(
        r=8,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=bnb_config if USE_QLORA else None,
        torch_dtype=torch.bfloat16,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
else:
    # Full fine-tuning: load in bfloat16 to match the collator's dtype.
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16
    )
    if FREEZE_VISION:
        # Train only the language model; keep the vision encoder and projector frozen.
        for param in model.vision_tower.parameters():
            param.requires_grad = False
        for param in model.multi_modal_projector.parameters():
            param.requires_grad = False

args = TrainingArguments(
    num_train_epochs=3,
    remove_unused_columns=False,  # keep the image column so the collator can access it
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=100,
    optim="adamw_hf",
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    push_to_hub=True,
    output_dir="paligemma_vqav2",
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False,
)

trainer = Trainer(
    model=model,
    train_dataset=ds,
    data_collator=collate_fn,
    args=args,
)
trainer.train()
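
# Optional sanity check after training (a minimal sketch, not part of the original
# recipe): run the just-trained model on one training example and confirm it
# generates a short English answer. The prompt mirrors the format used in collate_fn,
# and the generated prompt tokens are stripped before decoding.
model.eval()
sample = ds[0]
inputs = processor(
    text="answer en " + sample["question"],
    images=sample["image"].convert("RGB"),
    return_tensors="pt",
).to(torch.bfloat16).to(device)
with torch.inference_mode():
    generated = model.generate(**inputs, max_new_tokens=20)
answer_ids = generated[0][inputs["input_ids"].shape[-1]:]
print(processor.decode(answer_ids, skip_special_tokens=True))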