import os
from types import SimpleNamespace
from typing import Tuple, List, Optional, Union

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import Qwen2ForCausalLM, AutoModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
    Qwen2DecoderLayer,
    Qwen2Model,
    Qwen2PreTrainedModel,
)

from .configuration_xomni import XOmniConfig
from .modeling_siglip_tokenizer import create_anyres_preprocess, SiglipTokenizer
from .modeling_siglip_flux import FluxTransformer2DModelWithSigLIP, FluxPipelineWithSigLIP
from .modeling_vit import create_siglip_vit


class XOmniDecoderLayer(Qwen2DecoderLayer):
    def __init__(self, config: XOmniConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.layer_idx = layer_idx
        # Middle layers are shared language-model layers; the first `num_mm_adap_layers`
        # and the last `num_mm_head_layers` are multimodal-specific.
        self.is_lm_layer = config.num_mm_adap_layers <= layer_idx < config.num_hidden_layers - config.num_mm_head_layers

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # The incoming tensor carries a per-token multimodal mask in its second half.
        hidden_states, multimodal_mask = torch.split(hidden_states, hidden_states.shape[-1] // 2, dim=-1)
        if self.is_lm_layer:
            output_hidden_states, *others = super().forward(hidden_states, **kwargs)
            output_hidden_states = torch.cat([output_hidden_states, multimodal_mask], dim=-1)
            return output_hidden_states, *others
        # mm_hidden_states = torch.where(multimodal_mask.bool(), hidden_states, torch.zeros_like(hidden_states))
        output_hidden_states, *others = super().forward(hidden_states, **kwargs)
        # Multimodal-specific layers only update positions covered by the mask.
        output_hidden_states = torch.where(multimodal_mask.bool(), output_hidden_states, hidden_states)
        output_hidden_states = torch.cat([output_hidden_states, multimodal_mask], dim=-1)
        return output_hidden_states, *others


class XOmniModel(Qwen2Model, Qwen2PreTrainedModel):
    model_type = "x-omni"
    config_class = XOmniConfig

    def __init__(self, config: XOmniConfig):
        Qwen2PreTrainedModel.__init__(self, config)
        self.padding_idx = -1
        self.vocab_size = config.vocab_size

        self.lm_embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.mm_embed_tokens = nn.Embedding(config.mm_vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [XOmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.lm_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mm_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.lm_embed_tokens

    def set_input_embeddings(self, value):
        self.lm_embed_tokens = value

    def embed_tokens(self, input_ids):
        (B, L), C = input_ids.shape, self.config.hidden_size
        # Token ids >= vocab_size index into the multimodal (image) embedding table.
        multimodal_mask = input_ids >= self.config.vocab_size
        lm_input_ids = input_ids[~multimodal_mask][None, :]
        mm_input_ids = input_ids[multimodal_mask][None, :] - self.config.vocab_size
        lm_embeds = self.lm_embed_tokens(lm_input_ids)
        mm_embeds = self.mm_embed_tokens(mm_input_ids)
        inputs_embeds = lm_embeds.new_empty((B, L, C))
        multimodal_mask = multimodal_mask[:, :, None].expand_as(inputs_embeds)
        inputs_embeds[~multimodal_mask] = lm_embeds.reshape(-1)
        inputs_embeds[multimodal_mask] = mm_embeds.reshape(-1)
        # Concatenate the mask along the channel dim so decoder layers can route each
        # token between the LM and multimodal streams.
        inputs_embeds = torch.cat([inputs_embeds, multimodal_mask.to(inputs_embeds.dtype)], dim=-1)
        return inputs_embeds
    def norm(self, hidden_states):
        hidden_states, multimodal_mask = torch.split(hidden_states, hidden_states.shape[-1] // 2, dim=-1)
        return torch.where(multimodal_mask.bool(), self.mm_norm(hidden_states), self.lm_norm(hidden_states))


class XOmniForCausalLM(Qwen2ForCausalLM):
    model_type = "x-omni"
    config_class = XOmniConfig
    _keys_to_ignore_on_load_missing = r'image_tokenizer\.*'

    def __init__(self, config):
        super().__init__(config)
        self.model = XOmniModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.mm_head = nn.Linear(config.hidden_size, config.mm_vocab_size, bias=False)
        self.generation_mode = 'text'
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def device(self):
        return next(iter(self.parameters())).device

    def init_vision(self, flux_pipe_path):
        self.som_token = self.config.mm_special_tokens[0]
        self.eom_token = self.config.mm_special_tokens[1]
        self.img_token = self.config.mm_special_tokens[2]

        self.vision_config = SimpleNamespace(**self.config.vision_config)
        self.transform_config = SimpleNamespace(**self.vision_config.transform)
        self.encoder_config = SimpleNamespace(**self.vision_config.encoder)
        self.decoder_config = SimpleNamespace(**self.vision_config.decoder)

        dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
        self.vision_dtype = dtype_map[self.vision_config.dtype]

        self.image_transform = create_anyres_preprocess(**self.vision_config.transform)

        # Resolve encoder weights either from a local checkpoint directory or from the Hub.
        self.encoder_config.siglip_path = (
            os.path.join(self.name_or_path, self.encoder_config.siglip_path)
            if os.path.isdir(self.name_or_path)
            else hf_hub_download(repo_id=self.name_or_path, filename=self.encoder_config.siglip_path)
        )
        self.encoder_config.projector_path = (
            os.path.join(self.name_or_path, self.encoder_config.projector_path)
            if os.path.isdir(self.name_or_path)
            else hf_hub_download(repo_id=self.name_or_path, filename=self.encoder_config.projector_path)
        )
        self.image_tokenizer = SiglipTokenizer(**vars(self.encoder_config))
        self.image_tokenizer.to(self.device, self.vision_dtype)

        self.decoder_pipe = FluxPipelineWithSigLIP.from_pretrained(
            flux_pipe_path,
            torch_dtype=self.vision_dtype,
        )
        self.decoder_pipe.transformer = FluxTransformer2DModelWithSigLIP.from_pretrained(
            self.name_or_path,
            siglip_channels=self.encoder_config.z_channels,
            torch_dtype=self.vision_dtype,
            subfolder=self.decoder_config.model_path,
        )
        self.decoder_pipe.set_progress_bar_config(disable=True)
        self.decoder_pipe.to(self.device)

    def set_generation_mode(self, mode):
        assert mode in ('text', 'image'), f'Invalid generation mode: {mode}'
        self.generation_mode = mode

    def mmencode(self, tokenizer, texts=None, images=None, **kwargs):
        texts = texts or []
        images = images or []
        doc = ''
        # Interleave text chunks and tokenized images in the order they were given.
        while len(texts) > 0 or len(images) > 0:
            if len(texts) > 0:
                doc += texts.pop(0)
            if len(images) > 0:
                doc += self.tokenize_image(images.pop(0))
        return tokenizer.encode(doc, **kwargs)

    def mmdecode(self, tokenizer, token_ids, force_text=None, **kwargs):
        force_text = force_text or []
        if isinstance(token_ids, torch.Tensor):
            if len(token_ids.shape) == 2:
                assert token_ids.shape[0] == 1
                token_ids = token_ids[0]
            assert len(token_ids.shape) == 1
        else:
            if not isinstance(token_ids[0], int):
                assert len(token_ids) == 1
                token_ids = token_ids[0]
            assert isinstance(token_ids[0], int)

        doc = tokenizer.decode(token_ids, **kwargs)
        doc = doc.replace(tokenizer.pad_token, '')
        doc = doc.replace('', '')

        texts, images = [], []
        text_image_chunks = doc.split(self.eom_token)
        for chunk in text_image_chunks:
            text, image_str = chunk.split(self.som_token) \
                if self.som_token in chunk else (chunk, '')
            texts.append(text)
            if self.img_token in image_str:
                image_meta, token_str = image_str.split(self.img_token)
                H, W = tuple(map(int, image_meta.split(' ')))
                token_ids = list(map(
                    lambda x: int(x.split('>')[0]),
                    # NOTE: assumed image-token prefix (see tokenize_image); drop the
                    # leading chunk before the first image token.
                    token_str.split('<im_')[1:],
                ))
                if len(force_text) > 0:
                    image = self.detokenize_image([force_text.pop(0)], images, token_ids, (H, W))
                else:
                    image = self.detokenize_image(texts, images, token_ids, (H, W))
                images.append(image)
        return texts, images

    @torch.no_grad()
    def tokenize_image(self, image):
        assert hasattr(self, 'image_tokenizer'), 'Please call "init_vision" before that.'
        image_str = self.som_token
        image = self.image_transform(image)
        assert image is not None, (
            f'Unsupported image aspect ratio (max {self.transform_config.max_aspect_ratio}) '
            f'or image resolution is too low (min {self.transform_config.min_short_size})'
        )
        image = image[None, ...].to(self.device, self.vision_dtype)
        tokens = self.image_tokenizer.encode(image)
        B, H, W = tokens.shape
        tokens = tokens.view(B, -1).cpu().tolist()[0]
        # NOTE: assumed per-token text format (e.g. '<im_123>'); it must match the image-token
        # strings registered in the tokenizer and the parsing in mmdecode above.
        token_str = ''.join(map(lambda x: '<im_{token_id}>'.format(token_id=x), tokens))
        image_str = f'{self.som_token}{H} {W}{self.img_token}{token_str}{self.eom_token}'
        return image_str

    @torch.no_grad()
    def detokenize_image(self, texts, images, token_ids, shape):
        assert hasattr(self, 'image_tokenizer'), 'Please call "init_vision" before that.'
        assert len(texts) == 1 and len(images) == 0, 'Only support one image per sample.'
        H, W = shape
        tokens = torch.tensor(token_ids, device=self.device, dtype=torch.long)
        latents = self.image_tokenizer.decode(tokens, (1, H, W, self.encoder_config.codebook_dim))
        upscale_factor = self.decoder_config.upscale_factor
        latents = latents.reshape(*latents.shape[:2], -1).transpose(1, 2).contiguous()
        image = self.decoder_pipe(
            latents,
            [texts[0]],
            negative_prompt=[''],
            height=H * upscale_factor,
            width=W * upscale_factor,
            num_inference_steps=self.decoder_config.num_inference_steps,
            guidance_scale=1.0,
            true_cfg_scale=self.decoder_config.cfg_scale,
            true_cfg_scale_2=self.decoder_config.cfg_scale_2,
        ).images[0]
        return image

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        hidden_states = hidden_states[:, -num_logits_to_keep:, :]
        # Fill the combined vocabulary with the dtype minimum so tokens outside the
        # active head can never be sampled.
        logits = hidden_states.new_full(
            (*hidden_states.shape[:-1], self.config.vocab_size + self.config.mm_vocab_size),
            torch.finfo(hidden_states.dtype).min,
        )
        if self.generation_mode == 'text':
            logits[:, :, :self.config.vocab_size] = self.lm_head(hidden_states)
        else:
            logits[:, :, self.config.vocab_size:self.config.vocab_size + self.config.image_vocab_size] = \
                self.mm_head(hidden_states)[:, :, :self.config.image_vocab_size]
        # Upcast to float to avoid potential precision issues
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens over the full (text + multimodal) vocabulary
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


AutoModel.register(XOmniConfig, XOmniModel)
AutoModelForCausalLM.register(XOmniConfig, XOmniForCausalLM)
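

# ---------------------------------------------------------------------------
# Minimal usage sketch for image understanding (not part of the model code).
# Assumptions: the checkpoint ships this module as Hugging Face remote code with a
# matching tokenizer; 'X-Omni/placeholder-checkpoint', 'path/to/flux-pipeline' and
# 'example.jpg' are placeholder names to replace with real locations.
if __name__ == '__main__':
    from PIL import Image
    from transformers import AutoTokenizer

    repo = 'X-Omni/placeholder-checkpoint'  # placeholder repo id
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo, torch_dtype=torch.bfloat16, trust_remote_code=True
    ).to('cuda')
    model.init_vision('path/to/flux-pipeline')  # placeholder FLUX pipeline path

    # Encode an interleaved (text, image) prompt, then sample a text answer.
    model.set_generation_mode('text')
    input_ids = model.mmencode(
        tokenizer,
        texts=['Describe this image: '],
        images=[Image.open('example.jpg')],
        return_tensors='pt',
    )
    output_ids = model.generate(input_ids.to(model.device), max_new_tokens=256)
    answer, _ = model.mmdecode(tokenizer, output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)
    print(''.join(answer))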