SmerkyG committed on
Commit 2da855c · verified · 1 Parent(s): 864bad4

Update modeling_rwkv6qwen2.py

Files changed (1)
  1. modeling_rwkv6qwen2.py +19 -66
modeling_rwkv6qwen2.py CHANGED
@@ -423,7 +423,7 @@ class RWKV6Attention(nn.Module):
 
         # dealing with left-padding
         if attention_mask is not None:
-            v = v * attention_mask[:, None, -v.shape[-2]:, None]
+            v = v * attention_mask[:, -v.shape[-2]:, None]
 
         r = r.view(B,T,-1,N).to(v.dtype)
         k = k.view(B,T,-1,N).to(v.dtype)
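Note on the hunk above: dropping the extra None axis suggests that v is still a 3-D (B, T, hidden) tensor at this point (it is only reshaped to (B, T, H, N) in the following lines for r and k), so the mask slice needs shape (B, T, 1) rather than (B, 1, T, 1) to broadcast. A minimal shape sketch, with illustrative sizes that are not taken from the model config:

    import torch

    B, T, H, N = 2, 5, 4, 8                  # illustrative sizes only
    v = torch.randn(B, T, H * N)             # value projection, heads still flattened
    attention_mask = torch.ones(B, T)        # (B, T); zeros mark left-padding positions
    attention_mask[0, :2] = 0                # pretend sequence 0 starts with 2 pad tokens

    # new indexing: (B, T, 1) broadcasts cleanly against (B, T, H*N)
    masked = v * attention_mask[:, -v.shape[-2]:, None]
    assert masked.shape == v.shape

    # the old (B, 1, T, 1) indexing only lines up with a 4-D (B, H, T, N) value layout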
@@ -436,9 +436,6 @@ class RWKV6Attention(nn.Module):
         output_final_state = not self.training and use_cache and past_key_values is not None
         attn_output, output_kv_state = fused_recurrent_gla(r, k, v, log_w, None, scale, input_kv_state, output_final_state)
 
-        if output_final_state:
-            past_key_values.update(output_kv_state, output_shift_state, T, self.layer_idx)
-
         attn_output = attn_output.view(B, T, -1)
         if self.config.groupnorm_att:
             attn_output = self.ln_x(attn_output.view(B * T, -1)).view(B, T, -1)
@@ -446,6 +443,9 @@ class RWKV6Attention(nn.Module):
         attn_output = attn_output * g
         attn_output = self.o_proj(attn_output)
 
+        if output_final_state:
+            past_key_values.update(output_kv_state, output_shift_state, self.layer_idx, T)
+
         return attn_output, attn_weights
 
 class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
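Taken together, the two hunks above move the recurrent-state write-back from right after fused_recurrent_gla to the end of forward, and swap the last two arguments of past_key_values.update from (..., T, self.layer_idx) to (..., self.layer_idx, T). RWKV6State is defined elsewhere in this repo, so its exact signature is not visible in this diff; the sketch below is a hypothetical per-layer cache whose update(kv_state, shift_state, layer_idx, token_count) matches the new call order, only to illustrate why the argument order matters:

    class RecurrentStateCache:
        # Hypothetical sketch only -- not the repo's actual RWKV6State class.
        def __init__(self):
            self.kv_states = {}      # layer_idx -> recurrent kv state
            self.shift_states = {}   # layer_idx -> token-shift state
            self.seen_tokens = 0

        def update(self, kv_state, shift_state, layer_idx, token_count):
            # keyed by layer index, so passing the token count where the layer
            # index is expected would silently corrupt the cache
            self.kv_states[layer_idx] = kv_state
            self.shift_states[layer_idx] = shift_state
            if layer_idx == 0:
                self.seen_tokens += token_count

        def get_seq_length(self):
            return self.seen_tokens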
@@ -680,36 +680,23 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        # kept for BC (non `Cache` `past_key_values` inputs)
-        #return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, RWKV6State):
-            #return_legacy_cache = True
-            past_key_values = RWKV6State()
-            # if past_key_values is None:
-            #     past_key_values = DynamicCache()
-            # else:
-            #     past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            #     logger.warning_once(
-            #         "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
-            #         "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
-            #         "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
-            #     )
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
+        if use_cache and not isinstance(past_key_values, RWKV6State):
+            past_key_values = RWKV6State()
+
+        #if cache_position is None:
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
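In the rewritten block above, the RWKV6State cache is now created after the token embeddings are looked up, and cache_position is recomputed unconditionally from past_key_values.get_seq_length() (the old `if cache_position is None:` guard survives only as a comment, so a caller-supplied cache_position is effectively ignored). A simplified paraphrase of the resulting control flow, assuming get_seq_length() returns the number of tokens the recurrent state has already absorbed; this is a sketch, not the file's actual code:

    import torch

    def prepare_cache_inputs(inputs_embeds, past_key_values, use_cache, make_state):
        # make_state stands in for the repo's RWKV6State constructor (assumption for illustration)
        if use_cache and past_key_values is None:
            past_key_values = make_state()

        # cache_position is recomputed from the tokens the state has already consumed
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
        position_ids = cache_position.unsqueeze(0)
        return past_key_values, cache_position, position_ids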
@@ -723,9 +710,10 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
-        position_embeddings = None
         if self.config.use_rope:
             position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        else:
+            position_embeddings = None
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
@@ -902,41 +890,6 @@ class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )
 
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        past_key_values: Optional[Cache] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ):
-        # only last token for `inputs_ids` if the `past_key_values` is not empty.
-        if past_key_values is not None and len(past_key_values) > 0:
-            input_ids = input_ids[:, -1:]
-
-        model_inputs = {
-            'past_key_values': past_key_values,
-            'attention_mask': attention_mask,
-            'cache_position': cache_position,
-        }
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs['inputs_embeds'] = inputs_embeds
-        else:
-            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-            # recompiles graphs as the stride of the inputs is a guard.
-            # Ref: https://github.com/huggingface/transformers/pull/29114
-            # TODO: use `next_tokens` directly instead.
-            model_inputs['input_ids'] = input_ids.contiguous()
-
-        model_inputs.update(**kwargs)
-
-        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
-        model_inputs.pop("labels", None)
-
-        return model_inputs
-
 @add_start_docstrings(
     """
     The RWKV6Qwen2 Model transformer with a sequence classification head on top (linear layer).
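Removing the prepare_inputs_for_generation override means RWKV6Qwen2ForCausalLM now presumably relies on the default implementation inherited through GenerationMixin, which in recent transformers versions handles trimming input_ids against the cache, first-step inputs_embeds, and threading attention_mask/cache_position on its own. From the caller's side nothing changes; a typical usage sketch, where model_id is a placeholder for whatever RWKV6Qwen2 checkpoint is actually being used:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # model_id is a placeholder; substitute the actual RWKV6Qwen2 checkpoint
    model_id = "path/or/repo-of-an-rwkv6qwen2-checkpoint"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

    inputs = tokenizer("Hello", return_tensors="pt")
    # use_cache=True exercises the RWKV6State recurrent cache updated in this commit
    out = model.generate(**inputs, max_new_tokens=20, use_cache=True)
    print(tokenizer.decode(out[0], skip_special_tokens=True))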
 