import torch
import torch.nn as nn

from ._ops import ops


class SiluAndMul(nn.Module):
    """SiLU on the first half of the last dim, multiplied by the second half."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out


class GeluAndMul(nn.Module):
    """GELU (erf-based) on the first half of the last dim, multiplied by the second half."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.gelu_and_mul(out, x)
        return out


class GeluTanhAndMul(nn.Module):
    """Tanh-approximated GELU on the first half of the last dim, multiplied by the second half."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.gelu_tanh_and_mul(out, x)
        return out


class FatreluAndMul(nn.Module):
    """FATReLU (ReLU with a configurable threshold) on the first half of the last dim,
    multiplied by the second half."""

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.fatrelu_and_mul(out, x, self.threshold)
        return out


class FastGELU(nn.Module):
    """Fast (tanh-approximated) GELU applied elementwise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out


class NewGELU(nn.Module):
    """GELU-new variant (tanh approximation, as in GPT-2) applied elementwise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out


class QuickGELU(nn.Module):
    """Quick GELU, x * sigmoid(1.702 * x), applied elementwise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out
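

# Usage sketch (illustrative only, not part of the module): the *AndMul
# layers expect the last dimension to hold the concatenated [gate, up]
# projections, so an input of width 2*d produces an output of width d,
# while the plain GELU variants are elementwise and shape-preserving.
# This assumes the compiled `ops` extension was built for the current
# device (CUDA in this example).
#
#   x = torch.randn(4, 2 * 1024, device="cuda", dtype=torch.float16)
#   y = SiluAndMul()(x)   # y.shape == (4, 1024)
#
#   z = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
#   w = QuickGELU()(z)    # w.shape == z.shape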