Update modeling_quasrav4.py
modeling_quasrav4.py  CHANGED  (+17, -4)
@@ -9,7 +9,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.generation import GenerationMixin
 from transformers.utils import logging
 
-from
+from configuration_quasrav4 import InfinityFormerConfig
 
 logger = logging.get_logger(__name__)
 
@@ -136,13 +136,26 @@ class InfinityFormerLayer(nn.Module):
         self.embed_dim = config.hidden_size
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.self_attn = LinearAttention(config, layer_idx)
+
+        print(f"DEBUG: Layer {layer_idx}, use_memory_attention={config.use_memory_attention}") # DEBUG
+        if config.use_memory_attention:
+            self.mem_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.mem_attn = LinearAttention(config, layer_idx)
+
         self.ffn = GatedFeedForward(config)
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> Tuple[torch.Tensor, ...]:
         residual = hidden_states
-
-
-        hidden_states = residual +
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask, **kwargs)[0]
+        hidden_states = residual + hidden_states
+
+        if hasattr(self, 'mem_attn'):
+            mem_residual = hidden_states
+            hidden_states = self.mem_attn_layer_norm(hidden_states)
+            hidden_states = self.mem_attn(hidden_states, attention_mask=attention_mask, **kwargs)[0]
+            hidden_states = mem_residual + hidden_states
+
         hidden_states = self.ffn(hidden_states)
         return (hidden_states,)
 
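For context, below is a minimal, self-contained sketch of the layer structure this commit produces. It is not the repository's actual code: LinearAttention and GatedFeedForward are not shown in the diff, so trivial stand-ins (StubAttention, StubFFN) are used here, DummyConfig stands in for InfinityFormerConfig with only the fields the diff touches (hidden_size, layer_norm_eps, use_memory_attention), and the temporary DEBUG print is omitted. The only structure taken from the diff is the pre-norm self-attention sub-block with a residual, the optional memory-attention sub-block, and the feed-forward call.

# Sketch of the post-commit layer wiring; stand-in modules, assumed config fields.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn


@dataclass
class DummyConfig:  # stand-in for InfinityFormerConfig (assumed fields only)
    hidden_size: int = 64
    layer_norm_eps: float = 1e-5
    use_memory_attention: bool = True


class StubAttention(nn.Module):
    # Placeholder with the call convention the diff relies on:
    # module(hidden_states, attention_mask=..., **kwargs) returns a tuple whose
    # first element is the attention output.
    def __init__(self, config, layer_idx):
        super().__init__()
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        return (self.proj(hidden_states),)


class StubFFN(nn.Module):
    # Placeholder for GatedFeedForward: any hidden_size -> hidden_size map works here.
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.hidden_size, 4 * config.hidden_size),
            nn.GELU(),
            nn.Linear(4 * config.hidden_size, config.hidden_size),
        )

    def forward(self, hidden_states):
        return self.net(hidden_states)


class LayerSketch(nn.Module):
    # Mirrors the post-commit InfinityFormerLayer wiring: pre-norm self-attention
    # with a residual, an optional memory-attention sub-block built only when
    # config.use_memory_attention is True, then the feed-forward block.
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.self_attn = StubAttention(config, layer_idx)
        if config.use_memory_attention:
            self.mem_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
            self.mem_attn = StubAttention(config, layer_idx)
        self.ffn = StubFFN(config)

    def forward(self, hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                **kwargs) -> Tuple[torch.Tensor, ...]:
        # Self-attention sub-block: pre-norm, then residual add.
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask, **kwargs)[0]
        hidden_states = residual + hidden_states

        # Optional memory-attention sub-block, gated the same way as in the diff:
        # the submodule only exists when use_memory_attention was True at init.
        if hasattr(self, "mem_attn"):
            mem_residual = hidden_states
            hidden_states = self.mem_attn_layer_norm(hidden_states)
            hidden_states = self.mem_attn(hidden_states, attention_mask=attention_mask, **kwargs)[0]
            hidden_states = mem_residual + hidden_states

        hidden_states = self.ffn(hidden_states)
        return (hidden_states,)


if __name__ == "__main__":
    cfg = DummyConfig(use_memory_attention=True)
    layer = LayerSketch(cfg, layer_idx=0)
    out = layer(torch.randn(2, 8, cfg.hidden_size))[0]
    print(out.shape)  # torch.Size([2, 8, 64])

Note that the memory-attention branch is gated twice: at construction time by config.use_memory_attention and at run time by hasattr(self, 'mem_attn'), so a model configured with the flag off simply skips the extra sub-block in forward.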