import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
import opt_einsum as oe

contract = oe.contract

""" Utils for the training loop. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py """


class OptimModule(nn.Module):
    """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """

    def register(self, name, tensor, lr=None, wd=0.0):
        """Register a tensor with a configurable learning rate and weight decay (0.0 by default)."""

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {}
            if lr is not None: optim["lr"] = lr
            if wd is not None: optim["weight_decay"] = wd
            setattr(getattr(self, name), "_optim", optim)
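

# The `_optim` attribute set by `OptimModule.register` is only a convention: the
# training script has to read it when building optimizer parameter groups. Below is
# a minimal sketch of such a helper (illustrative, not part of the original file;
# the name `build_param_groups` and its signature are assumptions).
def build_param_groups(module, base_lr, base_wd):
    """Group parameters by their per-tensor `_optim` overrides (illustrative sketch)."""
    default, overrides = [], []
    for p in module.parameters():
        if hasattr(p, "_optim"):
            # per-tensor hyperparameters override the base ones
            overrides.append({"params": [p], "lr": base_lr, "weight_decay": base_wd, **p._optim})
        else:
            default.append(p)
    groups = [{"params": default, "lr": base_lr, "weight_decay": base_wd}] if default else []
    return groups + overrides
# Usage sketch: torch.optim.AdamW(build_param_groups(model, base_lr=1e-3, base_wd=0.1))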


def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
    # Reference FFT convolution: zero-pad to 2 * seqlen so the circular convolution
    # equals a causal (linear) convolution over the first seqlen outputs.
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    if k_rev is not None:
        k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
        k_f = k_f + k_rev_f.conj()
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)

    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]

    # D is a per-channel skip term; unsqueeze so it broadcasts over the sequence
    # dimension of u (shape: batch, channels, seqlen).
    out = y + u * D.unsqueeze(-1)

    if gelu:
        out = F.gelu(out)
    if dropout_mask is not None:
        return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
    else:
        return out.to(dtype=u.dtype)
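

# Numerical sanity check (illustrative, not part of the original file): with
# gelu=False and no dropout, fftconv_ref computes a causal convolution
# y[t] = sum_{s <= t} k[s] * u[t - s] plus the per-channel skip u * D. The helper
# below (its name and defaults are assumptions) compares it against an explicit loop.
def check_fftconv_ref(B=2, H=4, L=8, atol=1e-4):
    torch.manual_seed(0)
    u = torch.randn(B, H, L)
    k = torch.randn(H, L)
    D = torch.randn(H)
    ref = torch.zeros(B, H, L)
    for t in range(L):
        for s in range(t + 1):
            ref[..., t] += k[..., s] * u[..., t - s]
    ref = ref + u * D.unsqueeze(-1)
    out = fftconv_ref(u, k, D, dropout_mask=None, gelu=False)
    assert torch.allclose(out, ref, atol=atol), (out - ref).abs().max()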


@torch.jit.script
def mul_sum(q, y):
    return (q * y).sum(dim=1)


class Sin(nn.Module):
    """Sinusoidal activation with per-channel (optionally trainable) frequencies."""

    def __init__(self, dim, w=10, w_mod=1, train_freq=True):
        super().__init__()

        init_tensor = torch.ones(1, dim)
        self.freq = (
            nn.Parameter(w * init_tensor)
            if train_freq
            else w * torch.ones(1, dim)
        )
        self.w_mod = w_mod

    def forward(self, x):
        return torch.sin(self.w_mod * self.freq * x)


class PositionalEmbedding(OptimModule):
    def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs):
        """Complex exponential positional embeddings for Hyena filters."""
        super().__init__()

        self.seq_len = seq_len
        # time axis normalized to [0, 1]; shape (1, seq_len, 1)
        t = torch.linspace(0, 1, self.seq_len)[None, :, None]

        if emb_dim > 1:
            bands = (emb_dim - 1) // 2
        # frequencies are defined on the unnormalized index grid
        t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
        w = 2 * math.pi * t_rescaled / seq_len  # (1, seq_len, 1)

        f = torch.linspace(1e-4, bands - 1, bands)[None, None]
        z = torch.exp(-1j * f * w)
        z = torch.cat([t, z.real, z.imag], dim=-1)
        self.register("z", z, lr=lr_pos_emb)
        self.register("t", t, lr=0.0)

    def forward(self, L):
        return self.z[:, :L], self.t[:, :L]
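

# Shape sketch (illustrative helper, not part of the original file): for emb_dim=3
# the embedding concatenates the normalized time axis with the real and imaginary
# parts of a single complex-exponential band, so z has shape (1, L, emb_dim) and
# t has shape (1, L, 1) when sliced to length L.
def check_pos_emb_shapes(emb_dim=3, seq_len=1024, L=128):
    pe = PositionalEmbedding(emb_dim=emb_dim, seq_len=seq_len)
    z, t = pe(L)
    assert z.shape == (1, L, emb_dim) and t.shape == (1, L, 1)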
						

class ExponentialModulation(OptimModule):
    def __init__(
        self,
        d_model,
        fast_decay_pct=0.3,
        slow_decay_pct=1.5,
        target=1e-2,
        modulation_lr=0.0,
        shift: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.shift = shift
        max_decay = math.log(target) / fast_decay_pct
        min_decay = math.log(target) / slow_decay_pct
        deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
        self.register("deltas", deltas, lr=modulation_lr)

    def forward(self, t, x):
        decay = torch.exp(-t * self.deltas.abs())
        x = x * (decay + self.shift)
        return x
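

# Decay-rate sketch (illustrative, not part of the original file): each channel is
# windowed by exp(-t * |delta|) with delta = log(target) / decay_pct, so the
# envelope reaches `target` once the normalized time t equals decay_pct (between
# fast_decay_pct and slow_decay_pct of the sequence, depending on the channel).
def check_modulation_decay(d_model=8, target=1e-2, fast_decay_pct=0.3):
    mod = ExponentialModulation(d_model, fast_decay_pct=fast_decay_pct, target=target)
    t = torch.full((1, 1, 1), fast_decay_pct)  # normalized time at the fast-decay point
    decay = torch.exp(-t * mod.deltas.abs())
    # the fastest-decaying channel (largest |delta|) has fallen to ~target here
    assert torch.allclose(decay[..., -1], torch.tensor(target), rtol=1e-3)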


class HyenaFilter(OptimModule):
    def __init__(
        self,
        d_model,
        emb_dim=3,  # dimension of the positional encoding fed to the filter MLP
        order=16,  # width of the filter MLP
        seq_len=1024,
        lr=1e-3,
        lr_pos_emb=1e-5,
        dropout=0.0,
        w=1,  # frequency of the sinusoidal activation
        w_mod=1,  # non-learnable multiplier on the activation frequency
        wd=0,  # weight decay of the filter MLP parameters
        bias=True,
        num_inner_mlps=2,
        linear_mixer=False,
        modulate: bool = True,
        normalized=False,
        bidirectional=False,
        **kwargs,
    ):
        """
        Implicit long filter with modulation.

        Args:
            d_model: number of channels in the input
            emb_dim: dimension of the positional encoding; (`emb_dim` - 1) // 2 is the number of bands
            order: width of the FFN
            num_inner_mlps: number of inner linear layers inside the filter MLP

        Note:
            filter_dropout is not implemented
        """
        super().__init__()

        self.d_model = d_model
        self.emb_dim = emb_dim
        self.seq_len = seq_len
        self.modulate = modulate
        self.use_bias = bias
        self.bidirectional = bidirectional

        self.bias = nn.Parameter(torch.randn(self.d_model))
        self.dropout = nn.Dropout(dropout)

        act = Sin(dim=order, w=w, w_mod=w_mod)
        assert (
            emb_dim % 2 != 0 and emb_dim >= 3
        ), "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
        self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)

        # MLP that parameterizes the filter implicitly: positional encoding in,
        # d_model filter channels out, with a configurable number of inner layers
        if linear_mixer is False:
            self.implicit_filter = nn.Sequential(
                nn.Linear(emb_dim, order),
                act,
            )
            for i in range(num_inner_mlps):
                self.implicit_filter.append(nn.Linear(order, order))
                self.implicit_filter.append(act)
            self.implicit_filter.append(nn.Linear(order, d_model, bias=False))
        else:
            self.implicit_filter = nn.Sequential(
                nn.Linear(emb_dim, d_model, bias=False),
            )

        if self.bidirectional:
            self.implicit_filter_rev = nn.Sequential(
                nn.Linear(emb_dim, order),
                act,
            )
            for i in range(num_inner_mlps):
                self.implicit_filter_rev.append(nn.Linear(order, order))
                self.implicit_filter_rev.append(act)
            self.implicit_filter_rev.append(nn.Linear(order, d_model, bias=False))

        self.modulation = ExponentialModulation(d_model, **kwargs)

        self.normalized = normalized
        # attach per-tensor optimizer hyperparameters to the filter MLP weights
        for c in self.implicit_filter.children():
            for name, v in c.state_dict().items():
                optim = {"weight_decay": wd, "lr": lr}
                setattr(getattr(c, name), "_optim", optim)

    def filter(self, L, *args, **kwargs):
        z, t = self.pos_emb(L)
        h = self.implicit_filter(z)
        if self.modulate:
            h = self.modulation(t, h)
        if self.normalized:
            h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
        return h

    def filter_rev(self, L, *args, **kwargs):
        z, t = self.pos_emb(L)
        h = self.implicit_filter_rev(z)
        if self.modulate:
            h = self.modulation(t, h)
        if self.normalized:
            h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
        return h

    def forward(self, x, L, k_fwd=None, k_rev=None, bias=None, *args, **kwargs):
        if k_fwd is None:
            k_fwd = self.filter(L)
            if self.bidirectional and k_rev is None:
                k_rev = self.filter_rev(L)

        # ensure compatibility with filters that return a tuple
        k_fwd = k_fwd[0] if type(k_fwd) is tuple else k_fwd
        if bias is None:
            bias = self.bias
        bias = bias if self.use_bias else 0 * bias

        if self.bidirectional:
            k_rev = k_rev[0] if type(k_rev) is tuple else k_rev
            # pad the forward and (flipped) reverse kernels to length 2L so a
            # single FFT convolution covers both directions
            k = F.pad(k_fwd, (0, L)) \
                      + F.pad(k_rev.flip(-1), (L, 0))
        else:
            k = k_fwd

        y = fftconv_ref(
            x,
            k,
            bias,
            dropout_mask=None,
            gelu=False,
        )

        return y.to(dtype=x.dtype)
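

# End-to-end smoke test (illustrative, not part of the original file). The filter
# MLP returns a kernel of shape (1, L, d_model), while fftconv_ref convolves along
# the last axis, so the kernel is rearranged to (d_model, L) before the call,
# mirroring how a surrounding Hyena operator would typically prepare it.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, D_MODEL, L = 2, 16, 64
    filt = HyenaFilter(d_model=D_MODEL, emb_dim=3, order=16, seq_len=L)
    x = torch.randn(B, D_MODEL, L)
    k = rearrange(filt.filter(L), "1 l d -> d l")
    y = filt(x, L, k_fwd=k)
    print(y.shape)  # expected: torch.Size([2, 16, 64])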