imdatta0 committed on
Commit
29250b7
·
verified ·
1 Parent(s): 5c7bb87

Upload convert_to_linear.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. convert_to_linear.py +11 -0
convert_to_linear.py CHANGED
@@ -95,6 +95,17 @@ class NewGptOssExperts(nn.Module):
95
  return mixed.view(batch_size, -1, self.hidden_size)
96
 
97
 
 
98
  # monkey patch to linear
99
  from transformers.models.gpt_oss import modeling_gpt_oss
100
  modeling_gpt_oss.GptOssExperts = NewGptOssExperts
 
 
 
 
 
 
 
 
 
 
 
95
  return mixed.view(batch_size, -1, self.hidden_size)
96
 
97
 
98
+ # To load the converted model, run the following:
99
  # monkey patch to linear
100
  from transformers.models.gpt_oss import modeling_gpt_oss
101
  modeling_gpt_oss.GptOssExperts = NewGptOssExperts
102
+
103
+ from transformers import AutoModelForCausalLM, AutoTokenizer
104
+ import torch
105
+
106
+ model = AutoModelForCausalLM.from_pretrained(
107
+ "imdatta0/gpt_oss_20b_linear", # make sure you load the right weights
108
+ device_map='cuda:0', # modify appropriately.
109
+ torch_dtype=torch.bfloat16
110
+ )
111
+