Commit c768124
1 Parent(s): ad76444
fix: glu for non-flash-attn

Files changed:
- mlp.py (+8, -2)
- modeling_bert.py (+2, -1)
mlp.py

@@ -33,6 +33,7 @@ class GLUMLP(nn.Module):
         in_features,
         hidden_features,
         activation,
+        use_flash_attn,
         return_residual=False,
         hidden_dropout_prob=0.1
     ):
@@ -52,14 +53,19 @@ class GLUMLP(nn.Module):
         self.wo = nn.Linear(hidden_features, in_features)
         self.dropout = nn.Dropout(hidden_dropout_prob)
         self.return_residual = return_residual
+        self.use_flash_attn = use_flash_attn
         #self.layernorm = nn.LayerNorm(in_features, eps=layer_norm_eps)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         residual_connection = hidden_states
         # compute the activation
         hidden_states = self.gated_layers(hidden_states)
-        gated = hidden_states[:, : self.hidden_features]
-        non_gated = hidden_states[:, self.hidden_features :]
+        if self.use_flash_attn:
+            gated = hidden_states[:, : self.hidden_features]
+            non_gated = hidden_states[:, self.hidden_features :]
+        else:
+            gated = hidden_states[:, :, : self.hidden_features]
+            non_gated = hidden_states[:, :, self.hidden_features :]
         hidden_states = self.act(gated) * non_gated
         hidden_states = self.dropout(hidden_states)
         # multiply by the second matrix
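For context, here is a minimal, self-contained sketch of the gated MLP that this hunk changes. The slicing logic and the attribute names (`gated_layers`, `hidden_features`, `act`, `wo`, `dropout`, `use_flash_attn`) follow the diff; everything else (the GELU activation, the layer sizes, the example shapes) is an illustrative assumption. The point of the branch: with flash-attn the sequences are presumably packed into a 2D `(total_tokens, 2 * hidden_features)` tensor, so the gate and value halves are split along dim 1, while the non-flash-attn path keeps a padded `(batch, seq_len, 2 * hidden_features)` tensor and must split along dim 2, which is what this commit adds.

```python
# Minimal sketch of the GLU MLP split (names follow the diff; sizes, activation,
# and shapes are illustrative assumptions, not the repo's exact code).
import torch
import torch.nn as nn


class GLUMLPSketch(nn.Module):
    def __init__(self, in_features, hidden_features, use_flash_attn, hidden_dropout_prob=0.1):
        super().__init__()
        self.hidden_features = hidden_features
        self.use_flash_attn = use_flash_attn
        # Single projection producing the gate half and the value half side by side.
        self.gated_layers = nn.Linear(in_features, 2 * hidden_features, bias=False)
        self.act = nn.GELU()  # assumed; the real module takes the activation as an argument
        self.wo = nn.Linear(hidden_features, in_features)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.gated_layers(hidden_states)
        if self.use_flash_attn:
            # Packed/unpadded layout: (total_tokens, 2 * hidden_features)
            gated = hidden_states[:, : self.hidden_features]
            non_gated = hidden_states[:, self.hidden_features :]
        else:
            # Padded layout: (batch, seq_len, 2 * hidden_features)
            gated = hidden_states[:, :, : self.hidden_features]
            non_gated = hidden_states[:, :, self.hidden_features :]
        hidden_states = self.act(gated) * non_gated
        hidden_states = self.dropout(hidden_states)
        return self.wo(hidden_states)


# Both layouts come back with the original feature size.
print(GLUMLPSketch(768, 3072, use_flash_attn=True)(torch.randn(10, 768)).shape)     # (10, 768)
print(GLUMLPSketch(768, 3072, use_flash_attn=False)(torch.randn(2, 5, 768)).shape)  # (2, 5, 768)
```

An ellipsis slice such as `hidden_states[..., : self.hidden_features]` would cover both layouts with a single code path; the explicit branch in the commit instead keeps the packed-versus-padded distinction visible.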
modeling_bert.py

@@ -114,6 +114,7 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
             GLUMLP,
             hidden_features=inner_dim,
             activation=config.hidden_act,
+            use_flash_attn=config.use_flash_attn,
             hidden_dropout_prob=config.hidden_dropout_prob,
             return_residual=return_residual,
         )
@@ -802,4 +803,4 @@ class BertForMaskedLM(BertPreTrainedModel):
             loss=masked_lm_loss,
             prediction_logits=prediction_scores,
             seq_relationship_logits=seq_relationship_score,
-        )
+        )
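On the modeling side the change simply forwards the config flag into the MLP factory. A hypothetical sketch of that factory is shown below, assuming it wraps `GLUMLP` in `functools.partial` and that `inner_dim` comes from `config.intermediate_size`; only the keyword arguments themselves are taken from the diff.

```python
# Hypothetical reconstruction of the factory touched above; only the keyword
# arguments passed to GLUMLP are taken from the diff, the partial() wrapper and
# the inner_dim source are assumptions.
from functools import partial

from mlp import GLUMLP  # the module changed in this commit


def create_mlp_cls(config, layer_idx=None, return_residual=False):
    inner_dim = config.intermediate_size  # assumed source of inner_dim
    return partial(
        GLUMLP,
        hidden_features=inner_dim,
        activation=config.hidden_act,
        use_flash_attn=config.use_flash_attn,  # new: tells GLUMLP which dim to slice
        hidden_dropout_prob=config.hidden_dropout_prob,
        return_residual=return_residual,
    )
```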