fix-glu-mlp (#17)
fix: glu for non-flash-attn (c768124c3822ff864a2fc7477dbd2a175754fc5b)
- mlp.py +8 -2
- modeling_bert.py +2 -1
    	
mlp.py CHANGED

@@ -33,6 +33,7 @@ class GLUMLP(nn.Module):
         in_features,
         hidden_features,
         activation,
+        use_flash_attn,
         return_residual=False,
         hidden_dropout_prob=0.1
     ):
@@ -52,14 +53,19 @@ class GLUMLP(nn.Module):
         self.wo = nn.Linear(hidden_features, in_features)
         self.dropout = nn.Dropout(hidden_dropout_prob)
         self.return_residual = return_residual
+        self.use_flash_attn = use_flash_attn
         #self.layernorm = nn.LayerNorm(in_features, eps=layer_norm_eps)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         residual_connection = hidden_states
         # compute the activation
         hidden_states = self.gated_layers(hidden_states)
-        gated = hidden_states[:, : self.hidden_features]
-        non_gated = hidden_states[:, self.hidden_features :]
+        if self.use_flash_attn:
+            gated = hidden_states[:, : self.hidden_features]
+            non_gated = hidden_states[:, self.hidden_features :]
+        else:
+            gated = hidden_states[:, :, : self.hidden_features]
+            non_gated = hidden_states[:, :, self.hidden_features :]
         hidden_states = self.act(gated) * non_gated
         hidden_states = self.dropout(hidden_states)
         # multiply by the second matrix
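Why the branch is needed: with use_flash_attn the hidden states reach the MLP in flash-attn's unpadded layout, roughly (total_tokens, 2 * hidden_features), so the gate and value halves are split along dim 1; without flash-attn they keep the padded (batch, seq_len, 2 * hidden_features) layout, where the old 2-D slice would cut the sequence axis instead of the hidden axis. A minimal standalone sketch of the two layouts (shapes are arbitrary illustration values, not taken from the repo):

# Illustration only, not the repo's GLUMLP: shows which axis the GLU split
# must use for each input layout.
import torch

hidden_features = 4

# flash-attn path: tokens are unpadded and packed into a 2-D tensor.
packed = torch.randn(6, 2 * hidden_features)          # (total_tokens, 2*h)
gated, non_gated = packed[:, :hidden_features], packed[:, hidden_features:]
assert gated.shape == (6, hidden_features)

# non-flash-attn path: the padded 3-D layout needs the split on the last axis;
# the old 2-D slice here would have cut along seq_len instead.
padded = torch.randn(2, 3, 2 * hidden_features)       # (batch, seq_len, 2*h)
gated, non_gated = padded[:, :, :hidden_features], padded[:, :, hidden_features:]
assert gated.shape == (2, 3, hidden_features)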
    	
modeling_bert.py CHANGED

@@ -114,6 +114,7 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
             GLUMLP,
             hidden_features=inner_dim,
             activation=config.hidden_act,
+            use_flash_attn=config.use_flash_attn,
             hidden_dropout_prob=config.hidden_dropout_prob,
             return_residual=return_residual,
         )
@@ -802,4 +803,4 @@ class BertForMaskedLM(BertPreTrainedModel):
             loss=masked_lm_loss,
             prediction_logits=prediction_scores,
             seq_relationship_logits=seq_relationship_score,
-        )
+        )
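On the modeling side, the functional change is that the MLP factory now forwards the model-level config.use_flash_attn flag into GLUMLP, so forward() can take the matching slicing branch above. A hedged sketch of that wiring with a stand-in config (the partial-application pattern and the import path are assumptions; only the keyword arguments shown in the diff are taken from the repo):

# Sketch only: mirrors the keyword arguments from the diff with a stand-in
# config; the functools.partial pattern and the "from mlp import GLUMLP"
# path are assumptions, not copied from the repo's create_mlp_cls.
from functools import partial
from types import SimpleNamespace

from mlp import GLUMLP  # assumed import path for the class patched above

config = SimpleNamespace(
    hidden_size=768,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    use_flash_attn=False,  # padded (batch, seq_len, hidden) inputs -> 3-D slicing branch
)

mlp_cls = partial(
    GLUMLP,
    hidden_features=config.intermediate_size,
    activation=config.hidden_act,
    use_flash_attn=config.use_flash_attn,  # new keyword added by this commit
    hidden_dropout_prob=config.hidden_dropout_prob,
    return_residual=False,
)
mlp = mlp_cls(config.hidden_size)  # first positional argument is in_features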

