SmerkyG committed
Commit 35648d6 · verified · 1 Parent(s): 93e5064

Update modeling_rwkv6qwen2.py


bugfix for FLA and transformers lib updates
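The diff below makes two compatibility changes that the commit message refers to: the FLA install hint now points at the released flash-linear-attention package instead of the git URL, and RWKV6State.update moves token_count behind layer_idx (with a default of 0) so its positional argument order matches Cache.update(key_states, value_states, layer_idx, ...) in recent transformers releases. A minimal sketch of the reordered call, using a stand-in cache class rather than the real RWKV6State:

# Illustrative sketch only (not part of the commit): a stand-in cache with the reordered
# update() signature, showing that positional calls now follow the same
# (states..., layer_idx, ...) order as transformers' Cache/DynamicCache.update.
# ToyRWKV6State and all tensor shapes here are placeholders, not the real class.
from typing import Any, Dict, Optional, Tuple
import torch

class ToyRWKV6State:
    def __init__(self) -> None:
        self.layer_kv_states: Dict[int, torch.Tensor] = {}
        self.layer_shift_states: Dict[int, torch.Tensor] = {}

    def update(
        self,
        kv_state: torch.Tensor,
        shift_state: torch.Tensor,
        layer_idx: int,                 # layer_idx now comes third, as in Cache.update
        token_count: int = 0,           # token_count moved behind it and given a default
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self.layer_kv_states[layer_idx] = kv_state
        self.layer_shift_states[layer_idx] = shift_state
        return kv_state, shift_state

cache = ToyRWKV6State()
kv = torch.zeros(1, 8, 64, 64)              # hypothetical per-layer wkv state
shift = torch.zeros(1, 1, 512)              # hypothetical token-shift state
cache.update(kv, shift, 0, token_count=4)   # layer 0, 4 new tokens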

Files changed (1)
  1. modeling_rwkv6qwen2.py +232 -164
modeling_rwkv6qwen2.py CHANGED
@@ -51,6 +51,7 @@ from transformers.utils import (
51
  from .configuration_rwkv6qwen2 import RWKV6Qwen2Config
52
 
53
  from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2RMSNorm, repeat_kv
 
54
 
55
  logger = logging.get_logger(__name__)
56
 
@@ -58,7 +59,7 @@ logger = logging.get_logger(__name__)
58
  _CHECKPOINT_FOR_DOC = "RWKV/RWKV6Qwen2-7B"
59
  _CONFIG_FOR_DOC = "RWKV6Qwen2Config"
60
 
61
- class RWKV6State(Cache):
62
  def __init__(self) -> None:
63
  super().__init__()
64
  self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
@@ -113,28 +114,6 @@ class RWKV6State(Cache):
113
  """
114
  return None
115
 
116
- # def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
117
- # """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
118
- # backward compatibility."""
119
- # legacy_cache = ()
120
- # for layer_idx in range(len(self)):
121
- # legacy_cache += ((self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]),)
122
- # return legacy_cache
123
-
124
- # @classmethod
125
- # #@deprecate_kwarg("num_hidden_layers", version="4.47.0")
126
- # def from_legacy_cache(
127
- # cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None, num_hidden_layers: int | None = None
128
- # ) -> "RWKV6State":
129
- # """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
130
- # backward compatibility."""
131
- # cache = cls()
132
- # if past_key_values is not None:
133
- # for layer_idx in range(len(past_key_values)):
134
- # layer_kv_state, layer_shift_state = past_key_values[layer_idx]
135
- # cache.update(layer_kv_state, layer_shift_state, layer_idx)
136
- # return cache
137
-
138
  def crop(self, max_length: int):
139
  # can't implement this for linear attention variants
140
  return
@@ -144,8 +123,8 @@ class RWKV6State(Cache):
144
  self,
145
  kv_state: torch.Tensor,
146
  shift_state: torch.Tensor,
147
- token_count: int,
148
  layer_idx: int,
 
149
  cache_kwargs: Optional[Dict[str, Any]] = None,
150
  ) -> Tuple[torch.Tensor, torch.Tensor]:
151
  # Update the number of seen tokens
@@ -162,62 +141,140 @@ class RWKV6State(Cache):
162
 
163
  return self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]
164
 
165
- # @deprecate_kwarg("num_hidden_layers", version="4.47.0")
166
- # def batch_split(
167
- # self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
168
- # ) -> List["DynamicCache"]:
169
- # """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
170
- # `_split_model_inputs()` in `generation.utils`"""
171
- # out = []
172
- # for i in range(0, full_batch_size, split_size):
173
- # current_split = DynamicCache()
174
- # current_split._seen_tokens = self._seen_tokens
175
- # current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
176
- # current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
177
- # out.append(current_split)
178
- # return out
179
-
180
- # @classmethod
181
- # @deprecate_kwarg("num_hidden_layers", version="4.47.0")
182
- # def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
183
- # """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
184
- # `generation.utils`"""
185
- # cache = cls()
186
- # for idx in range(len(splits[0])):
187
- # key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
188
- # value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
189
- # if key_cache != []:
190
- # layer_keys = torch.cat(key_cache, dim=0)
191
- # layer_values = torch.cat(value_cache, dim=0)
192
- # cache.update(layer_keys, layer_values, idx)
193
- # return cache
194
-
195
- # def batch_repeat_interleave(self, repeats: int):
196
- # """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
197
- # for layer_idx in range(len(self)):
198
- # self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
199
- # self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
200
-
201
- # def batch_select_indices(self, indices: torch.Tensor):
202
- # """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
203
- # for layer_idx in range(len(self)):
204
- # self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
205
- # self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
206
-
207
  try:
208
  #from fla.ops.gla.chunk import chunk_gla
209
  from fla.ops.gla.fused_recurrent import fused_recurrent_gla
210
  except ImportError:
211
  print("Required module is not installed. Please install it using the following commands:")
212
- print("pip install -U git+https://github.com/fla-org/flash-linear-attention")
213
  print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
214
  print("pip install triton>=2.2.0")
215
 
216
  class RWKV6Attention(nn.Module):
217
  def __init__(self, config, layer_idx: Optional[int] = None):
218
  super().__init__()
219
  self.config = config
220
  self.layer_idx = layer_idx
 
221
  if layer_idx is None:
222
  logger.warning_once(
223
  f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
@@ -232,6 +289,11 @@ class RWKV6Attention(nn.Module):
232
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
233
  self.attention_dropout = config.attention_dropout
234
 
235
  if self.hidden_size % self.num_heads != 0:
236
  raise ValueError(
237
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
@@ -242,41 +304,55 @@ class RWKV6Attention(nn.Module):
242
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
243
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, 'attention_output_bias', config.attention_bias))
244
 
245
- self.gate = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
246
- nn.init.zeros_(self.gate.weight)
247
 
248
- n_layer = self.config.num_hidden_layers
249
- n_embd = self.hidden_size
250
- dim_att = self.num_heads * self.head_dim
251
- layer_id = self.layer_idx
252
 
253
  with torch.no_grad():
254
  ratio_0_to_1 = layer_id / (n_layer - 1) # 0 to 1
255
  ratio_1_to_almost0 = 1.0 - (layer_id / n_layer) # 1 to ~0
256
- ddd = torch.ones(1, 1, n_embd)
257
- for i in range(n_embd):
258
- ddd[0, 0, i] = i / n_embd
259
-
260
- ddd = torch.zeros(1, 1, n_embd)
261
- self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
262
- self.time_maa_r = nn.Parameter(torch.zeros_like(ddd))
263
- self.time_maa_k = nn.Parameter(torch.zeros_like(ddd))
264
- self.time_maa_v = nn.Parameter(torch.zeros_like(ddd))
265
- self.time_maa_w = nn.Parameter(torch.zeros_like(ddd))
266
- self.time_maa_g = nn.Parameter(torch.zeros_like(ddd))
267
-
268
- D_MIX_LORA = 32 if n_embd < 4096 else 64
269
- self.time_maa_w2 = nn.Parameter(torch.zeros(5, D_MIX_LORA, n_embd).uniform_(-0.01, 0.01))
270
- self.time_maa_w1 = nn.Parameter(torch.zeros(n_embd, D_MIX_LORA*self.time_maa_w2.size(0)))
271
 
272
  # RWKV-6
273
  decay_speed = torch.ones(dim_att)
274
  for n in range(dim_att):
275
  decay_speed[n] = -6 + 5 * (n / (dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
276
  self.time_decay = nn.Parameter(decay_speed.reshape(1,1,dim_att))
277
- D_DECAY_LORA = 64 if n_embd < 4096 else 128
278
- self.time_decay_w1 = nn.Parameter(torch.zeros(n_embd, D_DECAY_LORA))
279
- self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, dim_att).uniform_(-0.01, 0.01))
280
 
281
  def forward(
282
  self,
@@ -291,11 +367,12 @@ class RWKV6Attention(nn.Module):
291
  ):
292
  output_shift_state = hidden_states[:, -1:].detach().clone()
293
 
294
- bsz, q_len, hidden_dim = hidden_states.size()
295
- H = self.num_heads
296
-
297
  x = hidden_states
298
 
299
  if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx:
300
  input_kv_state, input_shift_state = past_key_values[self.layer_idx]
301
  xprev = torch.cat([input_shift_state, x[:, :-1]], dim=1)
@@ -303,83 +380,71 @@ class RWKV6Attention(nn.Module):
303
  input_kv_state = None
304
  xprev = F.pad(x, (0, 0, 1, -1))
305
 
306
- dxprev = xprev - x
307
-
308
- xxx = x + dxprev * self.time_maa_x
309
- xxx = torch.tanh(xxx @ self.time_maa_w1).view(bsz*q_len, self.time_maa_w2.size(0), -1).transpose(0, 1)
310
- xxx = torch.bmm(xxx, self.time_maa_w2).view(self.time_maa_w2.size(0), bsz, q_len, hidden_dim)
311
 
312
- mr, mk, mv, mw, mg = xxx.unbind(dim=0)
313
- xr = x + dxprev * (self.time_maa_r + mr)
314
- xk = x + dxprev * (self.time_maa_k + mk)
315
- xv = x + dxprev * (self.time_maa_v + mv)
316
- xw = x + dxprev * (self.time_maa_w + mw)
317
- xg = x + dxprev * (self.time_maa_g + mg)
318
 
319
- query_states = self.q_proj(xr)
320
- key_states = self.k_proj(xk)
321
- value_states = self.v_proj(xv)
322
- decay_states = (self.time_decay + torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2).to(query_states.dtype)
323
- gate_states = F.sigmoid(self.gate(xg))
324
-
325
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
326
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
327
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
328
- decay_states = decay_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
329
 
330
  # repeat k/v heads if n_kv_heads < n_heads
331
- key_states = repeat_kv(key_states, self.num_key_value_groups)
332
- value_states = repeat_kv(value_states, self.num_key_value_groups)
333
  dropout_rate = 0.0 if not self.training else self.attention_dropout
334
 
335
- decay_states_log = -decay_states.float().exp()
336
- decay_states_log = decay_states_log.clamp(-5) # FIXME - is this necessary?
337
- key_states = (key_states * (1 - decay_states_log.exp())).to(key_states.dtype)
 
338
 
 
339
  if attention_mask is not None:
340
- if q_len > 1:
341
- decay_states_log = decay_states_log - 100 * F.pad(1 - attention_mask, [1, -1]).view(bsz, 1, q_len, 1)
342
-
343
- query_states = query_states.to(value_states.dtype)
344
- key_states = key_states.to(value_states.dtype)
345
-
346
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
347
- # therefore the input hidden states gets silently casted in float32. Hence, we need
348
- # cast them back in float16 just to be sure everything works as expected.
349
- input_dtype = query_states.dtype
350
- if input_dtype == torch.float32:
351
- if torch.is_autocast_enabled():
352
- target_dtype = torch.get_autocast_gpu_dtype()
353
- # Handle the case where the model is quantized
354
- elif hasattr(self.config, "_pre_quantization_dtype"):
355
- target_dtype = self.config._pre_quantization_dtype
356
- else:
357
- target_dtype = self.q_proj.weight.dtype
358
 
359
- logger.warning_once(
360
- f"The input hidden states seems to be silently casted in float32, this might be related to"
361
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
362
- f" {target_dtype}."
363
- )
364
-
365
- query_states = query_states.to(target_dtype)
366
- key_states = key_states.to(target_dtype)
367
- value_states = value_states.to(target_dtype)
368
 
369
  attn_weights = torch.empty(0, device=x.device)
370
 
371
- scale = query_states.shape[-1] ** -0.5
372
  output_final_state = not self.training and use_cache and past_key_values is not None
373
- #attn_output, output_kv_state = ChunkGLAFunction.apply(query_states, key_states, value_states, decay_states_log.float(), scale, input_kv_state, output_final_state)
374
- #attn_output, output_kv_state = chunk_gla(query_states, key_states, value_states, decay_states_log, scale, input_kv_state, output_final_state)
375
- attn_output, output_kv_state = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log, None, scale, input_kv_state, output_final_state)
376
 
377
  if output_final_state:
378
- past_key_values.update(output_kv_state, output_shift_state, q_len, self.layer_idx)
379
 
380
- attn_output = attn_output.transpose(1, 2).contiguous()
381
- attn_output = attn_output.view(bsz, q_len, -1)
382
- attn_output = self.o_proj(attn_output * gate_states)
383
 
384
  return attn_output, attn_weights
385
 
@@ -578,7 +643,7 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
578
  )
579
  self._attn_implementation = config._attn_implementation
580
  self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
581
- #self.rotary_emb = Qwen2RotaryEmbedding(config=config)
582
 
583
  self.gradient_checkpointing = False
584
  # Initialize weights and apply final processing
@@ -640,13 +705,14 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
640
  if inputs_embeds is None:
641
  inputs_embeds = self.embed_tokens(input_ids)
642
 
643
- # if cache_position is None:
644
- # past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
645
- # cache_position = torch.arange(
646
- # past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
647
- # )
648
- # if position_ids is None:
649
- # position_ids = cache_position.unsqueeze(0)
 
650
 
651
  # causal_mask = self._update_causal_mask(
652
  # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
@@ -657,7 +723,9 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
657
  hidden_states = inputs_embeds
658
 
659
  # create position embeddings to be shared across the decoder layers
660
- position_embeddings = None #self.rotary_emb(hidden_states, position_ids)
661
 
662
  # decoder layers
663
  all_hidden_states = () if output_hidden_states else None
 
51
  from .configuration_rwkv6qwen2 import RWKV6Qwen2Config
52
 
53
  from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2RMSNorm, repeat_kv
54
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
55
 
56
  logger = logging.get_logger(__name__)
57
 
 
59
  _CHECKPOINT_FOR_DOC = "RWKV/RWKV6Qwen2-7B"
60
  _CONFIG_FOR_DOC = "RWKV6Qwen2Config"
61
 
62
+ class RWKV6State():
63
  def __init__(self) -> None:
64
  super().__init__()
65
  self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
 
114
  """
115
  return None
116
 
117
  def crop(self, max_length: int):
118
  # can't implement this for linear attention variants
119
  return
 
123
  self,
124
  kv_state: torch.Tensor,
125
  shift_state: torch.Tensor,
 
126
  layer_idx: int,
127
+ token_count: int = 0,
128
  cache_kwargs: Optional[Dict[str, Any]] = None,
129
  ) -> Tuple[torch.Tensor, torch.Tensor]:
130
  # Update the number of seen tokens
 
141
 
142
  return self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]
143
 
144
  try:
145
  #from fla.ops.gla.chunk import chunk_gla
146
  from fla.ops.gla.fused_recurrent import fused_recurrent_gla
147
  except ImportError:
148
  print("Required module is not installed. Please install it using the following commands:")
149
+ print("pip install --no-use-pep517 flash-linear-attention")
150
  print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
151
  print("pip install triton>=2.2.0")
152
 
153
+ class Qwen2RotaryEmbedding(nn.Module):
154
+ def __init__(self, config: RWKV6Qwen2Config, device=None):
155
+ super().__init__()
156
+ # BC: "rope_type" was originally "type"
157
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
158
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
159
+ else:
160
+ self.rope_type = "default"
161
+ self.max_seq_len_cached = config.max_position_embeddings
162
+ self.original_max_seq_len = config.max_position_embeddings
163
+
164
+ self.config = config
165
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
166
+
167
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
168
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
169
+ self.original_inv_freq = self.inv_freq
170
+
171
+ def _dynamic_frequency_update(self, position_ids, device):
172
+ """
173
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
174
+ 1 - growing beyond the cached sequence length (allow scaling)
175
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
176
+ """
177
+ seq_len = torch.max(position_ids) + 1
178
+ if seq_len > self.max_seq_len_cached: # growth
179
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
180
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
181
+ self.max_seq_len_cached = seq_len
182
+
183
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
184
+ # This .to() is needed if the model has been moved to a device after being initialized (because
185
+ # the buffer is automatically moved, but not the original copy)
186
+ self.original_inv_freq = self.original_inv_freq.to(device)
187
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
188
+ self.max_seq_len_cached = self.original_max_seq_len
189
+
190
+ @torch.no_grad()
191
+ def forward(self, x, position_ids):
192
+ if "dynamic" in self.rope_type:
193
+ self._dynamic_frequency_update(position_ids, device=x.device)
194
+
195
+ # Core RoPE block
196
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
197
+ position_ids_expanded = position_ids[:, None, :].float()
198
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
199
+ device_type = x.device.type
200
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
201
+ with torch.autocast(device_type=device_type, enabled=False):
202
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
203
+ emb = torch.cat((freqs, freqs), dim=-1)
204
+ cos = emb.cos()
205
+ sin = emb.sin()
206
+
207
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
208
+ cos = cos * self.attention_scaling
209
+ sin = sin * self.attention_scaling
210
+
211
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
212
+
213
+ def generate_rotary_embedding(max_seqlen:int, dim:int, theta:float = 10000.0, scale:float = 1):
214
+ #inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float).to(device) / dim))
215
+
216
+ angular_velocity = theta ** -(torch.arange(0, dim, 2, dtype=torch.float) / dim) / scale # frequencies from 1.0 ... 1/theta
217
+ angles = torch.outer(torch.arange(max_seqlen), angular_velocity)
218
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
219
+ emb = torch.cat((angles, angles), dim=-1)
220
+ return torch.stack([emb.cos(), emb.sin()], dim=0)
221
+ #return torch.polar(torch.ones_like(angles), angles)
222
+
223
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
224
+ def rotate_half(x):
225
+ """Rotates half the hidden dims of the input."""
226
+ x1 = x[..., : x.shape[-1] // 2]
227
+ x2 = x[..., x.shape[-1] // 2 :]
228
+ return torch.cat((-x2, x1), dim=-1)
229
+
230
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
231
+ """Applies Rotary Position Embedding to the query and key tensors.
232
+
233
+ Args:
234
+ q (`torch.Tensor`): The query tensor.
235
+ k (`torch.Tensor`): The key tensor.
236
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
237
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
238
+ position_ids (`torch.Tensor`, *optional*):
239
+ Deprecated and unused.
240
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
241
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
242
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
243
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
244
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
245
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
246
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
247
+ Returns:
248
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
249
+ """
250
+ cos = cos.unsqueeze(unsqueeze_dim)
251
+ sin = sin.unsqueeze(unsqueeze_dim)
252
+ q_embed = (q * cos) + (rotate_half(q) * sin)
253
+ k_embed = (k * cos) + (rotate_half(k) * sin)
254
+ return q_embed, k_embed
255
+
256
+ def ortho_init(x, scale):
257
+ with torch.no_grad():
258
+ shape = x.shape
259
+ if len(shape) == 2:
260
+ gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
261
+ #nn.init.orthogonal_(x, gain=gain * scale)
262
+ x.copy_(nn.init.orthogonal_(torch.empty_like(x, dtype=torch.float32), gain=gain * scale))
263
+ elif len(shape) == 3:
264
+ gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
265
+ for i in range(shape[0]):
266
+ #nn.init.orthogonal_(x[i], gain=gain * scale)
267
+ x[i].copy_(nn.init.orthogonal_(torch.empty_like(x[i], dtype=torch.float32), gain=gain * scale))
268
+ else:
269
+ assert False
270
+ return x
271
+
272
  class RWKV6Attention(nn.Module):
273
  def __init__(self, config, layer_idx: Optional[int] = None):
274
  super().__init__()
275
  self.config = config
276
  self.layer_idx = layer_idx
277
+
278
  if layer_idx is None:
279
  logger.warning_once(
280
  f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
 
289
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
290
  self.attention_dropout = config.attention_dropout
291
 
292
+ n_layer = self.config.num_hidden_layers
293
+ n_embd = self.hidden_size
294
+ dim_att = self.num_heads * self.head_dim
295
+ layer_id = self.layer_idx
296
+
297
  if self.hidden_size % self.num_heads != 0:
298
  raise ValueError(
299
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
 
304
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
305
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, 'attention_output_bias', config.attention_bias))
306
 
307
+ calc_lora_rank = lambda exponent, multiplier: max(1, round(self.hidden_size ** exponent * multiplier / 32)) * 32
 
308
 
309
+ if config.gate_rank_type == 1:
310
+ self.gate = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
311
+ elif config.gate_rank_type == 2:
312
+ lora_rank_gate = config.lora_rank_gate or calc_lora_rank(0.8, 0.6)
313
+ self.g1 = nn.Parameter(torch.empty(n_embd, lora_rank_gate))
314
+ self.g2 = nn.Parameter(torch.empty(lora_rank_gate, n_embd))
315
+
316
+ if config.groupnorm_att:
317
+ self.ln_x = nn.GroupNorm(self.num_heads, dim_att, eps=self.head_dim * 1e-5)
318
 
319
  with torch.no_grad():
320
+ if config.gate_rank_type == 1:
321
+ self.gate.weight.zero_()
322
+ elif config.gate_rank_type == 2:
323
+ self.g1.zero_()
324
+ ortho_init(self.g2, 0.1)
325
+
326
  ratio_0_to_1 = layer_id / (n_layer - 1) # 0 to 1
327
  ratio_1_to_almost0 = 1.0 - (layer_id / n_layer) # 1 to ~0
328
+
329
+ if self.config.use_tokenshift:
330
+ ddd = torch.ones(1, 1, n_embd)
331
+ for i in range(n_embd):
332
+ ddd[0, 0, i] = i / n_embd
333
+
334
+ ddd = torch.zeros(1, 1, n_embd)
335
+ self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
336
+ self.time_maa_r = nn.Parameter(torch.zeros_like(ddd))
337
+ self.time_maa_k = nn.Parameter(torch.zeros_like(ddd))
338
+ self.time_maa_v = nn.Parameter(torch.zeros_like(ddd))
339
+ self.time_maa_w = nn.Parameter(torch.zeros_like(ddd))
340
+ self.time_maa_g = nn.Parameter(torch.zeros_like(ddd))
341
+
342
+ lora_rank_tokenshift = config.lora_rank_tokenshift or (32 if n_embd < 4096 else 64)
343
+
344
+ self.time_maa_w2 = nn.Parameter(torch.zeros(5, lora_rank_tokenshift, n_embd).uniform_(-0.01, 0.01))
345
+ self.time_maa_w1 = nn.Parameter(torch.zeros(n_embd, lora_rank_tokenshift*self.time_maa_w2.size(0)))
346
+
347
+ lora_rank_decay = config.lora_rank_decay or (64 if n_embd < 4096 else 128)
348
 
349
  # RWKV-6
350
  decay_speed = torch.ones(dim_att)
351
  for n in range(dim_att):
352
  decay_speed[n] = -6 + 5 * (n / (dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
353
  self.time_decay = nn.Parameter(decay_speed.reshape(1,1,dim_att))
354
+ self.time_decay_w1 = nn.Parameter(torch.zeros(n_embd, lora_rank_decay))
355
+ self.time_decay_w2 = nn.Parameter(torch.zeros(lora_rank_decay, dim_att).uniform_(-0.01, 0.01))
 
356
 
357
  def forward(
358
  self,
 
367
  ):
368
  output_shift_state = hidden_states[:, -1:].detach().clone()
369
 
370
  x = hidden_states
371
 
372
+ B, T, C = hidden_states.shape
373
+ H = self.num_heads
374
+ N = self.head_dim
375
+
376
  if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx:
377
  input_kv_state, input_shift_state = past_key_values[self.layer_idx]
378
  xprev = torch.cat([input_shift_state, x[:, :-1]], dim=1)
 
380
  input_kv_state = None
381
  xprev = F.pad(x, (0, 0, 1, -1))
382
 
383
+ if self.config.use_tokenshift:
384
+ dxprev = xprev - x
385
 
386
+ xxx = x + dxprev * self.time_maa_x
387
+ xxx = torch.tanh(xxx @ self.time_maa_w1).view(B*T, self.time_maa_w2.size(0), -1).transpose(0, 1)
388
+ xxx = torch.bmm(xxx, self.time_maa_w2).view(self.time_maa_w2.size(0), B, T, C)
389
 
390
+ mr, mk, mv, mw, mg = xxx.unbind(dim=0)
391
+ xr = x + dxprev * (self.time_maa_r + mr)
392
+ xk = x + dxprev * (self.time_maa_k + mk)
393
+ xv = x + dxprev * (self.time_maa_v + mv)
394
+ xw = x + dxprev * (self.time_maa_w + mw)
395
+ xg = x + dxprev * (self.time_maa_g + mg)
396
+ else:
397
+ xr = xk = xv = xw = xg = x
398
+
399
+ r = self.q_proj(xr)
400
+ k = self.k_proj(xk)
401
+ v = self.v_proj(xv)
402
+ w_lora_result = (self.time_decay + torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2).to(r.dtype)
403
+ if self.config.gate_rank_type == 1:
404
+ g = torch.sigmoid(self.gate(xg))
405
+ elif self.config.gate_rank_type == 2:
406
+ g = torch.sigmoid(xg @ self.g1) @ self.g2
407
+
408
+ if position_embeddings is not None:
409
+ r = r.view(B,T,-1,N)
410
+ k = k.view(B,T,-1,N)
411
+ cos, sin = position_embeddings
412
+ r, k = apply_rotary_pos_emb(r, k, cos, sin, unsqueeze_dim=2)
413
 
414
  # repeat k/v heads if n_kv_heads < n_heads
415
+ k = k.view(B, T, -1, 1, self.head_dim).expand(-1, -1, -1, self.num_key_value_groups, -1).reshape(B, T, -1)
416
+ v = v.view(B, T, -1, 1, self.head_dim).expand(-1, -1, -1, self.num_key_value_groups, -1).reshape(B, T, -1)
417
  dropout_rate = 0.0 if not self.training else self.attention_dropout
418
 
419
+ log_w = -w_lora_result.float().exp()
420
+ log_w = log_w.clamp(-5)
421
+ if self.config.balance_state:
422
+ k = (k * (1 - log_w.exp())).to(k.dtype)
423
 
424
+ # dealing with left-padding
425
  if attention_mask is not None:
426
+ v = v * attention_mask[:, None, -v.shape[-2]:, None]
427
 
428
+ r = r.view(B,T,-1,N).to(v.dtype)
429
+ k = k.view(B,T,-1,N).to(v.dtype)
430
+ v = v.view(B,T,-1,N)
431
+ log_w = log_w.view(B,T,-1,N)
432
 
433
  attn_weights = torch.empty(0, device=x.device)
434
 
435
+ scale = r.shape[-1] ** -0.5
436
  output_final_state = not self.training and use_cache and past_key_values is not None
437
+ attn_output, output_kv_state = fused_recurrent_gla(r, k, v, log_w, None, scale, input_kv_state, output_final_state)
438
 
439
  if output_final_state:
440
+ past_key_values.update(output_kv_state, output_shift_state, T, self.layer_idx)
441
 
442
+ attn_output = attn_output.view(B, T, -1)
443
+ if self.config.groupnorm_att:
444
+ attn_output = self.ln_x(attn_output.view(B * T, -1)).view(B, T, -1)
445
+ if self.config.gate_rank_type != 0:
446
+ attn_output = attn_output * g
447
+ attn_output = self.o_proj(attn_output)
448
 
449
  return attn_output, attn_weights
450
 
 
643
  )
644
  self._attn_implementation = config._attn_implementation
645
  self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
646
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
647
 
648
  self.gradient_checkpointing = False
649
  # Initialize weights and apply final processing
 
705
  if inputs_embeds is None:
706
  inputs_embeds = self.embed_tokens(input_ids)
707
 
708
+ if cache_position is None:
709
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
710
+ cache_position = torch.arange(
711
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
712
+ )
713
+
714
+ if position_ids is None:
715
+ position_ids = cache_position.unsqueeze(0)
716
 
717
  # causal_mask = self._update_causal_mask(
718
  # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
 
723
  hidden_states = inputs_embeds
724
 
725
  # create position embeddings to be shared across the decoder layers
726
+ position_embeddings = None
727
+ if self.config.use_rope:
728
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
729
 
730
  # decoder layers
731
  all_hidden_states = () if output_hidden_states else None
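
For completeness, a minimal usage sketch of the updated modeling file, assuming flash-linear-attention and triton>=2.2.0 are installed as the ImportError message above instructs, and using the checkpoint named in _CHECKPOINT_FOR_DOC; the dtype and generation settings are illustrative only:

# Usage sketch (not part of the commit). Loading with trust_remote_code pulls in
# modeling_rwkv6qwen2.py from the model repo; dtype/generation settings are examples.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "RWKV/RWKV6Qwen2-7B"  # from _CHECKPOINT_FOR_DOC
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()

inputs = tokenizer("The RWKV-6 state replaces the usual KV cache", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))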