from typing import List, Optional

from transformers import PretrainedConfig
class MoonshineConfig(PretrainedConfig):
    """Hyper-parameter container for the ``"moonshine"`` model type.

    Each named argument is stored unchanged as an instance attribute of the
    same name; every other keyword argument is forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        dim: Stored as ``self.dim``. Presumably the model (embedding) width --
            confirm against the modeling code.
        inner_dim: Stored as ``self.inner_dim``. When ``None`` it falls back
            to ``dim``. Must be divisible by ``n_head``.
        enc_depth: Stored as ``self.enc_depth``. Presumably the number of
            encoder layers -- confirm against the modeling code.
        dec_depth: Stored as ``self.dec_depth``. Presumably the number of
            decoder layers -- confirm against the modeling code.
        n_head: Stored as ``self.n_head``. Must be a positive divisor of
            ``inner_dim``.
        dec_voc_size: Stored as ``self.dec_voc_size``. Presumably the decoder
            vocabulary size -- confirm against the modeling code.
        enc_ff_swiglu: Stored as ``self.enc_ff_swiglu``. Presumably selects a
            SwiGLU feed-forward in the encoder -- confirm against the
            modeling code.
        dec_ff_swiglu: Stored as ``self.dec_ff_swiglu``. Presumably selects a
            SwiGLU feed-forward in the decoder -- confirm against the
            modeling code.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``n_head`` is not positive, or if ``inner_dim`` is not
            divisible by ``n_head``.
    """

    model_type = "moonshine"

    def __init__(
        self,
        dim: int = 288,
        inner_dim: Optional[int] = None,
        enc_depth: int = 8,
        dec_depth: int = 8,
        n_head: int = 8,
        dec_voc_size: int = 32768,
        enc_ff_swiglu: bool = False,
        dec_ff_swiglu: bool = True,
        **kwargs,
    ):
        if inner_dim is None:
            inner_dim = dim

        # Reject non-positive head counts up front: zero would otherwise
        # escape as a bare ZeroDivisionError from the modulo below, and a
        # negative value could slip through the divisibility check.
        if n_head <= 0:
            raise ValueError(f"`n_head` must be a positive integer, got {n_head}")
        if inner_dim % n_head != 0:
            raise ValueError(
                f"`inner_dim` must be divisible by `n_head` "
                f"(got inner_dim={inner_dim}, n_head={n_head})"
            )

        self.dim = dim
        self.inner_dim = inner_dim
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.n_head = n_head
        self.dec_voc_size = dec_voc_size
        self.enc_ff_swiglu = enc_ff_swiglu
        self.dec_ff_swiglu = dec_ff_swiglu

        # Called last, after the model-specific attributes are set, exactly
        # as in the original ordering.
        super().__init__(**kwargs)