Upload folder using huggingface_hub

- config.json +37 -0
- generation_config.json +6 -0
- interpretation/examples-0.pkl +3 -0
- interpretation/examples-1.pkl +3 -0
- interpretation/examples-2.pkl +3 -0
- interpretation/examples-3.pkl +3 -0
- interpretation/examples-4.pkl +3 -0
- interpretation/examples-5.pkl +3 -0
- interpretation/inputs.pt +3 -0
- interpretation/routings-0.pkl +3 -0
- interpretation/routings-1.pkl +3 -0
- interpretation/routings-2.pkl +3 -0
- interpretation/routings-3.pkl +3 -0
- interpretation/routings-4.pkl +3 -0
- interpretation/routings-5.pkl +3 -0
- model.safetensors +3 -0
- modeling_monet.py +663 -0
- special_tokens_map.json +23 -0
- tokenizer.model +3 -0
- tokenizer_config.json +43 -0

config.json ADDED
@@ -0,0 +1,37 @@
+{
+  "architectures": ["MonetForCausalLM"],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "modeling_monet.MonetConfig",
+    "AutoModelForCausalLM": "modeling_monet.MonetForCausalLM"
+  },
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "hidden_act": "relu2",
+  "hidden_size": 2048,
+  "initializer_range": 0.02,
+  "intermediate_size": null,
+  "max_position_embeddings": 2048,
+  "mlp_bias": null,
+  "model_type": "monet",
+  "moe_decompose": "vertical",
+  "moe_dim": 16,
+  "moe_experts": 512,
+  "moe_groups": 4,
+  "moe_heads": 8,
+  "moe_topk": 8,
+  "num_attention_heads": 16,
+  "num_hidden_layers": 24,
+  "num_key_value_heads": 16,
+  "output_router_probs": false,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-6,
+  "rope_scaling": null,
+  "rope_theta": 10000.0,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.42.3",
+  "use_cache": true,
+  "vocab_size": 32000
+}
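
The `auto_map` entries route `AutoConfig` and `AutoModelForCausalLM` to the bundled `modeling_monet.py`, so the checkpoint loads with `trust_remote_code=True`. A minimal sketch, assuming the commit's files sit in a local directory (the path is a placeholder, not part of this commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    path = "./monet-checkpoint"  # hypothetical local clone of this repository

    # trust_remote_code lets transformers import MonetConfig/MonetForCausalLM
    # from the bundled modeling_monet.py instead of a built-in model class.
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(
        path, trust_remote_code=True, torch_dtype="bfloat16"
    )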

generation_config.json ADDED
@@ -0,0 +1,6 @@
+{
+  "_from_model_config": true,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "transformers_version": "4.42.3"
+}
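
The generation defaults only pin the BOS/EOS token ids (`_from_model_config` marks the file as derived from config.json), so decoding parameters must be passed at call time. A hedged sketch, continuing the hypothetical load above:

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))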

interpretation/examples-0.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:678aa1278b0e2aebe8a6eafdf9885bfdce16e549d58dc212dbe79beb06484036
+size 72515482
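
The three lines above are a Git LFS pointer, not the pickle itself: `oid` is the SHA-256 of the actual blob and `size` its length in bytes. Every large file in this commit uses the same format. A minimal sketch of verifying a downloaded blob against its pointer (paths are placeholders):

    import hashlib

    def verify_lfs_pointer(pointer_path, blob_path):
        # Pointer files are "key value" lines: version, oid, size.
        fields = dict(line.split(" ", 1)
                      for line in open(pointer_path).read().splitlines())
        blob = open(blob_path, "rb").read()
        return (len(blob) == int(fields["size"])
                and hashlib.sha256(blob).hexdigest()
                == fields["oid"].removeprefix("sha256:"))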

interpretation/examples-1.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:746272ad2c5008343dcf1be372a8aaefa3866d90204aa2c5ea7489b3082ed783
+size 139404798

interpretation/examples-2.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97514d59ae4b53e3772a8d1de0c3a040c32bff216c58149ab2fa2add26558199
+size 160892104

interpretation/examples-3.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2cdbebf070c753204d466d63a01157fc392f1011aa5a0ca9850b69ad3369dc05
+size 195673312

interpretation/examples-4.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a2199c5cbe30e28797d0b98d217a5d1b0985875493c1d0959831988167a622d
+size 227660355

interpretation/examples-5.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c88f484cc1040cff5d3ae07313f8c5a145dd9a626896cb01c10b9b9315518223
+size 220131055

interpretation/inputs.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41c464248caf2445896a40b49814cd899d9f06e0b16efc57c2953359f33e8a0b
+size 115213399
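
The interpretation artifacts appear to be ordinary torch/pickle dumps; this commit does not document their internal structure, so the sketch below only opens them (and, as with any pickle, only load files you trust):

    import pickle
    import torch

    inputs = torch.load("interpretation/inputs.pt")      # placeholder paths
    with open("interpretation/examples-0.pkl", "rb") as f:
        examples = pickle.load(f)
    print(type(inputs), type(examples))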

interpretation/routings-0.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a01878bce5db1040ac02e2e1e32d05d39095789703f48344645cb6e8b2238ee
+size 1960732940

interpretation/routings-1.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4d8657ae3a0f4f35ff5717e06f8c3966d7ac5e82ded219335fa0cdcdb95c834
+size 2399759022

interpretation/routings-2.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8e3418b74bf13e9c1857a2a8c3812c952f8044c6f55183adcf3ba281fc982ce
+size 2373573108

interpretation/routings-3.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1176df5400ee181c942c6af4df9b5cc4f847b8d9ee9ed44d1fc5b2df62ad501f
+size 2566831074

interpretation/routings-4.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:744a17dbdb0e85af47d70f55d2edd9c1b4d9263b3e4d5d8671b7a8c6704f1352
+size 2606568788

interpretation/routings-5.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39f52af9b20cb759158087b9210e642b5755e85f8ecf71149d85c4f50854ad4c
+size 2633436501

model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:82cd15b76fdcc474ba76869e52436b362c9a5f695792c8a3a9e8f95b77ebd437
+size 2930363080
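
The weights ship as a single shard of 2,930,363,080 bytes, consistent with roughly 1.46B parameters at 2 bytes each in bfloat16. A minimal sketch of listing the stored tensors once the real blob is downloaded (path is a placeholder):

    from safetensors import safe_open

    with safe_open("model.safetensors", framework="pt") as f:
        for name in f.keys():
            print(name, f.get_slice(name).get_shape())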

modeling_monet.py ADDED
@@ -0,0 +1,663 @@
+# fmt: off
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+import torch.utils.checkpoint
+from scipy.stats import norm
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+    LLAMA_ATTENTION_CLASSES,
+    LlamaRMSNorm,
+)
+from transformers.utils import ModelOutput, logging
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class MonetModelOutputWithPast(ModelOutput):
+    last_hidden_state: torch.FloatTensor = None
+    past_key_values: tuple[tuple[torch.FloatTensor]] | None = None
+    hidden_states: tuple[torch.FloatTensor, ...] | None = None
+    attentions: tuple[torch.FloatTensor, ...] | None = None
+    router_probs: tuple[tuple[torch.FloatTensor, ...], ...] | None = None
+
+
+@dataclass
+class MonetCausalLMOutputWithPast(ModelOutput):
+    loss: torch.FloatTensor | None = None
+    aux_loss: torch.FloatTensor | None = None
+    logits: torch.FloatTensor = None
+    past_key_values: tuple[tuple[torch.FloatTensor]] | None = None
+    hidden_states: tuple[torch.FloatTensor, ...] | None = None
+    attentions: tuple[torch.FloatTensor, ...] | None = None
+    router_probs: tuple[tuple[torch.FloatTensor, ...], ...] | None = None
+
+
+class MonetConfig(LlamaConfig):
+    model_type = "monet"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=None,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="relu2",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        mlp_bias=None,
+        moe_dim=8,
+        moe_heads=8,
+        moe_experts=512,
+        moe_topk=32,
+        moe_groups=4,
+        moe_decompose="vertical",
+        output_router_probs=False,
+        **kwargs,
+    ):
+        self.moe_dim = moe_dim
+        self.moe_heads = moe_heads
+        self.moe_experts = moe_experts
+        self.moe_topk = moe_topk
+        self.moe_groups = moe_groups
+        self.moe_decompose = moe_decompose
+        self.output_router_probs = output_router_probs
+
+        super().__init__(
+            vocab_size=vocab_size,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_hidden_layers=num_hidden_layers,
+            num_attention_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
+            hidden_act=hidden_act,
+            max_position_embeddings=max_position_embeddings,
+            initializer_range=initializer_range,
+            rms_norm_eps=rms_norm_eps,
+            use_cache=use_cache,
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            pretraining_tp=pretraining_tp,
+            tie_word_embeddings=tie_word_embeddings,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            attention_bias=attention_bias,
+            attention_dropout=attention_dropout,
+            mlp_bias=mlp_bias,
+            **kwargs,
+        )
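
# Annotation, not part of modeling_monet.py: MonetConfig only adds the moe_*
# fields on top of LlamaConfig. Its class defaults (hidden_size=4096, moe_dim=8,
# moe_topk=32) differ from this checkpoint's config.json (2048, 16, 8); the JSON
# values override the defaults at load time. A hedged sketch, assuming the
# module above has been imported:
config = MonetConfig(hidden_size=2048, num_hidden_layers=24,
                     num_attention_heads=16, num_key_value_heads=16,
                     moe_dim=16, moe_topk=8)
print(config.moe_experts, config.moe_topk)  # 512 8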

+
+
+class MonetRouter(nn.Module):
+    def __init__(self, config: MonetConfig):
+        super().__init__()
+        self.config = config
+        flatten_shape = config.moe_heads * config.moe_experts
+
+        self.w1 = nn.Linear(config.hidden_size, flatten_shape, bias=False)
+        self.w2 = nn.Linear(config.hidden_size, flatten_shape, bias=False)
+        self.norm1 = nn.BatchNorm1d(config.moe_heads, affine=False)
+        self.norm2 = nn.BatchNorm1d(config.moe_heads, affine=False)
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        g1z = self.w1(x).unflatten(-1, (self.config.moe_heads, -1)).float()
+        g2z = self.w2(x).unflatten(-1, (self.config.moe_heads, -1)).float()
+
+        g1n = self.norm1(g1z.transpose(2, 3).flatten(0, -2))
+        g2n = self.norm2(g2z.transpose(2, 3).flatten(0, -2))
+        g1n = g1n.view(g1z.size(0), g1z.size(1), g1z.size(3), -1).transpose(2, 3)
+        g2n = g2n.view(g2z.size(0), g2z.size(1), g2z.size(3), -1).transpose(2, 3)
+
+        sigma = float(norm.ppf(1 - self.config.moe_topk / self.config.moe_experts))
+        g1s = g1n.amax(-1, keepdim=True).clamp_max_(sigma)
+        g2s = g2n.amax(-1, keepdim=True).clamp_max_(sigma)
+
+        g1 = nn.functional.softmax(torch.where(g1n >= g1s, g1z, -1e10), dim=-1)
+        g2 = nn.functional.softmax(torch.where(g2n >= g2s, g2z, -1e10), dim=-1)
+        return g1, g2
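
# Annotation, not part of modeling_monet.py: the router approximates a per-head
# top-k over experts without an explicit sort. BatchNorm standardizes each
# head's expert logits, so under a roughly Gaussian assumption a logit clears
# the quantile threshold sigma = norm.ppf(1 - moe_topk / moe_experts) about
# moe_topk times out of moe_experts; clamping the per-head max into the
# threshold guarantees at least one expert always survives. A hedged,
# standalone sketch with this checkpoint's values:
import torch
from scipy.stats import norm

experts, topk = 512, 8
sigma = float(norm.ppf(1 - topk / experts))  # ~2.15 for 8 of 512
z = torch.randn(4, experts)                  # stand-in for normalized logits
kept = z >= z.amax(-1, keepdim=True).clamp_max(sigma)
print(kept.sum(-1))                          # roughly `topk` experts kept per row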

+
+
+class MonetMoVDE(nn.Module):
+    def __init__(self, config: MonetConfig):
+        super().__init__()
+        self.config = config
+        self.act_fn = ACT2FN[config.hidden_act]
+        flatten_shape = config.moe_experts * config.moe_dim // 2
+
+        self.u1 = nn.Linear(config.hidden_size, flatten_shape)
+        self.u2 = nn.Linear(config.hidden_size, flatten_shape)
+
+        self.v11 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
+        self.v12 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
+        self.v21 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
+        self.v22 = nn.Linear(flatten_shape, config.hidden_size // 2, bias=False)
+
+        self.b1 = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size // 2))
+        self.b2 = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size // 2))
+
+    def forward(
+        self, x: torch.Tensor, g1: torch.Tensor, g2: torch.Tensor
+    ) -> torch.Tensor:
+        g1, g2 = g1.type_as(x), g2.type_as(x)
+        x1 = self.act_fn(self.u1(x).unflatten(-1, (self.config.moe_experts, -1)))
+        x2 = self.act_fn(self.u2(x).unflatten(-1, (self.config.moe_experts, -1)))
+
+        x11 = self.v11(torch.einsum("btim,bthi->btim", x1, g1).flatten(-2))
+        x12 = self.v12(torch.einsum("btjm,bthj,bthi->btim", x2, g2, g1).flatten(-2))
+        x13 = torch.einsum("bthi,id->btd", g1, self.b1.type_as(x))
+
+        x21 = self.v21(torch.einsum("btim,bthi,bthj->btjm", x1, g1, g2).flatten(-2))
+        x22 = self.v22(torch.einsum("btjm,bthj->btjm", x2, g2).flatten(-2))
+        x23 = torch.einsum("bthj,jd->btd", g2, self.b2.type_as(x))
+
+        return torch.cat((x11 + x12 + x13, x21 + x22 + x23), dim=-1)
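
# Annotation, not part of modeling_monet.py: in the einsum subscripts above,
# b=batch, t=sequence, h=moe_heads, i/j=expert indices, m=moe_dim//2 per-expert
# channels, d=hidden_size//2. A hedged toy check that the expert mixing feeding
# v11 produces the expected flattened width (values mirror this checkpoint,
# where moe_experts=512 and moe_dim=16):
import torch

b, t, h, e, m = 2, 3, 8, 512, 8
x1 = torch.randn(b, t, e, m)              # per-expert activations
g1 = torch.rand(b, t, h, e)               # routing probabilities
mixed = torch.einsum("btim,bthi->btim", x1, g1)
print(mixed.flatten(-2).shape)            # torch.Size([2, 3, 4096])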

+
+
+class MonetMoHDE(nn.Module):
+    def __init__(self, config: MonetConfig):
+        super().__init__()
+        self.config = config
+        self.act_fn = ACT2FN[config.hidden_act]
+        flatten_shape = config.moe_experts * config.moe_dim
+
+        self.u = nn.Linear(config.hidden_size, flatten_shape)
+        self.v = nn.Linear(flatten_shape, config.hidden_size, bias=False)
+        self.b = nn.Parameter(torch.zeros(config.moe_experts, config.hidden_size))
+
+    def forward(
+        self, x: torch.Tensor, g1: torch.Tensor, g2: torch.Tensor
+    ) -> torch.Tensor:
+        g1, g2 = g1.type_as(x), g2.type_as(x)
+        x = self.act_fn(self.u(x).unflatten(-1, (self.config.moe_experts, -1)))
+        x = self.v(torch.einsum("btim,bthi,bthj->btjm", x, g1, g2).flatten(-2))
+        return x + torch.einsum("bthj,jd->btd", g2, self.b)
+
+
+class MonetDecoderLayer(nn.Module):
+    def __init__(self, config: MonetConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
+            config=config, layer_idx=layer_idx
+        )
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        if config.moe_decompose == "vertical":
+            self.moe = MonetMoVDE(config)
+        elif config.moe_decompose == "horizontal":
+            self.moe = MonetMoHDE(config)
+        if layer_idx % config.moe_groups == 0:
+            self.router = MonetRouter(config).requires_grad_(False)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        past_key_value: Cache | None = None,
+        previous_router_probs: tuple[torch.Tensor, torch.Tensor] | None = None,
+        output_attentions: bool | None = False,
+        use_cache: bool | None = False,
+        cache_position: torch.LongTensor | None = None,
+        **kwargs,
+    ) -> tuple[torch.FloatTensor, ...]:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        g1, g2 = (
+            self.router(hidden_states)
+            if hasattr(self, "router")
+            else previous_router_probs
+        )
+        hidden_states = self.moe(hidden_states, g1, g2)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs + ((g1, g2) if hasattr(self, "router") else None,)
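
# Annotation, not part of modeling_monet.py: only every moe_groups-th layer owns
# a (frozen) MonetRouter; the layers in between reuse the probabilities handed
# in as previous_router_probs. A hedged sketch of which layers own routers under
# this checkpoint's config (num_hidden_layers=24, moe_groups=4):
router_layers = [i for i in range(24) if i % 4 == 0]
print(router_layers)  # [0, 4, 8, 12, 16, 20] -- six routers, which lines up
                      # with the six interpretation/routings-*.pkl files above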

+
+
+class MonetPreTrainedModel(PreTrainedModel):
+    config_class = MonetConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MonetDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+class MonetModel(MonetPreTrainedModel):
+    def __init__(self, config: MonetConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)  # noqa
+        self.layers = nn.ModuleList([MonetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])  # noqa
+        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        past_key_values: Cache | list[torch.FloatTensor] | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        use_cache: bool | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        output_router_probs: bool | None = None,
+        return_dict: bool | None = None,
+        cache_position: torch.LongTensor | None = None,
+    ) -> tuple[torch.Tensor, ...] | MonetModelOutputWithPast:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
+        output_router_probs = output_router_probs if output_router_probs is not None else self.config.output_router_probs  # noqa
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict  # noqa
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one")  # noqa
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")  # noqa
+            use_cache = False
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)  # noqa
+            return_legacy_cache = True
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            logger.warning_once(
+                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "  # noqa
+                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"  # noqa
+            )
+
+        if cache_position is None:
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)  # noqa
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)  # noqa
+
+        # embed positions
         
     | 
| 362 | 
         
            +
                    hidden_states = inputs_embeds
         
     | 
| 363 | 
         
            +
             
     | 
| 364 | 
         
            +
                    # decoder layers
         
     | 
| 365 | 
         
            +
                    all_hidden_states = () if output_hidden_states else None
         
     | 
| 366 | 
         
            +
                    all_self_attns = () if output_attentions else None
         
     | 
| 367 | 
         
            +
                    all_router_probs = () if output_router_probs else None
         
     | 
| 368 | 
         
            +
                    previous_router_probs, next_decoder_cache = None, None
         
     | 
| 369 | 
         
            +
             
     | 
| 370 | 
         
            +
                    for decoder_layer in self.layers:
         
     | 
| 371 | 
         
            +
                        if output_hidden_states:
         
     | 
| 372 | 
         
            +
                            all_hidden_states += (hidden_states,)
         
     | 
| 373 | 
         
            +
             
     | 
| 374 | 
         
            +
                        if self.gradient_checkpointing and self.training:
         
     | 
| 375 | 
         
            +
                            layer_outputs = self._gradient_checkpointing_func(
         
     | 
| 376 | 
         
            +
                                decoder_layer.__call__,
         
     | 
| 377 | 
         
            +
                                hidden_states,
         
     | 
| 378 | 
         
            +
                                causal_mask,
         
     | 
| 379 | 
         
            +
                                position_ids,
         
     | 
| 380 | 
         
            +
                                past_key_values,
         
     | 
| 381 | 
         
            +
                                previous_router_probs,
         
     | 
| 382 | 
         
            +
                                output_attentions,
         
     | 
| 383 | 
         
            +
                                use_cache,
         
     | 
| 384 | 
         
            +
                                cache_position,
         
     | 
| 385 | 
         
            +
                            )
         
     | 
| 386 | 
         
            +
                        else:
         
     | 
| 387 | 
         
            +
                            layer_outputs = decoder_layer(
         
     | 
| 388 | 
         
            +
                                hidden_states,
         
     | 
| 389 | 
         
            +
                                attention_mask=causal_mask,
         
     | 
| 390 | 
         
            +
                                position_ids=position_ids,
         
     | 
| 391 | 
         
            +
                                past_key_value=past_key_values,
         
     | 
| 392 | 
         
            +
                                previous_router_probs=previous_router_probs,
         
     | 
| 393 | 
         
            +
                                output_attentions=output_attentions,
         
     | 
| 394 | 
         
            +
                                use_cache=use_cache,
         
     | 
| 395 | 
         
            +
                                cache_position=cache_position,
         
     | 
| 396 | 
         
            +
                            )
         
     | 
| 397 | 
         
            +
             
     | 
| 398 | 
         
            +
                        hidden_states = layer_outputs[0]
         
     | 
| 399 | 
         
            +
                        if use_cache:
         
     | 
| 400 | 
         
            +
                            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
         
     | 
| 401 | 
         
            +
                        if output_attentions:
         
     | 
| 402 | 
         
            +
                            all_self_attns += (layer_outputs[1],)
         
     | 
| 403 | 
         
            +
                        if output_router_probs:
         
     | 
| 404 | 
         
            +
                            all_router_probs += (layer_outputs[-1],)
         
     | 
| 405 | 
         
            +
                        previous_router_probs = (
         
     | 
| 406 | 
         
            +
                            layer_outputs[-1]
         
     | 
| 407 | 
         
            +
                            if layer_outputs[-1] is not None
         
     | 
| 408 | 
         
            +
                            else previous_router_probs
         
     | 
| 409 | 
         
            +
                        )
         
     | 
| 410 | 
         
            +
             
     | 
| 411 | 
         
            +
                    hidden_states = self.norm(hidden_states)
         
     | 
| 412 | 
         
            +
             
     | 
| 413 | 
         
            +
                    # add hidden states from the last decoder layer
         
     | 
| 414 | 
         
            +
                    if output_hidden_states:
         
     | 
| 415 | 
         
            +
                        all_hidden_states += (hidden_states,)
         
     | 
| 416 | 
         
            +
             
     | 
| 417 | 
         
            +
                    next_cache = next_decoder_cache if use_cache else None
         
     | 
| 418 | 
         
            +
                    if return_legacy_cache:
         
     | 
| 419 | 
         
            +
                        next_cache = next_cache.to_legacy_cache()
         
     | 
| 420 | 
         
            +
             
     | 
| 421 | 
         
            +
                    if not return_dict:
         
     | 
| 422 | 
         
            +
                        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_probs] if v is not None)  # noqa
         
     | 
| 423 | 
         
            +
                    return MonetModelOutputWithPast(
         
     | 
| 424 | 
         
            +
                        last_hidden_state=hidden_states,
         
     | 
| 425 | 
         
            +
                        past_key_values=next_cache,
         
     | 
| 426 | 
         
            +
                        hidden_states=all_hidden_states,
         
     | 
| 427 | 
         
            +
                        attentions=all_self_attns,
         
     | 
| 428 | 
         
            +
                        router_probs=all_router_probs,
         
     | 
| 429 | 
         
            +
                    )
         
     | 
| 430 | 
         
            +
             
     | 
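
    # `forward` exposes the Monet-specific `router_probs` output alongside the
    # usual hidden states and attentions: pass `output_router_probs=True` and
    # read `MonetModelOutputWithPast.router_probs`, which holds one entry per
    # decoder layer (entries are None for layers that performed no routing).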
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0  # noqa
        using_static_cache = isinstance(past_key_values, StaticCache)

        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:  # noqa
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            if attention_mask.max() != 0:
                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0")  # noqa
            causal_mask = attention_mask
        else:
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device  # noqa
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)  # noqa
            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)  # noqa
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit  # noqa
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]  # noqa
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)  # noqa
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)  # noqa

        return causal_mask

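# `_update_causal_mask` returns masks in the inverted convention: attendable
# positions hold 0.0 and masked positions hold `torch.finfo(dtype).min`, so
# the mask can be added directly to attention scores before the softmax.
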
class MonetForCausalLM(MonetPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = MonetModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_router_probs: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.LongTensor | None = None,
    ) -> tuple[torch.Tensor, ...] | MonetCausalLMOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
        output_router_probs = output_router_probs if output_router_probs is not None else self.config.output_router_probs  # noqa
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict  # noqa

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_probs=output_router_probs,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MonetCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_probs=outputs.router_probs,
        )
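
    # The loss above follows the standard next-token convention: logits at
    # position i are scored against labels at position i + 1, so passing
    # `labels=input_ids` is enough to train the model; the shift happens here.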
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        past_length = 0
        if past_key_values is not None:
            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()  # noqa
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                if past_key_values.get_max_length() is not None
                else None
            )
            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)  # noqa

            # Keep only the unprocessed tokens. If the attention mask is longer
            # than input_ids, some inputs are passed exclusively via the cache.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:  # noqa
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # Otherwise input_ids holds all tokens; discard the prefix already
            # covered by past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]  # noqa
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)  # noqa
        elif use_cache:
            cache_position = cache_position[-input_length:]

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),  # noqa
            )
        return reordered_past
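
End-to-end, the custom classes above load through the Auto* API. A minimal
usage sketch (the repo id "MonetLLM/monet-vd-1.4B" is a placeholder; substitute
the actual Hub path of this checkpoint):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "MonetLLM/monet-vd-1.4B"  # placeholder repo id
    tokenizer = AutoTokenizer.from_pretrained(repo)
    # trust_remote_code=True lets transformers import MonetForCausalLM from the
    # modeling_monet.py file shipped with the checkpoint.
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))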
    	
special_tokens_map.json
ADDED

@@ -0,0 +1,23 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
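
These entries surface as tokenizer attributes once the checkpoint is loaded; a
quick check (repo id is a placeholder):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("MonetLLM/monet-vd-1.4B")  # placeholder
    print(tok.bos_token, tok.eos_token, tok.unk_token)  # <s> </s> <unk>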
    	
tokenizer.model
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
size 499723
    	
tokenizer_config.json
ADDED

@@ -0,0 +1,43 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "add_prefix_space": true,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "legacy": false,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": null,
  "padding_side": "right",
  "sp_model_kwargs": {},
  "spaces_between_special_tokens": false,
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": "<unk>",
  "use_default_system_prompt": false
}
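
With add_bos_token=true and add_eos_token=false, encoding prepends <s> but
appends no </s>; a small sanity check (repo id is a placeholder):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("MonetLLM/monet-vd-1.4B")  # placeholder
    ids = tok("hello world").input_ids
    assert ids[0] == tok.bos_token_id   # BOS (<s>, id 1) is prepended
    assert ids[-1] != tok.eos_token_id  # no EOS (</s>, id 2) is appended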