lym00 committed
Commit e0364bc · verified · 1 Parent(s): 80bc98f

Upload test_sageattn3.py

Files changed (1)
  1. test_sageattn3.py +48 -0
test_sageattn3.py ADDED
@@ -0,0 +1,48 @@
+ #!/usr/bin/env python3
+
+ import torch
+ import torch.nn.functional as F
+ from sageattn import sageattn_blackwell
+ from torch.nn.attention import SDPBackend, sdpa_kernel
+
+
+ def get_rtol_atol(actual, expect):
+     # Summarize elementwise absolute and relative error between two tensors.
+     actual = actual.float()
+     expect = expect.float()
+     diff = (actual - expect).abs()
+     # eps floor keeps the relative error finite where both tensors are ~0.
+     eps = torch.tensor(
+         torch.finfo(actual.dtype).eps, device=actual.device, dtype=actual.dtype
+     )
+     rdiff = diff / torch.maximum(torch.maximum(actual.abs(), expect.abs()), eps)
+     return (
+         f"mean_rtol={rdiff.mean().item():.3g} "
+         f"max_rtol={rdiff.max().item():.3g} "
+         f"mean_atol={diff.mean().item():.3g} "
+         f"max_atol={diff.max().item():.3g}"
+     )
+
+
+ def main():
+     batch_size = 4
+     head_num = 32
+     seq_len = 64
+     head_dim = 128
+     dtype = torch.float16
+
+     q = torch.randn(batch_size, head_num, seq_len, head_dim, device="cuda", dtype=dtype)
+     k = torch.randn_like(q)
+     v = torch.randn_like(q)
+     print("q", tuple(q.shape), q.device, q.dtype)
+
+     # 'Mathematically correct' reference: force the math SDPA backend.
+     torch.backends.cuda.enable_math_sdp(True)
+     with sdpa_kernel(SDPBackend.MATH):
+         out_math = F.scaled_dot_product_attention(q, k, v)
+
+     out_sage = sageattn_blackwell(q, k, v, is_causal=False)
+     print("sage vs math:", get_rtol_atol(out_sage, out_math))
+     print("The above (except max_rtol) should be < 0.05 (on RTX 20xx/30xx) or < 0.1 (on RTX 40xx/50xx)")
+
+
+ if __name__ == "__main__":
+     main()
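
As a follow-up, here is a minimal sketch (not part of this commit) of the same accuracy check with causal masking. It assumes test_sageattn3.py is importable from the working directory and that sageattn_blackwell accepts is_causal=True, mirroring the is_causal=False call in the script:

import torch
import torch.nn.functional as F
from sageattn import sageattn_blackwell
from torch.nn.attention import SDPBackend, sdpa_kernel

from test_sageattn3 import get_rtol_atol  # error-metric helper from the uploaded file

q = torch.randn(4, 32, 64, 128, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Math backend as the reference, now with a causal mask on both sides.
with sdpa_kernel(SDPBackend.MATH):
    out_math = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Assumption: is_causal=True is supported by the Blackwell kernel.
out_sage = sageattn_blackwell(q, k, v, is_causal=True)
print("causal sage vs math:", get_rtol_atol(out_sage, out_math))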