nvedant07's picture
Upload folder using huggingface_hub
5c004da verified
raw
history blame contribute delete
719 Bytes
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-feature scale.

    Each vector along the last dimension of the input is multiplied by
    ``rsqrt(mean(x**2) + eps)`` and then scaled element-wise by ``weight``.
    No mean is subtracted and there is no bias term.

    Args:
        dimensions: Size of the feature (last) dimension. ``weight`` has shape
            ``(dimensions,)`` and is broadcast against the input's last axis.
        eps: Constant added to the mean square before ``rsqrt`` so that
            all-zero vectors do not cause a division by zero.
        device: Device on which ``weight`` is allocated. ``None`` uses
            PyTorch's current default device.
        dtype: dtype of ``weight`` (bfloat16 by default).
        norm_in_fp32: If True, the normalization is computed in float32 and
            the result is cast back to the input dtype before scaling.
    """

    def __init__(
        self,
        dimensions: int,
        eps: float,
        device: "torch.device | None" = None,
        dtype: torch.dtype = torch.bfloat16,
        norm_in_fp32: bool = False,
    ):
        super().__init__()
        self.eps = eps
        # Allocate directly on the target device instead of creating the
        # tensor on the default device and copying it over with `.to(device)`.
        self.weight = torch.nn.Parameter(
            torch.ones(dimensions, dtype=dtype, device=device)
        )
        self.norm_in_fp32 = norm_in_fp32

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and apply the learned scale.

        Returns ``normalized(x) * weight``. Its shape is the broadcast of
        ``x.shape`` with ``(dimensions,)``, normally just ``x.shape``. Its dtype
        follows PyTorch type promotion between ``x.dtype`` and ``weight.dtype``.
        """
        original_dtype = x.dtype
        if self.norm_in_fp32:
            # Upcast so the mean of squares and rsqrt are computed with
            # float32 precision even for low-precision inputs.
            x = x.float()
        out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # Tensor.to returns self unchanged when the dtype already matches,
        # so this is free on the non-upcast path.
        out = out.to(original_dtype)
        return out * self.weight