#!/usr/bin/env python3
# Based on
# https://raw.githubusercontent.com/woct0rdho/SageAttention/refs/heads/main/tests/test_sageattn.py
# https://huggingface.co/jt-zhang/SageAttention3
import torch
import torch.nn.functional as F
from sageattn import sageattn_blackwell
from torch.nn.attention import SDPBackend, sdpa_kernel
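# Note: running this assumes a CUDA-capable Blackwell GPU (e.g. RTX 50xx) and a
# SageAttention build that provides sageattn_blackwell (see the URLs above).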


def get_rtol_atol(actual, expect):
    """Summarize relative and absolute error between two tensors as a string."""
    actual = actual.float()
    expect = expect.float()
    diff = (actual - expect).abs()
    # Floor the denominator at machine epsilon to avoid division by zero
    eps = torch.tensor(
        torch.finfo(actual.dtype).eps, device=actual.device, dtype=actual.dtype
    )
    rdiff = diff / torch.maximum(torch.maximum(actual.abs(), expect.abs()), eps)
    return (
        f"mean_rtol={rdiff.mean().item():.3g} "
        f"max_rtol={rdiff.max().item():.3g} "
        f"mean_atol={diff.mean().item():.3g} "
        f"max_atol={diff.max().item():.3g}"
    )
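

# A quick illustrative use of get_rtol_atol (not part of the test): two tensors
# that differ by about 1e-3 relative error should report rtol values near 1e-3:
#   a = torch.ones(4, device="cuda")
#   b = a * (1 + 1e-3)
#   print(get_rtol_atol(a, b))  # mean_rtol and max_rtol are roughly 1e-3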


def main():
    batch_size = 4
    head_num = 32
    seq_len = 64
    head_dim = 128
    dtype = torch.float16
    q = torch.randn(batch_size, head_num, seq_len, head_dim, device="cuda", dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    print("q", tuple(q.shape), q.device, q.dtype)

    # 'Mathematically correct' reference: force the unfused math SDP backend
    torch.backends.cuda.enable_math_sdp(True)
    with sdpa_kernel(SDPBackend.MATH):
        out_math = F.scaled_dot_product_attention(q, k, v)

    out_sage = sageattn_blackwell(q, k, v, is_causal=False)
    print("sage vs math:", get_rtol_atol(out_sage, out_math))
    print("The above (except max_rtol) should be < 0.1 (on RTX 50xx)")


if __name__ == "__main__":
    main()