Update modeling_rwkv6qwen2.py
bugfix for FLA and transformers lib updates

modeling_rwkv6qwen2.py  (+232 -164)  CHANGED
Before (removed lines are prefixed with -):

@@ -51,6 +51,7 @@ from transformers.utils import (
 from .configuration_rwkv6qwen2 import RWKV6Qwen2Config

 from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2RMSNorm, repeat_kv

 logger = logging.get_logger(__name__)

@@ -58,7 +59,7 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "RWKV/RWKV6Qwen2-7B"
 _CONFIG_FOR_DOC = "RWKV6Qwen2Config"

-class RWKV6State(Cache):
     def __init__(self) -> None:
         super().__init__()
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

@@ -113,28 +114,6 @@ class RWKV6State(Cache):
         """
         return None

-    # def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-    #     """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
-    #     backward compatibility."""
-    #     legacy_cache = ()
-    #     for layer_idx in range(len(self)):
-    #         legacy_cache += ((self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]),)
-    #     return legacy_cache
-
-    # @classmethod
-    # #@deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    # def from_legacy_cache(
-    #     cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None, num_hidden_layers: int | None = None
-    # ) -> "RWKV6State":
-    #     """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
-    #     backward compatibility."""
-    #     cache = cls()
-    #     if past_key_values is not None:
-    #         for layer_idx in range(len(past_key_values)):
-    #             layer_kv_state, layer_shift_state = past_key_values[layer_idx]
-    #             cache.update(layer_kv_state, layer_shift_state, layer_idx)
-    #     return cache
-
     def crop(self, max_length: int):
         # can't implement this for linear attention variants
         return

@@ -144,8 +123,8 @@ class RWKV6State(Cache):
         self,
         kv_state: torch.Tensor,
         shift_state: torch.Tensor,
-        token_count: int,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Update the number of seen tokens

@@ -162,62 +141,140 @@ class RWKV6State(Cache):

         return self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]

-    # @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    # def batch_split(
-    #     self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
-    # ) -> List["DynamicCache"]:
-    #     """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
-    #     `_split_model_inputs()` in `generation.utils`"""
-    #     out = []
-    #     for i in range(0, full_batch_size, split_size):
-    #         current_split = DynamicCache()
-    #         current_split._seen_tokens = self._seen_tokens
-    #         current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
-    #         current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
-    #         out.append(current_split)
-    #     return out
-
-    # @classmethod
-    # @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    # def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
-    #     """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
-    #     `generation.utils`"""
-    #     cache = cls()
-    #     for idx in range(len(splits[0])):
-    #         key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
-    #         value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
-    #         if key_cache != []:
-    #             layer_keys = torch.cat(key_cache, dim=0)
-    #             layer_values = torch.cat(value_cache, dim=0)
-    #             cache.update(layer_keys, layer_values, idx)
-    #     return cache
-
-    # def batch_repeat_interleave(self, repeats: int):
-    #     """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
-    #     for layer_idx in range(len(self)):
-    #         self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
-    #         self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
-
-    # def batch_select_indices(self, indices: torch.Tensor):
-    #     """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
-    #     for layer_idx in range(len(self)):
-    #         self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
-    #         self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
-
 try:
     #from fla.ops.gla.chunk import chunk_gla
     from fla.ops.gla.fused_recurrent import fused_recurrent_gla
 except ImportError:
     print("Required module is not installed. Please install it using the following commands:")
-    print("pip install -
     print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
     print("pip install triton>=2.2.0")

 class RWKV6Attention(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
                 f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "

@@ -232,6 +289,11 @@ class RWKV6Attention(nn.Module):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.attention_dropout = config.attention_dropout

         if self.hidden_size % self.num_heads != 0:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"

@@ -242,41 +304,55 @@ class RWKV6Attention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, 'attention_output_bias', config.attention_bias))

-
-        nn.init.zeros_(self.gate.weight)

-
-
-
-

         with torch.no_grad():
             ratio_0_to_1 = layer_id / (n_layer - 1)  # 0 to 1
             ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 to ~0
-
-
-            ddd
-
-
-
-
-
-
-
-
-
-
-
-

             # RWKV-6
             decay_speed = torch.ones(dim_att)
             for n in range(dim_att):
                 decay_speed[n] = -6 + 5 * (n / (dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
             self.time_decay = nn.Parameter(decay_speed.reshape(1,1,dim_att))
-
-            self.
-            self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, dim_att).uniform_(-0.01, 0.01))

     def forward(
         self,

@@ -291,11 +367,12 @@ class RWKV6Attention(nn.Module):
     ):
         output_shift_state = hidden_states[:, -1:].detach().clone()

-        bsz, q_len, hidden_dim = hidden_states.size()
-        H = self.num_heads
-
         x = hidden_states

         if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx:
             input_kv_state, input_shift_state = past_key_values[self.layer_idx]
             xprev = torch.cat([input_shift_state, x[:, :-1]], dim=1)

@@ -303,83 +380,71 @@ class RWKV6Attention(nn.Module):
             input_kv_state = None
             xprev = F.pad(x, (0, 0, 1, -1))

-
-
-        xxx = x + dxprev * self.time_maa_x
-        xxx = torch.tanh(xxx @ self.time_maa_w1).view(bsz*q_len, self.time_maa_w2.size(0), -1).transpose(0, 1)
-        xxx = torch.bmm(xxx, self.time_maa_w2).view(self.time_maa_w2.size(0), bsz, q_len, hidden_dim)

-
-
-
-        xv = x + dxprev * (self.time_maa_v + mv)
-        xw = x + dxprev * (self.time_maa_w + mw)
-        xg = x + dxprev * (self.time_maa_g + mg)

-
-
-
-
-
-
-
-
-
-

         # repeat k/v heads if n_kv_heads < n_heads
-
-
         dropout_rate = 0.0 if not self.training else self.attention_dropout

-
-
-

         if attention_mask is not None:
-
-            decay_states_log = decay_states_log - 100 * F.pad(1 - attention_mask, [1, -1]).view(bsz, 1, q_len, 1)
-
-        query_states = query_states.to(value_states.dtype)
-        key_states = key_states.to(value_states.dtype)
-
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in float16 just to be sure everything works as expected.
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype

-
-
-
-
-            )
-
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)

         attn_weights = torch.empty(0, device=x.device)

-        scale =
         output_final_state = not self.training and use_cache and past_key_values is not None
-
-        #attn_output, output_kv_state = chunk_gla(query_states, key_states, value_states, decay_states_log, scale, input_kv_state, output_final_state)
-        attn_output, output_kv_state = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log, None, scale, input_kv_state, output_final_state)

         if output_final_state:
-            past_key_values.update(output_kv_state, output_shift_state,

-        attn_output = attn_output.
-
-

         return attn_output, attn_weights

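For reference, a tiny standalone illustration (not part of the diff) of the token-shift trick both versions of the forward rely on: F.pad(x, (0, 0, 1, -1)) shifts the sequence right by one position so each token sees its predecessor's hidden state, with zeros at position 0:

import torch
import torch.nn.functional as F

x = torch.arange(1, 7, dtype=torch.float).view(1, 3, 2)   # (batch=1, seq=3, hidden=2)
xprev = F.pad(x, (0, 0, 1, -1))                            # shift right along the sequence dimension
print(x[0].tolist())      # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(xprev[0].tolist())  # [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]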
@@ -578,7 +643,7 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         )
         self._attn_implementation = config._attn_implementation
         self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-

         self.gradient_checkpointing = False
         # Initialize weights and apply final processing

@@ -640,13 +705,14 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-
-
-
-
-
-
-

         # causal_mask = self._update_causal_mask(
         #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions

@@ -657,7 +723,9 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         hidden_states = inputs_embeds

         # create position embeddings to be shared across the decoder layers
-        position_embeddings = None

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
After (lines added are prefixed with +):

@@ -51,6 +51,7 @@ from transformers.utils import (
 from .configuration_rwkv6qwen2 import RWKV6Qwen2Config

 from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2RMSNorm, repeat_kv
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

 logger = logging.get_logger(__name__)

@@ -58,7 +59,7 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "RWKV/RWKV6Qwen2-7B"
 _CONFIG_FOR_DOC = "RWKV6Qwen2Config"

+class RWKV6State():
     def __init__(self) -> None:
         super().__init__()
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

@@ -113,28 +114,6 @@ class RWKV6State(Cache):
         """
         return None

     def crop(self, max_length: int):
         # can't implement this for linear attention variants
         return

@@ -144,8 +123,8 @@ class RWKV6State(Cache):
         self,
         kv_state: torch.Tensor,
         shift_state: torch.Tensor,
         layer_idx: int,
+        token_count: int = 0,
         cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Update the number of seen tokens

@@ -162,62 +141,140 @@ class RWKV6State(Cache):

         return self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]

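A minimal usage sketch (not part of the diff), assuming only what the listing shows: RWKV6State() takes no constructor arguments, supports len() and per-layer indexing, and update() now declares layer_idx before an optional token_count, mirroring the (key, value, layer_idx) ordering of transformers' Cache.update:

cache = RWKV6State()   # start generation with an empty recurrent state
# Per the attention code later in this file, each layer reads and writes its slot:
#     kv_state, shift_state = cache[layer_idx]
#     cache.update(new_kv_state, new_shift_state, ...)   # token_count now defaults to 0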
 try:
     #from fla.ops.gla.chunk import chunk_gla
     from fla.ops.gla.fused_recurrent import fused_recurrent_gla
 except ImportError:
     print("Required module is not installed. Please install it using the following commands:")
+    print("pip install --no-use-pep517 flash-linear-attention")
     print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
     print("pip install triton>=2.2.0")

+class Qwen2RotaryEmbedding(nn.Module):
+    def __init__(self, config: RWKV6Qwen2Config, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    def _dynamic_frequency_update(self, position_ids, device):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
+
+        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            # This .to() is needed if the model has been moved to a device after being initialized (because
+            # the buffer is automatically moved, but not the original copy)
+            self.original_inv_freq = self.original_inv_freq.to(device)
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            self._dynamic_frequency_update(position_ids, device=x.device)
+
+        # Core RoPE block
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+
+        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+def generate_rotary_embedding(max_seqlen:int, dim:int, theta:float = 10000.0, scale:float = 1):
+    #inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float).to(device) / dim))
+
+    angular_velocity = theta ** -(torch.arange(0, dim, 2, dtype=torch.float) / dim) / scale  # frequencies from 1.0 ... 1/theta
+    angles = torch.outer(torch.arange(max_seqlen), angular_velocity)
+    # Different from paper, but it uses a different permutation in order to obtain the same calculation
+    emb = torch.cat((angles, angles), dim=-1)
+    return torch.stack([emb.cos(), emb.sin()], dim=0)
+    #return torch.polar(torch.ones_like(angles), angles)
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+def ortho_init(x, scale):
+    with torch.no_grad():
+        shape = x.shape
+        if len(shape) == 2:
+            gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
+            #nn.init.orthogonal_(x, gain=gain * scale)
+            x.copy_(nn.init.orthogonal_(torch.empty_like(x, dtype=torch.float32), gain=gain * scale))
+        elif len(shape) == 3:
+            gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
+            for i in range(shape[0]):
+                #nn.init.orthogonal_(x[i], gain=gain * scale)
+                x[i].copy_(nn.init.orthogonal_(torch.empty_like(x[i], dtype=torch.float32), gain=gain * scale))
+        else:
+            assert False
+    return x

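A small sketch (not part of the diff) of how the helpers above are used later in this file: cos/sin of shape (B, T, head_dim) are applied to (B, T, heads, head_dim) tensors with unsqueeze_dim=2. The frequency construction here is hand-rolled for illustration; the model itself gets cos/sin from Qwen2RotaryEmbedding:

import torch

B, T, H, N = 2, 8, 4, 16
q = torch.randn(B, T, H, N)
k = torch.randn(B, T, H, N)

# build cos/sin for positions 0..T-1 the same way generate_rotary_embedding does
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, N, 2, dtype=torch.float) / N))
angles = torch.outer(torch.arange(T, dtype=torch.float), inv_freq)    # (T, N/2)
emb = torch.cat((angles, angles), dim=-1)                             # (T, N)
cos = emb.cos()[None].expand(B, -1, -1)                               # (B, T, N)
sin = emb.sin()[None].expand(B, -1, -1)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)  # helper defined above
assert q_rot.shape == (B, T, H, N) and k_rot.shape == (B, T, H, N)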
 class RWKV6Attention(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
+
         if layer_idx is None:
             logger.warning_once(
                 f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "

@@ -232,6 +289,11 @@ class RWKV6Attention(nn.Module):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.attention_dropout = config.attention_dropout

+        n_layer = self.config.num_hidden_layers
+        n_embd = self.hidden_size
+        dim_att = self.num_heads * self.head_dim
+        layer_id = self.layer_idx
+
         if self.hidden_size % self.num_heads != 0:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"

@@ -242,41 +304,55 @@ class RWKV6Attention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, 'attention_output_bias', config.attention_bias))

+        calc_lora_rank = lambda exponent, multiplier: max(1, round(self.hidden_size ** exponent * multiplier / 32)) * 32

+        if config.gate_rank_type == 1:
+            self.gate = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        elif config.gate_rank_type == 2:
+            lora_rank_gate = config.lora_rank_gate or calc_lora_rank(0.8, 0.6)
+            self.g1 = nn.Parameter(torch.empty(n_embd, lora_rank_gate))
+            self.g2 = nn.Parameter(torch.empty(lora_rank_gate, n_embd))
+
+        if config.groupnorm_att:
+            self.ln_x = nn.GroupNorm(self.num_heads, dim_att, eps=self.head_dim * 1e-5)

         with torch.no_grad():
+            if config.gate_rank_type == 1:
+                self.gate.weight.zero_()
+            elif config.gate_rank_type == 2:
+                self.g1.zero_()
+                ortho_init(self.g2, 0.1)
+
             ratio_0_to_1 = layer_id / (n_layer - 1)  # 0 to 1
             ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 to ~0
+
+            if self.config.use_tokenshift:
+                ddd = torch.ones(1, 1, n_embd)
+                for i in range(n_embd):
+                    ddd[0, 0, i] = i / n_embd
+
+                ddd = torch.zeros(1, 1, n_embd)
+                self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+                self.time_maa_r = nn.Parameter(torch.zeros_like(ddd))
+                self.time_maa_k = nn.Parameter(torch.zeros_like(ddd))
+                self.time_maa_v = nn.Parameter(torch.zeros_like(ddd))
+                self.time_maa_w = nn.Parameter(torch.zeros_like(ddd))
+                self.time_maa_g = nn.Parameter(torch.zeros_like(ddd))
+
+                lora_rank_tokenshift = config.lora_rank_tokenshift or (32 if n_embd < 4096 else 64)
+
+                self.time_maa_w2 = nn.Parameter(torch.zeros(5, lora_rank_tokenshift, n_embd).uniform_(-0.01, 0.01))
+                self.time_maa_w1 = nn.Parameter(torch.zeros(n_embd, lora_rank_tokenshift*self.time_maa_w2.size(0)))
+
+            lora_rank_decay = config.lora_rank_decay or (64 if n_embd < 4096 else 128)

             # RWKV-6
             decay_speed = torch.ones(dim_att)
             for n in range(dim_att):
                 decay_speed[n] = -6 + 5 * (n / (dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
             self.time_decay = nn.Parameter(decay_speed.reshape(1,1,dim_att))
+            self.time_decay_w1 = nn.Parameter(torch.zeros(n_embd, lora_rank_decay))
+            self.time_decay_w2 = nn.Parameter(torch.zeros(lora_rank_decay, dim_att).uniform_(-0.01, 0.01))

     def forward(
         self,

@@ -291,11 +367,12 @@ class RWKV6Attention(nn.Module):
     ):
         output_shift_state = hidden_states[:, -1:].detach().clone()

         x = hidden_states

+        B, T, C = hidden_states.shape
+        H = self.num_heads
+        N = self.head_dim
+
         if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx:
             input_kv_state, input_shift_state = past_key_values[self.layer_idx]
             xprev = torch.cat([input_shift_state, x[:, :-1]], dim=1)

@@ -303,83 +380,71 @@ class RWKV6Attention(nn.Module):
             input_kv_state = None
             xprev = F.pad(x, (0, 0, 1, -1))

+        if self.config.use_tokenshift:
+            dxprev = xprev - x

+            xxx = x + dxprev * self.time_maa_x
+            xxx = torch.tanh(xxx @ self.time_maa_w1).view(B*T, self.time_maa_w2.size(0), -1).transpose(0, 1)
+            xxx = torch.bmm(xxx, self.time_maa_w2).view(self.time_maa_w2.size(0), B, T, C)

+            mr, mk, mv, mw, mg = xxx.unbind(dim=0)
+            xr = x + dxprev * (self.time_maa_r + mr)
+            xk = x + dxprev * (self.time_maa_k + mk)
+            xv = x + dxprev * (self.time_maa_v + mv)
+            xw = x + dxprev * (self.time_maa_w + mw)
+            xg = x + dxprev * (self.time_maa_g + mg)
+        else:
+            xr = xk = xv = xw = xg = x
+
+        r = self.q_proj(xr)
+        k = self.k_proj(xk)
+        v = self.v_proj(xv)
+        w_lora_result = (self.time_decay + torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2).to(r.dtype)
+        if self.config.gate_rank_type == 1:
+            g = torch.sigmoid(self.gate(xg))
+        elif self.config.gate_rank_type == 2:
+            g = torch.sigmoid(xg @ self.g1) @ self.g2
+
+        if position_embeddings is not None:
+            r = r.view(B,T,-1,N)
+            k = k.view(B,T,-1,N)
+            cos, sin = position_embeddings
+            r, k = apply_rotary_pos_emb(r, k, cos, sin, unsqueeze_dim=2)

         # repeat k/v heads if n_kv_heads < n_heads
+        k = k.view(B, T, -1, 1, self.head_dim).expand(-1, -1, -1, self.num_key_value_groups, -1).reshape(B, T, -1)
+        v = v.view(B, T, -1, 1, self.head_dim).expand(-1, -1, -1, self.num_key_value_groups, -1).reshape(B, T, -1)
         dropout_rate = 0.0 if not self.training else self.attention_dropout

+        log_w = -w_lora_result.float().exp()
+        log_w = log_w.clamp(-5)
+        if self.config.balance_state:
+            k = (k * (1 - log_w.exp())).to(k.dtype)

+        # dealing with left-padding
         if attention_mask is not None:
+            v = v * attention_mask[:, None, -v.shape[-2]:, None]

+        r = r.view(B,T,-1,N).to(v.dtype)
+        k = k.view(B,T,-1,N).to(v.dtype)
+        v = v.view(B,T,-1,N)
+        log_w = log_w.view(B,T,-1,N)

         attn_weights = torch.empty(0, device=x.device)

+        scale = r.shape[-1] ** -0.5
         output_final_state = not self.training and use_cache and past_key_values is not None
+        attn_output, output_kv_state = fused_recurrent_gla(r, k, v, log_w, None, scale, input_kv_state, output_final_state)

         if output_final_state:
+            past_key_values.update(output_kv_state, output_shift_state, T, self.layer_idx)

+        attn_output = attn_output.view(B, T, -1)
+        if self.config.groupnorm_att:
+            attn_output = self.ln_x(attn_output.view(B * T, -1)).view(B, T, -1)
+        if self.config.gate_rank_type != 0:
+            attn_output = attn_output * g
+        attn_output = self.o_proj(attn_output)

         return attn_output, attn_weights

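A shape-level sketch of the new decode path (not from the diff; it assumes flash-linear-attention is installed, mirrors the (B, T, H, N) layout and positional argument order of the call above, and needs a GPU since the kernel is Triton-based; exact state shapes depend on the installed fla version):

import torch
from fla.ops.gla.fused_recurrent import fused_recurrent_gla

if torch.cuda.is_available():
    B, T, H, N = 1, 4, 2, 8
    r = torch.randn(B, T, H, N, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(r)
    v = torch.randn_like(r)

    # decay as in the forward above: log_w = -exp(time_decay + lora(xw)), clamped at -5, kept in float32
    w_lora_result = torch.randn(B, T, H, N, device="cuda")
    log_w = (-w_lora_result.float().exp()).clamp(-5)

    scale = r.shape[-1] ** -0.5
    attn_output, kv_state = fused_recurrent_gla(r, k, v, log_w, None, scale, None, True)
    print(attn_output.shape)   # (B, T, H, N); kv_state is the per-head recurrent state to cache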
@@ -578,7 +643,7 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         )
         self._attn_implementation = config._attn_implementation
         self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = Qwen2RotaryEmbedding(config=config)

         self.gradient_checkpointing = False
         # Initialize weights and apply final processing

@@ -640,13 +705,14 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)

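A small numeric illustration of the added bookkeeping (not part of the diff): with 5 tokens already consumed by the cache and a 3-token continuation, the derived tensors are:

import torch

past_seen_tokens, chunk_len = 5, 3                      # example values
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + chunk_len)
position_ids = cache_position.unsqueeze(0)              # shape (1, chunk_len)
print(cache_position.tolist(), tuple(position_ids.shape))  # [5, 6, 7] (1, 3)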
         # causal_mask = self._update_causal_mask(
         #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions

@@ -657,7 +723,9 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         hidden_states = inputs_embeds

         # create position embeddings to be shared across the decoder layers
+        position_embeddings = None
+        if self.config.use_rope:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)

         # decoder layers
         all_hidden_states = () if output_hidden_states else None