from typing import List, Optional, Tuple, Union
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from transformers import AutoModelForCausalLM
from .configuration_dots_vlm import DotsVisionConfig, DotsVLMConfig
from .modeling_dots_vision import DotsVisionTransformer
DOTS_VLM_MAX_IMAGES = 200
class DotsVLMForCausalLM(DeepseekV3ForCausalLM):
    config_class = DotsVLMConfig

    def __init__(self, config: DotsVLMConfig):
        super().__init__(config)
        # The vision config may arrive as a plain dict (e.g. loaded from JSON) or
        # as an already-instantiated DotsVisionConfig; normalize it to the latter.
        if isinstance(self.config.vision_config, dict):
            vision_config = DotsVisionConfig(**self.config.vision_config)
            self.config.vision_config = vision_config
        else:
            vision_config = self.config.vision_config
        self.vision_tower = DotsVisionTransformer(vision_config)
    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_thw: Optional[torch.FloatTensor] = None,
        img_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:
        """Embed the text tokens and scatter the vision-tower embeddings into the
        positions marked by ``img_mask``."""
        inputs_embeds = self.get_input_embeddings()(input_ids)
        if pixel_values is not None:
            assert img_mask is not None
            if grid_thw.shape[0] > DOTS_VLM_MAX_IMAGES:
                print(
                    f"Num images exceeded: {grid_thw.shape[0]} > {DOTS_VLM_MAX_IMAGES}, which may cause an FSDP hang"
                )

            vision_embeddings = self.vision_tower(pixel_values, grid_thw)

            # If the mask marks more positions than there are vision embeddings,
            # keep only the first vision_embeddings.size(0) marked positions.
            true_indices = torch.nonzero(img_mask).squeeze()
            if len(true_indices) > vision_embeddings.size(0):
                print(
                    f"img_mask has more True entries than vision embeddings and will be truncated: "
                    f"mask.sum()={len(true_indices)}, {vision_embeddings.size(0)=}"
                )
                true_indices = true_indices[: vision_embeddings.size(0)]
                new_img_mask = torch.zeros_like(img_mask, device=img_mask.device)
                new_img_mask[true_indices[:, 0], true_indices[:, 1]] = True
            else:
                new_img_mask = img_mask

            assert (
                vision_embeddings.size(0) == new_img_mask.sum()
            ), f"{vision_embeddings.size(0)=}, {new_img_mask.sum()=}"

            # Write each vision embedding into its corresponding image-token slot.
            inputs_embeds = inputs_embeds.masked_scatter(
                new_img_mask.to(inputs_embeds.device).unsqueeze(-1).expand_as(inputs_embeds),
                vision_embeddings.to(inputs_embeds.device).type(inputs_embeds.dtype),
            )

        return inputs_embeds
    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        assert len(input_ids) >= 1, f"empty input_ids {input_ids.shape=} will cause gradnorm nan"

        if inputs_embeds is None:
            # Mark the image placeholder tokens and splice the vision embeddings
            # into the text embeddings before running the language model.
            img_mask = input_ids == self.config.image_token_id
            inputs_embeds = self.prepare_inputs_embeds(input_ids, pixel_values, image_grid_thw, img_mask)

        # Delegate to the DeepseekV3 language model. input_ids is set to None because
        # the (possibly vision-augmented) inputs_embeds are passed instead.
        outputs = super().forward(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            labels=labels,
            use_cache=use_cache if use_cache is not None else self.config.use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **kwargs,
        )

        # Only pass pixel_values on the prefill step (cache_position starts at 0);
        # cached decoding steps do not need to re-encode the images.
        if cache_position is not None and cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values

        return model_inputs
# Register the model with AutoModelForCausalLM so it can be loaded via from_pretrained
AutoModelForCausalLM.register(DotsVLMConfig, DotsVLMForCausalLM)
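

if __name__ == "__main__":
    # Illustrative sketch only (not part of the released model code): it shows how
    # prepare_inputs_embeds uses masked_scatter to place vision-tower outputs at the
    # image placeholder positions. The shapes and the image_token_id value below are
    # made-up example values, not taken from any real DotsVLMConfig.
    hidden_size = 8
    image_token_id = 3

    input_ids = torch.tensor([[5, 3, 3, 7]])        # two image placeholder tokens
    text_embeds = torch.zeros(1, 4, hidden_size)    # stand-in for get_input_embeddings()(input_ids)
    vision_embeds = torch.ones(2, hidden_size)      # stand-in for the vision tower output

    img_mask = input_ids == image_token_id          # (1, 4) boolean mask over the sequence
    merged = text_embeds.masked_scatter(
        img_mask.unsqueeze(-1).expand_as(text_embeds),
        vision_embeds.to(text_embeds.dtype),
    )
    # Positions 1 and 2 now carry the vision embeddings; positions 0 and 3 are unchanged.
    print(merged[0, :, 0])  # tensor([0., 1., 1., 0.])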