Update modeling_glm.py
modeling_glm.py  (+27 −12)
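Summary: splices vision-encoder features into GlmModel's input embeddings at the begin_of_image placeholder positions, with per-row flags so batches can mix image and text-only rows; also whitespace-only cleanups in GlmSdpaAttention and in the __all__ list.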
@@ -417,7 +417,7 @@ class GlmSdpaAttention(GlmAttention):
        )

        bsz, q_len, _ = hidden_states.size()
-
+
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

@@ -425,7 +425,7 @@ class GlmSdpaAttention(GlmAttention):
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
+
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor
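For context on the unchanged lines above: the rotary call takes a partial_rotary_factor, meaning only part of each head's channels are rotated. Below is a minimal, generic sketch of that idea; it is not the apply_rotary_pos_emb helper defined in this file, and the function name, shapes, and 0.5 default are all illustrative assumptions.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Negate-and-swap the two halves of the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_partial_rotary(q, k, cos, sin, partial_rotary_factor=0.5):
    # Rotate only the first `rotary_dim` channels of each head;
    # the remaining channels pass through unchanged.
    rotary_dim = int(q.shape[-1] * partial_rotary_factor)
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)

# Shapes as in the attention code above: (bsz, num_heads, q_len, head_dim),
# with cos/sin broadcastable over the rotated half only.
q, k = torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)
cos, sin = torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4)
q2, k2 = apply_partial_rotary(q, k, cos, sin)
assert torch.equal(q2[..., 4:], q[..., 4:])  # un-rotated half is untouched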
@@ -763,21 +763,36 @@ class GlmModel(GlmPreTrainedModel):
        assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
        inputs_embeds = self.embed_tokens(input_ids)
        new_input_embeds = []
-
-
-
-
+        multi_flags = [True if self.config.boi_token_id in input_id.tolist() else False for input_id in input_ids]
+        images_features = None
+        if not is_empty(images) and images.bool().any():
+            imgs = list()
+            for i in range(len(multi_flags)):
+                if multi_flags[i]:
+                    imgs.append(images[i])
+            imgs = torch.stack(imgs, dim=0)
+        else:
+            imgs = torch.unsqueeze(images[0], 0)
+        images_features = self.vision(imgs).to(inputs_embeds.dtype)
        image_count = 0
        for i in range(len(input_ids)):
            input_id = input_ids[i].tolist()
-            if
+            if multi_flags[i]:
                boi_token_pos = input_id.index(self.config.boi_token_id)
                assert boi_token_pos >= 0, "begin_of_image not found!"
                num_image_padding_tokens = input_id.count(self.config.boi_token_id)
-                assert
-
-
-
+                assert (
+                    num_image_padding_tokens == images_features[image_count].shape[0]
+                ), f"Wrong image padding token number: {num_image_padding_tokens}"
+                new_input_embeds.append(
+                    torch.cat(
+                        (
+                            inputs_embeds[i, :boi_token_pos],
+                            images_features[image_count].to(inputs_embeds.device),
+                            inputs_embeds[i, boi_token_pos + num_image_padding_tokens :],
+                        )
+                    )
+                )
                image_count += 1
            else:
                new_input_embeds.append(inputs_embeds[i] + (0 * images_features[0].sum()))
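To make the new image path concrete, the sketch below replays the added lines (flag rows, encode their images, splice features over the placeholder span) on toy data. The sizes, the boi_id value, and the vision lambda are all made up; the lambda merely stands in for self.vision and is not the module's real API.

import torch

# Made-up sizes; `vision` stands in for self.vision (3 feature tokens per image).
hidden, boi_id = 8, 900
vision = lambda imgs: torch.randn(imgs.shape[0], 3, hidden)

input_ids = torch.tensor([
    [1, boi_id, boi_id, boi_id, 5, 2],  # row with an image: 3 placeholder tokens
    [1, 4, 5, 6, 7, 2],                 # text-only row
])
images = torch.randn(2, 3, 32, 32)      # one image slot per row; row 1's is unused
inputs_embeds = torch.randn(2, 6, hidden)

# 1) Flag the rows that actually contain an image placeholder.
multi_flags = [boi_id in row.tolist() for row in input_ids]

# 2) Encode only those rows' images, preserving batch order.
imgs = torch.stack([images[i] for i, f in enumerate(multi_flags) if f])
images_features = vision(imgs)          # (num_image_rows, 3, hidden)

# 3) Splice the features over each placeholder span.
new_input_embeds, image_count = [], 0
for i, row in enumerate(input_ids):
    ids = row.tolist()
    if multi_flags[i]:
        pos, count = ids.index(boi_id), ids.count(boi_id)
        assert count == images_features[image_count].shape[0]
        new_input_embeds.append(torch.cat((
            inputs_embeds[i, :pos],
            images_features[image_count],
            inputs_embeds[i, pos + count:],
        )))
        image_count += 1
    else:
        # Zero-valued hook keeps text-only rows tied to the vision output.
        new_input_embeds.append(inputs_embeds[i] + 0 * images_features[0].sum())

print(torch.stack(new_input_embeds).shape)  # torch.Size([2, 6, 8])

On the design choice in the else branch: the 0 * images_features[0].sum() term adds nothing numerically, but it makes every row's output depend on the vision tower, a common trick so gradient-synchronizing backends (e.g. DDP) do not see unused parameters. It does assume at least one image in the batch was encoded, since it indexes images_features[0].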
@@ -1316,4 +1331,4 @@ __all__ = [
    "GlmModel",
    "GlmForCausalLM",
    "GlmForSequenceClassification",
-]
+]