Upload hunyuan.py with huggingface_hub

hunyuan.py CHANGED (+1 / -29)

This commit strips the vision-tower (ViT) integration out of hunyuan.py: the vit_model import, the ViT construction in HunYuanMoEV1ForCausalLM.__init__, and the two NaVitForward/VitForward image-embedding branches in MultimodelHunYuanForCausalLM (one before the forward-pass config resolution, one before the call to super().generate()).
@@ -41,7 +41,6 @@ from transformers.utils.import_utils import is_torch_fx_available
 from transformers.generation.utils import GenerateOutput
 from .configuration_hunyuan import HunYuanConfig
 from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
-from .vit_model import NaVitForward, VitForward, Vit
 
 
 if is_flash_attn_2_available():
@@ -363,16 +362,7 @@ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
 
     def __init__(self, config: HunYuanConfig):
         super().__init__(config)
-
-        if "-tp" in config.vit_type:
-            config.vit_type = config.vit_type.replace("-tp", "")
-        self.vit_type = config.vit_type
-        if self.vit_type not in ['NaVit', 'EvaVit']:
-            if config.vit_mapping_type == 'mlp':
-                self.vit_linear_encoder = torch.nn.Linear(config.hidden_size, config.hidden_size)
-            self.vit = Vit(config)
-        else:
-            self.vit = None
+
         self.config = config
         self.model = HunYuanModel(config)
         self.add_classification_head = config.add_classification_head
@@ -643,15 +633,6 @@ class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
         video_start_id = self.config.video_start_id
         video_end_id = self.config.video_end_id
 
-        if self.vit is not None and imgs is not None:
-            encoder_input = self.model.embed_tokens(input_ids)
-            if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
-                inputs_embeds, input_ids = NaVitForward(input_ids, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
-                    im_start_id, im_end_id, image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
-            else:
-                inputs_embeds, input_ids = VitForward(input_ids, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
-                    self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
-
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -738,15 +719,6 @@ class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
         if "inputs_embeds" in kwargs:
             raise NotImplementedError("`inputs_embeds` is not supported")
 
-        if self.vit is not None:
-            encoder_input = self.model.embed_tokens(inputs)
-            if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
-                inputs_embeds, input_ids = NaVitForward(inputs, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
-                    self.config.im_start_id, self.config.im_end_id, self.config.image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
-            else:
-                inputs_embeds, input_ids = VitForward(inputs, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
-                    self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
-
         return super().generate(
             inputs=input_ids,
             position_ids=position_ids,
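For context on what was removed: the deleted branches implemented the common "embed, then splice" multimodal pattern, where text tokens are embedded normally and the embeddings at image-placeholder positions are overwritten with vision-tower features. The sketch below is a minimal, hypothetical illustration of that pattern only; it is not the repo's NaVitForward/VitForward (whose bodies live in the removed vit_model module), and every name in it is illustrative.

    # Minimal sketch (not the repo's code) of the embed-then-splice pattern
    # the deleted branches implemented. All names are illustrative.
    import torch

    def splice_image_features(token_embeds: torch.Tensor,
                              input_ids: torch.Tensor,
                              image_features: torch.Tensor,
                              image_token_id: int) -> torch.Tensor:
        # token_embeds:   (batch, seq, hidden) from model.embed_tokens(input_ids)
        # image_features: (num_image_tokens, hidden) from a ViT + projection layer
        out = token_embeds.clone()
        mask = input_ids == image_token_id        # placeholder positions for patches
        out[mask] = image_features.to(out.dtype)  # overwrite placeholders with ViT features
        return out

    # Toy check with random tensors; 9 stands in for a hypothetical image_token_id
    embeds = torch.randn(1, 8, 16)
    ids = torch.tensor([[1, 2, 9, 9, 9, 3, 4, 5]])
    feats = torch.randn(3, 16)
    merged = splice_image_features(embeds, ids, feats, image_token_id=9)
    assert merged.shape == embeds.shape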
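With the ViT path gone, the checkpoint behaves as a text-only causal LM. Below is a hedged loading sketch via transformers' trust_remote_code path, which is how repos shipping configuration_hunyuan/modeling_hunyuan-style custom code are normally consumed; the repo id is a placeholder, not confirmed by this commit.

    # Hedged usage sketch after this change; the repo id is a placeholder.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "your-org/your-hunyuan-checkpoint"  # placeholder, not a real repo id
    tok = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

    inputs = tok("Hello", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=32)
    print(tok.decode(out[0], skip_special_tokens=True))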