Upload convert_to_linear.py with huggingface_hub
Browse files- convert_to_linear.py +20 -0
    	
        convert_to_linear.py
    CHANGED
    
    | @@ -109,3 +109,23 @@ model = AutoModelForCausalLM.from_pretrained( | |
| 109 | 
             
                torch_dtype=torch.bfloat16
         | 
| 110 | 
             
            )
         | 
| 111 |  | 
| … | … |
| 109 | 
             
                torch_dtype=torch.bfloat16
         | 
| 110 | 
             
            )
         | 
| 111 |  | 
# Alternative to offline conversion: convert each layer's experts to the
# linear implementation on the fly, right after loading the checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    device_map='cuda:0', # modify appropriately.
    torch_dtype=torch.bfloat16
)

from tqdm import tqdm

# Walk every decoder layer; wherever the MLP still holds the fused
# GptOssExperts module, replace it with the converted linear experts,
# moved onto the model's own device and dtype so the swap is transparent.
for decoder_layer in tqdm(model.model.layers):
    expert_module = decoder_layer.mlp.experts
    if isinstance(expert_module, GptOssExperts):
        linear_experts = convert_to_linear_experts(expert_module, model.config) # function is defined in the file above
        decoder_layer.mlp.experts = linear_experts.to(model.device, model.dtype)
print('✅ All experts converted to linear')
