import torch | |
import torch.nn as nn | |
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm).

    Normalizes the last dimension of the input by its root-mean-square
    (no mean subtraction, unlike LayerNorm) and applies a learned
    per-dimension scale.

    Args:
        dimensions: Size of the last input dimension; also the size of the
            learned scale vector.
        eps: Small constant added to the mean square for numerical stability.
        device: Device on which the scale parameter is allocated.
        dtype: Dtype of the scale parameter (default ``torch.bfloat16``).
        norm_in_fp32: If ``True``, compute the normalization in float32 and
            cast back to the input dtype afterwards — improves numerical
            stability for low-precision inputs (e.g. bf16/fp16).
    """

    def __init__(self, dimensions: int, eps: float, device: torch.device, dtype: torch.dtype = torch.bfloat16, norm_in_fp32: bool = False):
        super().__init__()
        self.eps = eps
        # Construct directly on the target device instead of allocating on
        # CPU and copying with .to(device).
        self.weight = torch.nn.Parameter(torch.ones(dimensions, dtype=dtype, device=device))
        self.norm_in_fp32 = norm_in_fp32

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply RMS normalization over the last dimension of ``x``.

        Returns a tensor of the same shape as ``x``; the result dtype
        follows PyTorch promotion between the input dtype and the weight
        dtype.
        """
        original_dtype = x.dtype
        if self.norm_in_fp32:
            # Upcast so the square/mean/rsqrt chain runs in fp32.
            x = x.float()
        # x / RMS(x), with eps inside the sqrt for stability.
        out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        if out.dtype != original_dtype:
            # Cast back BEFORE the weight multiply so the elementwise scale
            # is applied in the caller's precision, matching the
            # norm_in_fp32=False path.
            out = out.to(original_dtype)
        return out * self.weight