kernel
danieldk HF Staff commited on
Commit
e78fee4
·
1 Parent(s): 77fc3a8

Ensure tests also work in test shells

Browse files
Files changed (1) hide show
  1. tests/test_rotary.py +6 -3
tests/test_rotary.py CHANGED
@@ -2,7 +2,6 @@ import pytest
2
  import torch
3
 
4
  from tests.utils import infer_device, supports_bfloat16
5
- from kernels import get_local_kernel
6
  from pathlib import Path
7
 
8
  # import rotary
@@ -10,8 +9,12 @@ from pathlib import Path
10
  # set_seed(42)
11
 
12
  # Set the local repo path, relative path
13
- repo_path = Path(__file__).parent.parent
14
- rotary = get_local_kernel(repo_path=repo_path, package_name="rotary")
 
 
 
 
15
 
16
  def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
17
  assert x1.shape == x2.shape, "x1 and x2 must have the same shape"
 
2
  import torch
3
 
4
  from tests.utils import infer_device, supports_bfloat16
 
5
  from pathlib import Path
6
 
7
  # import rotary
 
9
  # set_seed(42)
10
 
11
  # Set the local repo path, relative path
12
+ try:
13
+ import rotary
14
+ except ImportError:
15
+ from kernels import get_local_kernel
16
+ repo_path = Path(__file__).parent.parent
17
+ rotary = get_local_kernel(repo_path=repo_path, package_name="rotary")
18
 
19
  def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
20
  assert x1.shape == x2.shape, "x1 and x2 must have the same shape"