Output logits differ significantly for different attn_implementation choices on image inputs
#53 · opened by zzigakovacic
I've been comparing the fa2, sdpa, and eager attention implementations. My understanding is that these should produce very similar logits.
For textual inputs, the relative mean difference is ~1%. For image inputs, the relative mean difference is ~25% and the relative max difference is ~75%. These are very large. I tried this on:
Qwen2.5-VL 7B and 3B, and Qwen2-VL 7B.
Any thoughts?
Here is the script (slightly modified from the example at https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct):
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
torch.manual_seed(42)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
messages = [{
    "role": "user",
    "content": [
        # {"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"},
        {"type": "text", "text": "Describe this image."},
    ],
}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
model_flash = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map={"": 0},
)
model_flash.eval()

model_sdpa = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or omit (sdpa is typically the default)
    # attn_implementation="eager",
    device_map={"": 1},
)
model_sdpa.eval()
# === Run FlashAttention2 ===
with torch.no_grad():
    inputs_flash = {k: v.to("cuda:0") for k, v in inputs.items()}
    out_flash = model_flash(**inputs_flash).logits.cpu()

# === Run SDPA ===
with torch.no_grad():
    inputs_sdpa = {k: v.to("cuda:1") for k, v in inputs.items()}
    out_sdpa = model_sdpa(**inputs_sdpa).logits.cpu()
diff = (out_flash - out_sdpa).abs()
print(f"Max abs diff: {diff.max().item():.6f}")
print(f"Mean abs diff: {diff.mean().item():.6f}")
print(f"Relative max diff: {(diff.max() / out_flash.abs().max()).item():.6f}")
print(f"Relative mean diff: {(diff.mean() / out_flash.abs().mean()).item():.6f}")
Thanks!
I encountered the same issue! Whenever I pass output_attentions=True in the forward call, the resulting logits differ from those with output_attentions=False. I think this is because the forward call switches to eager instead of SDPA attention when attention weights are requested...
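A minimal sketch to test this, assuming the model_sdpa and inputs_sdpa variables from the script above (loaded with attn_implementation="sdpa"): run the same model with and without output_attentions and compare the logits.

# Sketch: same model, same inputs, only output_attentions differs.
with torch.no_grad():
    out_plain = model_sdpa(**inputs_sdpa).logits
    # Requesting attention weights may force a fallback to eager attention for this
    # forward pass, which would explain a logit mismatch (assumption to verify).
    out_attn = model_sdpa(**inputs_sdpa, output_attentions=True).logits
print(f"Max abs diff with output_attentions=True: {(out_plain - out_attn).abs().max().item():.6f}")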