Update modeling_rwkv6qwen2.py
modeling_rwkv6qwen2.py  (+19 -66)  CHANGED
@@ -423,7 +423,7 @@ class RWKV6Attention(nn.Module):
 
         # dealing with left-padding
         if attention_mask is not None:
-            v = v * attention_mask[:,
+            v = v * attention_mask[:, -v.shape[-2]:, None]
 
         r = r.view(B,T,-1,N).to(v.dtype)
         k = k.view(B,T,-1,N).to(v.dtype)
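The new masking line broadcasts the (batch, time) attention mask over the channel dimension so that left-padded positions contribute zero value vectors to the recurrence. A minimal standalone sketch of that broadcast, with illustrative shapes rather than the model's actual tensors:

```python
import torch

# Illustrative shapes only: B=batch, T=tokens in the current chunk, C=channels.
B, T, C = 2, 4, 8
v = torch.randn(B, T, C)

# Row 0 has two left-padded tokens; the mask may cover more than T positions when
# a cache is present, so only the trailing T columns are taken before broadcasting.
attention_mask = torch.tensor([[0., 0., 1., 1.],
                               [1., 1., 1., 1.]])
v = v * attention_mask[:, -v.shape[-2]:, None]   # (B, T, 1) broadcast over C

assert torch.all(v[0, :2] == 0)  # padded positions feed zeros into the state
```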
@@ -436,9 +436,6 @@ class RWKV6Attention(nn.Module):
         output_final_state = not self.training and use_cache and past_key_values is not None
         attn_output, output_kv_state = fused_recurrent_gla(r, k, v, log_w, None, scale, input_kv_state, output_final_state)
 
-        if output_final_state:
-            past_key_values.update(output_kv_state, output_shift_state, T, self.layer_idx)
-
         attn_output = attn_output.view(B, T, -1)
         if self.config.groupnorm_att:
            attn_output = self.ln_x(attn_output.view(B * T, -1)).view(B, T, -1)
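The `output_final_state` flag controls whether the fused recurrent kernel is asked to return its final state, which only matters for cached inference. The gating condition in isolation, as a sketch rather than the module itself:

```python
# A final recurrent state is only requested when we are not training,
# caching is enabled, and a cache object actually exists.
def wants_final_state(training: bool, use_cache: bool, past_key_values) -> bool:
    return (not training) and use_cache and (past_key_values is not None)

assert wants_final_state(False, True, object()) is True
assert wants_final_state(True, True, object()) is False
assert wants_final_state(False, True, None) is False
```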
@@ -446,6 +443,9 @@ class RWKV6Attention(nn.Module):
         attn_output = attn_output * g
         attn_output = self.o_proj(attn_output)
 
+        if output_final_state:
+            past_key_values.update(output_kv_state, output_shift_state, self.layer_idx, T)
+
         return attn_output, attn_weights
 
 class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
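The write-back into the cache now happens after the output projection, and the call passes `self.layer_idx` before the token count `T` rather than the other way around. The toy cache below is a hypothetical stand-in used only to illustrate that call-site ordering; it is not the real `RWKV6State` implementation, whose signature is only inferred from this call:

```python
# Hypothetical stand-in for RWKV6State, showing the argument order
# (kv_state, shift_state, layer_idx, token_count) at the update call site.
class ToyRecurrentCache:
    def __init__(self):
        self.kv_states = {}      # layer_idx -> final recurrent KV state
        self.shift_states = {}   # layer_idx -> token-shift state
        self.seen_tokens = 0

    def update(self, kv_state, shift_state, layer_idx, token_count):
        self.kv_states[layer_idx] = kv_state
        self.shift_states[layer_idx] = shift_state
        if layer_idx == 0:
            # count the sequence length once per forward pass
            self.seen_tokens += token_count
```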
@@ -680,36 +680,23 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        # kept for BC (non `Cache` `past_key_values` inputs)
-        #return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, RWKV6State):
-            #return_legacy_cache = True
-            past_key_values = RWKV6State()
-            # if past_key_values is None:
-            #     past_key_values = DynamicCache()
-            # else:
-            #     past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            #     logger.warning_once(
-            #         "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
-            #         "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
-            #         "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
-            #     )
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
+        if use_cache and not isinstance(past_key_values, RWKV6State):
+            past_key_values = RWKV6State()
+
+        #if cache_position is None:
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
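With the guard commented out, `cache_position` is now recomputed unconditionally from the number of tokens the cache has already absorbed. A standalone sketch with illustrative values (in the real forward pass, `past_seen_tokens` comes from `past_key_values.get_seq_length()`):

```python
import torch

past_seen_tokens = 7                    # tokens already absorbed by the cache
inputs_embeds = torch.zeros(1, 3, 16)   # current chunk of 3 token embeddings

cache_position = torch.arange(
    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
position_ids = cache_position.unsqueeze(0)

print(cache_position.tolist())  # [7, 8, 9]
print(position_ids.shape)       # torch.Size([1, 3])
```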
@@ -723,9 +710,10 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
-        position_embeddings = None
         if self.config.use_rope:
             position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        else:
+            position_embeddings = None
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
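After this change, `position_embeddings` is still bound on both branches: it is computed only when `config.use_rope` is set and is `None` otherwise. A minimal sketch of the same control flow, where the `rotary_emb` argument stands in for the model's rotary embedding module:

```python
import torch

def make_position_embeddings(use_rope: bool, hidden_states, position_ids, rotary_emb=None):
    # position_embeddings is defined on both paths, mirroring the edited forward()
    if use_rope:
        return rotary_emb(hidden_states, position_ids)
    return None

print(make_position_embeddings(False, torch.zeros(1, 3, 8), torch.arange(3)[None]))  # None
```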
@@ -902,41 +890,6 @@ class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )
 
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        past_key_values: Optional[Cache] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ):
-        # only last token for `inputs_ids` if the `past_key_values` is not empty.
-        if past_key_values is not None and len(past_key_values) > 0:
-            input_ids = input_ids[:, -1:]
-
-        model_inputs = {
-            'past_key_values': past_key_values,
-            'attention_mask': attention_mask,
-            'cache_position': cache_position,
-        }
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs['inputs_embeds'] = inputs_embeds
-        else:
-            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-            # recompiles graphs as the stride of the inputs is a guard.
-            # Ref: https://github.com/huggingface/transformers/pull/29114
-            # TODO: use `next_tokens` directly instead.
-            model_inputs['input_ids'] = input_ids.contiguous()
-
-        model_inputs.update(**kwargs)
-
-        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
-        model_inputs.pop("labels", None)
-
-        return model_inputs
-
 @add_start_docstrings(
 """
 The RWKV6Qwen2 Model transformer with a sequence classification head on top (linear layer).
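Deleting the override means generation falls back to the `prepare_inputs_for_generation` inherited from `GenerationMixin`. The core behaviour the custom version provided, trimming `input_ids` to the newest token once a non-empty cache exists, reduced to a standalone sketch:

```python
import torch

def trim_for_decoding(input_ids: torch.LongTensor, cache_len: int) -> torch.LongTensor:
    # with a non-empty cache, only the newest token id needs to be fed forward
    return input_ids[:, -1:] if cache_len > 0 else input_ids

ids = torch.tensor([[5, 9, 2, 7]])
assert trim_for_decoding(ids, cache_len=3).tolist() == [[7]]
assert trim_for_decoding(ids, cache_len=0).tolist() == [[5, 9, 2, 7]]
```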