Question Answering
Transformers
Safetensors
English
doge
text-generation
trl
sft
dpo
custom_code
JingzeShi committed on
Commit c33b773 · verified
1 Parent(s): 6f93486

Update modeling_doge.py

Files changed (1)
  1. modeling_doge.py +339 -698
modeling_doge.py CHANGED
@@ -5,10 +5,9 @@
5
  # modular_doge.py file directly. One of our CI enforces this.
6
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
  # coding=utf-8
8
- # Copyright 2024 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
9
  #
10
- # This code is based on the Wonderful Matrices paper implementation.
11
- # The Doge family of small language models is trained by Jingze Shi.
12
  #
13
  # Licensed under the Apache License, Version 2.0 (the "License");
14
  # you may not use this file except in compliance with the License.
@@ -23,39 +22,33 @@
23
  # limitations under the License.
24
 
25
  import math
26
- from typing import Callable, List, Optional, Tuple, Union
27
- from packaging import version
28
 
29
  import torch
30
  import torch.nn.functional as F
31
  from torch import nn
32
 
33
  from transformers.activations import ACT2FN
34
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
35
  from transformers.generation import GenerationMixin
36
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
37
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
38
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
39
- from transformers.modeling_utils import PreTrainedModel
 
 
 
40
  from transformers.processing_utils import Unpack
41
- from transformers.utils import (
42
- LossKwargs,
43
- add_start_docstrings,
44
- add_start_docstrings_to_model_forward,
45
- is_torch_flex_attn_available,
46
- logging,
47
- replace_return_docstrings,
48
- )
49
  from .configuration_doge import DogeConfig
50
 
51
- if is_torch_flex_attn_available() and version.parse(torch.__version__) >= version.parse("2.6.0"):
52
- from torch.nn.attention.flex_attention import flex_attention
53
 
54
- logger = logging.get_logger(__name__)
55
-
56
- _CONFIG_FOR_DOC = "DogeConfig"
57
 
58
 
 
59
  class DogeRMSNorm(nn.Module):
60
  def __init__(self, hidden_size, eps=1e-6):
61
  """
@@ -92,7 +85,7 @@ class DogeRotaryEmbedding(nn.Module):
92
  def __init__(self, config: DogeConfig, device=None):
93
  super().__init__()
94
  # BC: "rope_type" was originally "type"
95
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
96
  self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
97
  else:
98
  self.rope_type = "default"
@@ -106,45 +99,18 @@ class DogeRotaryEmbedding(nn.Module):
106
  self.register_buffer("inv_freq", inv_freq, persistent=False)
107
  self.original_inv_freq = self.inv_freq
108
 
109
- def _dynamic_frequency_update(self, position_ids, device):
110
- """
111
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
112
- 1 - growing beyond the cached sequence length (allow scaling)
113
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
114
- """
115
- seq_len = torch.max(position_ids) + 1
116
- if seq_len > self.max_seq_len_cached: # growth
117
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
118
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
119
- self.max_seq_len_cached = seq_len
120
-
121
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
122
- # This .to() is needed if the model has been moved to a device after being initialized (because
123
- # the buffer is automatically moved, but not the original copy)
124
- self.original_inv_freq = self.original_inv_freq.to(device)
125
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
126
- self.max_seq_len_cached = self.original_max_seq_len
127
-
128
  @torch.no_grad()
 
129
  def forward(self, x, position_ids):
130
- if "dynamic" in self.rope_type:
131
- self._dynamic_frequency_update(position_ids, device=x.device)
132
-
133
- # Core RoPE block
134
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
135
  position_ids_expanded = position_ids[:, None, :].float()
136
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
137
- device_type = x.device.type
138
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
139
- with torch.autocast(device_type=device_type, enabled=False):
140
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
141
  emb = torch.cat((freqs, freqs), dim=-1)
142
- cos = emb.cos()
143
- sin = emb.sin()
144
-
145
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
146
- cos = cos * self.attention_scaling
147
- sin = sin * self.attention_scaling
148
 
149
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
150
 
@@ -203,105 +169,60 @@ def eager_attention_forward(
203
  attention_mask: Optional[torch.Tensor],
204
  scaling: float,
205
  dropout: float = 0.0,
206
- **kwargs,
207
- ) -> Tuple[torch.Tensor, torch.Tensor]:
208
  key_states = repeat_kv(key, module.num_key_value_groups)
209
  value_states = repeat_kv(value, module.num_key_value_groups)
210
 
211
- attn_weights = torch.matmul(query, key_states.transpose(-1, -2)) * scaling
212
  if attention_mask is not None:
213
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
214
  attn_weights = attn_weights + causal_mask
215
 
216
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
217
- attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
218
  attn_output = torch.matmul(attn_weights, value_states)
219
  attn_output = attn_output.transpose(1, 2).contiguous()
220
 
221
  return attn_output, attn_weights
222
 
223
 
224
- def sdpa_attention_forward(
225
- module: nn.Module,
226
- query: torch.Tensor,
227
- key: torch.Tensor,
228
- value: torch.Tensor,
229
- attention_mask: Optional[torch.Tensor],
230
- dropout: float = 0.0,
231
- scaling: Optional[float] = None,
232
- is_causal: Optional[bool] = None,
233
- **kwargs,
234
- ) -> Tuple[torch.Tensor, None]:
235
- key = repeat_kv(key, module.num_key_value_groups)
236
- value = repeat_kv(value, module.num_key_value_groups)
237
-
238
- causal_mask = attention_mask
239
- if attention_mask is not None:
240
- causal_mask = causal_mask[:, :, :, : key.shape[-2]]
241
-
242
- # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
243
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
244
- query = query.contiguous()
245
- key = key.contiguous()
246
- value = value.contiguous()
247
-
248
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
249
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
250
- if is_causal is None:
251
- is_causal = causal_mask is None and query.shape[2] > 1
252
-
253
- # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
254
- # We convert it to a bool for the SDPA kernel that only accepts bools.
255
- if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
256
- is_causal = is_causal.item()
257
-
258
- # NOTE: As of pytorch 2.5.1, SDPA backward pass of cuDNN is still incorrect, so we disable cuDNN SDPA (see https://github.com/pytorch/pytorch/issues/138581)
259
- torch.backends.cuda.enable_cudnn_sdp(False)
260
- attn_output = F.scaled_dot_product_attention(
261
- query=query,
262
- key=key,
263
- value=value,
264
- attn_mask=causal_mask,
265
- dropout_p=dropout,
266
- scale=scaling,
267
- is_causal=is_causal,
268
- )
269
- attn_output = attn_output.transpose(1, 2).contiguous()
270
-
271
- return attn_output, None
272
-
273
-
274
  def flex_attention_forward(
275
  module: nn.Module,
276
  query: torch.Tensor,
277
  key: torch.Tensor,
278
  value: torch.Tensor,
279
- attention_mask: Optional[torch.Tensor],
280
  scaling: Optional[float] = None,
281
- is_causal: Optional[bool] = None,
282
  softcap: Optional[float] = None,
283
  head_mask: Optional[torch.Tensor] = None,
284
  **kwargs,
285
- ) -> Tuple[torch.Tensor, torch.Tensor]:
286
- causal_mask = attention_mask
287
- if attention_mask is not None:
 
 
 
 
 
 
288
  causal_mask = causal_mask[:, :, :, : key.shape[-2]]
289
 
290
- # NOTE: Pytorch 2.6.0 and above support dynamic mask attention
291
- def mask_mod(score, batch, head, q_idx, kv_idx):
292
  if softcap is not None:
293
  score = softcap * torch.tanh(score / softcap)
294
  if causal_mask is not None:
295
- score = score + causal_mask[batch][head][q_idx][kv_idx]
296
  if head_mask is not None:
297
- score = score + head_mask[batch][head][0][0]
298
  return score
299
 
300
- attn_output, attention_weights = flex_attention(
301
- query=query,
302
- key=key,
303
- value=value,
304
- score_mod=mask_mod,
 
305
  enable_gqa=True,
306
  scale=scaling,
307
  # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
@@ -315,16 +236,11 @@ def flex_attention_forward(
315
  return attn_output, attention_weights
316
 
317
 
318
- ALL_ATTENTION_FUNCTIONS = {
319
- "eager": eager_attention_forward,
320
- "sdpa": sdpa_attention_forward,
321
- "flex_attention": flex_attention_forward,
322
- }
323
-
324
 
325
- class DogeDynamicMaskAttention(nn.Module):
326
- """Dynamic Mask Attention from 'Wonderful Matrices' paper."""
327
 
 
328
  def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
329
  super().__init__()
330
  self.config = config
@@ -334,35 +250,34 @@ class DogeDynamicMaskAttention(nn.Module):
334
  self.scaling = self.head_dim**-0.5
335
  self.attention_dropout = config.attention_dropout
336
  self.keep_window_size = config.keep_window_size
337
- self.dynamic_mask_ratio = config.dynamic_mask_ratio
338
 
339
  self.q_proj = nn.Linear(
340
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.hidden_bias
341
  )
342
  self.k_proj = nn.Linear(
343
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.hidden_bias
344
  )
345
  self.v_proj = nn.Linear(
346
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.hidden_bias
347
  )
348
  # dynamic mask for the QK^T attention weights matrix
349
  self.A = nn.Parameter(torch.zeros(config.num_attention_heads))
350
  self.dt_proj = nn.Linear(
351
- config.num_key_value_heads * self.head_dim, config.num_attention_heads, bias=config.hidden_bias
352
  )
353
  self.o_proj = nn.Linear(
354
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.hidden_bias
355
  )
356
 
357
  def forward(
358
  self,
359
  hidden_states: torch.Tensor,
360
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
361
  attention_mask: Optional[torch.Tensor] = None,
362
  past_key_value: Optional[Cache] = None,
363
  cache_position: Optional[torch.LongTensor] = None,
364
  **kwargs,
365
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
366
  input_shape = hidden_states.shape[:-1]
367
  hidden_shape = (*input_shape, -1, self.head_dim)
368
 
@@ -382,24 +297,17 @@ class DogeDynamicMaskAttention(nn.Module):
382
  dt_states = self.dt_proj(
383
  value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
384
  )
385
- dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
386
  attn_mask = self.prepare_dynamic_mask(
387
  hidden_states=hidden_states,
388
- dynamic_mask=dynamic_mask,
389
  keep_window_size=self.keep_window_size,
390
- dynamic_mask_ratio=self.dynamic_mask_ratio,
391
  attention_mask=attention_mask,
392
  )
393
 
394
  attention_interface: Callable = eager_attention_forward
395
  if self.config._attn_implementation != "eager":
396
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
397
- logger.warning_once(
398
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
399
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
400
- )
401
- else:
402
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
403
 
404
  attn_output, attn_weights = attention_interface(
405
  self,
@@ -419,77 +327,80 @@ class DogeDynamicMaskAttention(nn.Module):
419
  def prepare_dynamic_mask(
420
  self,
421
  hidden_states: torch.Tensor,
422
- dynamic_mask: torch.Tensor,
423
  keep_window_size: int = 2048,
424
- dynamic_mask_ratio: float = 0.0,
425
  attention_mask: Optional[torch.Tensor] = None,
426
  ):
427
  """
428
  The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention.
429
 
430
- Combine `dynamic_mask` with `attention_mask` to generate the final `attn_mask`.
431
 
432
  Args:
433
  hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
434
- dynamic_mask (`torch.Tensor`): dynamic mask of shape `(batch_size, num_heads, key_sequence_length)`.
435
  keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
436
- dynamic_mask_ratio (`float`): Ratio from 0.0 to 1.0 used to control the proportion of the dynamic mask filled with the minimum value.
437
  attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
438
  """
439
- attn_mask = dynamic_mask[:, :, None, :]
440
- if dynamic_mask.shape[-1] > keep_window_size:
441
- if 0.0 < dynamic_mask_ratio <= 1.0:
442
- min_type = torch.finfo(hidden_states.dtype).min
443
- num_dynamic_mask = int((attn_mask.shape[-1] - keep_window_size) * dynamic_mask_ratio)
444
- if num_dynamic_mask > 0:
445
- rate_value = torch.kthvalue(attn_mask, num_dynamic_mask, dim=-1, keepdim=True).values
446
- attn_mask = attn_mask.masked_fill(attn_mask < rate_value, min_type)
447
- else:
448
- ValueError("`dynamic_mask_ratio` should be in the range (0.0, 1.0)")
449
- if attention_mask is not None:
450
- attn_mask = attn_mask + attention_mask[:, :, :, : attn_mask.shape[-1]]
451
-
 
 
 
 
452
  return attn_mask
453
 
454
 
455
  class DogeMLP(nn.Module):
456
- def __init__(self, config: DogeConfig):
457
  super().__init__()
458
- self.hidden_dim = config.hidden_size
459
- self.intermediate_dim = config.intermediate_size
 
 
 
 
460
  self.act_fn = ACT2FN[config.hidden_act]
461
 
462
- self.gate_proj = nn.Linear(self.hidden_dim, self.intermediate_dim, bias=config.hidden_bias)
463
- self.up_proj = nn.Linear(self.hidden_dim, self.intermediate_dim, bias=config.hidden_bias)
464
- self.down_proj = nn.Linear(self.intermediate_dim, self.hidden_dim, bias=config.hidden_bias)
465
 
466
- def forward(
467
- self,
468
- hidden_states: torch.Tensor,
469
- **kwargs,
470
- ) -> torch.Tensor:
471
- hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
472
- return hidden_states
473
-
474
-
475
- class DogeCDMoE(DogeMLP):
476
- """Cross Domain Mixture of Experts from 'Wonderful Matrices' paper."""
477
 
 
478
  def __init__(self, config: DogeConfig):
479
- super().__init__(config)
480
- self.hidden_dim = config.hidden_size
 
481
  self.act_fn = ACT2FN[config.hidden_act]
482
 
483
  self.num_experts = config.num_experts
 
484
  self.top_k = config.num_experts_per_tok
485
- self.num_keys = int(math.sqrt(self.num_experts))
 
 
 
 
 
486
 
487
  # router gate for retrieval experts
488
- self.router_gate = nn.Linear(self.hidden_dim, self.num_keys * 2)
489
 
490
- # experts
491
- self.down_embed = nn.Embedding(self.num_experts, self.hidden_dim)
492
- self.up_embed = nn.Embedding(self.num_experts, self.hidden_dim)
493
 
494
  def forward(
495
  self,
@@ -498,288 +409,169 @@ class DogeCDMoE(DogeMLP):
498
  ) -> torch.Tensor:
499
  bsz, seq_len, _ = hidden_states.shape
500
 
501
- # get routing weights with router gate
502
- routing_weights = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
503
 
504
- # get experts with the highest routing weights
505
- (scores_x, scores_y), (indices_x, indices_y) = [w.topk(self.num_keys, dim=-1) for w in routing_weights]
506
  all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
507
  all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
508
  all_scores = all_scores.view(*all_scores.shape[:-2], -1)
509
  all_indices = all_indices.view(*all_indices.shape[:-2], -1)
510
- scores, indices = all_scores.topk(self.top_k, dim=-1)
 
 
 
 
 
 
511
  down_embed = self.down_embed(indices)
512
  up_embed = self.up_embed(indices)
513
-
514
- # mix experts states with cross domain states
515
  experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1)
516
- experts_weights = self.act_fn(experts_weights) * scores.softmax(dim=-1)
517
  experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1)
518
  hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
519
  hidden_states = hidden_states + experts_states
520
- return hidden_states
521
 
522
 
523
- class DogeDecoderLayer(nn.Module):
524
  def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
525
  super().__init__()
526
  self.hidden_dropout = config.hidden_dropout
527
 
528
- self.pre_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
529
- self.self_attn = DogeDynamicMaskAttention(config=config, layer_idx=layer_idx)
530
- self.pre_residual = DogeResidual(config.hidden_size)
531
 
532
- self.post_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
533
- self.feed_forward = DogeMLP(config) if not config.is_moe else DogeCDMoE(config)
534
- self.post_residual = DogeResidual(config.hidden_size)
535
 
536
  def forward(
537
  self,
538
  hidden_states: torch.Tensor,
 
539
  attention_mask: Optional[torch.Tensor] = None,
540
  position_ids: Optional[torch.LongTensor] = None,
541
- past_key_value: Optional[Cache] = None,
542
- output_attentions: Optional[bool] = False,
543
  use_cache: Optional[bool] = False,
544
  cache_position: Optional[torch.LongTensor] = None,
545
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
546
- **kwargs,
547
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
548
  # sequence transformation
549
  residual = hidden_states
550
- hidden_states = self.pre_layernorm(hidden_states)
551
  hidden_states, self_attn_weights = self.self_attn(
552
  hidden_states=hidden_states,
 
553
  attention_mask=attention_mask,
554
  position_ids=position_ids,
555
  past_key_value=past_key_value,
556
- output_attentions=output_attentions,
557
  use_cache=use_cache,
558
  cache_position=cache_position,
559
- position_embeddings=position_embeddings,
560
  **kwargs,
561
  )
562
- self_attn_weights = None
563
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
564
- hidden_states = self.pre_residual(residual, hidden_states)
565
 
566
  # state transformation
567
  residual = hidden_states
568
- hidden_states = self.post_layernorm(hidden_states)
569
- hidden_states = self.feed_forward(hidden_states)
570
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
571
- hidden_states = self.post_residual(residual, hidden_states)
572
-
573
- outputs = (hidden_states,)
574
- if output_attentions:
575
- outputs += (self_attn_weights,)
576
-
577
- return outputs
578
-
579
 
580
- DOGE_START_DOCSTRING = r"""
581
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
582
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
583
- etc.)
584
-
585
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
586
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
587
- and behavior.
588
-
589
- Parameters:
590
- config ([`DogeConfig`]):
591
- Model configuration class with all the parameters of the model. Initializing with a config file does not
592
- load the weights associated with the model, only the configuration. Check out the
593
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
594
- """
595
 
596
 
597
- @add_start_docstrings(
598
- "The bare Doge Model outputting raw hidden-states without any specific head on top.",
599
- DOGE_START_DOCSTRING,
600
- )
601
  class DogePreTrainedModel(PreTrainedModel):
602
- config_class = DogeConfig
603
  base_model_prefix = "model"
604
  supports_gradient_checkpointing = True
605
  _no_split_modules = ["DogeDecoderLayer"]
606
  _skip_keys_device_placement = ["past_key_values"]
 
607
  _supports_sdpa = True
608
  _supports_flex_attn = True
609
- _supports_cache_class = True
610
- _supports_quantized_cache = True
611
- _supports_static_cache = True
 
 
 
 
612
 
613
  def _init_weights(self, module):
614
- std = self.config.initializer_range
615
- if isinstance(module, nn.Linear):
616
- module.weight.data.normal_(mean=0.0, std=std)
617
- if module.bias is not None:
618
- module.bias.data.zero_()
619
- elif isinstance(module, nn.Embedding):
620
- module.weight.data.normal_(mean=0.0, std=std)
621
- if module.padding_idx is not None:
622
- module.weight.data[module.padding_idx].zero_()
623
-
624
-
625
- DOGE_INPUTS_DOCSTRING = r"""
626
- Args:
627
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
628
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
629
- it.
630
-
631
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
632
- [`PreTrainedTokenizer.__call__`] for details.
633
-
634
- [What are input IDs?](../glossary#input-ids)
635
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
636
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
637
-
638
- - 1 for tokens that are **not masked**,
639
- - 0 for tokens that are **masked**.
640
-
641
- [What are attention masks?](../glossary#attention-mask)
642
-
643
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
644
- [`PreTrainedTokenizer.__call__`] for details.
645
-
646
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
647
- `past_key_values`).
648
-
649
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
650
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
651
- information on the default strategy.
652
-
653
- - 1 indicates the head is **not masked**,
654
- - 0 indicates the head is **masked**.
655
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
656
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
657
- config.n_positions - 1]`.
658
-
659
- [What are position IDs?](../glossary#position-ids)
660
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
661
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
662
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
663
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
664
-
665
- Two formats are allowed:
666
- - a [`~cache_utils.Cache`] instance, see our
667
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
668
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
669
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
670
- cache format.
671
-
672
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
673
- legacy cache format will be returned.
674
-
675
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
676
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
677
- of shape `(batch_size, sequence_length)`.
678
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
679
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
680
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
681
- model's internal embedding lookup matrix.
682
- use_cache (`bool`, *optional*):
683
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
684
- `past_key_values`).
685
- output_attentions (`bool`, *optional*):
686
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
687
- tensors for more detail.
688
- output_hidden_states (`bool`, *optional*):
689
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
690
- more detail.
691
- return_dict (`bool`, *optional*):
692
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
693
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
694
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
695
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
696
- the complete sequence length.
697
- """
698
-
699
-
700
- @add_start_docstrings(
701
- "The bare Doge Model outputting raw hidden-states without any specific head on top.",
702
- DOGE_START_DOCSTRING,
703
- )
704
  class DogeModel(DogePreTrainedModel):
705
- """
706
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DogeDecoderLayer`]
707
-
708
- Args:
709
- config: DogeConfig
710
- """
711
-
712
  def __init__(self, config: DogeConfig):
713
  super().__init__(config)
714
- self.config = config
715
  self.padding_idx = config.pad_token_id
716
  self.vocab_size = config.vocab_size
717
 
718
- self.word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
719
- self.rotary_emb = DogeRotaryEmbedding(config)
720
  self.layers = nn.ModuleList(
721
  [DogeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
722
  )
723
- self.final_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
724
  self.gradient_checkpointing = False
725
 
726
  # Initialize weights and apply final processing
727
  self.post_init()
728
 
729
- def get_input_embeddings(self):
730
- return self.word_embed
731
-
732
- def set_input_embeddings(self, value):
733
- self.word_embed = value
734
-
735
- @add_start_docstrings_to_model_forward(DOGE_INPUTS_DOCSTRING)
736
  def forward(
737
  self,
738
- input_ids: torch.LongTensor = None,
739
  attention_mask: Optional[torch.Tensor] = None,
740
  position_ids: Optional[torch.LongTensor] = None,
741
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
742
  inputs_embeds: Optional[torch.FloatTensor] = None,
743
  use_cache: Optional[bool] = None,
744
- output_attentions: Optional[bool] = None,
745
- output_hidden_states: Optional[bool] = None,
746
- return_dict: Optional[bool] = None,
747
  cache_position: Optional[torch.LongTensor] = None,
748
- **kwargs,
749
- ) -> Union[Tuple, BaseModelOutputWithPast]:
750
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
751
- output_hidden_states = (
752
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
753
- )
754
- use_cache = use_cache if use_cache is not None else self.config.use_cache
755
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
756
-
757
  if (input_ids is None) ^ (inputs_embeds is not None):
758
- raise ValueError("You cannot specify both input_ids and inputs_embeds")
759
-
760
- if self.gradient_checkpointing and self.training and use_cache:
761
- logger.warning_once(
762
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
763
- )
764
- use_cache = False
765
-
766
- if inputs_embeds is None:
767
- inputs_embeds = self.word_embed(input_ids)
768
 
769
  if use_cache and past_key_values is None:
770
  past_key_values = DynamicCache()
771
 
 
 
 
772
  if cache_position is None:
773
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
774
  cache_position = torch.arange(
775
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
776
  )
777
-
778
  if position_ids is None:
779
  position_ids = cache_position.unsqueeze(0)
780
 
781
- causal_mask = self._update_causal_mask(
782
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
 
 
 
 
 
 
783
  )
784
 
785
  hidden_states = inputs_embeds
@@ -787,236 +579,185 @@ class DogeModel(DogePreTrainedModel):
787
  # create position embeddings to be shared across the decoder layers
788
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
789
 
790
- # decoder layers
791
- all_hidden_states = () if output_hidden_states else None
792
- all_self_attns = () if output_attentions else None
793
-
794
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
795
- if output_hidden_states:
796
- all_hidden_states += (hidden_states,)
797
-
798
- if self.gradient_checkpointing and self.training:
799
- layer_outputs = self._gradient_checkpointing_func(
800
- decoder_layer.__call__,
801
- hidden_states,
802
- causal_mask,
803
- position_ids,
804
- past_key_values,
805
- output_attentions,
806
- use_cache,
807
- cache_position,
808
- position_embeddings,
809
- )
810
- else:
811
- layer_outputs = decoder_layer(
812
- hidden_states,
813
- attention_mask=causal_mask,
814
- position_ids=position_ids,
815
- past_key_value=past_key_values,
816
- output_attentions=output_attentions,
817
- use_cache=use_cache,
818
- cache_position=cache_position,
819
- position_embeddings=position_embeddings,
820
- **kwargs,
821
- )
822
 
823
- hidden_states = layer_outputs[0]
824
 
825
- if output_attentions:
826
- all_self_attns += (layer_outputs[1],)
827
 
828
- hidden_states = self.final_layernorm(hidden_states)
 
 
829
 
830
- # add hidden states from the last decoder layer
831
- if output_hidden_states:
832
- all_hidden_states += (hidden_states,)
833
 
834
- output = BaseModelOutputWithPast(
835
- last_hidden_state=hidden_states,
836
- past_key_values=past_key_values if use_cache else None,
837
- hidden_states=all_hidden_states,
838
- attentions=all_self_attns,
839
- )
840
- return output if return_dict else output.to_tuple()
841
 
842
- def _update_causal_mask(
843
- self,
844
- attention_mask: torch.Tensor,
845
- input_tensor: torch.Tensor,
846
- cache_position: torch.Tensor,
847
- past_key_values: Cache,
848
- output_attentions: bool,
849
- ):
850
- # We have to provide attention_mask for dynamic mask computation
851
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
852
- using_static_cache = isinstance(past_key_values, StaticCache)
853
-
854
- dtype, device = input_tensor.dtype, input_tensor.device
855
- sequence_length = input_tensor.shape[1]
856
- if using_static_cache:
857
- target_length = past_key_values.get_max_cache_shape()
858
- else:
859
- target_length = (
860
- attention_mask.shape[-1]
861
- if isinstance(attention_mask, torch.Tensor)
862
- else past_seen_tokens + sequence_length + 1
863
- )
864
 
865
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
866
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
867
- attention_mask,
868
- sequence_length=sequence_length,
869
- target_length=target_length,
870
- dtype=dtype,
871
- device=device,
872
- cache_position=cache_position,
873
- batch_size=input_tensor.shape[0],
874
  )
 
875
 
876
- if (
877
- self.config._attn_implementation == "sdpa"
878
- and attention_mask is not None
879
- and attention_mask.device.type in ["cuda", "xpu"]
880
- and not output_attentions
881
- ):
882
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
883
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
884
- # Details: https://github.com/pytorch/pytorch/issues/110213
885
- min_dtype = torch.finfo(dtype).min
886
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
887
-
888
- return causal_mask
889
-
890
- @staticmethod
891
- def _prepare_4d_causal_attention_mask_with_cache_position(
892
- attention_mask: torch.Tensor,
893
- sequence_length: int,
894
- target_length: int,
895
- dtype: torch.dtype,
896
- device: torch.device,
897
- cache_position: torch.Tensor,
898
- batch_size: int,
899
- **kwargs,
900
- ):
901
- """
902
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
903
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
904
 
905
- Args:
906
- attention_mask (`torch.Tensor`):
907
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
908
- `(batch_size, 1, query_length, key_value_length)`.
909
- sequence_length (`int`):
910
- The sequence length being processed.
911
- target_length (`int`):
912
- The target length: when generating with static cache, the mask should be as long as the static cache,
913
- to account for the 0 padding, the part of the cache that is not filled yet.
914
- dtype (`torch.dtype`):
915
- The dtype to use for the 4D attention mask.
916
- device (`torch.device`):
917
- The device to place the 4D attention mask on.
918
- cache_position (`torch.Tensor`):
919
- Indices depicting the position of the input sequence tokens in the sequence.
920
- batch_size (`torch.Tensor`):
921
- Batch size.
922
- """
923
- if attention_mask is not None and attention_mask.dim() == 4:
924
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
925
- causal_mask = attention_mask
926
- else:
927
- min_dtype = torch.finfo(dtype).min
928
- causal_mask = torch.full(
929
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
930
- )
931
- if sequence_length != 1:
932
- causal_mask = torch.triu(causal_mask, diagonal=1)
933
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
934
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
935
- if attention_mask is not None:
936
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
937
- mask_length = attention_mask.shape[-1]
938
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
939
- padding_mask = padding_mask == 0
940
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
941
- padding_mask, min_dtype
942
- )
943
 
944
- return causal_mask
 
 
 
 
 
 
945
 
946
 
 
947
  class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
948
  _tied_weights_keys = ["lm_head.weight"]
949
  _tp_plan = {"lm_head": "colwise_rep"}
 
950
 
951
- def __init__(self, config: DogeConfig):
952
  super().__init__(config)
953
- self.config = config
954
  self.model = DogeModel(config)
955
  self.vocab_size = config.vocab_size
956
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
 
 
957
 
958
  # Initialize weights and apply final processing
959
  self.post_init()
960
 
961
- def get_input_embeddings(self):
962
- return self.model.word_embed
963
-
964
- def set_input_embeddings(self, value):
965
- self.model.word_embed = value
966
-
967
- def get_output_embeddings(self):
968
- return self.lm_head
969
-
970
- def set_output_embeddings(self, new_embeddings):
971
- self.lm_head = new_embeddings
972
 
973
  def get_decoder(self):
974
  return self.model
975
 
976
- def set_decoder(self, decoder):
977
- self.model = decoder
978
-
979
- @add_start_docstrings_to_model_forward(DOGE_INPUTS_DOCSTRING)
980
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
981
  def forward(
982
  self,
983
- input_ids: torch.LongTensor = None,
984
  attention_mask: Optional[torch.Tensor] = None,
985
  position_ids: Optional[torch.LongTensor] = None,
986
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
987
  inputs_embeds: Optional[torch.FloatTensor] = None,
988
  labels: Optional[torch.LongTensor] = None,
989
  use_cache: Optional[bool] = None,
990
- output_attentions: Optional[bool] = None,
991
- output_hidden_states: Optional[bool] = None,
992
- return_dict: Optional[bool] = None,
993
  cache_position: Optional[torch.LongTensor] = None,
994
  logits_to_keep: Union[int, torch.Tensor] = 0,
995
- **kwargs: Unpack[LossKwargs],
996
- ) -> Union[Tuple, CausalLMOutputWithPast]:
 
997
  r"""
998
- Args:
999
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1000
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1001
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1002
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1003
-
1004
- logits_to_keep (`int`, *optional*):
1005
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
1006
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1007
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1008
- If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
1009
- This is useful when using packed tensor format (single dimension for batch and sequence length).
1010
-
1011
- Returns:
1012
 
1013
  Example:
1014
 
1015
  ```python
1016
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM
1017
 
1018
- >>> model = AutoModelForCausalLM.from_pretrained("SmallDoge/Doge-20M")
1019
- >>> tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-20M")
1020
 
1021
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1022
  >>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -1026,156 +767,56 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
1026
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1027
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1028
  ```"""
1029
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1030
- output_hidden_states = (
1031
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1032
  )
1033
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1034
 
1035
- # decoder output consists of (dec_features, layer_state, dec_hidden, dec_attn)
1036
- outputs = self.model(
1037
  input_ids=input_ids,
1038
  attention_mask=attention_mask,
1039
  position_ids=position_ids,
1040
  past_key_values=past_key_values,
1041
  inputs_embeds=inputs_embeds,
1042
  use_cache=use_cache,
1043
- output_attentions=output_attentions,
1044
- output_hidden_states=output_hidden_states,
1045
- return_dict=return_dict,
1046
  cache_position=cache_position,
1047
  **kwargs,
1048
  )
1049
 
1050
- hidden_states = outputs[0]
1051
- # only compute necessary logits, and do not upcast them to float if we are not computing the loss
1052
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1053
  logits = self.lm_head(hidden_states[:, slice_indices, :])
1054
 
1055
  loss = None
1056
  if labels is not None:
1057
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
1058
-
1059
- if not return_dict:
1060
- output = (logits,) + outputs[1:]
1061
- return (loss,) + output if loss is not None else output
1062
 
1063
- return CausalLMOutputWithPast(
1064
  loss=loss,
 
1065
  logits=logits,
1066
  past_key_values=outputs.past_key_values,
1067
  hidden_states=outputs.hidden_states,
1068
  attentions=outputs.attentions,
 
1069
  )
1070
 
1071
 
1072
- @add_start_docstrings(
1073
- """
1074
- The Doge Model transformer with a sequence classification head on top (linear layer).
1075
-
1076
- [`DogeForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1077
- (e.g. GPT-2) do.
1078
-
1079
- Since it does classification on the last token, it requires to know the position of the last token. If a
1080
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1081
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1082
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1083
- each row of the batch).
1084
- """,
1085
- DOGE_START_DOCSTRING,
1086
- )
1087
- class DogeForSequenceClassification(DogePreTrainedModel):
1088
- def __init__(self, config: DogeConfig):
1089
- super().__init__(config)
1090
- self.num_labels = config.num_labels
1091
-
1092
- self.model = DogeModel(config)
1093
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1094
- self.config = config
1095
-
1096
- # Initialize weights and apply final processing
1097
- self.post_init()
1098
-
1099
- def get_input_embeddings(self):
1100
- return self.model.word_embed
1101
-
1102
- def set_input_embeddings(self, value):
1103
- self.model.word_embed = value
1104
-
1105
- @add_start_docstrings_to_model_forward(DOGE_INPUTS_DOCSTRING)
1106
- def forward(
1107
- self,
1108
- input_ids: Optional[torch.LongTensor] = None,
1109
- attention_mask: Optional[torch.Tensor] = None,
1110
- position_ids: Optional[torch.LongTensor] = None,
1111
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1112
- inputs_embeds: Optional[torch.FloatTensor] = None,
1113
- labels: Optional[torch.LongTensor] = None,
1114
- use_cache: Optional[bool] = None,
1115
- output_attentions: Optional[bool] = None,
1116
- output_hidden_states: Optional[bool] = None,
1117
- return_dict: Optional[bool] = None,
1118
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1119
- r"""
1120
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1121
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1122
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1123
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1124
- """
1125
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1126
-
1127
- transformer_outputs = self.model(
1128
- input_ids,
1129
- attention_mask=attention_mask,
1130
- position_ids=position_ids,
1131
- past_key_values=past_key_values,
1132
- inputs_embeds=inputs_embeds,
1133
- use_cache=use_cache,
1134
- output_attentions=output_attentions,
1135
- output_hidden_states=output_hidden_states,
1136
- return_dict=return_dict,
1137
- )
1138
- hidden_states = transformer_outputs[0]
1139
- logits = self.score(hidden_states)
1140
-
1141
- if input_ids is not None:
1142
- batch_size = input_ids.shape[0]
1143
- else:
1144
- batch_size = inputs_embeds.shape[0]
1145
-
1146
- if self.config.pad_token_id is None and batch_size != 1:
1147
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1148
- if self.config.pad_token_id is None:
1149
- last_non_pad_token = -1
1150
- elif input_ids is not None:
1151
- # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1152
- non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
1153
- token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
1154
- last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1155
- else:
1156
- last_non_pad_token = -1
1157
- logger.warning_once(
1158
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1159
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1160
- )
1161
-
1162
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
1163
-
1164
- loss = None
1165
- if labels is not None:
1166
- loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1167
-
1168
- if not return_dict:
1169
- output = (pooled_logits,) + transformer_outputs[1:]
1170
- return ((loss,) + output) if loss is not None else output
1171
-
1172
- return SequenceClassifierOutputWithPast(
1173
- loss=loss,
1174
- logits=pooled_logits,
1175
- past_key_values=transformer_outputs.past_key_values,
1176
- hidden_states=transformer_outputs.hidden_states,
1177
- attentions=transformer_outputs.attentions,
1178
- )
1179
 
1180
 
1181
  __all__ = ["DogeForCausalLM", "DogeModel", "DogePreTrainedModel", "DogeForSequenceClassification"]
 
5
  # modular_doge.py file directly. One of our CI enforces this.
6
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
  # coding=utf-8
8
+ # Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
9
  #
10
+ # The Doge family of small language models is trained by SmallDoge Team.
 
11
  #
12
  # Licensed under the Apache License, Version 2.0 (the "License");
13
  # you may not use this file except in compliance with the License.
 
22
  # limitations under the License.
23
 
24
  import math
25
+ from typing import Callable, Optional, Union
 
26
 
27
  import torch
28
  import torch.nn.functional as F
29
  from torch import nn
30
 
31
  from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, DynamicCache
33
  from transformers.generation import GenerationMixin
34
+ from transformers.integrations import use_kernel_forward_from_hub
35
+ from transformers.integrations.flex_attention import compile_friendly_flex_attention
36
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
37
+ from transformers.modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
38
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
39
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
40
+ from transformers.modeling_utils import AttentionInterface, PreTrainedModel
41
  from transformers.processing_utils import Unpack
42
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available
43
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
 
 
 
 
 
 
44
  from .configuration_doge import DogeConfig
45
 
 
 
46
 
47
+ if is_torch_flex_attn_available():
48
+ from torch.nn.attention.flex_attention import BlockMask
 
49
 
50
 
51
+ @use_kernel_forward_from_hub("RMSNorm")
52
  class DogeRMSNorm(nn.Module):
53
  def __init__(self, hidden_size, eps=1e-6):
54
  """
 
85
  def __init__(self, config: DogeConfig, device=None):
86
  super().__init__()
87
  # BC: "rope_type" was originally "type"
88
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
89
  self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
90
  else:
91
  self.rope_type = "default"
 
99
  self.register_buffer("inv_freq", inv_freq, persistent=False)
100
  self.original_inv_freq = self.inv_freq
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  @torch.no_grad()
103
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
104
  def forward(self, x, position_ids):
105
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 
 
 
 
106
  position_ids_expanded = position_ids[:, None, :].float()
107
+
108
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
109
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
 
110
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
111
  emb = torch.cat((freqs, freqs), dim=-1)
112
+ cos = emb.cos() * self.attention_scaling
113
+ sin = emb.sin() * self.attention_scaling
 
 
 
 
114
 
115
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
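The rotary module above only builds the cos/sin tables; the attention layer applies them with the usual rotate-half helpers, which are unchanged by this commit and therefore not shown in the diff. A minimal reference sketch, assuming the stock Transformers formulation:

```python
import torch

def rotate_half(x):
    # Split the last dimension in half and swap the halves with a sign flip.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin have shape (batch, seq_len, head_dim); unsqueeze so they broadcast over the head dim.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

The same cos/sin table is shared by every attention head, which is why the tensors are broadcast rather than repeated per head.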
116
 
 
169
  attention_mask: Optional[torch.Tensor],
170
  scaling: float,
171
  dropout: float = 0.0,
172
+ **kwargs: Unpack[TransformersKwargs],
173
+ ):
174
  key_states = repeat_kv(key, module.num_key_value_groups)
175
  value_states = repeat_kv(value, module.num_key_value_groups)
176
 
177
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
178
  if attention_mask is not None:
179
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
180
  attn_weights = attn_weights + causal_mask
181
 
182
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
183
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
184
  attn_output = torch.matmul(attn_weights, value_states)
185
  attn_output = attn_output.transpose(1, 2).contiguous()
186
 
187
  return attn_output, attn_weights
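The eager path relies on `repeat_kv` (also unchanged here, so absent from the diff) to expand the key/value heads up to the query-head count for grouped-query attention. A sketch of the standard helper, for readers following along:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```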
188
 
189
190
  def flex_attention_forward(
191
  module: nn.Module,
192
  query: torch.Tensor,
193
  key: torch.Tensor,
194
  value: torch.Tensor,
195
+ attention_mask: Union[torch.Tensor, "BlockMask"],
196
  scaling: Optional[float] = None,
 
197
  softcap: Optional[float] = None,
198
  head_mask: Optional[torch.Tensor] = None,
199
  **kwargs,
200
+ ) -> tuple[torch.Tensor, torch.Tensor]:
201
+ block_mask = None
202
+ causal_mask = None
203
+ if isinstance(attention_mask, BlockMask):
204
+ block_mask = attention_mask
205
+ else:
206
+ causal_mask = attention_mask
207
+
208
+ if causal_mask is not None:
209
  causal_mask = causal_mask[:, :, :, : key.shape[-2]]
210
 
211
+ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
 
212
  if softcap is not None:
213
  score = softcap * torch.tanh(score / softcap)
214
  if causal_mask is not None:
215
+ score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
216
  if head_mask is not None:
217
+ score = score + head_mask[batch_idx][head_idx][0][0]
218
  return score
219
 
220
+ attn_output, attention_weights = compile_friendly_flex_attention(
221
+ query,
222
+ key,
223
+ value,
224
+ score_mod=score_mod,
225
+ block_mask=block_mask,
226
  enable_gqa=True,
227
  scale=scaling,
228
  # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
 
236
  return attn_output, attention_weights
237
 
238
 
239
+ ALL_ATTENTION_FUNCTIONS = AttentionInterface()
240
+ ALL_ATTENTION_FUNCTIONS["doge_flex_attention"] = flex_attention_forward
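Because the flex-attention kernel is now registered on the shared `AttentionInterface` under the name `doge_flex_attention`, it should be selectable like any other backend when loading the checkpoint. A hedged usage sketch (repo id taken from the docstring example further up; whether your installed Transformers version accepts custom registered names via `attn_implementation` needs to be verified):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: the SmallDoge/Doge-20M checkpoint from the docstring example and a
# recent Transformers release; "doge_flex_attention" is the name registered above.
tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-20M")
model = AutoModelForCausalLM.from_pretrained(
    "SmallDoge/Doge-20M",
    trust_remote_code=True,                     # custom_code repository
    attn_implementation="doge_flex_attention",  # or "eager" / "sdpa"
)
```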
 
 
 
 
241
 
 
 
242
 
243
+ class DogeAttention(nn.Module):
244
  def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
245
  super().__init__()
246
  self.config = config
 
250
  self.scaling = self.head_dim**-0.5
251
  self.attention_dropout = config.attention_dropout
252
  self.keep_window_size = config.keep_window_size
 
253
 
254
  self.q_proj = nn.Linear(
255
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
256
  )
257
  self.k_proj = nn.Linear(
258
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
259
  )
260
  self.v_proj = nn.Linear(
261
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
262
  )
263
  # dynamic mask for the QK^T attention weights matrix
264
  self.A = nn.Parameter(torch.zeros(config.num_attention_heads))
265
  self.dt_proj = nn.Linear(
266
+ config.num_key_value_heads * self.head_dim, config.num_attention_heads, bias=config.attention_bias
267
  )
268
  self.o_proj = nn.Linear(
269
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
270
  )
271
 
272
  def forward(
273
  self,
274
  hidden_states: torch.Tensor,
275
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
276
  attention_mask: Optional[torch.Tensor] = None,
277
  past_key_value: Optional[Cache] = None,
278
  cache_position: Optional[torch.LongTensor] = None,
279
  **kwargs,
280
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
281
  input_shape = hidden_states.shape[:-1]
282
  hidden_shape = (*input_shape, -1, self.head_dim)
283
 
 
297
  dt_states = self.dt_proj(
298
  value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
299
  )
300
+ dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
301
  attn_mask = self.prepare_dynamic_mask(
302
  hidden_states=hidden_states,
303
+ dt_states=dt_states,
304
  keep_window_size=self.keep_window_size,
 
305
  attention_mask=attention_mask,
306
  )
307
 
308
  attention_interface: Callable = eager_attention_forward
309
  if self.config._attn_implementation != "eager":
310
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
 
 
 
 
 
311
 
312
  attn_output, attn_weights = attention_interface(
313
  self,
 
327
  def prepare_dynamic_mask(
328
  self,
329
  hidden_states: torch.Tensor,
330
+ dt_states: torch.Tensor,
331
  keep_window_size: int = 2048,
 
332
  attention_mask: Optional[torch.Tensor] = None,
333
  ):
334
  """
335
  The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention.
336
 
337
+ Combine `dt_states` with `attention_mask` to generate the final `attn_mask`.
338
 
339
  Args:
340
  hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
341
+ dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`.
342
  keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
 
343
  attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
344
  """
345
+ min_dtype = torch.finfo(hidden_states.dtype).min
346
+ dtype = hidden_states.dtype
347
+ attn_mask = dt_states[:, :, None, :].expand(
348
+ -1, -1, hidden_states.shape[1], -1
349
+ ) # [batch_size, num_heads, query_len, key_len]
350
+ if attention_mask is not None and not isinstance(attention_mask, BlockMask):
351
+ if attention_mask.dtype == torch.bool:
352
+ dtype = hidden_states.dtype
353
+ attention_mask = torch.where(
354
+ attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype
355
+ )
356
+ attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype)
357
+ if attn_mask.shape[-1] > keep_window_size:
358
+ active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
359
+ topk_indices = torch.topk(attn_mask, keep_window_size, dim=-1, largest=True, sorted=False).indices
360
+ active_mask = active_mask.scatter(-1, topk_indices, 1.0)
361
+ attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
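+ # every query can now attend to at most `keep_window_size` keys (those with the largest
+ # dynamic bias); all other positions are filled with the dtype minimum and vanish after softmax.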
362
  return attn_mask
363
 
364
 
365
  class DogeMLP(nn.Module):
366
+ def __init__(self, config):
367
  super().__init__()
368
+ self.config = config
369
+ self.hidden_size = config.hidden_size
370
+ self.intermediate_size = config.intermediate_size
371
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
372
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
373
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
374
  self.act_fn = ACT2FN[config.hidden_act]
375
 
376
+ def forward(self, x):
377
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
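+ # a gated MLP (SwiGLU-style when `hidden_act` is silu): the activated gate branch
+ # modulates the up-projection before mapping back to `hidden_size`.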
378
+ return down_proj
379
 
 
 
 
 
 
 
 
 
 
 
 
380
 
381
+ class DogeCDMoE(nn.Module):
382
  def __init__(self, config: DogeConfig):
383
+ super().__init__()
384
+ self.hidden_size = config.hidden_size
385
+ self.intermediate_size = config.intermediate_size
386
  self.act_fn = ACT2FN[config.hidden_act]
387
 
388
  self.num_experts = config.num_experts
389
+ self.num_keys = math.floor(math.sqrt(self.num_experts))
390
  self.top_k = config.num_experts_per_tok
391
+ self.norm_topk_prob = config.norm_topk_prob
392
+
393
+ # shared expert
394
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
395
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
396
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
397
 
398
  # router gate for retrieval experts
399
+ self.router_gate = nn.Linear(self.hidden_size, self.num_keys * 2, bias=False)
400
 
401
+ # routed experts
402
+ self.down_embed = nn.Embedding(self.num_experts, self.hidden_size)
403
+ self.up_embed = nn.Embedding(self.num_experts, self.hidden_size)
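+ # routed experts are stored as two embedding tables (one "down" row and one "up" row per
+ # expert) and addressed with product keys: expert index = i * num_keys + j, so the router
+ # only scores 2 * num_keys logits per token instead of num_experts.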
404
 
405
  def forward(
406
  self,
 
409
  ) -> tuple[torch.Tensor, torch.Tensor]:
410
  bsz, seq_len, _ = hidden_states.shape
411
 
412
+ # get routing logits with router gate
413
+ router_logits = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
414
 
415
+ # get experts with the highest routing logits
416
+ (scores_x, scores_y), (indices_x, indices_y) = router_logits.topk(self.num_keys, dim=-1)
417
  all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
418
  all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
419
  all_scores = all_scores.view(*all_scores.shape[:-2], -1)
420
  all_indices = all_indices.view(*all_indices.shape[:-2], -1)
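+ # `all_scores` / `all_indices` now enumerate every (i, j) key pair, i.e. all
+ # num_keys * num_keys routed experts, scored as the sum of the two key scores;
+ # the final top_k experts are selected from this flattened list below.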
421
+ scores, position_indices = all_scores.topk(self.top_k, dim=-1)
422
+ indices = all_indices.gather(-1, position_indices)
423
+ routing_weights = F.softmax(scores, dim=-1)
424
+ if self.norm_topk_prob:
425
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
426
+
427
+ # mix routed expert states with shared expert states
428
  down_embed = self.down_embed(indices)
429
  up_embed = self.up_embed(indices)
 
 
430
  experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1)
431
+ experts_weights = self.act_fn(experts_weights) * routing_weights
432
  experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1)
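+ # each routed expert acts as a rank-1 MLP: its `down_embed` row scores the token, the
+ # activation and routing weight gate that score, and its `up_embed` row maps it back to
+ # the hidden space; the result is added to the shared-expert path below.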
433
  hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
434
  hidden_states = hidden_states + experts_states
435
+ return hidden_states, router_logits
436
 
437
 
438
+ class DogeDecoderLayer(GradientCheckpointingLayer):
439
  def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
440
  super().__init__()
441
  self.hidden_dropout = config.hidden_dropout
442
 
443
+ self.input_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
444
+ self.self_attn = DogeAttention(config=config, layer_idx=layer_idx)
445
+ self.input_residual = nn.Parameter(torch.ones(config.hidden_size))
446
 
447
+ self.post_attention_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
448
+ self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config)
449
+ self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size))
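+ # `input_residual` / `post_attention_residual` are learnable per-channel scales on the
+ # residual branches, initialised to 1.0 so training starts from a plain residual connection.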
450
 
451
  def forward(
452
  self,
453
  hidden_states: torch.Tensor,
454
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
455
  attention_mask: Optional[torch.Tensor] = None,
456
  position_ids: Optional[torch.LongTensor] = None,
457
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
 
458
  use_cache: Optional[bool] = False,
459
  cache_position: Optional[torch.LongTensor] = None,
460
+ **kwargs: Unpack[TransformersKwargs],
461
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
462
  # sequence transformation
463
  residual = hidden_states
464
+ hidden_states = self.input_layernorm(hidden_states)
465
  hidden_states, self_attn_weights = self.self_attn(
466
  hidden_states=hidden_states,
467
+ position_embeddings=position_embeddings,
468
  attention_mask=attention_mask,
469
  position_ids=position_ids,
470
  past_key_value=past_key_value,
 
471
  use_cache=use_cache,
472
  cache_position=cache_position,
 
473
  **kwargs,
474
  )
 
475
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
476
+ hidden_states = self.input_residual * residual + hidden_states
477
 
478
  # state transformation
479
  residual = hidden_states
480
+ hidden_states = self.post_attention_layernorm(hidden_states)
481
+ hidden_states = self.mlp(hidden_states)
+ if isinstance(hidden_states, tuple):
+     # DogeCDMoE also returns its router logits; only the hidden states flow onward here
+     hidden_states = hidden_states[0]
482
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
483
+ hidden_states = self.post_attention_residual * residual + hidden_states
484
 
485
+ return hidden_states
486
 
487
 
488
+ @auto_docstring
 
 
 
489
  class DogePreTrainedModel(PreTrainedModel):
490
+ config: DogeConfig
491
  base_model_prefix = "model"
492
  supports_gradient_checkpointing = True
493
  _no_split_modules = ["DogeDecoderLayer"]
494
  _skip_keys_device_placement = ["past_key_values"]
495
+ _supports_flash_attn = False
496
  _supports_sdpa = True
497
  _supports_flex_attn = True
498
+ _can_compile_fullgraph = False
499
+ _supports_attention_backend = True
500
+ _can_record_outputs = {
501
+ "router_logits": OutputRecorder(DogeCDMoE, index=1),
502
+ "hidden_states": DogeDecoderLayer,
503
+ "attentions": DogeAttention,
504
+ }
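+ # with `_can_record_outputs`, the `check_model_inputs` wrapper on `DogeModel.forward`
+ # collects router logits (output index 1 of each DogeCDMoE), per-layer hidden states and
+ # attention weights through forward hooks instead of threading `output_*` flags around.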
505
 
506
  def _init_weights(self, module):
507
+ """Initialize the weights"""
508
+ super()._init_weights(module)
509
+ if isinstance(module, DogeAttention):
510
+ if hasattr(module, "A"):
511
+ module.A.data.zero_()
512
+ elif isinstance(module, DogeDecoderLayer):
513
+ if hasattr(module, "input_residual"):
514
+ module.input_residual.data.fill_(1.0)
515
+ if hasattr(module, "post_attention_residual"):
516
+ module.post_attention_residual.data.fill_(1.0)
517
+
518
+
519
+ @auto_docstring
520
  class DogeModel(DogePreTrainedModel):
521
  def __init__(self, config: DogeConfig):
522
  super().__init__(config)
 
523
  self.padding_idx = config.pad_token_id
524
  self.vocab_size = config.vocab_size
525
 
526
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
527
  self.layers = nn.ModuleList(
528
  [DogeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
529
  )
530
+ self.norm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
531
+ self.rotary_emb = DogeRotaryEmbedding(config=config)
532
  self.gradient_checkpointing = False
533
 
534
  # Initialize weights and apply final processing
535
  self.post_init()
536
 
537
+ @check_model_inputs
538
+ @auto_docstring
 
 
 
 
 
539
  def forward(
540
  self,
541
+ input_ids: Optional[torch.LongTensor] = None,
542
  attention_mask: Optional[torch.Tensor] = None,
543
  position_ids: Optional[torch.LongTensor] = None,
544
+ past_key_values: Optional[Cache] = None,
545
  inputs_embeds: Optional[torch.FloatTensor] = None,
546
  use_cache: Optional[bool] = None,
 
 
 
547
  cache_position: Optional[torch.LongTensor] = None,
548
+ **kwargs: Unpack[TransformersKwargs],
549
+ ) -> MoeModelOutputWithPast:
550
  if (input_ids is None) ^ (inputs_embeds is not None):
551
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
552
 
553
  if use_cache and past_key_values is None:
554
  past_key_values = DynamicCache()
555
 
556
+ if inputs_embeds is None:
557
+ inputs_embeds = self.embed_tokens(input_ids)
558
+
559
  if cache_position is None:
560
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
561
  cache_position = torch.arange(
562
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
563
  )
 
564
  if position_ids is None:
565
  position_ids = cache_position.unsqueeze(0)
566
 
567
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
568
+ causal_mask = mask_function(
569
+ config=self.config,
570
+ input_embeds=inputs_embeds,
571
+ attention_mask=attention_mask,
572
+ cache_position=cache_position,
573
+ past_key_values=past_key_values,
574
+ position_ids=position_ids,
575
  )
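+ # depending on `config.sliding_window`, this builds either a full causal mask or a
+ # sliding-window causal mask, in the layout expected by the selected attention backend
+ # (for flex attention this may be a BlockMask rather than a dense tensor).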
576
 
577
  hidden_states = inputs_embeds
 
579
  # create position embeddings to be shared across the decoder layers
580
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
581
 
 
 
 
 
582
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
583
+ hidden_states = decoder_layer(
584
+ hidden_states,
585
+ position_embeddings=position_embeddings,
586
+ attention_mask=causal_mask,
587
+ position_ids=position_ids,
588
+ past_key_value=past_key_values,
589
+ use_cache=use_cache,
590
+ cache_position=cache_position,
591
+ **kwargs,
592
+ )
593
+
594
+ hidden_states = self.norm(hidden_states)
595
+
596
+ return MoeModelOutputWithPast( # MoE output type, so the recorded router logits can be exposed
597
+ last_hidden_state=hidden_states,
598
+ past_key_values=past_key_values,
599
+ )
600
 
 
601
 
602
+ def load_balancing_loss_func(
603
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
604
+ num_experts: Optional[int] = None,
605
+ num_keys: Optional[int] = None,
606
+ top_k: int = 2,
607
+ attention_mask: Optional[torch.Tensor] = None,
608
+ ) -> Union[torch.Tensor, int]:
609
+ r"""
610
+ Computes the auxiliary load balancing loss as in Switch Transformer, implemented in PyTorch.
611
 
612
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
613
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
614
+ experts is too unbalanced.
615
 
616
+ Args:
617
+ gate_logits:
618
+ Logits from the `router_gate`, should be a tuple of model.config.num_hidden_layers tensors of
619
+ shape [2, batch_size * sequence_length, num_keys].
620
+ num_experts:
621
+ Number of experts
622
+ num_keys:
623
+ Number of keys
624
+ top_k:
625
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
626
+ parameter.
627
+ attention_mask (`torch.Tensor`, *optional*):
628
+ The attention_mask used in forward function
629
+ shape [batch_size X sequence_length] if not None.
630
 
631
+ Returns:
632
+ The auxiliary loss.
633
+ """
634
+ if gate_logits is None or not isinstance(gate_logits, tuple):
635
+ return 0
 
 
636
 
637
+ compute_dtype = gate_logits[0].dtype
638
+ compute_device = gate_logits[0].device
639
+ all_expert_indices = []
640
+ all_routing_weights = []
641
 
642
+ for layer_gate_logits in gate_logits:
643
+ layer_gate_logits = layer_gate_logits.to(compute_device)
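+ # re-derive this layer's routing decision from the raw logits with the same product-key
+ # procedure as `DogeCDMoE.forward`; note the softmax below is taken over all
+ # num_keys * num_keys candidate experts, not only the selected top_k.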
644
+
645
+ (scores_x, scores_y), (indices_x, indices_y) = layer_gate_logits.topk(num_keys, dim=-1)
646
+
647
+ all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
648
+ all_indices = indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)
649
+ all_scores = all_scores.view(*all_scores.shape[:-2], -1)
650
+ all_indices = all_indices.view(*all_indices.shape[:-2], -1)
651
+
652
+ _, position_indices = all_scores.topk(top_k, dim=-1)
653
+ expert_indices = all_indices.gather(-1, position_indices)
654
+
655
+ routing_weights = F.softmax(all_scores, dim=-1)
656
+
657
+ all_expert_indices.append(expert_indices)
658
+ all_routing_weights.append(routing_weights)
659
+ all_expert_indices = torch.cat(all_expert_indices, dim=0)
660
+ all_routing_weights = torch.cat(all_routing_weights, dim=0)
661
+
662
+ if attention_mask is None:
663
+ # Compute the percentage of tokens routed to each expert
664
+ all_expert_indices = all_expert_indices.view(-1)
665
+ tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
666
+ pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
667
+ tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / all_expert_indices.shape[0]
668
+
669
+ # Compute the average probability of routing to these experts
670
+ router_prob_per_expert = torch.mean(all_routing_weights, dim=0)
671
+ else:
672
+ batch_size, sequence_length = attention_mask.shape
673
+ num_hidden_layers = len(gate_logits)
674
+
675
+ # Compute the mask that zeroes out all padding tokens, with the same shape as expert_mask
676
+ expert_attention_mask = (
677
+ attention_mask[None, :, :, None]
678
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k))
679
+ .reshape(-1)
680
+ .to(compute_device)
681
  )
682
+ all_expert_indices = all_expert_indices.view(-1)[expert_attention_mask.bool()]
683
 
684
+ # Compute the percentage of tokens routed to each expert
685
+ tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
686
+ pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
687
+ tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / torch.sum(
688
+ expert_attention_mask
689
+ )
690
 
691
+ # Compute the mask that zeroes out all padding tokens, with the same shape as tokens_per_expert
692
+ router_per_expert_attention_mask = (
693
+ attention_mask[None, :, :, None]
694
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
695
+ .reshape(-1, num_experts)
696
+ .to(compute_device)
697
+ )
698
 
699
+ # Compute the average probability of routing to these experts
700
+ router_prob_per_expert = torch.sum(all_routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
701
+ router_per_expert_attention_mask, dim=0
702
+ )
703
+
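+ # Switch-Transformer-style auxiliary loss: sum over experts of (fraction of tokens routed
+ # to expert e) * (mean router probability of expert e), scaled by num_experts so a
+ # perfectly balanced router yields a loss of roughly 1.0.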
704
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
705
+ return overall_loss * num_experts
706
 
707
 
708
+ @auto_docstring
709
  class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
710
  _tied_weights_keys = ["lm_head.weight"]
711
  _tp_plan = {"lm_head": "colwise_rep"}
712
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
713
 
714
+ def __init__(self, config):
715
  super().__init__(config)
 
716
  self.model = DogeModel(config)
717
  self.vocab_size = config.vocab_size
718
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
719
+ self.router_aux_loss_coef = config.router_aux_loss_coef
720
+ self.num_experts = config.num_experts
721
+ self.num_experts_per_tok = config.num_experts_per_tok
722
 
723
  # Initialize weights and apply final processing
724
  self.post_init()
725
 
726
+ def set_decoder(self, decoder):
727
+ self.model = decoder
728
 
729
  def get_decoder(self):
730
  return self.model
731
 
732
+ @can_return_tuple
733
+ @auto_docstring
 
 
 
734
  def forward(
735
  self,
736
+ input_ids: Optional[torch.LongTensor] = None,
737
  attention_mask: Optional[torch.Tensor] = None,
738
  position_ids: Optional[torch.LongTensor] = None,
739
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
740
  inputs_embeds: Optional[torch.FloatTensor] = None,
741
  labels: Optional[torch.LongTensor] = None,
742
  use_cache: Optional[bool] = None,
 
 
 
743
  cache_position: Optional[torch.LongTensor] = None,
744
  logits_to_keep: Union[int, torch.Tensor] = 0,
745
+ output_router_logits: Optional[bool] = None,
746
+ **kwargs: Unpack[TransformersKwargs],
747
+ ) -> MoeCausalLMOutputWithPast:
748
  r"""
749
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
750
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
751
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
752
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
753
 
754
  Example:
755
 
756
  ```python
757
+ >>> from transformers import AutoTokenizer, DogeForCausalLM
758
 
759
+ >>> model = DogeForCausalLM.from_pretrained("SmallDoge/Doge-320M")
760
+ >>> tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M")
761
 
762
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
763
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
767
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
768
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
769
  ```"""
770
+ output_router_logits = (
771
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
 
772
  )
 
773
 
774
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
775
+ outputs: MoeModelOutputWithPast = self.model(
776
  input_ids=input_ids,
777
  attention_mask=attention_mask,
778
  position_ids=position_ids,
779
  past_key_values=past_key_values,
780
  inputs_embeds=inputs_embeds,
781
  use_cache=use_cache,
 
 
 
782
  cache_position=cache_position,
783
  **kwargs,
784
  )
785
 
786
+ hidden_states = outputs.last_hidden_state
787
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
788
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
789
  logits = self.lm_head(hidden_states[:, slice_indices, :])
790
 
791
  loss = None
792
  if labels is not None:
793
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
794
+
795
+ aux_loss = None
796
+ if output_router_logits:
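+ # recompute the balancing loss from the recorded router logits; num_keys is re-derived as
+ # floor(sqrt(num_experts)), matching DogeCDMoE, and the scaled result is added to the LM
+ # loss below when labels are provided.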
797
+ aux_loss = load_balancing_loss_func(
798
+ outputs.router_logits,
799
+ self.num_experts,
800
+ math.floor(math.sqrt(self.num_experts)),
801
+ self.num_experts_per_tok,
802
+ attention_mask,
803
+ )
804
+ if labels is not None:
805
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure the aux loss is on the same device as the LM loss
806
 
807
+ return MoeCausalLMOutputWithPast(
808
  loss=loss,
809
+ aux_loss=aux_loss,
810
  logits=logits,
811
  past_key_values=outputs.past_key_values,
812
  hidden_states=outputs.hidden_states,
813
  attentions=outputs.attentions,
814
+ router_logits=outputs.router_logits,
815
  )
816
 
817
 
818
+ class DogeForSequenceClassification(GenericForSequenceClassification, DogePreTrainedModel):
819
+ pass
820
 
821
 
822
  __all__ = ["DogeForCausalLM", "DogeModel", "DogePreTrainedModel", "DogeForSequenceClassification"]