Update modeling_quasrav4.py
modeling_quasrav4.py  +2 -2
@@ -136,6 +136,7 @@ class InfinityFormerLayer(nn.Module):
         self.mem_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mem_attn = LinearAttention(config, layer_idx)
         self.ffn = GatedFeedForward(config)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

     def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> Tuple[torch.Tensor, ...]:
         residual = hidden_states
@@ -150,6 +151,7 @@ class InfinityFormerLayer(nn.Module):
         hidden_states = mem_residual + hidden_states

         hidden_states = self.ffn(hidden_states)
+        hidden_states = self.final_layer_norm(hidden_states)
         return (hidden_states,)

 class InfinityFormerEmbeddings(nn.Module):
@@ -196,7 +198,6 @@ class InfinityFormerModel(InfinityFormerPreTrainedModel):
         self.config = config
         self.embeddings = InfinityFormerEmbeddings(config)
         self.layers = nn.ModuleList([InfinityFormerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
-        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.gradient_checkpointing = False
         self.post_init()

@@ -215,7 +216,6 @@ class InfinityFormerModel(InfinityFormerPreTrainedModel):
             else:
                 layer_outputs = layer_module(hidden_states, attention_mask=attention_mask)
             hidden_states = layer_outputs[0]
-        hidden_states = self.final_layer_norm(hidden_states)
         if not return_dict:
             return (hidden_states,)
         return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None)
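Net effect of the change: the final LayerNorm moves from the model level into every InfinityFormerLayer, so each layer normalizes its own output right after its GatedFeedForward block, instead of the stack being normalized once at the end. Below is a minimal sketch of the resulting control flow with toy stand-ins for the sub-modules; the real LinearAttention / GatedFeedForward internals and the attention portion of forward are not shown in this diff, so only the LayerNorm placement mirrors the actual code.

import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    """Toy stand-in for InfinityFormerLayer; only the LayerNorm placement mirrors the diff."""
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.ffn = nn.Linear(dim, dim)                      # stands in for GatedFeedForward
        self.final_layer_norm = nn.LayerNorm(dim, eps=eps)  # added per layer (new line 139)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.ffn(hidden_states)
        # Normalization now happens inside every layer (new line 154) ...
        return self.final_layer_norm(hidden_states)

class ToyModel(nn.Module):
    """Toy stand-in for InfinityFormerModel."""
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([ToyLayer(dim) for _ in range(num_layers)])
        # ... and the model-level final_layer_norm (old lines 199 and 218) is gone.

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states

x = torch.randn(2, 8, 16)
print(ToyModel(dim=16, num_layers=3)(x).shape)  # torch.Size([2, 8, 16])

With this placement, every layer hands an already-normalized tensor to the next layer, whereas before the final LayerNorm was applied only once, after the last layer's output.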