zR committed
Commit ca14f13 · 1 Parent(s): d2415e6
update

modeling_cogvlm.py CHANGED (+34, -37)
@@ -8,26 +8,17 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 from torchvision import transforms
 from einops import rearrange
-
-from decord import VideoReader, cpu
-import decord
-import io
-import numpy as np
-
 from transformers import PreTrainedModel, PreTrainedTokenizer
 from transformers.utils.logging import get_logger
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from torchvision.transforms.functional import InterpolationMode
 from torchvision.transforms import Lambda
-from torchvision.transforms._transforms_video import NormalizeVideo,
-from pytorchvideo.transforms import
+from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo
+from pytorchvideo.transforms import ShortSideScale
 from .configuration_cogvlm import CogVLMConfig
 from .util import FastRotaryEmbedding
 from .visual import EVA2CLIPModel
 
-
-
 if TYPE_CHECKING:
     from transformers.utils import ModelOutput
 
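For context, the rewritten imports are the video counterparts of the usual image transforms. Below is a minimal sketch of the kind of preprocessing pipeline they enable; the normalization constants, the 224-px target size, and the dummy clip shape are illustrative assumptions, not values taken from this hunk.

    import torch
    from torchvision import transforms
    from torchvision.transforms import Lambda
    from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo
    from pytorchvideo.transforms import ShortSideScale

    # Assumed ImageNet statistics and a 224-px crop; the real values live
    # elsewhere in modeling_cogvlm.py and are not shown in this hunk.
    transform = transforms.Compose(
        [
            Lambda(lambda x: x / 255.0),
            NormalizeVideo(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ShortSideScale(size=224),
            CenterCropVideo(224),
        ]
    )

    clip = torch.randint(0, 256, (3, 8, 256, 320)).float()  # (C, T, H, W) dummy video
    out = transform(clip)                                    # -> shape (3, 8, 224, 224)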
@@ -101,7 +92,8 @@ class MLP(nn.Module):
 
 def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
     vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
-    vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (
+    vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (
+            token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
     language_token_mask = ~vision_token_mask
     return vision_token_mask, language_token_mask
 
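The rewrapped mask line is easiest to read against a toy batch. A standalone sketch of the rule it implements follows; the 0/1 values for LANGUAGE_TOKEN_TYPE and VISION_TOKEN_TYPE are assumed from the module's constants, which are not shown in this hunk.

    import torch

    LANGUAGE_TOKEN_TYPE, VISION_TOKEN_TYPE = 0, 1  # assumed module constants

    token_type_ids = torch.tensor([[0, 1, 1, 1, 0, 0]])
    vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
    vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (
            token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
    language_token_mask = ~vision_token_mask

    # A position counts as "vision" only if the next token is also vision, so the
    # last token of each vision block is routed to the language expert:
    print(vision_token_mask)    # tensor([[False,  True,  True, False, False, False]])
    print(language_token_mask)  # tensor([[ True, False, False,  True,  True,  True]])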
@@ -117,7 +109,7 @@ class VisionExpertMLP(nn.Module):
         # vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
         # output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
         # output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
-
+
         output = self.language_mlp(hidden_states)
         return output
 
@@ -177,7 +169,7 @@ class VisionExpertAttention(nn.Module):
     def _transpose_for_scores(self, tensor):
         """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
         new_tensor_shape = tensor.size()[:-1] + \
-                           (-1,
+                           (-1,  # flexible for multi-query
                             self.hidden_size_per_attention_head)
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
@@ -214,7 +206,8 @@ class VisionExpertAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
 
-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids,
+        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids,
+                                                   max_seqlen=position_ids.max() + 1)
 
         if past_key_value is not None:
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
@@ -222,10 +215,13 @@ class VisionExpertAttention(nn.Module):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
-        key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads, -1,
+        key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads, -1,
+                                                    -1).contiguous().view(
             bsz, self.num_attention_heads, *key_states.shape[2:])
-        value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads,
-        -1
+        value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads,
+                                                        -1,
+                                                        -1).contiguous().view(bsz, self.num_attention_heads,
+                                                                              *value_states.shape[2:])
 
         context_layer = attention_fn(
             query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
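The rewrapped lines carry out the usual multi-query expansion: the few shared key/value heads are repeated until they match the number of query heads. A toy sketch of the same reshape, with made-up head counts rather than the model's config values:

    import torch

    bsz, seq_len, head_dim = 2, 5, 16
    num_attention_heads, num_multi_query_heads = 8, 2  # illustrative only

    key_states = torch.randn(bsz, num_multi_query_heads, seq_len, head_dim)
    key_states = key_states.unsqueeze(2).expand(
        -1, -1, num_attention_heads // num_multi_query_heads, -1, -1
    ).contiguous().view(bsz, num_attention_heads, seq_len, head_dim)

    print(key_states.shape)  # torch.Size([2, 8, 5, 16]) -- now matches the query heads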
@@ -240,7 +236,7 @@ class VisionExpertAttention(nn.Module):
         # attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
         # attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
         # attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
-
+
         attn_output = self.language_expert_dense(context_layer)
 
         if output_attentions:
@@ -329,7 +325,8 @@ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
     return True
 
 
-def build_position_ids(x: "torch.BoolTensor(B, L)",
+def build_position_ids(x: "torch.BoolTensor(B, L)",
+                       attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
     if attention_mask is not None:
         tmp = x.clone()
         tmp[~(attention_mask.bool())] = -1
@@ -344,7 +341,8 @@ def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["to
     tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
     # final position ids
     y = torch.zeros_like(x, dtype=torch.long)
-    y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
+    y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
+            (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
     y = y.cumsum(dim=-1)
     return y
 
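The rewrapped update encodes the position scheme where positions only advance on language tokens and on the first vision token of a block, so the rest of a vision block shares one position id. A toy check of just this fragment (token-type constants again assumed as 0/1, and tmp standing in for the already boi/eoi-adjusted types):

    import torch

    LANGUAGE_TOKEN_TYPE, VISION_TOKEN_TYPE = 0, 1  # assumed module constants

    tmp = torch.tensor([[0, 0, 1, 1, 1, 0, 0]])    # two language, three vision, two language
    y = torch.zeros_like(tmp, dtype=torch.long)
    y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
            (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
    y = y.cumsum(dim=-1)

    print(y)  # tensor([[0, 1, 2, 2, 2, 3, 4]]) -- the vision block shares position 2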
@@ -407,7 +405,8 @@ class CogVLMVideoModel(CogVLMPreTrainedModel):
             inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
         else:  # single-modality
             if token_type_ids is None:
-                token_type_ids = torch.ones_like(input_ids, dtype=torch.long,
+                token_type_ids = torch.ones_like(input_ids, dtype=torch.long,
+                                                 device=input_ids.device) * LANGUAGE_TOKEN_TYPE
             assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -588,7 +587,7 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         self.model = CogVLMVideoModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.video_downsample = 1
+        self.video_downsample = 1  # TODO: change this to config
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -685,7 +684,8 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)  # type: ignore
 
     def prepare_inputs_for_generation(
-            self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None,
+            self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None,
+            **kwargs
     ):
         # build position_ids if needed
         position_ids = kwargs.get("position_ids", None)
@@ -732,7 +732,8 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         # update token_type_ids with last value
         if "token_type_ids" in model_kwargs:
             token_type_ids = model_kwargs["token_type_ids"]
-            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
+            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
+                                            device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
             model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
 
         if not is_encoder_decoder:
@@ -761,8 +762,6 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         )
         return reordered_past
 
-
-
     def build_conversation_input_ids(
             self,
             tokenizer: "PreTrainedTokenizer",
@@ -780,7 +779,7 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         text = _history_to_prompt(template_version, history, query)
         input_ids = [tokenizer.bos_token_id]
         token_type_ids = [LANGUAGE_TOKEN_TYPE]
-        add_time_indices = False
+        add_time_indices = True if template_version == 'chat' else False
         if images is not None and len(images) == 1:
             # vision
             transform = transforms.Compose(
@@ -793,18 +792,19 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
                     # RandomHorizontalFlipVideo(p=0.5),
                 ]
             )
-            images = [transform(images[0]).transpose(0, 1)]
+            images = [transform(images[0]).transpose(0, 1)]  # (T, C, H, W)
            num_eois = len(images[0])
            tokenizer.pad_token_id = 128002
-            vision_token_num = (64 + 2) * num_eois
            if not add_time_indices:
-
+                vision_token_num = (64 + 2) * num_eois
+                input_ids += [tokenizer.pad_token_id] * vision_token_num  # add spetial token
                token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
            else:
                video_ids, video_type_ids = [], []
+                sing_vision_token_num = (64 + 2)
                for _time_idx in range(num_eois):
-                    video_ids += [tokenizer.pad_token_id] *
-                    video_type_ids += [VISION_TOKEN_TYPE] *
+                    video_ids += [tokenizer.pad_token_id] * sing_vision_token_num
+                    video_type_ids += [VISION_TOKEN_TYPE] * sing_vision_token_num
                    # add time indices
                    time_indices = tokenizer.encode(str(_time_idx), add_special_tokens=False)
                    video_ids += time_indices
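This hunk hard-codes a per-frame budget of 64 + 2 placeholder tokens (the frame's vision tokens plus its boi/eoi pair); with add_time_indices each frame additionally appends the text tokens of its index. A back-of-the-envelope sketch of the resulting prompt length, assuming an 8-frame clip and single-token frame indices (both assumptions, not values from this diff):

    num_eois = 8                        # assumed number of sampled frames
    per_frame = 64 + 2                  # per-frame placeholder budget from the hunk

    without_time = per_frame * num_eois
    print(without_time)                 # 528 pad tokens, all tagged VISION_TOKEN_TYPE

    # With add_time_indices, frame i also contributes len(tokenizer.encode(str(i))) text tokens.
    time_tokens_per_frame = [1] * num_eois  # assume "0".."7" each encode to one token
    with_time = sum(per_frame + n for n in time_tokens_per_frame)
    print(with_time)                    # 536 tokens; the time indices are LANGUAGE_TOKEN_TYPE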
@@ -812,7 +812,7 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
                # llama3 adapt for cogvlm
                input_ids += video_ids
                token_type_ids += video_type_ids
-
+
        text_ids = tokenizer.encode(text, add_special_tokens=False)
 
        if answer is not None:
@@ -820,7 +820,6 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
            answer_ids += [tokenizer.eos_token_id]
            text_ids += answer_ids
 
-
        input_ids += text_ids
        token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
        attention_mask = [1] * len(input_ids)
@@ -837,5 +836,3 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
            'images': images,
            'labels': labels,
        }
-
-