eyad-silx committed (verified)
Commit: 20026e7
Parent(s): da44902

Update modeling_quasrav4.py

Files changed (1):
  1. modeling_quasrav4.py  +17 -4
modeling_quasrav4.py CHANGED
@@ -9,7 +9,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
 from transformers.generation import GenerationMixin
 from transformers.utils import logging
 
-from .configuration_quasrav4 import InfinityFormerConfig
+from configuration_quasrav4 import InfinityFormerConfig
 
 logger = logging.get_logger(__name__)
 
@@ -136,13 +136,26 @@ class InfinityFormerLayer(nn.Module):
         self.embed_dim = config.hidden_size
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.self_attn = LinearAttention(config, layer_idx)
+
+        print(f"DEBUG: Layer {layer_idx}, use_memory_attention={config.use_memory_attention}")  # DEBUG
+        if config.use_memory_attention:
+            self.mem_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.mem_attn = LinearAttention(config, layer_idx)
+
         self.ffn = GatedFeedForward(config)
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> Tuple[torch.Tensor, ...]:
         residual = hidden_states
-        hidden_states_ln = self.self_attn_layer_norm(hidden_states)
-        attn_outputs, _ = self.self_attn(hidden_states=hidden_states_ln, attention_mask=attention_mask, **kwargs)
-        hidden_states = residual + attn_outputs
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask, **kwargs)[0]
+        hidden_states = residual + hidden_states
+
+        if hasattr(self, 'mem_attn'):
+            mem_residual = hidden_states
+            hidden_states = self.mem_attn_layer_norm(hidden_states)
+            hidden_states = self.mem_attn(hidden_states, attention_mask=attention_mask, **kwargs)[0]
+            hidden_states = mem_residual + hidden_states
+
         hidden_states = self.ffn(hidden_states)
         return (hidden_states,)
 
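The second hunk gives each InfinityFormerLayer an optional memory-attention block: when config.use_memory_attention is set, __init__ builds mem_attn and mem_attn_layer_norm, and forward runs a second pre-norm residual attention pass between self-attention and the gated feed-forward network, guarded by the matching hasattr(self, 'mem_attn') check. The sketch below is a minimal, hypothetical smoke test of that path, not part of the commit; it assumes modeling_quasrav4.py and configuration_quasrav4.py are importable from the working directory (matching the top-level import this commit switches to) and that InfinityFormerConfig accepts the keyword arguments shown, which this diff does not confirm:

# Hypothetical smoke test; argument names other than hidden_size and
# use_memory_attention (both visible in the diff) are assumptions.
import torch

from configuration_quasrav4 import InfinityFormerConfig
from modeling_quasrav4 import InfinityFormerLayer

config = InfinityFormerConfig(
    hidden_size=64,
    use_memory_attention=True,  # makes __init__ build mem_attn / mem_attn_layer_norm
)
layer = InfinityFormerLayer(config, layer_idx=0)

hidden_states = torch.randn(2, 16, config.hidden_size)  # (batch, seq_len, hidden)
(output,) = layer(hidden_states)  # forward returns a one-element tuple

# With use_memory_attention=True, forward applies two pre-norm residual
# attention blocks (self_attn, then mem_attn) before the gated feed-forward.
assert output.shape == hidden_states.shape

Note that the added print(f"DEBUG: ...") fires once per layer at construction time, so it reads as a temporary diagnostic rather than permanent logging.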