lym00 committed
Commit 05870d0 · verified · 1 parent: e0364bc

Update test_sageattn3.py

Files changed (1)
  1. test_sageattn3.py +48 -47
test_sageattn3.py CHANGED
@@ -1,48 +1,49 @@
- #!/usr/bin/env python3
-
- import torch
- import torch.nn.functional as F
- from sageattn import sageattn_blackwell
- from torch.nn.attention import SDPBackend, sdpa_kernel
-
-
- def get_rtol_atol(actual, expect):
-     actual = actual.float()
-     expect = expect.float()
-     diff = (actual - expect).abs()
-     eps = torch.tensor(
-         torch.finfo(actual.dtype).eps, device=actual.device, dtype=actual.dtype
-     )
-     # Floor the denominator at machine epsilon to avoid division by zero
-     rdiff = diff / torch.maximum(torch.maximum(actual.abs(), expect.abs()), eps)
-     return (
-         f"mean_rtol={rdiff.mean().item():.3g} "
-         f"max_rtol={rdiff.max().item():.3g} "
-         f"mean_atol={diff.mean().item():.3g} "
-         f"max_atol={diff.max().item():.3g}"
-     )
-
-
- def main():
-     batch_size = 4
-     head_num = 32
-     seq_len = 64
-     head_dim = 128
-     dtype = torch.float16
-
-     q = torch.randn(batch_size, head_num, seq_len, head_dim, device="cuda", dtype=dtype)
-     k = torch.randn_like(q)
-     v = torch.randn_like(q)
-     print("q", tuple(q.shape), q.device, q.dtype)
-
-     # 'Mathematically correct' implementation
-     torch.backends.cuda.enable_math_sdp(True)
-     with sdpa_kernel(SDPBackend.MATH):
-         out_math = F.scaled_dot_product_attention(q, k, v)
-
-     out_sage = sageattn_blackwell(q, k, v, is_causal=False)
-     print("sage vs math:", get_rtol_atol(out_sage, out_math))
-     print("The above (except max_rtol) should be < 0.05 (on RTX 20xx/30xx) or < 0.1 (on RTX 40xx/50xx)")
-
-
- if __name__ == "__main__":
+ #!/usr/bin/env python3
+ # Based on https://raw.githubusercontent.com/woct0rdho/SageAttention/refs/heads/main/tests/test_sageattn.py
+
+ import torch
+ import torch.nn.functional as F
+ from sageattn import sageattn_blackwell
+ from torch.nn.attention import SDPBackend, sdpa_kernel
+
+
+ def get_rtol_atol(actual, expect):
+     actual = actual.float()
+     expect = expect.float()
+     diff = (actual - expect).abs()
+     eps = torch.tensor(
+         torch.finfo(actual.dtype).eps, device=actual.device, dtype=actual.dtype
+     )
+     # Floor the denominator at machine epsilon to avoid division by zero
+     rdiff = diff / torch.maximum(torch.maximum(actual.abs(), expect.abs()), eps)
+     return (
+         f"mean_rtol={rdiff.mean().item():.3g} "
+         f"max_rtol={rdiff.max().item():.3g} "
+         f"mean_atol={diff.mean().item():.3g} "
+         f"max_atol={diff.max().item():.3g}"
+     )
+
+
+ def main():
+     batch_size = 4
+     head_num = 32
+     seq_len = 64
+     head_dim = 128
+     dtype = torch.float16
+
+     q = torch.randn(batch_size, head_num, seq_len, head_dim, device="cuda", dtype=dtype)
+     k = torch.randn_like(q)
+     v = torch.randn_like(q)
+     print("q", tuple(q.shape), q.device, q.dtype)
+
+     # 'Mathematically correct' implementation
+     torch.backends.cuda.enable_math_sdp(True)
+     with sdpa_kernel(SDPBackend.MATH):
+         out_math = F.scaled_dot_product_attention(q, k, v)
+
+     out_sage = sageattn_blackwell(q, k, v, is_causal=False)
+     print("sage vs math:", get_rtol_atol(out_sage, out_math))
+     print("The above (except max_rtol) should be < 0.05 (on RTX 20xx/30xx) or < 0.1 (on RTX 40xx/50xx)")
+
+
+ if __name__ == "__main__":
      main()
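
Why the note in the script excludes max_rtol from the threshold check: get_rtol_atol floors the denominator of the relative error at machine epsilon, so elements where both outputs are near zero can still report a large relative error from a tiny absolute difference. A minimal toy sketch of that effect (the values here are illustrative, not taken from the test):

    import torch

    # One well-scaled element and one near-zero element (illustrative values)
    actual = torch.tensor([1.0000, 1e-6])
    expect = torch.tensor([1.0005, 2e-6])
    diff = (actual - expect).abs()
    eps = torch.tensor(torch.finfo(actual.dtype).eps)
    # Same denominator floor as get_rtol_atol in the test script
    rdiff = diff / torch.maximum(torch.maximum(actual.abs(), expect.abs()), eps)
    print(rdiff)  # ~[5e-4, 0.5]: the near-zero element dominates max_rtol

To run the test itself, a CUDA GPU supported by sageattn_blackwell is assumed: python test_sageattn3.py.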