Ensure tests also work in test shells
Browse files- tests/test_rotary.py +6 -3
tests/test_rotary.py
CHANGED
@@ -2,7 +2,6 @@ import pytest
|
|
2 |
import torch
|
3 |
|
4 |
from tests.utils import infer_device, supports_bfloat16
|
5 |
-
from kernels import get_local_kernel
|
6 |
from pathlib import Path
|
7 |
|
8 |
# import rotary
|
@@ -10,8 +9,12 @@ from pathlib import Path
|
|
10 |
# set_seed(42)
|
11 |
|
12 |
# Set the local repo path, relative path
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
15 |
|
16 |
def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
|
17 |
assert x1.shape == x2.shape, "x1 and x2 must have the same shape"
|
|
|
2 |
import torch
|
3 |
|
4 |
from tests.utils import infer_device, supports_bfloat16
|
|
|
5 |
from pathlib import Path
|
6 |
|
7 |
# import rotary
|
|
|
9 |
# set_seed(42)
|
10 |
|
11 |
# Set the local repo path, relative path
|
12 |
+
try:
|
13 |
+
import rotary
|
14 |
+
except ImportError:
|
15 |
+
from kernels import get_local_kernel
|
16 |
+
repo_path = Path(__file__).parent.parent
|
17 |
+
rotary = get_local_kernel(repo_path=repo_path, package_name="rotary")
|
18 |
|
19 |
def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
|
20 |
assert x1.shape == x2.shape, "x1 and x2 must have the same shape"
|