Update modeling_rwkv6qwen2.py
modeling_rwkv6qwen2.py  (+181 -143)
@@ -29,7 +29,7 @@ from torch import nn
 import torch.nn.functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from transformers.cache_utils import Cache, StaticCache
+from transformers.cache_utils import Cache, StaticCache, DynamicCache
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -209,7 +209,7 @@ try:
     from fla.ops.gla.fused_recurrent import fused_recurrent_gla
 except ImportError:
     print("Required module is not installed. Please install it using the following commands:")
-    print("pip install -U git+https://github.com/
+    print("pip install -U git+https://github.com/fla-org/flash-linear-attention")
     print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
     print("pip install triton>=2.2.0")

@@ -230,7 +230,6 @@ class RWKV6Attention(nn.Module):
         self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.is_causal = True
         self.attention_dropout = config.attention_dropout

         if self.hidden_size % self.num_heads != 0:
@@ -284,7 +283,7 @@ class RWKV6Attention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-
+        past_key_values: Optional[RWKV6State] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
@@ -297,8 +296,8 @@ class RWKV6Attention(nn.Module):

         x = hidden_states

-        if use_cache and
-            input_kv_state, input_shift_state =
+        if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx:
+            input_kv_state, input_shift_state = past_key_values[self.layer_idx]
             xprev = torch.cat([input_shift_state, x[:, :-1]], dim=1)
         else:
             input_kv_state = None
@@ -334,9 +333,13 @@ class RWKV6Attention(nn.Module):
         dropout_rate = 0.0 if not self.training else self.attention_dropout

         decay_states_log = -decay_states.float().exp()
-
+        decay_states_log = decay_states_log.clamp(-5) # FIXME - is this necessary?
         key_states = (key_states * (1 - decay_states_log.exp())).to(key_states.dtype)

+        if attention_mask is not None:
+            if q_len > 1:
+                decay_states_log = decay_states_log - 100 * F.pad(1 - attention_mask, [1, -1]).view(bsz, 1, q_len, 1)
+
         query_states = query_states.to(value_states.dtype)
         key_states = key_states.to(value_states.dtype)

@@ -366,19 +369,19 @@ class RWKV6Attention(nn.Module):
         attn_weights = torch.empty(0, device=x.device)

         scale = query_states.shape[-1] ** -0.5
-        output_final_state = not self.training and use_cache and
+        output_final_state = not self.training and use_cache and past_key_values is not None
         #attn_output, output_kv_state = ChunkGLAFunction.apply(query_states, key_states, value_states, decay_states_log.float(), scale, input_kv_state, output_final_state)
         #attn_output, output_kv_state = chunk_gla(query_states, key_states, value_states, decay_states_log, scale, input_kv_state, output_final_state)
         attn_output, output_kv_state = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log, None, scale, input_kv_state, output_final_state)

         if output_final_state:
-
+            past_key_values.update(output_kv_state, output_shift_state, q_len, self.layer_idx)

         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output * gate_states)

-        return attn_output, attn_weights
+        return attn_output, attn_weights

 class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
     def __init__(self, config: RWKV6Qwen2Config, layer_idx: int):
@@ -391,6 +394,48 @@ class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
 RWKV6QWEN2_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -581,6 +626,7 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         #return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, RWKV6State):
             #return_legacy_cache = True
+            print("creating past_key_values", past_key_values)
             past_key_values = RWKV6State()
             # if past_key_values is None:
             #     past_key_values = DynamicCache()
@@ -638,9 +684,9 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=
+                    attention_mask=attention_mask,
                     position_ids=position_ids,
-
+                    past_key_values=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
@@ -649,9 +695,6 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):

             hidden_states = layer_outputs[0]

-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

@@ -661,15 +704,14 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = next_decoder_cache if use_cache else None
         #if return_legacy_cache:
         #    next_cache = next_cache.to_legacy_cache()

         if not return_dict:
-            return tuple(v for v in [hidden_states,
+            return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
@@ -793,130 +835,126 @@ class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )

-    def prepare_inputs_for_generation(
-        …
-    ):
-        …
-        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
-        model_inputs.pop("labels", None)
-        return model_inputs
+    # def prepare_inputs_for_generation(
+    #     self,
+    #     input_ids: torch.LongTensor,
+    #     past_key_values: Optional[Cache] = None,
+    #     attention_mask: Optional[torch.LongTensor] = None,
+    #     inputs_embeds: Optional[torch.FloatTensor] = None,
+    #     cache_position: Optional[torch.LongTensor] = None,
+    #     **kwargs,
+    # ):
+    #     """
+    #     Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or
+    #     slicing inputs given the existing cache.
+
+    #     See the forward pass in the model documentation for expected arguments (different models might have different
+    #     requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
+    #     """
+
+    #     # 1. Handle BC:
+    #     model_inputs = {}
+    #     # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
+    #     if self._supports_cache_class:
+    #         model_inputs["cache_position"] = cache_position
+    #     # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
+    #     # function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
+    #     # (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
+    #     elif cache_position is None:
+    #         past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+    #         cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+
+    #     # 2. Generic cache-dependent input preparation
+    #     # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+    #     # Exception 1: when passing input_embeds, input_ids may be missing entries
+    #     # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+    #     # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
+    #     if past_key_values is not None:
+    #         model_inputs["past_key_values"] = past_key_values
+    #         if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3
+    #             input_ids = input_ids[:, -cache_position.shape[0] :]
+    #         elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
+    #             input_ids = input_ids[:, cache_position]
+
+    #     # 3. Prepare base model inputs
+    #     input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+    #     # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+    #     if not self.config.is_encoder_decoder:
+    #         if inputs_embeds is not None and cache_position[0] == 0:
+    #             model_inputs[input_ids_key] = None
+    #             model_inputs["inputs_embeds"] = inputs_embeds
+    #         else:
+    #             # `clone` calls in this function ensure a consistent stride. See #32227
+    #             model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
+    #             model_inputs["inputs_embeds"] = None
+    #     else:
+    #         model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
+
+    #     # 4. Create missing `position_ids` on the fly
+    #     if (attention_mask is not None and kwargs.get("position_ids") is None and "position_ids" in set(inspect.signature(self.forward).parameters.keys())):
+    #         position_ids = attention_mask.long().cumsum(-1) - 1
+    #         position_ids.masked_fill_(attention_mask == 0, 1)
+    #         kwargs["position_ids"] = position_ids # placed in kwargs for further processing (see below)
+
+    #     # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
+    #     for model_input_name in ["position_ids", "token_type_ids"]:
+    #         model_input = kwargs.get(model_input_name)
+    #         if model_input is not None:
+    #             if past_key_values:
+    #                 model_input = model_input[:, -input_ids.shape[1] :]
+    #                 model_input = model_input.clone(memory_format=torch.contiguous_format)
+    #             model_inputs[model_input_name] = model_input
+
+    #     # 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
+    #     if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+    #         if model_inputs["inputs_embeds"] is not None:
+    #             batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+    #             device = model_inputs["inputs_embeds"].device
+    #         else:
+    #             batch_size, sequence_length = model_inputs[input_ids_key].shape
+    #             device = model_inputs[input_ids_key].device
+
+    #         # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
+    #         # the 4D causal mask exists, it should be present in the base model (XXXModel class).
+    #         base_model = getattr(self, self.base_model_prefix, None)
+    #         if base_model is None:
+    #             causal_mask_creation_function = getattr(
+    #                 self, "_prepare_4d_causal_attention_mask_with_cache_position", None
+    #             )
+    #         else:
+    #             causal_mask_creation_function = getattr(
+    #                 base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
+    #             )
+    #         if causal_mask_creation_function is None:
+    #             logger.warning_once(
+    #                 f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
+    #                 "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
+    #                 "writing code, see Llama for an example implementation. If you're a user, please report this "
+    #                 "issue on GitHub."
+    #             )
+    #         else:
+    #             attention_mask = causal_mask_creation_function(
+    #                 attention_mask,
+    #                 sequence_length=sequence_length,
+    #                 target_length=past_key_values.get_max_cache_shape(),
+    #                 dtype=self.dtype,
+    #                 device=device,
+    #                 cache_position=cache_position,
+    #                 batch_size=batch_size,
+    #                 config=self.config,
+    #                 past_key_values=past_key_values,
+    #             )
+    #     if attention_mask is not None:
+    #         model_inputs["attention_mask"] = attention_mask
+
+    #     # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+    #     for key, value in kwargs.items():
+    #         if key not in model_inputs:
+    #             model_inputs[key] = value
+
+    #     # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
+    #     model_inputs.pop("labels", None)
+    #     return model_inputs

 @add_start_docstrings(
     """
@@ -1215,4 +1253,4 @@ class RWKV6Qwen2ForQuestionAnswering(RWKV6Qwen2PreTrainedModel):
             end_logits=end_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
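
A note on the attention-mask handling introduced in `RWKV6Attention.forward` above: `F.pad(1 - attention_mask, [1, -1])` shifts the "this token is padding" indicator one position to the right, so the large negative offset lands on the log-decay of the first position after each padded token, which appears intended to keep left-padding from carrying recurrent state into the real tokens during prefill (the branch only runs when `q_len > 1`). Below is a minimal sketch of just that shift in plain PyTorch; the `(bsz, q_len)` mask shape and `(bsz, num_heads, q_len, head_dim)` decay shape are assumptions inferred from the surrounding code, not confirmed by the diff.

    import torch
    import torch.nn.functional as F

    bsz, num_heads, q_len, head_dim = 1, 2, 5, 4
    # two left-padded tokens followed by three real tokens
    attention_mask = torch.tensor([[0., 0., 1., 1., 1.]])
    decay_states_log = torch.zeros(bsz, num_heads, q_len, head_dim)

    # pad=[1, -1] prepends one zero and drops the last element along the last dim,
    # i.e. it shifts the padding indicator one token to the right
    shifted_pad = F.pad(1 - attention_mask, [1, -1]).view(bsz, 1, q_len, 1)
    print(shifted_pad.flatten())  # tensor([0., 1., 1., 0., 0.])

    # positions whose previous token was padding get their log-decay pushed to -100,
    # so exp(decay) is ~0 there and the state carried in from padding is effectively wiped
    decay_states_log = decay_states_log - 100 * shifted_pad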