Upload 2 files
- configuration_minicpm.py +5 -0
- modeling_minicpm.py +4 -4
    	
configuration_minicpm.py
CHANGED

@@ -174,6 +174,11 @@ class MiniCPMConfig(PretrainedConfig):
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
+        try:
+            import flash_attn
+            self._attn_implementation = "flash_attention_2"
+        except:
+            pass
 
     def _rope_scaling_validation(self):
         """
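In effect, MiniCPMConfig now probes for the flash_attn package while the config is being constructed: if the import succeeds, it switches _attn_implementation to "flash_attention_2" on its own; otherwise the bare except swallows the failure and the default backend is kept. A minimal sketch of what this looks like from the loading side; the repo id, dtype, and use of AutoConfig/AutoModelForCausalLM here are illustrative assumptions, not part of this commit:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative checkpoint; any repo that ships these remote-code files behaves the same way.
repo = "openbmb/MiniCPM-2B-sft-bf16"

config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
# If `import flash_attn` succeeded inside MiniCPMConfig.__init__, the config has already
# switched itself to FlashAttention-2; otherwise it keeps its default implementation.
print(config._attn_implementation)

model = AutoModelForCausalLM.from_pretrained(
    repo,
    config=config,
    torch_dtype=torch.bfloat16,  # flash-attn kernels expect fp16/bf16 activations
    trust_remote_code=True,
)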
    	
modeling_minicpm.py
CHANGED

@@ -51,10 +51,11 @@ from transformers.utils.import_utils import is_torch_fx_available
 from .configuration_minicpm import MiniCPMConfig
 import re
 
-
-if is_flash_attn_2_available():
+try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+except:
+    pass
 
 
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
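This replaces the is_flash_attn_2_available() gate with an opportunistic import: the module now tries the flash_attn imports unconditionally and simply carries on without them when the package is missing or fails to load, so the file still imports on machines without flash-attn. A common variant of the same pattern records whether the import worked so later code can branch on it; a sketch (the availability flag is an illustration, the commit itself just uses pass):

# Same opportunistic-import pattern, with an explicit availability flag added for illustration.
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func  # noqa: F401
    _flash_attn_available = True
except ImportError:
    _flash_attn_available = False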
@@ -125,7 +126,7 @@ ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm)
 
 
 class MiniCPMRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
         self.dim = dim
@@ -763,7 +764,6 @@ class MiniCPMDecoderLayer(nn.Module):
     def __init__(self, config: MiniCPMConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-
         self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
         self.mlp = MiniCPMMLP(config)
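The decoder layer still picks its attention class via MINICPM_ATTENTION_CLASSES[config._attn_implementation] (an unchanged context line above), so the auto-detection added in configuration_minicpm.py is what ultimately routes every layer onto the FlashAttention-2 path when flash_attn is installed. A runnable toy of that dispatch-by-string pattern; the class names below are placeholders, not the real MiniCPM attention classes:

# Toy dispatch-by-string, mirroring how the decoder layer chooses its attention implementation.
# Placeholder classes stand in for the real eager / flash-attention-2 attention modules.
class EagerAttention:
    pass

class FlashAttention2:
    pass

ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
}

attn_implementation = "flash_attention_2"  # value the config change selects when flash_attn imports
self_attn = ATTENTION_CLASSES[attn_implementation]()
print(type(self_attn).__name__)  # FlashAttention2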
