import torch
from .vision_encoder import VisionEncoder
from .configuration_moondream import MoondreamConfig
from transformers import PreTrainedModel

from .modeling_phi import PhiForCausalLM
from .configuration_moondream import PhiConfig


class Moondream(PreTrainedModel):
    config_class = MoondreamConfig
    _supports_flash_attn_2 = True

    def __init__(self, config):
        super().__init__(config)
        self.vision_encoder = VisionEncoder()
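        # config.phi_config may arrive as a plain dict when the model config is
        # deserialized; in that case rebuild a PhiConfig and forward the
        # requested attention implementation.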
        if isinstance(config.phi_config, dict):
            phi_config = PhiConfig(
                **config.phi_config, attn_implementation=config._attn_implementation
            )
        else:
            phi_config = config.phi_config
        self.text_model = PhiForCausalLM(phi_config)

    @property
    def device(self):
        return self.text_model.device

    def encode_image(self, image):
        return self.vision_encoder(image)

    def input_embeds(self, prompt, image_embeds, tokenizer):
        def _tokenize(txt):
            return tokenizer(
                txt, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(self.device)

        text_emb = self.text_model.get_input_embeddings()
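        # Start the sequence with the BOS token embedding.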
        embeds = []
        embeds.append(
            text_emb(torch.tensor([[tokenizer.bos_token_id]], device=self.device))
        )
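        # Splice the image embeddings in at the "<image>" placeholder; a prompt
        # without the placeholder is embedded as plain text.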
        if "<image>" not in prompt:
            embeds.append(text_emb(_tokenize(prompt)))
        else:
            assert prompt.count("<image>") == 1
            before, after = prompt.split("<image>")
            if len(before) > 0:
                embeds.append(text_emb(_tokenize(before)))
            embeds.append(image_embeds.to(self.device))
            if len(after) > 0:
                embeds.append(text_emb(_tokenize(after)))

        return torch.cat(embeds, dim=1)

    def generate(
        self,
        image_embeds,
        prompt,
        tokenizer,
        max_new_tokens=128,
        **kwargs,
    ):
        generate_config = {
            "eos_token_id": tokenizer.eos_token_id,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.bos_token_id,
            "max_new_tokens": max_new_tokens,
            **kwargs,
        }

        with torch.no_grad():
            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds, **generate_config
            )
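        # Generation was driven purely by inputs_embeds (no input_ids), so
        # output_ids should contain only the newly generated tokens.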
        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer,
        chat_history="",
        result_queue=None,
        **kwargs,
    ):
        prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
        answer = self.generate(
            image_embeds,
            prompt,
            tokenizer=tokenizer,
            max_new_tokens=512,
            **kwargs,
        )[0]
        cleaned_answer = answer.strip()
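        # If a result queue was provided (e.g. when called from a background
        # thread), hand the answer back through it instead of returning it.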
        if result_queue:
            result_queue.put(cleaned_answer)
        else:
            return cleaned_answer

    def batch_answer(
        self,
        images,
        prompts,
        tokenizer,
        **kwargs,
    ):
        image_embeds = self.encode_image(images)
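        # Wrap each prompt in the question/answer template and embed every
        # (image, prompt) pair individually.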
        templated_prompts = [
            f"<image>\n\nQuestion: {prompt}\n\nAnswer:" for prompt in prompts
        ]
        prompt_embs = [
            self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
            for prompt, image_embed in zip(templated_prompts, image_embeds)
        ]
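        # Left-pad every sequence to the longest prompt, reusing the BOS
        # embedding as padding, and mask those positions out of attention.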
        bos_emb = prompt_embs[0][0]
        max_len = max(p.shape[0] for p in prompt_embs)

        inputs_embeds = torch.cat(
            [
                torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
                for p in prompt_embs
            ],
            dim=0,
        )
        attention_mask = torch.cat(
            [
                torch.cat(
                    [
                        torch.zeros(
                            1,
                            max_len - p.shape[0],
                            device=self.device,
                            dtype=torch.long,
                        ),
                        torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
                    ],
                    dim=1,
                )
                for p in prompt_embs
            ],
            dim=0,
        )

        generate_config = {
            "eos_token_id": tokenizer.eos_token_id,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.bos_token_id,
            "max_new_tokens": 512,
            **kwargs,
        }

        with torch.no_grad():
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **generate_config,
            )

        return [
            x.strip()
            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        ]
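# Example usage (sketch; the repository id and image path below are
# placeholders):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     from PIL import Image
#
#     tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream1")
#     model = AutoModelForCausalLM.from_pretrained(
#         "vikhyatk/moondream1", trust_remote_code=True
#     )
#     image = Image.open("example.jpg")
#     image_embeds = model.encode_image(image)
#     print(model.answer_question(image_embeds, "What is in this image?", tokenizer))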