Commit: d021148 · 1 parent: 8770b6f
Replace the inplace operation (#7)

- Update modeling_minicpmo.py (909a86b1f20fd048c8a8fbe4119910812cc3eaaf)

Co-authored-by: Zhangchi Feng <[email protected]>
1 file changed: modeling_minicpmo.py (+10, -6)

modeling_minicpmo.py CHANGED
@@ -377,10 +377,12 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
 
+        new_vllm_embedding = vllm_embedding.clone()
+
         vision_hidden_states = [
             i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
         ]
-
+
         bs = len(data["input_ids"])
         for i in range(bs):
             cur_vs_hs = vision_hidden_states[i]
@@ -392,15 +394,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                     [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                 ).to(vllm_embedding.device)
 
-                cur_vllm_emb.scatter_(
+                new_vllm_embedding[i] = cur_vllm_emb.scatter(
                     0,
                     image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                     cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                 )
+
             elif self.training:
-                cur_vllm_emb += cur_vs_hs[0].mean() * 0
+                new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0
 
-        return vllm_embedding, vision_hidden_states
+        return new_vllm_embedding, vision_hidden_states
 
     def get_audio_embedding_streaming(self, data):
         r"""
@@ -595,7 +598,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         elif self.training:
             for i in range(bs):
                 # dummy audio_embeddings
-                input_embeddings += audio_embeddings[0].mean() * 0
+                input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0
 
         return input_embeddings
 
@@ -751,7 +754,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         input_ids=None,
         pixel_values=None,
         tgt_sizes=None,
-        audio_features=None,
+        audio_features=[],
         audio_feature_lens=None,
         image_bound=None,
         audio_bounds=None,
@@ -2655,6 +2658,7 @@ class ConditionalChatTTS(PreTrainedModel):
     """
 
     config_class = ConditionalChatTTSConfig
+    _no_split_modules = []
 
     def __init__(self, config: ConditionalChatTTSConfig):
         super().__init__(config)
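Why the scatter change matters: cur_vllm_emb is a view into vllm_embedding, so the old in-place scatter_ mutated a tensor that autograd may already have saved for the backward pass. The sketch below reproduces that failure and the out-of-place fix; it is a standalone illustration with assumed shapes (token_embeddings, image_feats, and indices are stand-in names, not code from this repo).

import torch

def token_embeddings():
    # Plays the role of embed_tokens(...): a non-leaf (seq_len, hidden)
    # tensor that a later op saves for its backward.
    weight = torch.randn(10, 8, requires_grad=True)
    emb = weight[torch.arange(4)]
    loss = (emb ** 2).sum()  # pow saves `emb` for backward
    return emb, loss

image_feats = torch.ones(2, 8)                           # stands in for cur_vs_hs
indices = torch.tensor([1, 2]).view(-1, 1).repeat(1, 8)  # stands in for image_indices

# Old pattern: in-place scatter_ bumps the version counter of a tensor
# autograd has saved, so backward() raises a RuntimeError.
emb, loss = token_embeddings()
emb.scatter_(0, indices, image_feats)
try:
    loss.backward()
except RuntimeError as err:
    print("in-place scatter_ breaks backward:", err)

# New pattern, as in this commit: clone, then out-of-place scatter; the
# tensor saved for backward is never mutated.
emb, loss = token_embeddings()
new_emb = emb.clone().scatter(0, indices, image_feats)
loss.backward()  # succeeds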
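The `* 0` lines in the vision and audio hunks are the usual dummy-gradient trick (my reading of the diff, not repo documentation): when a sample has no image or audio, adding feats.mean() * 0 leaves the embedding values unchanged but keeps the otherwise unused encoder in the autograd graph, so its parameters still receive (zero) gradients and distributed training does not stall on unused parameters. The commit keeps the trick but routes it through the cloned tensor, and rewrites input_embeddings += ... as an out-of-place add. A minimal sketch, with encoder as a stand-in module:

import torch
import torch.nn as nn

encoder = nn.Linear(8, 8)           # stands in for the audio (or vision) encoder
feats = encoder(torch.zeros(1, 8))  # dummy forward through the unused branch

x = torch.randn(4, 8, requires_grad=True)
x = x + feats.mean() * 0            # values unchanged, graph edge added
x.sum().backward()

print(encoder.weight.grad)          # all zeros, but populated: no unused-parameter error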
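Two smaller changes round out the commit. The audio_features default becomes an empty list (the removed default is truncated on the page; an empty list lets length checks on the argument run without a None guard). And ConditionalChatTTS gains _no_split_modules = [], which opts it into transformers' device-map machinery: when this class attribute is None, from_pretrained(..., device_map=...) refuses to dispatch the model across devices, while an empty list declares that any submodule may be split. A short sketch against the public PreTrainedModel API (TinyConfig and TinyModel are illustrative names, not code from this repo):

import torch
from transformers import PretrainedConfig, PreTrainedModel

class TinyConfig(PretrainedConfig):
    model_type = "tiny"

class TinyModel(PreTrainedModel):
    config_class = TinyConfig
    _no_split_modules = []  # [] means "safe to split anywhere"; None blocks device_map

    def __init__(self, config):
        super().__init__(config)
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)

model = TinyModel(TinyConfig())
print(model._no_split_modules)  # []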
 
		