duzx16 committed
Commit · 0cfae21
1 Parent(s): aea6cef

Fix backward for quantization

Files changed:
- modeling_chatglm.py  +6 -8
- quantization.py      +2 -2
modeling_chatglm.py CHANGED

@@ -134,11 +134,11 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
 
 
 class PrefixEncoder(torch.nn.Module):
-
+    """
     The torch.nn model to encode the prefix
     Input shape: (batch-size, prefix-length)
     Output shape: (batch-size, prefix-length, 2*layers*hidden)
-
+    """
     def __init__(self, config):
         super().__init__()
         self.prefix_projection = config.prefix_projection
@@ -148,7 +148,7 @@ class PrefixEncoder(torch.nn.Module):
             self.trans = torch.nn.Sequential(
                 torch.nn.Linear(config.hidden_size, config.hidden_size),
                 torch.nn.Tanh(),
-                torch.nn.Linear(config.hidden_size, config.num_layers
+                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
             )
         else:
             self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
@@ -814,7 +814,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             self.num_attention_heads,
             self.hidden_size // self.num_attention_heads
         )
-        #seq_len, b, nh, hidden_size
+        # seq_len, b, nh, hidden_size
         past_key_values = self.dropout(past_key_values)
         past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
         # past_key_values = [(v[0], v[1]) for v in past_key_values]
@@ -909,7 +909,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         )
 
         if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+                attention_mask.device)
             prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
             attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
@@ -942,9 +943,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         else:
             attention_mask = attention_mask.to(input_ids.device)
 
-        if self.training:
-            hidden_states = hidden_states.requires_grad_(True)
-
         for i, layer in enumerate(self.layers):
 
             if output_hidden_states:
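In modeling_chatglm.py the commit restores the triple quotes around the PrefixEncoder docstring, touches the final Linear layer of the prefix MLP, builds the prefix attention mask with an explicit .to(attention_mask.device), and drops the training-time hidden_states.requires_grad_(True) workaround, presumably unnecessary once the quantized linear has a working backward. The snippet below is a standalone sketch of the prefix-mask handling only, with made-up sizes; it is not code from the repository.

# Standalone sketch of the prefix attention mask handling (lines 911-914 above),
# with made-up sizes; not code from the repository. The prefix mask is created
# directly on attention_mask's device, so the torch.cat never mixes devices.
import torch

batch_size, seq_len, pre_seq_len = 2, 5, 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Existing mask: True marks positions that must NOT be attended.
attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device)

# Prefix positions are always visible: ones -> "< 0.5" -> all False (unmasked).
prefix_attention_mask = torch.ones(batch_size, 1, seq_len, pre_seq_len).to(attention_mask.device)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()

# Prepend the prefix columns along the key dimension (dim=3).
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
print(attention_mask.shape)  # torch.Size([2, 1, 5, 8])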
quantization.py CHANGED

@@ -14,11 +14,11 @@ class W8A16Linear(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
         ctx.inp_shape = inp.size()
-        ctx.weight_shape = quant_w.size()
         ctx.weight_bit_width = weight_bit_width
         out_features = quant_w.size(0)
         inp = inp.contiguous().view(-1, inp.size(-1))
         weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
+        ctx.weight_shape = weight.size()
         output = inp.mm(weight.t())
         ctx.save_for_backward(inp, quant_w, scale_w)
         return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
@@ -30,7 +30,7 @@ class W8A16Linear(torch.autograd.Function):
         grad_output = grad_output.contiguous().view(-1, weight.size(0))
         grad_input = grad_output.mm(weight)
         grad_weight = grad_output.t().mm(inp)
-        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
+        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
 
 
 class Kernel:
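The two quantization.py changes are what fix the backward pass. First, ctx.weight_shape is now taken from the dequantized weight returned by extract_weight_to_half rather than from quant_w; when weights are stored packed (e.g. two 4-bit values per byte) the packed tensor's shape differs from the dequantized one, so grad_weight.view(ctx.weight_shape) would otherwise fail. Second, a torch.autograd.Function's backward must return one value per forward argument, and forward takes four (inp, quant_w, scale_w, weight_bit_width), so the return gains a fourth None. The sketch below is a toy, kernel-free version that mirrors this structure; the real extract_weight_to_half CUDA path is replaced by a plain per-row scale multiply, and the names are illustrative only.

# Toy sketch (not the repository's kernel-backed code) showing the two
# properties this commit fixes in W8A16Linear.
import torch


class ToyQuantLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, quant_w, scale_w, weight_bit_width):
        ctx.inp_shape = inp.size()
        ctx.weight_bit_width = weight_bit_width
        out_features = quant_w.size(0)
        inp = inp.contiguous().view(-1, inp.size(-1))
        # Stand-in for extract_weight_to_half: int8 weights times per-row scales.
        weight = quant_w.to(inp.dtype) * scale_w[:, None]
        # Record the *dequantized* weight shape; for packed 4-bit weights,
        # quant_w.size() would not match grad_weight computed in backward.
        ctx.weight_shape = weight.size()
        output = inp.mm(weight.t())
        ctx.save_for_backward(inp, quant_w, scale_w)
        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))

    @staticmethod
    def backward(ctx, grad_output):
        inp, quant_w, scale_w = ctx.saved_tensors
        weight = quant_w.to(grad_output.dtype) * scale_w[:, None]
        grad_output = grad_output.contiguous().view(-1, weight.size(0))
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(inp)
        # One return value per forward argument: inp, quant_w, scale_w, weight_bit_width.
        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None


# Quick smoke test: backward now runs instead of raising an
# "incorrect number of gradients" error.
x = torch.randn(2, 3, 8, requires_grad=True)
qw = torch.randint(-127, 127, (16, 8), dtype=torch.int8)
scale = torch.rand(16)
ToyQuantLinear.apply(x, qw, scale, 8).sum().backward()
print(x.grad.shape)  # torch.Size([2, 3, 8])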