import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Applies Rotary Positional Embeddings (RoPE) to queries and keys.

    Args:
        q: Query tensor of shape (B, nh, T, hd)
        k: Key tensor of shape (B, nh, T, hd)
        cos: Cosine values of shape (1, 1, T, hd)
        sin: Sine values of shape (1, 1, T, hd)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (q_rotated, k_rotated)
    """
    hd = q.shape[-1]
    assert hd % 2 == 0, "head_dim must be even for RoPE"

    # Split the head dimension into two halves ("rotate-half" formulation).
    q1, q2 = q[..., :hd // 2], q[..., hd // 2:]
    k1, k2 = k[..., :hd // 2], k[..., hd // 2:]

    # Only the first hd/2 cos/sin values are needed for the half-split rotation.
    cos_half = cos[..., :hd // 2]
    sin_half = sin[..., :hd // 2]

    q_rot = torch.cat([
        q1 * cos_half - q2 * sin_half,
        q1 * sin_half + q2 * cos_half,
    ], dim=-1)

    k_rot = torch.cat([
        k1 * cos_half - k2 * sin_half,
        k1 * sin_half + k2 * cos_half,
    ], dim=-1)

    return q_rot, k_rot
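

def _demo_apply_rotary_pos_emb() -> None:
    # Illustrative sketch only (not used by the module): builds a standard RoPE
    # cos/sin cache with base 10000 -- an assumption, since the cache construction
    # is not defined in this file -- and checks that shapes are preserved.
    B, nh, T, hd = 2, 4, 16, 32
    inv_freq = 1.0 / (10000 ** (torch.arange(0, hd, 2).float() / hd))  # (hd/2,)
    t = torch.arange(T).float()
    freqs = torch.outer(t, inv_freq)                                   # (T, hd/2)
    emb = torch.cat([freqs, freqs], dim=-1)                            # (T, hd)
    cos = emb.cos()[None, None, :, :]                                  # (1, 1, T, hd)
    sin = emb.sin()[None, None, :, :]

    q = torch.randn(B, nh, T, hd)
    k = torch.randn(B, nh, T, hd)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape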
|
|
|
|
def router_aux_loss(alpha: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Mean entropy of the alpha distribution over the K branches.
    alpha: (B, T, K)
    Returns the entropy normalized to roughly [0, 1].
    """
    if alpha is None:
        return torch.tensor(0.0)
    eps = 1e-9
    k = alpha.size(-1)
    if k <= 1:
        # With a single branch the routing entropy is identically zero.
        return alpha.new_zeros(())
    ent = -(alpha * alpha.clamp_min(eps).log()).sum(dim=-1)
    norm_ent = ent / math.log(float(k))
    return norm_ent.mean()
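

def _demo_router_aux_loss() -> None:
    # Illustrative sketch only: a uniform routing distribution has normalized
    # entropy ~1, while a one-hot (collapsed) routing has normalized entropy ~0.
    uniform = torch.full((2, 5, 4), 0.25)  # (B, T, K) with K = 4
    one_hot = torch.zeros(2, 5, 4)
    one_hot[..., 0] = 1.0
    assert torch.isclose(router_aux_loss(uniform), torch.tensor(1.0), atol=1e-4)
    assert router_aux_loss(one_hot) < 1e-3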
|
|
|
|
class DepthwiseCausalConv1d(nn.Module):
    """
    Depthwise 1D causal convolution over the sequence dimension.

    Input: (B, T, H) -> output: (B, T, H)
    groups=H gives one filter per channel.
    """

    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        assert kernel_size >= 1 and kernel_size % 2 == 1, "kernel_size should be odd"
        self.kernel_size = kernel_size
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=0,
            groups=channels,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, H) -> (B, H, T) for Conv1d.
        x_c = x.transpose(1, 2)
        # Left-pad only, so position t never sees inputs at positions > t.
        x_c = F.pad(x_c, (self.pad, 0))
        y = self.conv(x_c)
        # Back to (B, T, H).
        y = y.transpose(1, 2)
        return y
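

def _demo_depthwise_causal_conv1d() -> None:
    # Illustrative sketch only: verifies causality -- perturbing future
    # timesteps must not change the outputs at earlier timesteps.
    torch.manual_seed(0)
    conv = DepthwiseCausalConv1d(channels=8, kernel_size=3)
    x = torch.randn(1, 10, 8)          # (B, T, H)
    y = conv(x)
    x_perturbed = x.clone()
    x_perturbed[:, 5:, :] += 1.0       # change only timesteps >= 5
    y_perturbed = conv(x_perturbed)
    assert torch.allclose(y[:, :5], y_perturbed[:, :5])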
|
|
|
|
class ChannelAttention(nn.Module):
    """
    Per-channel (SE-style) attention, applied independently per token.
    """

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        hidden = max(channels // reduction, 1)
        self.ln = nn.LayerNorm(channels)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Squeeze-and-excitation style gate in [0, 1] per channel.
        g = self.ln(x)
        g = F.gelu(self.fc1(g))
        g = torch.sigmoid(self.fc2(g))
        return x * g
|
|
|
|
class Fp32LayerNorm(nn.Module):
    """
    LayerNorm computed in float32 for numerical stability, casting the input
    up and the output back down. Parameters are kept in float32.
    """

    def __init__(self, normalized_shape: int, eps: float = 1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(normalized_shape, eps=eps)
        self.ln.to(dtype=torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Disable autocast so the normalization really runs in float32.
        device_type = "cuda" if x.is_cuda else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            y = self.ln(x.to(torch.float32))
        return y.to(orig_dtype)
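

def _demo_fp32_layer_norm() -> None:
    # Illustrative sketch only: the input dtype is preserved on the output even
    # though the normalization itself runs in float32.
    ln = Fp32LayerNorm(16)
    x = torch.randn(2, 4, 16, dtype=torch.bfloat16)
    y = ln(x)
    assert y.dtype == torch.bfloat16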
|
|
|
|
|
|
|
|
class SwigluMLP(nn.Module):
    """SwiGLU feed-forward: down(silu(a) * b), with [a, b] = split(up(x))."""

    def __init__(self, hidden_size: int, mlp_mult: float):
        super().__init__()
        mlp_dim = int(round(mlp_mult * hidden_size))
        self.mlp_dim = mlp_dim
        self.up = nn.Linear(hidden_size, 2 * mlp_dim)
        self.down = nn.Linear(mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        up = self.up(x)
        a, b = up.split(self.mlp_dim, dim=-1)
        y = F.silu(a) * b
        return self.down(y)
|
|
|
|
class GluMLP(nn.Module):
    """GLU feed-forward: down(sigmoid(a) * b), with [a, b] = split(up(x))."""

    def __init__(self, hidden_size: int, mlp_mult: float):
        super().__init__()
        mlp_dim = int(round(mlp_mult * hidden_size))
        self.mlp_dim = mlp_dim
        self.up = nn.Linear(hidden_size, 2 * mlp_dim)
        self.down = nn.Linear(mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        up = self.up(x)
        a, b = up.split(self.mlp_dim, dim=-1)
        y = torch.sigmoid(a) * b
        return self.down(y)
|
|
|
|
class DepthwiseConvBranch(nn.Module):
    """Depthwise causal convolution followed by a GELU expand/contract MLP."""

    def __init__(self, hidden_size: int, mlp_mult: float = 4.0):
        super().__init__()
        mlp_dim = int(round(mlp_mult * hidden_size))
        self.dw = DepthwiseCausalConv1d(hidden_size, kernel_size=3)
        self.expand = nn.Linear(hidden_size, mlp_dim)
        self.act = nn.GELU()
        self.contract = nn.Linear(mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.dw(x)
        y = self.expand(y)
        y = self.act(y)
        return self.contract(y)
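

def _demo_branches() -> None:
    # Illustrative sketch only: all three base branches map (B, T, H) -> (B, T, H),
    # which is what PolymorphicMLP relies on when stacking their outputs.
    x = torch.randn(2, 6, 32)
    for branch in (SwigluMLP(32, 4.0), GluMLP(32, 4.0), DepthwiseConvBranch(32, 4.0)):
        assert branch(x).shape == x.shape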
|
|
|
|
class PolymorphicMLP(nn.Module):
    """
    Polymorphic MLP:

    - Router: produces alpha of shape (B, T, K)
    - K base branches in a ModuleList (e.g. SwiGLU, GLU, depthwise conv)
    - Output: weighted sum of the branch outputs
    - Optional ChannelAttention
    - Exposes:
        - last_alpha (B, T, K) for logging
        - last_aux (mean normalized entropy) for the auxiliary loss
    - force_func: if >= 0, forces a single branch (debugging / per-branch training)
    """
|
|
    def __init__(
        self,
        hidden_size: int,
        mlp_mult: float = 4.0,
        num_funcs: int = 3,
        router_dim: Optional[int] = None,
        dropout: float = 0.0,
        use_channel_attention: bool = False,
        router_tau: float = 1.0,
    ):
        super().__init__()
        assert num_funcs >= 1, "PolymorphicMLP requires at least one base function"
        self.hidden_size = hidden_size
        self.mlp_mult = mlp_mult
        self.num_funcs = num_funcs

        # Router: token-wise MLP producing logits over the K branches.
        r_dim = router_dim or hidden_size
        self.router = nn.Sequential(
            nn.Linear(hidden_size, r_dim),
            nn.GELU(),
            nn.Linear(r_dim, num_funcs),
        )
        self.router_tau = router_tau

        # Small-std init keeps the routing close to uniform at the start.
        for m in self.router.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # Base branches: SwiGLU, GLU and depthwise conv, in that order.
        funcs: List[nn.Module] = []
        if num_funcs >= 1:
            funcs.append(SwigluMLP(hidden_size, mlp_mult))
        if num_funcs >= 2:
            funcs.append(GluMLP(hidden_size, mlp_mult))
        if num_funcs >= 3:
            funcs.append(DepthwiseConvBranch(hidden_size, mlp_mult))

        # Any additional branches default to extra SwiGLU MLPs.
        while len(funcs) < num_funcs:
            funcs.append(SwigluMLP(hidden_size, mlp_mult))

        self.funcs = nn.ModuleList(funcs)

        self.dropout = nn.Dropout(dropout)
        self.use_channel_attention = use_channel_attention
        self.chan_attn = ChannelAttention(hidden_size) if use_channel_attention else None

        # Diagnostics exposed after each forward pass.
        self.last_alpha: Optional[torch.Tensor] = None
        self.last_aux: Optional[torch.Tensor] = None

        # If >= 0, route every token to this single branch.
        self.force_func: int = -1
|
|
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Routing distribution over branches: (B, T, K).
        logits = self.router(x)
        tau = float(self.router_tau) if self.router_tau is not None and self.router_tau > 0.0 else 1.0
        alpha = F.softmax(logits / tau, dim=-1)

        # Optionally force a single branch (debugging / per-branch training).
        if 0 <= self.force_func < self.num_funcs:
            one_hot = torch.zeros_like(alpha)
            one_hot[..., self.force_func] = 1.0
            alpha = one_hot

        # Run every branch and take the alpha-weighted sum: (B, T, K, H) -> (B, T, H).
        ys = [f(x) for f in self.funcs]
        y_stack = torch.stack(ys, dim=2)

        alpha_exp = alpha.unsqueeze(-1)
        y = (alpha_exp * y_stack).sum(dim=2)

        if self.use_channel_attention and self.chan_attn is not None:
            y = self.chan_attn(y)

        y = self.dropout(y)

        # Keep a detached copy of the routing weights for logging.
        self.last_alpha = alpha.detach()

        self.last_aux = None
        if self.training:
            # Per-token entropy of the routing distribution.
            token_ent = router_aux_loss(alpha)
            # Entropy of the batch-averaged routing (global branch-usage balance).
            p = alpha.mean(dim=(0, 1))
            ent = -(p * (p + 1e-9).log()).sum()
            k = alpha.size(-1)
            global_ent = ent / math.log(float(k)) if k > 1 else ent.new_zeros(())
            # Combined auxiliary term exposed to the caller.
            self.last_aux = 0.5 * token_ent + 0.5 * global_ent

        return y, alpha
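

def _demo_polymorphic_mlp() -> None:
    # Illustrative sketch only: one forward pass in training mode, reading the
    # routing weights and the auxiliary entropy term, then a pass with a single
    # branch forced for debugging. Hyperparameters here are arbitrary.
    mlp = PolymorphicMLP(hidden_size=32, mlp_mult=4.0, num_funcs=3)
    mlp.train()
    x = torch.randn(2, 6, 32)            # (B, T, H)

    y, alpha = mlp(x)
    assert y.shape == x.shape and alpha.shape == (2, 6, 3)
    assert mlp.last_aux is not None      # available for an auxiliary loss term

    mlp.force_func = 1                   # route everything to the GLU branch
    _, alpha_forced = mlp(x)
    assert torch.all(alpha_forced[..., 1] == 1.0)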
|
|