Upload DogeForCausalLM
- configuration_doge.py (+1, -1)
- modeling_doge.py (+5, -4)
configuration_doge.py
CHANGED
@@ -3,7 +3,7 @@
 #
 # This code is based on the Wonderful Matrices paper implementation.
 #
-# https://arxiv.org/abs/
+# https://arxiv.org/abs/2412.11834
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
modeling_doge.py
CHANGED
@@ -3,7 +3,7 @@
 #
 # This code is based on the Wonderful Matrices paper implementation.
 #
-# https://arxiv.org/abs/
+# https://arxiv.org/abs/2412.11834
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -184,6 +184,7 @@ def apply_QK_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 
 
 class DogeDynamicMaskAttention(nn.Module):
+    """Dynamic Mask Attention from 'Wonderful Matrices' paper."""
 
     def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
         super().__init__()
@@ -387,6 +388,7 @@ class DogeMLP(nn.Module):
 
 
 class DogeCDMoE(DogeMLP):
+    """Cross Domain Mixture of Experts from 'Wonderful Matrices' paper."""
 
     def __init__(self, config: DogeConfig):
         super().__init__(config)
@@ -816,7 +818,7 @@ class DogeModel(DogePreTrainedModel):
         )
 
         # in case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self.
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask=attention_mask,
             sequence_length=sequence_length,
             target_length=target_length,
@@ -829,7 +831,7 @@ class DogeModel(DogePreTrainedModel):
         return causal_mask
 
     @staticmethod
-    def 
+    def _prepare_4d_causal_attention_mask_with_cache_position(
         attention_mask: torch.Tensor = None,
         sequence_length: int = None,
         target_length: int = None,
@@ -875,7 +877,6 @@ class DogeModel(DogePreTrainedModel):
         causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
-            # print(f"attention_mask: {attention_mask.shape}")
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
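For reference, below is a minimal, self-contained sketch of what the renamed helper `_prepare_4d_causal_attention_mask_with_cache_position` does: expand a 2D padding mask into the 4D additive causal mask used by attention. The mask-construction lines mirror the ones visible in the last hunk; the function signature, the `dtype`/`min_dtype` handling, and the final `masked_fill` step are assumptions for illustration, not taken from this diff.

```python
import torch


def prepare_4d_causal_mask_sketch(
    attention_mask: torch.Tensor,   # 2D padding mask, shape (batch_size, mask_length); 1 = real token, 0 = padding
    sequence_length: int,           # number of query positions in this forward pass
    target_length: int,             # key length, including cached positions
    cache_position: torch.Tensor,   # absolute positions of the current queries, shape (sequence_length,)
    dtype: torch.dtype = torch.float32,
    device: str = "cpu",
) -> torch.Tensor:
    """Sketch only: build a (batch, 1, seq, target) additive mask; 0 = attend, dtype-min = masked."""
    min_dtype = torch.finfo(dtype).min
    batch_size = attention_mask.shape[0]

    # start fully masked, then zero out (i.e. allow) key positions at or before each query position
    causal_mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype, device=device)
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        # where the 2D mask marks padding (0) and the causal entry is 0, the sum is 0 -> mask it out
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )
    return causal_mask


# Example: batch of 1, two new queries at positions 2 and 3, key length 4, last key padded.
mask = prepare_4d_causal_mask_sketch(
    attention_mask=torch.tensor([[1, 1, 1, 0]]),
    sequence_length=2,
    target_length=4,
    cache_position=torch.tensor([2, 3]),
)
print(mask.shape)  # torch.Size([1, 1, 2, 4])
```

Entries left at 0 are attended to; entries set to the dtype minimum vanish after softmax, which is why the debug `print` removed in the last hunk was the only behavioral no-op in this commit.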