# --------------------------------------------------------
# Ristretto
# Copyright (c) 2025 LiAutoAD
# Licensed under The MIT License
# --------------------------------------------------------
import copy
from typing import Any, List, Optional, Tuple, Union

import torch.distributed as dist
import torch.utils.checkpoint
import transformers
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (GenerationConfig, LlamaConfig, LlamaForCausalLM,
                          PretrainedConfig, Qwen2Config, Qwen2ForCausalLM,
                          SiglipVisionConfig, SiglipVisionModel)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.trainer_pt_utils import LabelSmoother
from transformers.utils import logging

from .conversation import get_conv_template
from .projector import TokenAdaptiveProjector

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

# Placeholder in a user question that marks where the image tokens
# (IMG_START + IMG_CONTEXT * n + IMG_END) are spliced into the prompt.
IMAGE_PLACEHOLDER = '<image>'

logger = logging.get_logger(__name__)
logger.setLevel(logging.INFO)


def version_cmp(v1, v2, op='eq'):
    """Compare two version strings with a function from :mod:`operator`.

    Args:
        v1: Left-hand version string (e.g. ``transformers.__version__``).
        v2: Right-hand version string.
        op: Name of an :mod:`operator` comparison function
            (``'eq'``, ``'ne'``, ``'lt'``, ``'le'``, ``'gt'``, ``'ge'``).

    Returns:
        Result of ``operator.<op>(version.parse(v1), version.parse(v2))``.
    """
    import operator

    from packaging import version

    op_func = getattr(operator, op)
    return op_func(version.parse(v1), version.parse(v2))


class RistrettoConfig(PretrainedConfig):
    """Composite configuration: a SigLIP vision tower plus a Llama/Qwen2 LLM.

    Args:
        vision_config: Dict of vision-tower settings. Only
            ``model_type == 'siglip_vision_model'`` is supported. Defaults to
            ``{'model_type': 'siglip_vision_model'}``.
        llm_config: Dict of LLM settings. ``architectures[0]`` must be
            ``'LlamaForCausalLM'`` or ``'Qwen2ForCausalLM'``. Defaults to
            ``{'architectures': ['Qwen2ForCausalLM']}``.
        pad2square, select_layer, force_image_size, num_image_token, template,
        dynamic_image_size, use_thumbnail, min_dynamic_patch, max_dynamic_patch:
            Stored verbatim on the config and serialized by :meth:`to_dict`.

    Raises:
        ValueError: If the vision model type or LLM architecture is unsupported.
    """

    model_type = 'ristretto'
    is_composition = True

    def __init__(
            self,
            vision_config=None,
            llm_config=None,
            pad2square=False,
            select_layer=-1,
            force_image_size=None,
            num_image_token=256,
            template=None,
            dynamic_image_size=False,
            use_thumbnail=False,
            min_dynamic_patch=1,
            max_dynamic_patch=6,
            **kwargs):
        super().__init__(**kwargs)

        # Build the defaults per call instead of sharing one mutable dict
        # literal across every instance (and accept an explicit ``None``).
        if vision_config is None:
            vision_config = dict(model_type='siglip_vision_model')
        if llm_config is None:
            llm_config = dict(architectures=['Qwen2ForCausalLM'])

        if vision_config['model_type'] == 'siglip_vision_model':
            self.vision_config = SiglipVisionConfig(**vision_config)
        else:
            raise ValueError('Unsupported architecture: {}'.format(vision_config['model_type']))

        if llm_config['architectures'][0] == 'LlamaForCausalLM':
            self.llm_config = LlamaConfig(**llm_config)
        elif llm_config['architectures'][0] == 'Qwen2ForCausalLM':
            self.llm_config = Qwen2Config(**llm_config)
        else:
            raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))

        self.pad2square = pad2square
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.num_image_token = num_image_token
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch

        logger.info(f'vision_select_layer: {self.select_layer}')
        logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
        logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        The nested vision/LLM configs are converted to plain dicts so the
        result round-trips through ``RistrettoConfig(**output)``.

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        output['vision_config'] = self.vision_config.to_dict()
        output['llm_config'] = self.llm_config.to_dict()
        output['model_type'] = self.__class__.model_type
        output['pad2square'] = self.pad2square
        output['select_layer'] = self.select_layer
        output['force_image_size'] = self.force_image_size
        output['num_image_token'] = self.num_image_token
        output['template'] = self.template
        output['dynamic_image_size'] = self.dynamic_image_size
        output['use_thumbnail'] = self.use_thumbnail
        output['min_dynamic_patch'] = self.min_dynamic_patch
        output['max_dynamic_patch'] = self.max_dynamic_patch

        return output


