eyad-silx committed
Commit 0e9d476 · verified · 1 Parent(s): a44b70f

Update modeling_quasrav4.py

Files changed (1): modeling_quasrav4.py (+2 -2)
modeling_quasrav4.py CHANGED
@@ -136,6 +136,7 @@ class InfinityFormerLayer(nn.Module):
         self.mem_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mem_attn = LinearAttention(config, layer_idx)
         self.ffn = GatedFeedForward(config)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> Tuple[torch.Tensor, ...]:
         residual = hidden_states
@@ -150,6 +151,7 @@ class InfinityFormerLayer(nn.Module):
         hidden_states = mem_residual + hidden_states
 
         hidden_states = self.ffn(hidden_states)
+        hidden_states = self.final_layer_norm(hidden_states)
         return (hidden_states,)
 
 class InfinityFormerEmbeddings(nn.Module):
@@ -196,7 +198,6 @@ class InfinityFormerModel(InfinityFormerPreTrainedModel):
         self.config = config
         self.embeddings = InfinityFormerEmbeddings(config)
         self.layers = nn.ModuleList([InfinityFormerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
-        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.gradient_checkpointing = False
         self.post_init()
 
@@ -215,7 +216,6 @@ class InfinityFormerModel(InfinityFormerPreTrainedModel):
         else:
             layer_outputs = layer_module(hidden_states, attention_mask=attention_mask)
         hidden_states = layer_outputs[0]
-        hidden_states = self.final_layer_norm(hidden_states)
         if not return_dict:
             return (hidden_states,)
         return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None)
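For reference, a minimal sketch of what this commit changes: the final LayerNorm is no longer owned by InfinityFormerModel (sized with config.hidden_size and applied once after the layer stack) but by each InfinityFormerLayer (sized with self.embed_dim) and applied after the layer's feed-forward block. The ToyLayer/ToyModel names and the nn.Linear stand-ins below are illustrative placeholders, not the real LinearAttention or GatedFeedForward modules from modeling_quasrav4.py.

import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    """Simplified layer mirroring the post-commit InfinityFormerLayer layout."""
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim, eps=eps)
        self.attn = nn.Linear(dim, dim)   # stand-in for LinearAttention
        self.ffn = nn.Linear(dim, dim)    # stand-in for GatedFeedForward
        # After this commit: each layer normalizes its own output.
        self.final_layer_norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = self.ffn(x)
        return self.final_layer_norm(x)   # per-layer norm replaces the model-level one

class ToyModel(nn.Module):
    """Simplified model mirroring the post-commit InfinityFormerModel layout."""
    def __init__(self, dim: int = 32, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer(dim) for _ in range(num_layers))
        # Before this commit the model held a single final LayerNorm here and
        # applied it after the loop below; both the module and its call were removed.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

if __name__ == "__main__":
    out = ToyModel()(torch.randn(1, 8, 32))
    print(out.shape)  # torch.Size([1, 8, 32])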