Output logits differ significantly for different attn_implementations on image inputs

#53
by zzigakovacic

I've been comparing the flash_attention_2, sdpa, and eager attention implementations. My understanding is that these should produce very similar logits.

For text-only inputs, the relative mean difference is ~1%. For image inputs, the relative mean difference is ~25% and the relative max difference is ~75%, which seems very large. I tried this on:

Qwen2.5-VL 7B and 3B, and Qwen2-VL 7B.

Any thoughts?

Here is the script (slightly modified from the example at https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct):

import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

torch.manual_seed(42)

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # match the checkpoint loaded below
messages = [{
    "role": "user",
    "content": [
        # {"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"},
        {"type": "text", "text": "Describe this simage."}
    ],
}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)

inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt"
)

model_flash = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map={"": 0},
)
model_flash.eval()

model_sdpa = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    # attn_implementation="eager",
    device_map={"": 1},
)
model_sdpa.eval()

# === Run FlashAttention2 ===
with torch.no_grad():
    inputs_flash = {k: v.to("cuda:0") for k, v in inputs.items()}
    out_flash = model_flash(**inputs_flash).logits.cpu()

# === Run SDPA ===
with torch.no_grad():
    inputs_sdpa = {k: v.to("cuda:1") for k, v in inputs.items()}
    out_sdpa = model_sdpa(**inputs_sdpa).logits.cpu()

diff = (out_flash.float() - out_sdpa.float()).abs()  # compare in fp32 to avoid extra bf16 rounding in the stats
print(f"Max abs diff: {diff.max().item():.6f}")
print(f"Mean abs diff: {diff.mean().item():.6f}")
print(f"Relative max diff: {(diff.max() / out_flash.abs().max()).item():.6f}")
print(f"Relative mean diff: {(diff.mean() / out_flash.abs().mean()).item():.6f}")

Thanks!

I encountered the same issue! Whenever I pass output_attentions=True in the forward call, the resulting logits differ from those with output_attentions=False. I think this is because the forward call falls back to eager attention instead of sdpa...
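
A minimal way to check this (a sketch reusing model_sdpa and inputs_sdpa from the script above) is to run the same forward pass twice and only toggle output_attentions:

with torch.no_grad():
    logits_plain = model_sdpa(**inputs_sdpa).logits
    # sdpa cannot return attention weights, so this call falls back to the eager path
    logits_attn = model_sdpa(**inputs_sdpa, output_attentions=True).logits

diff = (logits_plain.float() - logits_attn.float()).abs()
print(f"Max abs diff with/without output_attentions: {diff.max().item():.6f}")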
