Update modeling_rwkv6qwen2.py
modeling_rwkv6qwen2.py  CHANGED  (+34 -120)
@@ -834,126 +834,40 @@ class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # if past_key_values is not None:
-        #     model_inputs["past_key_values"] = past_key_values
-        #     if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3
-        #         input_ids = input_ids[:, -cache_position.shape[0] :]
-        #     elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
-        #         input_ids = input_ids[:, cache_position]
-
-        # # 3. Prepare base model inputs
-        # input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-        # # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        # if not self.config.is_encoder_decoder:
-        #     if inputs_embeds is not None and cache_position[0] == 0:
-        #         model_inputs[input_ids_key] = None
-        #         model_inputs["inputs_embeds"] = inputs_embeds
-        #     else:
-        #         # `clone` calls in this function ensure a consistent stride. See #32227
-        #         model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
-        #         model_inputs["inputs_embeds"] = None
-        # else:
-        #     model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
-
-        # # 4. Create missing `position_ids` on the fly
-        # if (attention_mask is not None and kwargs.get("position_ids") is None and "position_ids" in set(inspect.signature(self.forward).parameters.keys())):
-        #     position_ids = attention_mask.long().cumsum(-1) - 1
-        #     position_ids.masked_fill_(attention_mask == 0, 1)
-        #     kwargs["position_ids"] = position_ids  # placed in kwargs for further processing (see below)
-
-        # # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
-        # for model_input_name in ["position_ids", "token_type_ids"]:
-        #     model_input = kwargs.get(model_input_name)
-        #     if model_input is not None:
-        #         if past_key_values:
-        #             model_input = model_input[:, -input_ids.shape[1] :]
-        #             model_input = model_input.clone(memory_format=torch.contiguous_format)
-        #         model_inputs[model_input_name] = model_input
-
-        # # 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
-        # if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-        #     if model_inputs["inputs_embeds"] is not None:
-        #         batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
-        #         device = model_inputs["inputs_embeds"].device
-        #     else:
-        #         batch_size, sequence_length = model_inputs[input_ids_key].shape
-        #         device = model_inputs[input_ids_key].device
-
-        #     # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
-        #     # the 4D causal mask exists, it should be present in the base model (XXXModel class).
-        #     base_model = getattr(self, self.base_model_prefix, None)
-        #     if base_model is None:
-        #         causal_mask_creation_function = getattr(
-        #             self, "_prepare_4d_causal_attention_mask_with_cache_position", None
-        #         )
-        #     else:
-        #         causal_mask_creation_function = getattr(
-        #             base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
-        #         )
-        #     if causal_mask_creation_function is None:
-        #         logger.warning_once(
-        #             f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
-        #             "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
-        #             "writing code, see Llama for an example implementation. If you're a user, please report this "
-        #             "issue on GitHub."
-        #         )
-        #     else:
-        #         attention_mask = causal_mask_creation_function(
-        #             attention_mask,
-        #             sequence_length=sequence_length,
-        #             target_length=past_key_values.get_max_cache_shape(),
-        #             dtype=self.dtype,
-        #             device=device,
-        #             cache_position=cache_position,
-        #             batch_size=batch_size,
-        #             config=self.config,
-        #             past_key_values=past_key_values,
-        #         )
-        # if attention_mask is not None:
-        #     model_inputs["attention_mask"] = attention_mask
-
-        # # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
-        # for key, value in kwargs.items():
-        #     if key not in model_inputs:
-        #         model_inputs[key] = value
-
-        # # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
-        # model_inputs.pop("labels", None)
-        # return model_inputs
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Cache] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ):
+        # only keep the last token of `input_ids` if `past_key_values` is not empty
+        if past_key_values is not None and len(past_key_values) > 0:
+            input_ids = input_ids[:, -1:]
+
+        model_inputs = {
+            'past_key_values': past_key_values,
+            'attention_mask': attention_mask,
+            'cache_position': cache_position,
+        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs['inputs_embeds'] = inputs_embeds
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard.
+            # Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs['input_ids'] = input_ids.contiguous()
+
+        model_inputs.update(**kwargs)
+
+        # Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
+        model_inputs.pop("labels", None)
+
+        return model_inputs

 @add_start_docstrings(
     """
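
For orientation, below is a minimal, self-contained sketch (not part of this commit) of how the new, simplified prepare_inputs_for_generation behaves across decoding steps: the full prompt is forwarded on the prefill step, and once the recurrent state cache is non-empty only the last token is fed. The standalone function and the DummyCache stand-in are illustrative assumptions; the real model uses the transformers Cache API on self.

import torch

# Illustrative stand-in for a recurrent-state cache; for this sketch only __len__ matters.
class DummyCache:
    def __init__(self, filled: bool):
        self.filled = filled

    def __len__(self):
        return 1 if self.filled else 0

def prepare_inputs_for_generation(input_ids, past_key_values=None, attention_mask=None,
                                  inputs_embeds=None, cache_position=None, **kwargs):
    # Mirrors the added method: with a non-empty cache, only the last token is kept.
    if past_key_values is not None and len(past_key_values) > 0:
        input_ids = input_ids[:, -1:]
    model_inputs = {
        "past_key_values": past_key_values,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
    }
    # inputs_embeds are only used on the first step, when no cache exists yet.
    if inputs_embeds is not None and past_key_values is None:
        model_inputs["inputs_embeds"] = inputs_embeds
    else:
        model_inputs["input_ids"] = input_ids.contiguous()
    model_inputs.update(**kwargs)
    model_inputs.pop("labels", None)
    return model_inputs

prompt = torch.tensor([[101, 7, 8, 9]])
print(prepare_inputs_for_generation(prompt)["input_ids"].shape)                    # prefill: torch.Size([1, 4])
print(prepare_inputs_for_generation(prompt, DummyCache(True))["input_ids"].shape)  # decode: torch.Size([1, 1])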