"""Triton layer normalization kernels
This kernel implements layers normalization using Triton. This kernel is from
the `flash-attention <https://github.com/Dao-AILab/flash-attention>`_ project.
"""
from typing import Optional
import torch
from . import layers
from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
def layer_norm(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
residual: Optional[torch.Tensor] = None,
x1: Optional[torch.Tensor] = None,
weight1: Optional[torch.Tensor] = None,
bias1: Optional[torch.Tensor] = None,
eps: float = 1e-6,
dropout_p: float = 0.0,
    rowscale: Optional[torch.Tensor] = None,
prenorm: bool = False,
residual_in_fp32: bool = False,
zero_centered_weight: bool = False,
is_rms_norm: bool = False,
return_dropout_mask: bool = False,
out: Optional[torch.Tensor] = None,
residual_out: Optional[torch.Tensor] = None,
):
"""
Apply layer normalization to the input tensor with Triton acceleration.
Args:
x (`torch.Tensor`):
Input tensor to normalize.
weight (`torch.Tensor`):
Scale parameter for normalization.
bias (`torch.Tensor`):
Shift parameter for normalization.
residual (`torch.Tensor`, *optional*):
Optional residual tensor to add to the input before normalization.
x1 (`torch.Tensor`, *optional*):
Optional second input tensor to combine with `x`. When provided, the function
first adds `x1` to `x` and then applies normalization.
weight1 (`torch.Tensor`, *optional*):
Scale parameter for the second normalization.
bias1 (`torch.Tensor`, *optional*):
Shift parameter for the second normalization.
eps (`float`, *optional*, defaults to 1e-6):
Small constant added for numerical stability in normalization.
dropout_p (`float`, *optional*, defaults to 0.0):
Dropout probability. If greater than 0, applies dropout to the input before
normalization and residual addition.
rowscale (`torch.Tensor`, *optional*):
Optional scaling factor applied to each row of the input tensor.
Not compatible with the use of `x1`.
prenorm (`bool`, *optional*, defaults to False):
If True, returns both the normalized output and the unnormalized input+residual.
residual_in_fp32 (`bool`, *optional*, defaults to False):
If True, performs the residual connection in FP32 precision.
        zero_centered_weight (`bool`, *optional*, defaults to False):
            If True, 1.0 is added to the weight before it is applied, i.e. the
            weight is stored zero-centered.
is_rms_norm (`bool`, *optional*, defaults to False):
If True, uses RMS normalization instead of layer normalization.
return_dropout_mask (`bool`, *optional*, defaults to False):
If True, returns the dropout mask used for the computation.
out (`torch.Tensor`, *optional*):
Output tensor for the normalized result. If `None`, a new tensor is allocated.
residual_out (`torch.Tensor`, *optional*):
Output tensor for the residual result when using prenorm. If `None`, a new tensor
is allocated when needed.
Returns:
`torch.Tensor` or tuple of `torch.Tensor`:
- The normalized input.
- The second normalization of the input if `weight1` is provided.
- The residual tensor if `prenorm` is set.
- The dropout mask if `return_dropout_mask` is set.
- The dropout mask for `x1` if `x1` is provided and `return_dropout_mask` is set.
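
    Example (a minimal sketch; shapes, dtype, and the CUDA device are
    illustrative assumptions, since the Triton kernel runs on GPU):

    ```python
    import torch

    hidden = 128
    x = torch.randn(4, hidden, device="cuda", dtype=torch.float16)
    weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
    bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)

    # Plain layer normalization.
    out = layer_norm(x, weight, bias)

    # Pre-norm residual block: also returns the unnormalized input + residual.
    residual = torch.randn_like(x)
    out, residual_out = layer_norm(
        x, weight, bias, residual=residual, prenorm=True, residual_in_fp32=True
    )
    ```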
"""
    # Pass the optional arguments by keyword so that every flag, including
    # `zero_centered_weight`, maps to the matching `layer_norm_fn` parameter.
    return layer_norm_fn(
        x,
        weight,
        bias,
        residual=residual,
        x1=x1,
        weight1=weight1,
        bias1=bias1,
        eps=eps,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        residual_in_fp32=residual_in_fp32,
        zero_centered_weight=zero_centered_weight,
        is_rms_norm=is_rms_norm,
        return_dropout_mask=return_dropout_mask,
        out=out,
        residual_out=residual_out,
    )
__kernel_metadata__ = {
"license": "bsd-3-clause",
}
__all__ = [
"__kernel_metadata__",
"layers",
"layer_norm",
"layer_norm_fn",
"layer_norm_linear_fn",
"rms_norm_fn",
]