Kernels
danieldk's picture
danieldk HF Staff
Build
4db658e
raw
history blame
324 Bytes
from typing import Tuple
import torch
from ._ops import ops
def apply_rotary(
x1: torch.Tensor,
x2: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
out1: torch.Tensor,
out2: torch.Tensor,
conj: bool,
):
ops.apply_rotary(x1, x2, cos, sin, out1, out2, conj)
__all__ = ["apply_rotary"]