MedGemma Fine-tuned for FLARE 2025 Medical Image Analysis

This model is a fine-tuned version of google/medgemma-4b-it specifically optimized for medical image analysis tasks in the FLARE 2025 2D Medical Multimodal Dataset challenge.

Model Description

  • Base Model: MedGemma-4B-IT (Google's medical-specialized Gemma model)
  • Fine-tuning Method: QLoRA (Quantized Low-Rank Adaptation)
  • Target Domain: Medical imaging across 7 modalities (CT, MRI, X-ray, Ultrasound, Fundus, Pathology, Endoscopy)
  • Tasks: Medical image captioning, visual question answering, report generation, diagnosis support
  • Training Data: 19 FLARE 2025 datasets with comprehensive medical annotations

Training Details

Training Data

The model was fine-tuned on 19 diverse medical imaging datasets from FLARE 2025, including:

  • Classification: Disease diagnosis with balanced accuracy optimization
  • Multi-label Classification: Multi-pathology identification
  • Detection: Anatomical structure and pathology detection
  • Instance Detection: Identity-aware detection (e.g., chromosome analysis)
  • Counting: Cell counting and quantitative analysis
  • Regression: Continuous medical measurements
  • Report Generation: Comprehensive medical report writing

Details available at: https://huggingface.co/datasets/FLARE-MedFM/FLARE-Task5-MLLM-2D

Training Configuration

# LoRA Configuration
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
target_modules: ['gate_proj', 'up_proj', 'o_proj', 'down_proj', 'v_proj', 'q_proj', 'k_proj']
task_type: CAUSAL_LM
bias: none
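
For reference, the settings above map directly onto a peft.LoraConfig. Below is a minimal sketch of how such an adapter could be constructed; the actual training script is not included in this repository.

from peft import LoraConfig

# Minimal sketch: the LoRA settings listed above expressed as a peft.LoraConfig.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["gate_proj", "up_proj", "o_proj", "down_proj",
                    "v_proj", "q_proj", "k_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)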

Training Procedure

  • Base Architecture: MedGemma-4B with medical domain pre-training
  • Optimization: 4-bit quantization with BitsAndBytesConfig (see the configuration sketch after this list)
  • LoRA Configuration:
    • r=16, alpha=32, dropout=0.1 (matching the configuration above)
    • Target modules: All attention and MLP layers
  • Memory Optimization: Gradient checkpointing, flash attention
  • Batch Size: Dynamic based on image resolution and GPU memory
  • Learning Rate: 1e-4 with cosine scheduling
  • Training Steps: 4000 steps with evaluation every 500 steps
  • Chat Template: Gemma-style chat formatting for medical conversations
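
A minimal sketch of how the quantization and optimization settings above could be expressed with the Hugging Face APIs. The NF4/double-quantization choices, bf16 training, and the output directory are assumptions for illustration, not values confirmed by a released training script.

import torch
from transformers import BitsAndBytesConfig, TrainingArguments

# 4-bit quantization as described above; NF4 with double quantization is a
# common QLoRA default and is assumed here rather than confirmed.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Optimization settings stated above: 1e-4 learning rate with cosine schedule,
# 4000 steps, evaluation every 500 steps, gradient checkpointing.
# Batch size is omitted because it was chosen dynamically per image resolution.
training_args = TrainingArguments(
    output_dir="./flare25-medgemma-qlora",  # illustrative path
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    max_steps=4000,
    eval_strategy="steps",
    eval_steps=500,
    gradient_checkpointing=True,
    bf16=True,
)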

Model Performance

This model has been evaluated across multiple medical imaging tasks using FLARE 2025 evaluation metrics:

Evaluation Metrics by Task Type

Classification Tasks (Disease Diagnosis):

  • Balanced Accuracy (PRIMARY): Handles class imbalance in medical diagnosis
  • Accuracy: Standard classification accuracy
  • F1 Score: Weighted F1 for multi-class scenarios
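
As an illustration (not the official evaluation code), these three metrics correspond to standard scikit-learn calls; y_true and y_pred below are hypothetical labels.

from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

# Hypothetical ground-truth and predicted class labels.
y_true = [0, 1, 2, 1, 0]
y_pred = [0, 1, 1, 1, 0]

print("Balanced accuracy:", balanced_accuracy_score(y_true, y_pred))  # primary metric
print("Accuracy:", accuracy_score(y_true, y_pred))
print("Weighted F1:", f1_score(y_true, y_pred, average="weighted"))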

Multi-label Classification (Multi-pathology):

  • F1 Score (PRIMARY): Sample-wise F1 across multiple medical conditions
  • Precision: Label prediction precision
  • Recall: Medical condition coverage recall

Detection Tasks (Anatomical/Pathological):

  • F1 Score @ IoU > 0.5 (PRIMARY): Standard computer vision detection metric
  • Precision: Detection precision at IoU threshold
  • Recall: Detection recall at IoU threshold

Instance Detection (Identity-Aware Detection):

  • F1 Score @ IoU > 0.3 (PRIMARY): Medical imaging standard (e.g., chromosome detection)
  • F1 Score @ IoU > 0.5: Computer vision standard
  • Average F1: COCO-style average across IoU thresholds (0.3-0.7)
  • Per-instance metrics: Detailed breakdown by object identity
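
For illustration, a minimal sketch of F1 at a fixed IoU threshold, applicable to both the detection and instance-detection metrics above. The greedy one-to-one matching is a simplification rather than the official FLARE evaluation code, and the (x1, y1, x2, y2) box format is an assumption.

# Sketch of F1 @ IoU threshold for axis-aligned boxes (x1, y1, x2, y2).
def iou(a, b):
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def detection_f1(preds, gts, iou_thr=0.5):
    matched_gt = set()
    tp = 0
    for p in preds:
        # Greedily match each prediction to the best remaining ground-truth box.
        best = max(
            ((i, iou(p, g)) for i, g in enumerate(gts) if i not in matched_gt),
            key=lambda x: x[1],
            default=(None, 0.0),
        )
        if best[1] >= iou_thr:
            matched_gt.add(best[0])
            tp += 1
    precision = tp / len(preds) if preds else 0.0
    recall = tp / len(gts) if gts else 0.0
    return 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0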

Counting Tasks (Cell/Structure Counting):

  • Mean Absolute Error (PRIMARY): Cell counting accuracy
  • Root Mean Squared Error: Additional counting precision metric

Regression Tasks (Medical Measurements):

  • Mean Absolute Error (PRIMARY): Continuous value prediction accuracy
  • Root Mean Squared Error: Regression precision metric
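
A minimal numpy sketch of MAE and RMSE as used for both the counting and regression metrics above; the values are hypothetical.

import numpy as np

# Hypothetical predicted and ground-truth quantities (e.g., cell counts).
preds = np.array([12.0, 30.0, 7.0])
targets = np.array([10.0, 33.0, 7.0])

mae = np.mean(np.abs(preds - targets))            # primary metric
rmse = np.sqrt(np.mean((preds - targets) ** 2))   # secondary metric
print(f"MAE: {mae:.2f}, RMSE: {rmse:.2f}")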

Report Generation (Medical Reports):

  • GREEN Score (PRIMARY): Comprehensive medical report evaluation with 7 components:
    • Entity matching with severity assessment (30%)
    • Location accuracy with laterality (20%)
    • Negation and uncertainty handling (15%)
    • Temporal accuracy (10%)
    • Size/measurement accuracy (10%)
    • Clinical significance weighting (10%)
    • Report structure completeness (5%)
  • BLEU Score: Text generation quality
  • Clinical Efficacy: Medical relevance scoring
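
For illustration only, the GREEN component weights listed above can be combined as a weighted sum, assuming each component has already been scored in [0, 1]. This sketch is not the official GREEN implementation.

# Illustrative aggregation of the seven GREEN component weights listed above.
GREEN_WEIGHTS = {
    "entity_severity": 0.30,
    "location_laterality": 0.20,
    "negation_uncertainty": 0.15,
    "temporal": 0.10,
    "size_measurement": 0.10,
    "clinical_significance": 0.10,
    "structure_completeness": 0.05,
}

def weighted_green(components: dict) -> float:
    """Weighted sum of per-component scores, each expected in [0, 1]."""
    return sum(GREEN_WEIGHTS[name] * components[name] for name in GREEN_WEIGHTS)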

Usage

Installation

pip install transformers torch peft accelerate bitsandbytes pillow

Basic Usage

import torch
from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
from peft import PeftModel
from PIL import Image

# Load the fine-tuned model
base_model_name = "google/medgemma-4b-it"
adapter_model_name = "leoyinn/flare25-medgemma"

# Load tokenizer and processor
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(base_model_name, trust_remote_code=True)

# Load base model
base_model = AutoModelForImageTextToText.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="eager"
)

# Load the fine-tuned adapter
model = PeftModel.from_pretrained(base_model, adapter_model_name)

# Prepare input with MedGemma chat format
image = Image.open("medical_image.jpg").convert("RGB")
image = image.resize((448, 448))  # MedGemma standard size

# Create proper message format
messages = [
    {
        "role": "system",
        "content": [{
            "type": "text", 
            "text": "You are an expert medical AI assistant specialized in analyzing medical images and providing accurate diagnostic insights."
        }]
    },
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe the medical findings in this image and provide a diagnostic assessment."}
        ]
    }
]

# Apply chat template
full_text = tokenizer.apply_chat_template(
    messages, 
    tokenize=False,
    add_generation_prompt=True
)

# Process and generate
inputs = processor(
    images=[image],
    text=full_text,
    return_tensors="pt",
    padding=True,
    truncation=False
).to(model.device, dtype=torch.bfloat16)

# Generate medical response
with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=False,  # Deterministic for medical applications
        use_cache=True,
        cache_implementation="dynamic"
    )

# Decode response
input_len = inputs["input_ids"].shape[-1]
response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
print(response)
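
Optionally, the LoRA weights can be merged into the base model for standalone inference. This is a standard PEFT operation rather than a requirement of this model; the save path below is illustrative.

# Optional: fold the adapter weights into the base model and save the result.
merged_model = model.merge_and_unload()
merged_model.save_pretrained("flare25-medgemma-merged")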

Limitations and Ethical Considerations

Limitations

  • Model outputs may contain inaccuracies and should be verified by medical professionals
  • Performance may vary across different medical imaging modalities and populations
  • Training data may contain biases present in medical literature and datasets
  • Model has not been validated in clinical settings
  • Designed for research and educational purposes, not clinical decision-making

Intended Use

  • Medical education and training
  • Research in medical AI and computer vision
  • Development of clinical decision support tools (with proper validation)
  • Academic research in multimodal medical AI
  • Medical image analysis prototyping

Out-of-Scope Use

  • Direct clinical diagnosis without physician oversight
  • Treatment recommendations without medical professional validation
  • Use in emergency medical situations
  • Deployment in production clinical systems without extensive validation
  • Patient-facing applications without proper medical supervision

Citation

If you use this model in your research, please cite:

@misc{medgemma-flare2025,
  title={MedGemma Fine-tuned for FLARE 2025 Medical Image Analysis},
  author={Shuolin Yin},
  year={2025},
  publisher={Hugging Face},
  url={https://huggingface.co/leoyinn/flare25-medgemma}
}

@misc{medgemma-base,
  title={MedGemma: Medical Gemma Models for Healthcare},
  author={Google Research},
  year={2025},
  publisher={Hugging Face},
  url={https://huggingface.co/google/medgemma-4b-it}
}

@misc{flare2025,
  title={FLARE 2025: A Multi-Modal Foundation Model Challenge for Medical AI},
  year={2025},
  url={https://huggingface.co/datasets/FLARE-MedFM/FLARE-Task5-MLLM-2D}
}

Model Details

  • Model Type: Vision-Language Model (VLM) specialized for medical applications
  • Architecture: MedGemma (Gemma-based) with LoRA adapters
  • Parameters: ~4B base parameters + LoRA adapters
  • Precision: bfloat16 base model + full precision adapters
  • Framework: PyTorch, Transformers, PEFT
  • Input Resolution: 448x448 pixels (standard for MedGemma)
  • Context Length: Supports long medical reports and conversations

Technical Specifications

  • Base Model: google/medgemma-4b-it
  • Adapter Type: LoRA (Low-Rank Adaptation)
  • Target Modules: All attention projection layers and MLP layers
  • Chat Template: Gemma-style with medical system prompts
  • Attention Implementation: Eager attention for stability
  • Cache Implementation: Dynamic caching for efficient inference

Contact

For questions or issues, please open an issue in the model repository or contact the authors.


Disclaimer: This model is for research and educational purposes only. Always consult qualified medical professionals for clinical decisions.
