import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from .utils import print_rank_0, column_split
from .cache import InferenceParams, RecurrentInferenceParams
from .engine import HyenaInferenceEngine
from .layers import (
    RMSNorm,
    ParallelGatedMLP,
    VocabParallelEmbedding,
)

try:
    from flash_attn.modules.mha import MHA
except ImportError:
    # flash_attn is optional; AttentionBlock cannot be constructed without it.
    MHA = None

# FlashDepthwiseConv1d is used below when `use_flash_depthwise` is enabled but was not imported
# in this file; it is assumed here to come from the optional flashfftconv package.
try:
    from flashfftconv import FlashDepthwiseConv1d
except ImportError:
    FlashDepthwiseConv1d = None


class AttentionBlock(nn.Module):
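    """Pre-norm multi-head attention followed by a post-norm gated MLP, each wrapped in a residual connection."""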
					
						
    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.pre_norm, self.post_norm = RMSNorm(config), RMSNorm(config)
        self.layer_idx = layer_idx
        self.proj_groups = config.get("proj_groups", 1)
        dtype = config.get("attn_block_dtype", torch.bfloat16)
        mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size_per_attention_head = config.hidden_size // config.num_attention_heads

        self.counter = 0
        self.inner_mha_cls = MHA(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_heads_kv=config.num_attention_heads // self.proj_groups,
            rotary_emb_dim=config.hidden_size // config.num_attention_heads,
            qkv_proj_bias=config.get("qkv_proj_bias", True),
            rotary_emb_base=config.get("rotary_emb_base", 10000),
            causal=True,
            layer_idx=layer_idx,
            out_proj_bias=config.get("mha_out_proj_bias", True),
            use_flash_attn=self.config.use_flash_attn,
        ).to(dtype=dtype)

        if self.config.get("smeared_gqa", False):
            self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
        self.inner_mha_cls.rotary_emb.register_buffer(
            "inv_freq", self.inner_mha_cls.rotary_emb.inv_freq
        )

        self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
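        """Apply attention and the MLP with residuals; padded positions are zeroed out when a padding mask is given."""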
					
						
        if type(padding_mask) == torch.Tensor:
            u = u * padding_mask[..., None]

        u = (
            self.inner_mha_cls(
                self.pre_norm(u),
                inference_params=inference_params,
            )
            + u
        )
        if type(padding_mask) == torch.Tensor:
            u = u * padding_mask[..., None]
        u = self.mlp(self.post_norm(u)) + u
        return u, None


class ParallelHyenaFilter(nn.Module):
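    """Hyena operator: a short depthwise FIR filter followed by a long implicit filter parameterized
    by complex poles and residues, with a parallel path for prefill and a stateful sequential path
    for single-token decoding."""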
					
						
    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hyena_filter_groups = config.get("hyena_filter_groups", self.config.hidden_size)

        self.use_flashfft = config.get("use_flashfft", False)
        self.state_size = config.state_size
        self.hidden_size = config.hidden_size
        self.num_filters = config.num_filters
        self.inference_mode = config.get("inference_mode", True)
        self.counter = 0
        self.column_split_hyena = config.get("column_split_hyena", True)

        assert self.hidden_size % self.num_filters == 0 and self.num_filters <= self.hidden_size

        self.D = nn.Parameter(torch.zeros(self.hidden_size))

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads

        self.short_filter_length = config.short_filter_length
        self.short_filter_weight = nn.Parameter(
            torch.randn(3 * config.hidden_size, 1, config.short_filter_length)
        )
        self.short_filter_bias = (
            nn.Parameter(torch.randn(3 * config.hidden_size)) if config.short_filter_bias else None
        )

        self.engine = HyenaInferenceEngine(layer_idx=layer_idx)
        self.use_flash_depthwise = config.get("use_flash_depthwise", False)
        self.data_dtype = None

        if self.use_flash_depthwise:
            self.fir_fn = FlashDepthwiseConv1d(
                channels=3 * self.hidden_size,
                kernel_size=self.short_filter_length,
                padding=self.short_filter_length - 1,
                weights=self.short_filter_weight,
                bias=self.short_filter_bias,
                device=None,
                dtype=self.config.get("depthwise_dtype", torch.bfloat16),
            )
        else:
            self.fir_fn = F.conv1d

        self.fftconv_fn = None
        self.long_fir_threshold = config.get("long_fir_threshold", None)
        if self.long_fir_threshold is not None:
            assert (
                self.use_flashfft is False
            ), "long_fir_threshold not compatible with fused flashfft"

        self.num_systems = self.hidden_size // self.hyena_filter_groups
        self.poles = nn.Parameter(torch.randn(self.num_systems, self.state_size, 1, 2))
        self.residues = nn.Parameter(torch.randn(self.num_systems, self.state_size, 1, 2))
        self.h = None

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
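        """Use the stateful single-token path when a cached FIR state exists for this layer; otherwise run the parallel (prefill) path."""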
					
						
        if (
            inference_params is not None
            and self.layer_idx in inference_params.fir_state_dict.keys()
        ):
            return self.sequential_forward(u, inference_params)
        else:
            return self.parallel_forward(u, inference_params, padding_mask)

    def parallel_forward(self, u, inference_params=None, padding_mask=None):
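        """Prefill path: apply the short FIR filter over the full sequence, then the long filter via the inference engine's parallel IIR / FFT convolution."""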
					
						
        L = u.shape[1]
        z_pre, fir_state = self.engine.parallel_fir(
            self.fir_fn,
            u,
            self.short_filter_weight,
            self.short_filter_bias,
            L,
            fir_length=self.short_filter_length,
            inference_params=inference_params,
            padding_mask=padding_mask,
        )
        if inference_params:
            inference_params.fir_state_dict[self.layer_idx] = fir_state

        if self.h is None:
            h, filter_dtype, poles, residues = self.compute_filter(L, u.device)
        else:
            h = self.h
            filter_dtype = self.h.dtype

        if self.hyena_filter_groups > 1:
            h = h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 1)

        dims = (
            self.hidden_size,
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            self.state_size,
            self.hyena_filter_groups,
        )
        y = self.engine.parallel_iir(
            z_pre,
            h,
            self.D,
            L,
            t=self.t,
            poles=self.poles,
            dims=dims,
            inference_params=inference_params,
            layer_idx=self.layer_idx,
            prefill_style=self.config.get("prefill_style", "fft"),
            use_flashfft=self.use_flashfft,
            fftconv_fn=self.fftconv_fn,
            column_split_hyena=self.column_split_hyena,
            long_fir_threshold=self.long_fir_threshold,
            padding_mask=padding_mask,
        )

        return y, inference_params

    def sequential_forward(self, u, inference_params):
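        """Decode path: advance the cached FIR and IIR states by one token and return that token's output."""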
					
						
        if self.data_dtype is None:
            self.data_dtype = u.dtype
        if len(u.shape) > 2:
            u = u[:, -1]

        fir_state, iir_state = (
            inference_params.fir_state_dict[self.layer_idx],
            inference_params.state_dict[self.layer_idx],
        )

        z_pre, fir_state = self.engine.step_fir(
            u, fir_state, weight=self.short_filter_weight, bias=self.short_filter_bias
        )
        x2, x1, v = (
            column_split(z_pre, self.num_attention_heads, self.hidden_size_per_attention_head)
            if self.column_split_hyena
            else z_pre.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1)
        )

        y, iir_state = self.engine.step_iir(
            x2,
            x1,
            v,
            self.D,
            self.residues,
            self.poles,
            iir_state,
            iir_groups=self.hyena_filter_groups,
        )

        inference_params.fir_state_dict[self.layer_idx] = fir_state
        inference_params.state_dict[self.layer_idx] = iir_state
        y = y.to(dtype=self.data_dtype)
        return y[:, None], inference_params

    def update_time(self, L, device):
        """
        Set self.t = [0, 1, ..., L-1], where L is the length of the current batch of inputs.
        If L is greater than the length of the previous batch, the time vector is
        reinitialized; otherwise, it is truncated from the cached version.
        """
        if not hasattr(self, "t"):
            self.t = torch.arange(L, device=device)[None, None]
        elif self.t.shape[-1] < L:
            self.t = torch.arange(L, device=device)[None, None]
        else:
            self.t = self.t[..., :L]

    def compute_filter(self, L, device):
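        """Materialize the long filter h[t] = sum_s Re(R_s * p_s**t) for t = 0, ..., L - 1 from the
        complex residues R_s and poles p_s (stored as real/imaginary pairs), returning the filter,
        its dtype, the log-poles, and the residues."""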
					
						
        self.update_time(L, device)
        filter_dtype = torch.float32
        residues, log_poles = (
            torch.view_as_complex(self.residues.to(filter_dtype)),
            torch.view_as_complex(self.poles.to(filter_dtype)).log(),
        )
        h = (residues * (log_poles * self.t).exp()).real.sum(1)[None]
        return h, filter_dtype, log_poles, residues


class ParallelGatedConvBlock(nn.Module):
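    """Gated convolution block: a pre-norm dense projection into three channel groups, the Hyena
    filter, an output projection with residual, and a post-norm gated MLP with residual."""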
					
						
    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        dtype = config.get("hyena_block_dtype", torch.float32)
        mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
        self.pre_norm, self.post_norm = RMSNorm(config).to(dtype=dtype), RMSNorm(config).to(
            dtype=dtype
        )
        self.filter = ParallelHyenaFilter(config, layer_idx).to(dtype=dtype)
        self.projections = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.out_filter_dense = nn.Linear(config.hidden_size, config.hidden_size).to(dtype)
        self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        z = self.projections(self.pre_norm(u))
        if type(padding_mask) == torch.Tensor:
            z = z * padding_mask[..., None]

        z, inference_params = self.filter(
            z, inference_params=inference_params, padding_mask=padding_mask
        )

        u = self.out_filter_dense(z) + u
        if type(padding_mask) == torch.Tensor:
            u = u * padding_mask[..., None]
        u = self.mlp(self.post_norm(u)) + u
        return u, inference_params


def get_block(config, layer_idx, flash_fft=None):
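    """Return the block for `layer_idx`: attention if the index is in `config.attn_layer_idxs`,
    a gated Hyena convolution block if it is in `config.hyena_layer_idxs`."""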
					
						
    if layer_idx in config.attn_layer_idxs:
        return AttentionBlock(config, layer_idx)
    elif layer_idx in config.hyena_layer_idxs:
        block = ParallelGatedConvBlock(config, layer_idx)
        if config.get("use_flashfft", False):
            block.filter.fftconv_fn = flash_fft
        return block
    else:
        raise NotImplementedError


class StripedHyena(nn.Module):
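    """Striped backbone interleaving attention and gated Hyena convolution blocks, with a token
    embedding, an optional final RMSNorm, and an (optionally tied) unembedding."""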
					
						
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding_layer = VocabParallelEmbedding(config)
        self.norm = RMSNorm(config) if config.get("final_norm", True) else None
        self.unembed = (
            self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)
        )
        self.gradient_checkpointing = False

        if config.get("use_flashfft", False):
            raise NotImplementedError("Please use standalone SH code for other custom kernels")
        else:
            self.flash_fft = None

        self.blocks = nn.ModuleList(
            get_block(config, layer_idx, flash_fft=self.flash_fft)
            for layer_idx in range(config.num_layers)
        )

    def forward(self, x, inference_params_dict=None, padding_mask=None):
        L = x.shape[1]
        x = self.embedding_layer.embed(x)
        if inference_params_dict is not None:
            x, inference_params_dict_out = self.stateful_forward(
                x,
                inference_params_dict=inference_params_dict,
            )
        else:
            x, inference_params_dict_out = self.stateless_forward(x, padding_mask=padding_mask)
        x = self.norm(x)
        x = self.unembed.unembed(x)
        return x, inference_params_dict_out

    def stateful_forward(self, x, inference_params_dict=None):
        for block_idx, block in enumerate(self.blocks):
            block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
            inference_params = inference_params_dict[block_name]
            x, _ = block(x, inference_params=inference_params)

        return x, inference_params_dict

    def stateless_forward(self, x, padding_mask=None):
        if type(padding_mask) == torch.Tensor:
            x = x * padding_mask[..., None]

        for block_idx, block in enumerate(self.blocks):
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, inference_params=None, padding_mask=padding_mask)

                    return custom_forward

                x, _ = checkpoint(create_custom_forward(block), x, use_reentrant=False)
            else:
                x, _ = block(x, inference_params=None, padding_mask=padding_mask)
        return x, None

    def initialize_inference_params(self):
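        """Build fresh per-block-type inference caches: KV-cache parameters for attention ("mha")
        and recurrent FIR/IIR state parameters for the Hyena blocks ("hyena")."""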
					
						
        print_rank_0("Initializing inference params...")
        inference_params_dict = {
            "mha": InferenceParams(
                max_seqlen=self.config.get("max_seqlen", 8192),
                max_batch_size=self.config.get("max_batch_size", 1),
                seqlen_offset=0,
            ),
            "hyena": RecurrentInferenceParams(
                fir_filter_length=self.config.short_filter_length,
                state_dim=self.config.state_size,
                seqlen_offset=0,
            ),
        }
        return inference_params_dict

    def precompute_filters(self, L, device):
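        """Precompute and cache the long convolution filter h (up to length L) for every Hyena block."""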
					
						
        for block_idx, block in enumerate(self.blocks):
            if type(block) == ParallelGatedConvBlock:
                if type(block.filter) == ParallelHyenaFilter:
                    L = block.filter.long_fir_threshold or L
                    print_rank_0(f"Precomputing filters, L={L}...")

                    filter_dtype = torch.float16 if L >= 2048 else torch.float32

                    block.filter.update_time(L, device)
                    residues, poles = (
                        torch.view_as_complex(block.filter.residues.to(torch.float16)),
                        torch.view_as_complex(block.filter.poles.to(torch.float16)),
                    )

                    block.filter.h = (residues * poles**block.filter.t).real.sum(1)[None]
                    block.filter.h = block.filter.h.to(dtype=filter_dtype)

    def load_poles_residues(self, path):
        """Load different poles and residues for each layer."""
        for block_idx, block in enumerate(self.blocks):
            if type(block) == ParallelGatedConvBlock:
                if type(block.filter) == ParallelHyenaFilter:
                    print(f"Loading poles and residues for block {block_idx}")
                    poles = torch.load(path + f"/approx_poles_{block_idx+1}.pt", map_location="cpu")
                    poles = torch.view_as_real(poles)
                    residues = torch.load(
                        path + f"/approx_residues_{block_idx+1}.pt", map_location="cpu"
                    )
                    residues = torch.view_as_real(residues)
                    poles = poles.permute(1, 0, 2).unsqueeze(-2)
                    residues = residues.permute(1, 0, 2).unsqueeze(-2)

                    block.filter.poles = nn.Parameter(poles)
                    block.filter.residues = nn.Parameter(residues)

    def to_bfloat16_except_poles_residues(self):
        """Convert all parameters to bfloat16 except for the poles and residues.

        Particularly important for longer prompts.
        """
        for k, p in self.named_parameters():
            if "poles" not in k and "residues" not in k:
                p.data = p.data.to(torch.bfloat16)
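

# Minimal usage sketch (not part of the original module), showing how the pieces above fit
# together. It assumes a dict-like `config` exposing the attributes read in this file
# (hidden_size, num_layers, attn_layer_idxs, hyena_layer_idxs, state_size, short_filter_length,
# num_attention_heads, num_filters, tie_embeddings, ...); names and values below are hypothetical.
#
#   config = load_config("configs/sh.yml")  # hypothetical config loader
#   model = StripedHyena(config).cuda()
#   model.to_bfloat16_except_poles_residues()
#   inference_params_dict = model.initialize_inference_params()
#   logits, inference_params_dict = model(
#       input_ids, inference_params_dict=inference_params_dict
#   )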
					
						