OlivierDehaene committed · Commit f0792b2 · 1 Parent(s): 02857b0

fuse qkv

Browse files:
- model.safetensors +2 -2
- modeling_gpt2_mq.py +15 -25
model.safetensors CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:ecb114682d0f35efe851ed162500e9d7babdf7c8c008fdcf76ec679e9a788533
+size 4903278480
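The checkpoint is re-uploaded because fusing changes the parameter layout: the separate q_attn (embed_dim -> embed_dim) and kv_attn (embed_dim -> 2 * head_dim) projections become a single attn projection of width embed_dim + 2 * head_dim. Below is a minimal conversion sketch of how such a fused weight could be built offline, assuming the old kv_attn output was ordered (key, value) so the concatenation lines up with the new split; the helper name is hypothetical and not part of this commit.

import torch

def fuse_qkv_weights(q_weight, q_bias, kv_weight, kv_bias):
    # q_weight:  (embed_dim, embed_dim),     q_bias:  (embed_dim,)
    # kv_weight: (2 * head_dim, embed_dim),  kv_bias: (2 * head_dim,)
    # nn.Linear stores weights as (out_features, in_features), so the fused
    # projection of size embed_dim + 2 * head_dim is a concatenation along dim 0,
    # assuming kv_attn emitted key then value (matching the new split order q, k, v).
    fused_weight = torch.cat([q_weight, kv_weight], dim=0)
    fused_bias = torch.cat([q_bias, kv_bias], dim=0)
    return fused_weight, fused_bias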
modeling_gpt2_mq.py CHANGED

@@ -13,7 +13,6 @@ from transformers.modeling_outputs import (
 )
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2Block, GPT2PreTrainedModel, GPT2LMHeadModel
 from transformers.utils import logging
-
 from .configuration_gpt2_mq import GPT2CustomConfig, MULTI_QUERY

 logger = logging.get_logger(__name__)
@@ -130,10 +129,7 @@ class GPT2MQAttention(nn.Module):
         if self.is_cross_attention:
             raise NotImplementedError("Cross-attention not implemented for MQA")
         else:
-
-            self.q_attn = nn.Linear(self.embed_dim, self.embed_dim)
-            # Keys and values are shared across heads
-            self.kv_attn = nn.Linear(self.embed_dim, 2 * self.head_dim)
+            self.attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim)
         self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)

         self.attn_dropout = nn.Dropout(config.attn_pdrop)
@@ -143,13 +139,13 @@ class GPT2MQAttention(nn.Module):
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)

     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        # query: (b,
+        # query: (b, sq * num_heads, head_dim)
         # key: (b, head_dim, sk)
         # value: (b, sk, head_dim)
         batch_size = query.size(0)
         query_length = query.size(1) // self.num_heads
         key_length = key.size(2)
-        # (b,
+        # (b, sq * num_heads, head_dim) x (b, head_dim, sk) -> (b, sq * num_heads, sk)

         if self.scale_attn_weights:
             query *= self.inv_norm_factor
@@ -157,7 +153,7 @@ class GPT2MQAttention(nn.Module):
         attn_weights = torch.bmm(query, key)

         # -> (b, num_heads, sq, sk)
-        attn_weights = attn_weights.view(batch_size, self.num_heads,
+        attn_weights = attn_weights.view(batch_size, query_length, self.num_heads, key_length)

         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
@@ -174,13 +170,13 @@ class GPT2MQAttention(nn.Module):

         # Mask heads if we want to
         if head_mask is not None:
-
+            raise NotImplementedError

         # (b, num_heads, sq, sk) -> (b, num_heads * sq, sk)
-        _attn_weights = attn_weights.view(batch_size, self.num_heads
+        _attn_weights = attn_weights.view(batch_size, query_length * self.num_heads, key_length)
         # (b, num_heads * sq, sk) x (b, sk, head_dim) -> (b, num_heads * sq, head_dim)
         attn_output = torch.bmm(_attn_weights, value)
-        attn_output = attn_output.view(batch_size, self.num_heads,
+        attn_output = attn_output.view(batch_size, query_length, self.num_heads, self.head_dim)

         return attn_output, attn_weights

@@ -188,10 +184,8 @@ class GPT2MQAttention(nn.Module):
         """
         Merges attn_head_size dim and num_attn_heads dim into hidden_size
         """
-        batch_size,
-
-        tensor = tensor.permute(0, 2, 1, 3)
-        return tensor.reshape(batch_size, seq_length, num_heads * head_dim)
+        batch_size, seq_length, num_heads, head_dim = tensor.shape
+        return tensor.view(batch_size, seq_length, num_heads * head_dim)

     def forward(
         self,
@@ -207,17 +201,14 @@ class GPT2MQAttention(nn.Module):
         if encoder_hidden_states is not None:
             raise NotImplementedError("Cross-attention not implemented for MQA")
         else:
-
-            key, value =
+            qkv = self.attn(hidden_states)
+            query, key, value = qkv.split([self.embed_dim, self.head_dim, self.head_dim], dim=2)

         batch_size, seq_length = query.shape[:2]
-        # (query_length, batch, num_heads, head_dim)
-        # (batch, num_heads * query_length, head_dim)\

-        # (batch, query_length, hidden_size) -> (batch, num_heads,
-
-
-        query = query.reshape(batch_size, self.num_heads * seq_length, self.head_dim)
+        # (batch, query_length, hidden_size) -> (batch, query_length * num_heads, head_dim)
+        # forced to reshape here
+        query = query.reshape(batch_size, seq_length * self.num_heads, self.head_dim)

         key = key.transpose(1, 2)  # (batch_size, head_dim, seq_length)

@@ -360,8 +351,7 @@ class GPT2CustomModel(GPT2Model):
             past_key_values_length=past_key_values_length,
         )

-        attention_mask = attention_mask.unsqueeze(
-            *attention_mask.shape[1:])
+        attention_mask = attention_mask.unsqueeze(2).expand(batch_size, attention_mask.shape[1], self.config.num_attention_heads, attention_mask.shape[2])

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]