|
|
|
|
|
|
|
|
|
|
|
"""OLMo 2 configuration.""" |
|
|
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.utils import logging |
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class InstellaConfig(PretrainedConfig):
    r"""
    Configuration container for an [`InstellaModel`].

    An instance stores the hyper-parameters that define the model architecture. With every argument left at its
    default the configuration describes a decoder with 32 layers, 32 attention heads and a hidden size of 4096.

    The class derives from [`PretrainedConfig`]; the token ids, `tie_word_embeddings` and any extra keyword
    arguments are forwarded to it. See the [`PretrainedConfig`] documentation for the inherited options.

    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Size of the vocabulary, i.e. the number of distinct token ids that may appear in the `inputs_ids`
            passed to [`InstellaModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Width of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Width of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of decoder layers.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads in each attention layer.
        num_key_value_heads (`int`, *optional*):
            Number of key/value heads used for Grouped Query Attention. With
            `num_key_value_heads=num_attention_heads` the model uses Multi Head Attention (MHA), with
            `num_key_value_heads=1` it uses Multi Query Attention (MQA), anything in between is GQA. When a
            multi-head checkpoint is converted to GQA, each grouped key and value head should be built by
            mean-pooling the original heads of that group; see [this paper](https://arxiv.org/pdf/2305.13245.pdf).
            When left unset, the value of `num_attention_heads` is used.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            Non-linear activation (function or its name) used in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            Longest sequence length the model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated_normal_initializer used for all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 50279):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie the input and output embedding weights.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Scaling configuration for the RoPE embeddings, in the form
            `{"type": strategy name, "factor": scaling factor}`. The strategy must be `"linear"` or `"dynamic"`
            and the factor must be a float greater than 1. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. Background on how the strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/.
            This is an experimental feature, subject to breaking API changes in future versions.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio applied to the attention probabilities.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            Epsilon used by the RMS normalization layers.

    ```python
    >>> from transformers import InstellaModel, InstellaConfig

    >>> # Initializing a Instella 7B style configuration
    >>> configuration = InstellaConfig()

    >>> # Initializing a model from the Instella 7B style configuration
    >>> model = InstellaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "instella"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        rms_norm_eps=1e-5,
        **kwargs,
    ):
        # Token ids, weight tying and unknown keyword arguments are handled
        # by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Core dimensions of the network.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # An unset KV head count means plain multi-head attention.
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        # Rotary embedding settings; validation must run after
        # `rope_scaling` has been stored because it reads the attribute.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()

        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.rms_norm_eps = rms_norm_eps

    def _rope_scaling_validation(self):
        """
        Check that `self.rope_scaling` is either `None` or a well-formed dict.

        Raises:
            ValueError: if the value is not a two-entry dict, if its `type` is
                neither `"linear"` nor `"dynamic"`, or if its `factor` is not
                a float strictly greater than 1.
        """
        scaling = self.rope_scaling
        if scaling is None:
            return

        well_formed = isinstance(scaling, dict) and len(scaling) == 2
        if not well_formed:
            raise ValueError(
                f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {scaling}"
            )

        strategy = scaling.get("type", None)
        factor = scaling.get("factor", None)

        # A missing `type` (None) is rejected here as well.
        if strategy not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {strategy}"
            )

        # A missing factor (None) and integer factors fail the float check.
        if not isinstance(factor, float) or factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {factor}")
|
|
|
|
|
|
|
|
|
from functools import partial |
|
|
|
from typing import Iterable, Optional, Set, Tuple, Union |
|
|
|
import torch |
|
from torch import nn |
|
|
|
|
|
from vllm.attention import Attention |
|
from vllm.config import VllmConfig |
|
|
|
|
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size |
|
from vllm.distributed.communication_op import tensor_model_parallel_all_gather |
|
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank |
|
from vllm.distributed.utils import split_tensor_along_last_dim |
|
from vllm.model_executor.layers.activation import SiluAndMul |
|
from vllm.model_executor.layers.layernorm import RMSNorm |
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, |
|
QKVParallelLinear, |
|
RowParallelLinear) |
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor |
|
from vllm.model_executor.layers.rotary_embedding import get_rope |
|
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput |
|
from vllm.model_executor.layers.vocab_parallel_embedding import ( |
|
ParallelLMHead, VocabParallelEmbedding) |
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
|
from vllm.model_executor.models.interfaces import SupportsPP |
|
from vllm.model_executor.models.utils import ( |
|
is_pp_missing_parameter, make_empty_intermediate_tensors_factory, |
|
make_layers) |
|
from vllm.model_executor.sampling_metadata import SamplingMetadata |
|
from vllm.sequence import IntermediateTensors |
|
|
|
class InstellaAttention(nn.Module):
    """
    Self-attention block with RMS-normalized queries and keys.

    ``forward`` runs, in order: the fused QKV projection, RMSNorm on the
    query and key projections, rotary position embedding, the attention
    kernel, and the output projection. The residual connection and the
    layer norm applied to the block input live in the enclosing decoder
    layer, not here.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Build the projections, norms, rotary embedding and attention kernel.

        :param vllm_config: vLLM configuration; the Hugging Face model
            config, cache config and quantization config are read from it.
        :param prefix: dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config

        hidden_size = self.config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = self.config.num_attention_heads

        assert hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0

        # Per-rank head counts under tensor parallelism.
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = (self.config.num_key_value_heads
                                   or self.total_num_heads)
        if self.total_num_kv_heads >= self.tp_size:
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            assert self.tp_size % self.total_num_kv_heads == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused projection;
        # `forward` uses them to split the QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = self.config.max_position_embeddings
        self.rope_theta = self.config.rope_theta

        # Fused query/key/value projection.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        # The q/k norms are sized for the full (un-sharded) projections;
        # `_apply_qk_norm` gathers the shards before normalizing.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.k_norm = RMSNorm(
            self.total_num_kv_heads * self.head_dim,
            eps=self.config.rms_norm_eps,
        )
        self.q_norm = RMSNorm(self.config.hidden_size,
                              eps=self.config.rms_norm_eps)

        # Rotary embeddings over the full head dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
            prefix=prefix,
        )

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply ``q_norm`` / ``k_norm`` to the query and key projections.

        With tensor parallelism the per-rank shards are all-gathered first,
        normalized, and then this rank's slice is taken back out.

        NOTE(review): when ``total_num_kv_heads < tp_size`` the per-rank KV
        head count is clamped to 1, so the gathered ``k`` would presumably be
        wider than ``k_norm`` expects -- confirm whether that configuration
        is meant to be supported.
        """
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm.forward_native(q)
        k = self.k_norm.forward_native(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run self-attention on ``hidden_states``.

        :param positions: token positions passed to the rotary embedding.
        :param hidden_states: input activations for the QKV projection.
        :return: the output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split by the actual q / k / v widths. An even 3-way chunk is only
        # correct when q_size == kv_size (plain MHA); with grouped- or
        # multi-query attention the k/v slices are narrower than q.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
|
|
|
class InstellaMLP(nn.Module):
    """
    Feed-forward block of a decoder layer.

    The output is ``down_proj(act_fn(gate_up_proj(x)))`` where ``act_fn`` is
    ``SiluAndMul`` and ``gate_up_proj`` is a merged projection producing two
    ``intermediate_size``-wide outputs. Normalization of the input and the
    residual connection are handled by the enclosing decoder layer.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: vLLM configuration; ``hidden_size`` and
            ``intermediate_size`` are read from its Hugging Face config.
        :param prefix: dotted module path used to name the sub-layers.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        model_width = hf_config.hidden_size
        ffn_width = hf_config.intermediate_size

        # Merged gate and up projections (two outputs of equal width).
        self.gate_up_proj = MergedColumnParallelLinear(
            model_width,
            [ffn_width, ffn_width],
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

        # Activation applied to the merged projection output.
        self.act_fn = SiluAndMul()

        # Projection back to the model width.
        self.down_proj = RowParallelLinear(
            ffn_width,
            model_width,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the merged projection, the activation and the down projection."""
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
|
|
|
class InstellaDecoderLayer(nn.Module):
    """
    One pre-norm transformer block.

    With ``h = x + Attention(LN(x))`` the layer returns ``h + MLP(LN(h))``:
    each sub-block sees a normalized input and its output is added back to
    the un-normalized stream.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: vLLM configuration shared with the sub-blocks.
        :param prefix: dotted module path used to name the sub-layers.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        width = hf_config.hidden_size
        eps = hf_config.rms_norm_eps

        # Attention sub-block.
        self.self_attn = InstellaAttention(vllm_config=vllm_config,
                                           prefix=f"{prefix}.self_attn")

        # Feed-forward sub-block.
        self.mlp = InstellaMLP(vllm_config=vllm_config,
                               prefix=f"{prefix}.mlp")

        # Norms applied to the input of each sub-block.
        self.pre_attention_layernorm = RMSNorm(width, eps=eps)
        self.pre_feedforward_layernorm = RMSNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        :param positions: token positions forwarded to the attention block.
        :param hidden_states: activations entering the layer.
        :return: activations after the attention and MLP residual updates.
        """
        # Attention with residual connection.
        attn_input = self.pre_attention_layernorm(hidden_states)
        attn_output = self.self_attn(positions, attn_input)
        after_attn = attn_output + hidden_states

        # Feed-forward with residual connection.
        mlp_input = self.pre_feedforward_layernorm(after_attn)
        mlp_output = self.mlp(mlp_input)
        return after_attn + mlp_output
|
|
|
class InstellaModel(nn.Module):
    """
    Decoder stack: token embedding, the decoder layers and a final RMSNorm.

    Under pipeline parallelism only the first rank embeds tokens and only
    the last rank applies the final norm; other ranks exchange
    ``IntermediateTensors`` carrying ``"hidden_states"``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: vLLM configuration; sizes are read from its
            Hugging Face config.
        :param prefix: dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        width = self.config.hidden_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            width,
            prefix=f"{prefix}.embed_tokens",
        )

        # The factory's parameter stays named `prefix` in case
        # `make_layers` passes it by keyword.
        def build_layer(prefix: str) -> InstellaDecoderLayer:
            return InstellaDecoderLayer(vllm_config=vllm_config,
                                        prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(width, eps=self.config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    width))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """
        :param input_ids: token ids, embedded on the first pipeline rank
            unless ``inputs_embeds`` is given.
        :param positions: token positions forwarded to every layer.
        :param intermediate_tensors: hidden states received from the
            previous pipeline rank; required on every rank but the first.
        :param inputs_embeds: optional precomputed embeddings that replace
            the embedding lookup on the first rank.
        :return: normalized hidden states on the last pipeline rank,
            otherwise ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        if not get_pp_group().is_first_rank:
            # Later pipeline stages continue from the previous stage.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            assert isinstance(hidden_states, torch.Tensor)
        elif inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embed_tokens(input_ids)

        # Run the layers owned by this pipeline rank, one after another.
        for block in self.layers[self.start_layer:self.end_layer]:
            hidden_states = block(positions, hidden_states)

        if get_pp_group().is_last_rank:
            # The final norm is applied only at the end of the pipeline.
            return self.norm(hidden_states)

        return IntermediateTensors({"hidden_states": hidden_states})
|
|
|
class InstellaForCausalLM(nn.Module, SupportsPP):
    """
    Causal language model: an ``InstellaModel`` plus an LM head, logits
    processor and sampler, with a loader for checkpoint weights.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: vLLM configuration; the Hugging Face config and
            quantization config are read from it.
        :param prefix: dotted module path used to name the sub-layers.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        self.config = vllm_config.model_config.hf_config
        self.model = InstellaModel(vllm_config=vllm_config,
                                   prefix=f"{prefix}.model")

        if not hf_config.tie_word_embeddings:
            self.unpadded_vocab_size = hf_config.vocab_size
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                quant_config=vllm_config.quant_config,
                prefix=f"{prefix}.lm_head",
            )
        else:
            # Tied embeddings: the input embedding doubles as the LM head.
            self.lm_head = self.model.embed_tokens

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the wrapped ``InstellaModel`` and return its result."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor with the LM head on ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """
        Copy checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/
        ``up_proj`` checkpoint entries are routed into the fused
        ``qkv_proj`` and ``gate_up_proj`` parameters.

        :param weights: iterable of ``(name, tensor)`` checkpoint entries.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        fused_shards = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Rotary-embedding entries found in checkpoints that are never loaded.
        rotary_entries = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )

        parameters = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if any(entry in name for entry in rotary_entries):
                continue
            # Parameters that belong to another pipeline rank.
            if is_pp_missing_parameter(name, self):
                continue
            # With tied embeddings the head shares the embedding weight, so
            # a separate head entry in the checkpoint is ignored.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for fused_name, shard_name, shard_id in fused_shards:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Bias entries with no matching parameter are skipped.
                if name.endswith(".bias") and name not in parameters:
                    continue
                target = parameters[name]
                target.weight_loader(target, loaded_weight, shard_id)
                break
            else:
                # Not loaded as a fused shard: load it as a plain parameter.
                if name.endswith(".bias") and name not in parameters:
                    continue
                target = parameters[name]
                loader = getattr(target, "weight_loader",
                                 default_weight_loader)
                loader(target, loaded_weight)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|