Rename to paged-attention
- build.toml +1 -1
- tests/kernels/conftest.py +3 -3
- tests/kernels/test_attention.py +2 -2
- tests/kernels/test_cache.py +2 -2
- tests/kernels/utils.py +1 -1
- torch-ext/{attention → paged_attention}/__init__.py +0 -0
- torch-ext/{attention → paged_attention}/_custom_ops.py +0 -0
- torch-ext/{attention → paged_attention}/platforms.py +0 -0
build.toml
CHANGED

@@ -2,7 +2,7 @@
 version = "0.0.1"
 
 [torch]
-name = "attention"
+name = "paged_attention"
 src = [
   "torch-ext/registration.h",
   "torch-ext/torch_binding.cpp",
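The `name` field under `[torch]` is presumably what the built extension is imported as, so renaming it here is what forces every `import attention` in the tests to become `import paged_attention`; the `torch-ext/attention` directory is renamed to match at the bottom of this diff. A minimal sanity check, assuming only that the rebuilt package is installed in the current environment:

```python
# Not part of the repo: verify the extension resolves under its new name.
# find_spec returns None when a module cannot be located, so the old name
# should stop resolving once any stale build is removed.
import importlib.util

assert importlib.util.find_spec("paged_attention") is not None
assert importlib.util.find_spec("attention") is None
```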
tests/kernels/conftest.py
CHANGED

@@ -1,6 +1,6 @@
 from typing import List, Optional, Tuple, Union
 
-import attention as ops
+import paged_attention as ops
 import pytest
 import torch
 
@@ -41,7 +41,7 @@ def create_kv_caches_with_random(
         raise ValueError(
             f"Does not support key cache of type fp8 with head_size {head_size}"
         )
-    from attention.platforms import current_platform
+    from paged_attention.platforms import current_platform
 
     current_platform.seed_everything(seed)
 
@@ -88,7 +88,7 @@ def create_kv_caches_with_random_flash(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    from attention.platforms import current_platform
+    from paged_attention.platforms import current_platform
 
     current_platform.seed_everything(seed)
 
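Both KV-cache fixtures import the platform helper locally, right before seeding, rather than at module top. A standalone sketch of that pattern, assuming nothing beyond what the diff shows (`current_platform.seed_everything(seed)`):

```python
def _seed_for_cache_fixture(seed: int = 0) -> None:
    """Sketch of the pattern used by create_kv_caches_with_random*:
    pull the platform helper in locally, then seed the RNGs so the
    randomly filled caches are reproducible across test runs."""
    # Local import, mirroring the hunks above.
    from paged_attention.platforms import current_platform

    current_platform.seed_everything(seed)
```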
tests/kernels/test_attention.py
CHANGED

@@ -1,10 +1,10 @@
 import random
 from typing import List, Optional, Tuple
 
-import attention as ops
+import paged_attention as ops
 import pytest
 import torch
-from attention.platforms import current_platform
+from paged_attention.platforms import current_platform
 
 from .allclose_default import get_default_atol, get_default_rtol
 from .utils import get_max_shared_memory_bytes, opcheck
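Only the two package imports change; the relative test helpers stay as they are. An illustrative test module using the renamed imports, with the assumption (not stated in the diff) that `seed_everything` also seeds torch's default RNG:

```python
import pytest
import torch

# Assumed importable after the rename, as in the diff above.
from paged_attention.platforms import current_platform


@pytest.mark.parametrize("seed", [0, 1])
def test_seeding_is_deterministic(seed: int) -> None:
    # Assumption: seed_everything covers torch's global RNG.
    current_platform.seed_everything(seed)
    first = torch.randn(8)
    current_platform.seed_everything(seed)
    second = torch.randn(8)
    torch.testing.assert_close(first, second)
```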
tests/kernels/test_cache.py
CHANGED

@@ -1,10 +1,10 @@
 import random
 from typing import List, Tuple
 
-import attention as ops
+import paged_attention as ops
 import pytest
 import torch
-from attention.platforms import current_platform
+from paged_attention.platforms import current_platform
 
 from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 
tests/kernels/utils.py
CHANGED

@@ -83,7 +83,7 @@ def opcheck(
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
-    from attention import ops
+    from paged_attention import ops
 
     max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
     # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
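`get_max_shared_memory_bytes` is the one place in the test helpers that calls into the extension's `ops` attribute directly. A hedged usage sketch; the helper name `max_seq_len_for_tests` and the fp32-based divisor are illustrative, not the formula test_attention.py actually uses:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def max_seq_len_for_tests(gpu: int = 0, bytes_per_element: int = 4) -> int:
    """Derive a rough sequence-length bound from the device's shared memory."""
    # Same deferred import and device query as get_max_shared_memory_bytes above.
    from paged_attention import ops

    max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # A zero here is the failure mode the comment above warns about.
    assert max_shared_mem > 0, "device attribute query returned 0"
    return max_shared_mem // bytes_per_element
```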
torch-ext/{attention → paged_attention}/__init__.py
RENAMED (file without changes)

torch-ext/{attention → paged_attention}/_custom_ops.py
RENAMED (file without changes)

torch-ext/{attention → paged_attention}/platforms.py
RENAMED (file without changes)
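The three modules under torch-ext/ move verbatim; only their package path changes. Collected from the diffs above, these are the names the updated tests resolve from the renamed package (running the last line requires a GPU, which is an assumption of this sketch):

```python
# Every paged_attention name the updated tests rely on, per the diffs above.
import paged_attention                                   # imported as `ops` by the test modules
from paged_attention import ops                          # `ops` attribute used in tests/kernels/utils.py
from paged_attention.platforms import current_platform   # helper from platforms.py

current_platform.seed_everything(0)                      # seeding entry point used by the fixtures
ops.get_max_shared_memory_per_block_device_attribute(0)  # device query behind get_max_shared_memory_bytes
```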