# llava-v1.5-7b / Untitled-2.py
# Provenance: sherlockjjkj (Hugging Face Hub) — "Upload 12 files", commit 44f770b (verified)
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, \
LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from dataclasses import dataclass, field
from transformers.models.llama.configuration_llama import LlamaConfig
class LlavaConfig(LlamaConfig):
    """Configuration for a LLaVA model built on a LLaMA language backbone.

    Extends :class:`LlamaConfig` with the multimodal (vision->language)
    projector settings; every other keyword argument is forwarded to the
    base LLaMA configuration.
    """

    model_type = "llava_llama"

    # NOTE(review): the original added identity aliases here
    # ("vocab_size" -> "vocab_size", "attention_dropout" -> "attention_dropout").
    # Identity entries are no-ops in PretrainedConfig's attribute aliasing,
    # so only the inherited map is kept.
    attribute_map = {**LlamaConfig.attribute_map}

    def __init__(self,
                 vocab_size: int = 32000,
                 attention_dropout: float = 0.1,
                 mm_projector_type: str = "mlp2x_gelu",
                 mm_hidden_size: int = 1024,
                 **kwargs):
        """
        Args:
            vocab_size: Token-vocabulary size (default matches LLaMA-1/2).
            attention_dropout: Dropout probability on attention weights.
            mm_projector_type: Architecture tag of the multimodal projector
                (e.g. ``"mlp2x_gelu"``).
            mm_hidden_size: Hidden size of the vision-encoder features fed
                to the projector.
            **kwargs: Remaining LLaMA / PretrainedConfig options.
        """
        # setdefault lets an explicit kwarg (e.g. values loaded from a
        # config.json via from_pretrained) win over the defaults above.
        kwargs.setdefault("vocab_size", vocab_size)
        kwargs.setdefault("attention_dropout", attention_dropout)
        # super().__init__ consumes vocab_size / attention_dropout from
        # kwargs and sets them on self; the original re-assigned both
        # afterwards from the same kwargs, which was redundant.
        super().__init__(**kwargs)
        # LLaVA-specific multimodal fields.
        self.mm_projector_type = mm_projector_type
        self.mm_hidden_size = mm_hidden_size
def _smoke_test() -> None:
    """Load the published LLaVA config from the Hub and print its vocab size.

    Performs network I/O: downloads (or reads from the local HF cache)
    ``config.json`` for the hard-coded repo.
    """
    cfg = LlavaConfig.from_pretrained("sherlockjjkj/llava-v1.5-7b")
    print("vocab_size:", cfg.vocab_size)


if __name__ == "__main__":
    # Guarded so that importing this module does not trigger a Hub download;
    # running the file as a script behaves exactly as before.
    _smoke_test()