class RistrettoModel(PreTrainedModel):
    """Vision-language model: SigLIP vision tower -> TokenAdaptiveProjector -> causal LLM.

    Projected image features overwrite the input embeddings at every position
    whose token id equals ``self.img_context_token_id``.
    """

    config_class = RistrettoConfig
    main_input_name = 'pixel_values'
    _no_split_modules = ['SiglipVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']
    _supports_flash_attn_2 = True
    _keys_to_ignore_on_save = []

    def __init__(self, config: RistrettoConfig, vision_model=None, language_model=None):
        """Build the sub-modules from ``config`` unless pre-built ones are passed in.

        Args:
            config: Model configuration.
            vision_model: Optional pre-built vision tower; otherwise a
                ``SiglipVisionModel`` is created from ``config.vision_config``.
            language_model: Optional pre-built LLM; otherwise a Llama or Qwen2
                causal LM is created from ``config.llm_config``.

        Raises:
            NotImplementedError: For unsupported vision/LLM architectures.
        """
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        self.num_image_token = config.num_image_token
        self.llm_arch_name = config.llm_config.architectures[0]
        self.vision_model_type = config.vision_config.model_type

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            if config.vision_config.model_type == 'siglip_vision_model':
                self.vision_model = SiglipVisionModel(config.vision_config)
            else:
                raise NotImplementedError(f'{config.vision_config.model_type} is not implemented.')

        if language_model is not None:
            self.language_model = language_model
        else:
            if config.llm_config.architectures[0] == 'LlamaForCausalLM':
                self.language_model = LlamaForCausalLM(config.llm_config)
            elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
                self.language_model = Qwen2ForCausalLM(config.llm_config)
            else:
                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        self.projector = TokenAdaptiveProjector(
            vit_hidden_size=vit_hidden_size,
            llm_hidden_size=llm_hidden_size,
            num_image_token=self.num_image_token,
        )

        # Set later by `chat` / `batch_chat` (or by the training code) from
        # the tokenizer; `generate` asserts it has been set.
        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message
        self.num_samples = 0

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            num_image_tokens: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training forward pass: splice image features into the text embeddings and run the LLM.

        Args:
            pixel_values: Image tiles for the whole batch (first dim = number of tiles).
            input_ids: Token ids, reshaped internally as ``(B, N)``.
            image_flags: Optional per-tile flags; tiles whose flag != 1 are
                dropped before splicing. When ``None`` every tile is kept.
            num_image_tokens: Optional per-sample image-token count; all entries
                must be equal and the value is forwarded to the projector.
            labels: Optional target ids; positions equal to the loss ignore
                index are not supervised.
            Remaining arguments are forwarded to the language model.

        Returns:
            ``CausalLMOutputWithPast`` or, if ``return_dict`` is False, a tuple
            ``(loss?, logits, *rest)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        distributed = dist.is_available() and dist.is_initialized()

        num_image_token = None
        if num_image_tokens is not None:
            assert num_image_tokens.unique().shape[0] == 1, 'num_image_tokens must be the same for all samples in a batch'
            num_image_token = num_image_tokens[0].item()

        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

        vit_embeds = self.extract_feature(pixel_values, num_image_token)
        if image_flags is not None:
            # Keep only features of tiles flagged as real images (flag == 1).
            image_flags = image_flags.squeeze(-1)
            vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        if distributed and dist.get_rank() == 0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        try:
            # `* 0.0 +` keeps the (unused) text embedding in the autograd graph.
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
            ignore_flag = False
        except Exception as e:
            # Deliberate best-effort: on a token/feature count mismatch, fill
            # what fits and zero the loss for this batch instead of crashing.
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
            ignore_flag = True

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # With return_dict=False the LLM returns a plain tuple whose first
        # element is the logits (no labels are passed to it).
        logits = outputs.logits if return_dict else outputs[0]

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(reduction='none')
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Per-sample token weight 1 / sqrt(#supervised tokens in that sample);
            # assumes labels are (batch, seq) — TODO confirm for other layouts.
            loss_token_mask = shift_labels != loss_fct.ignore_index
            loss_token_num = loss_token_mask.sum(dim=1, keepdim=True).float()
            loss_token_weight = 1. / (loss_token_num.expand_as(shift_labels) ** 0.5 + 1e-6)
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            loss_token_weight = loss_token_weight.view(-1)
            loss_token_mask = loss_token_mask.view(-1)
            # Enable model parallelism: keep everything on the logits' device.
            shift_labels = shift_labels.to(shift_logits.device)
            loss_token_weight = loss_token_weight.to(shift_logits.device)
            loss_token_mask = loss_token_mask.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

            # Normalize by the weight sum over all ranks (local sum when not distributed).
            all_token_weight = (loss_token_weight * loss_token_mask.float()).sum()
            world_size = 1
            if distributed:
                dist.all_reduce(all_token_weight, op=dist.ReduceOp.SUM)
                world_size = dist.get_world_size()
            loss = (loss * loss_token_weight * loss_token_mask.float()).sum() / (all_token_weight + 1e-6)
            # Hack for DDP training, since the loss is reduced in the forward function
            loss = loss * world_size
            if ignore_flag:
                loss = loss * 0.0

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def extract_feature(self, pixel_values, num_image_token=None):
        """Encode image tiles and project them into the LLM embedding space.

        Args:
            pixel_values: Image tiles passed to the vision tower.
            num_image_token: Optional token count forwarded to the projector.

        Returns:
            Projector output for the hidden states of ``self.select_layer``
            (the last hidden state when ``select_layer == -1``).
        """
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]

        vit_embeds = self.projector(vit_embeds, num_image_token=num_image_token)
        return vit_embeds

    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
        """Answer several single-turn questions in one left-padded generate call.

        Args:
            tokenizer: Tokenizer that knows the image special tokens.
            pixel_values: Image tiles for all questions, or ``None``.
            questions: One question string per sample.
            generation_config: Dict of ``generate`` kwargs (copied, not mutated).
            num_patches_list: Per question, an iterable of tile counts, one per
                ``<image>`` placeholder in that question.
            image_counts: Deprecated alias of ``num_patches_list``.

        Returns:
            List of decoded responses, truncated at the template separator.

        Raises:
            NotImplementedError: If ``history`` or ``return_history`` is used.
        """
        if history is not None or return_history:
            print('Now multi-turn chat is not supported in batch_chat.')
            raise NotImplementedError

        if image_counts is not None:
            num_patches_list = image_counts
            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        queries = []
        for idx, _num_patches_list in enumerate(num_patches_list):
            question = questions[idx]
            if pixel_values is not None and IMAGE_PLACEHOLDER not in question:
                question = IMAGE_PLACEHOLDER + '\n' + question
            template = get_conv_template(self.template)
            template.system_message = self.system_message
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            # Replace each placeholder, in order, with that image's token run.
            for num_patches in _num_patches_list:
                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
                query = query.replace(IMAGE_PLACEHOLDER, image_tokens, 1)
            queries.append(query)

        # Decoder-only generation needs left padding; restore the caller's
        # tokenizer setting afterwards instead of changing it permanently.
        original_padding_side = tokenizer.padding_side
        tokenizer.padding_side = 'left'
        try:
            model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
        finally:
            tokenizer.padding_side = original_padding_side
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()

        # The separator is the same for every query; a fresh template also
        # works when `num_patches_list` is empty.
        sep = get_conv_template(self.template).sep
        generation_config = dict(generation_config)
        generation_config['eos_token_id'] = tokenizer.convert_tokens_to_ids(sep)
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        responses = [response.split(sep)[0].strip() for response in responses]
        return responses

    def chat(self, tokenizer, pixel_values, question, generation_config, num_image_token=None,
             history=None, return_history=False, num_patches_list=None, IMG_START_TOKEN='<img>',
             IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False):
        """Answer one question, optionally continuing a multi-turn history.

        Args:
            tokenizer: Tokenizer that knows the image special tokens.
            pixel_values: Image tiles, or ``None`` for text-only chat.
            question: User question; ``<image>`` is prepended on the first turn
                when images are given and the question has no placeholder.
            generation_config: Dict of ``generate`` kwargs (copied, not mutated).
            num_image_token: Tokens per tile; defaults to ``self.num_image_token``.
            history: Optional list of ``(question, answer)`` pairs (copied, not mutated).
            num_patches_list: Tile count per ``<image>`` placeholder; defaults to
                all tiles for one image.

        Returns:
            The response string, or ``(response, history)`` if ``return_history``.
        """
        if history is None and pixel_values is not None and IMAGE_PLACEHOLDER not in question:
            question = IMAGE_PLACEHOLDER + '\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        template = get_conv_template(self.template)
        template.system_message = self.system_message

        # Work on a copy so the caller's history list is not mutated.
        history = [] if history is None else list(history)
        for (old_question, old_answer) in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        if num_image_token is None:
            num_image_token = self.num_image_token

        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace(IMAGE_PLACEHOLDER, image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()

        generation_config = dict(generation_config)
        generation_config['eos_token_id'] = tokenizer.eos_token_id
        generation_config['pad_token_id'] = tokenizer.pad_token_id
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_image_token=num_image_token,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep)[0].strip()
        history.append((question, response))
        if return_history:
            return response, history

        query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
        query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', IMAGE_PLACEHOLDER)
        if verbose:
            print(query_to_print, response)
        return response

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            num_image_token: Optional[int] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate text, splicing image features into the prompt embeddings.

        Args:
            pixel_values: Image tiles to encode; ignored when ``visual_features`` is given.
            input_ids: Prompt token ids, reshaped internally as ``(B, N)``.
            visual_features: Pre-computed projected image features, used instead
                of encoding ``pixel_values``.
            num_image_token: Token count forwarded to the projector.
            return_dict: Accepted for API compatibility; not used.
            **generate_kwargs: Forwarded to ``language_model.generate``.

        Returns:
            Output of ``language_model.generate``.
        """
        assert self.img_context_token_id is not None
        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        if pixel_values is not None or visual_features is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values, num_image_token)

            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            selected = (input_ids.reshape(B * N) == self.img_context_token_id)
            assert selected.sum() != 0
            # Match device and dtype of the embedding table before the index-put.
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device, input_embeds.dtype)

            input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs