from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    LlamaConfig,
    LlamaModel,
    LlamaForCausalLM,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM

class LlavaConfig(LlamaConfig):
    model_type = "llava_llama"

    # attribute_map maps attribute aliases to canonical names; the identity
    # entries below simply restate the canonical names and are harmless no-ops.
    attribute_map = {
        **LlamaConfig.attribute_map,
        "vocab_size": "vocab_size",
        "attention_dropout": "attention_dropout",
    }

    def __init__(
        self,
        vocab_size=32000,
        attention_dropout=0.1,
        mm_projector_type="mlp2x_gelu",
        mm_hidden_size=1024,
        **kwargs,
    ):
        kwargs.setdefault("vocab_size", vocab_size)
        kwargs.setdefault("attention_dropout", attention_dropout)

        super().__init__(**kwargs)  # Call superclass first

        # Explicitly set attributes AFTER superclass init
        self.vocab_size = kwargs.get("vocab_size", vocab_size)
        self.attention_dropout = kwargs.get("attention_dropout", attention_dropout)
        self.mm_projector_type = mm_projector_type
        self.mm_hidden_size = mm_hidden_size
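
# Optional sketch, assuming AutoConfig-based loading is wanted: registering the
# class lets AutoConfig resolve the custom "llava_llama" model_type from a
# checkpoint's config.json via the standard AutoConfig.register hook.
AutoConfig.register("llava_llama", LlavaConfig)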

if __name__ == "__main__":
    # Quick smoke test: load the config from the Hub and check an overridden field.
    cfg = LlavaConfig.from_pretrained("sherlockjjkj/llava-v1.5-7b")
    print("vocab_size:", cfg.vocab_size)