from typing import List, Optional, Tuple, Union
from dataclasses import dataclass, field

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, \
    LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(LlamaConfig):
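    """Configuration class for the LLaVA Llama model.

    Extends LlamaConfig with the multimodal projector settings
    (mm_projector_type, mm_hidden_size) used by the LLaVA
    vision-language wrapper around Llama.
    """
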
    model_type = "llava_llama"

    # Extend Llama's attribute aliases; vocab_size and attention_dropout are
    # mapped to their own names.
    attribute_map = {
        **LlamaConfig.attribute_map,
        "vocab_size": "vocab_size",
        "attention_dropout": "attention_dropout",
    }

    def __init__(self,
                 vocab_size=32000,
                 attention_dropout=0.1,
                 mm_projector_type="mlp2x_gelu",
                 mm_hidden_size=1024,
                 **kwargs):
        # Route the explicit arguments through kwargs so LlamaConfig.__init__
        # receives them alongside any other loaded config values.
        kwargs.setdefault("vocab_size", vocab_size)
        kwargs.setdefault("attention_dropout", attention_dropout)

        super().__init__(**kwargs)

        self.vocab_size = kwargs.get("vocab_size", vocab_size)
        self.attention_dropout = kwargs.get("attention_dropout", attention_dropout)
        # Multimodal projector settings: projector type and the hidden size of
        # the vision features it consumes.
        self.mm_projector_type = mm_projector_type
        self.mm_hidden_size = mm_hidden_size


if __name__ == "__main__":
    # Smoke test: load a pretrained config from the Hub and inspect it.
    cfg = LlavaConfig.from_pretrained("sherlockjjkj/llava-v1.5-7b")
    print("vocab_size:", cfg.vocab_size)
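
    # Illustrative sketch: build a LlavaConfig locally and round-trip it
    # through save_pretrained / from_pretrained. The vocab_size below is an
    # arbitrary example value, not one taken from a released checkpoint.
    import tempfile

    local_cfg = LlavaConfig(
        vocab_size=32064,
        mm_projector_type="mlp2x_gelu",
        mm_hidden_size=1024,
    )
    with tempfile.TemporaryDirectory() as tmp:
        local_cfg.save_pretrained(tmp)            # writes config.json into tmp
        reloaded = LlavaConfig.from_pretrained(tmp)
    print("mm_projector_type:", reloaded.mm_projector_type)
    print("mm_hidden_size:", reloaded.mm_hidden_size)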