File size: 878 Bytes
da1403d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
Configuration class for MoLA-LM
"""

from typing import Dict, List, Optional

from transformers import PretrainedConfig

# Default expert/task identifiers: the string digits "0" through "8".
# Used as the fallback set of task labels when a MoLAConfig is created
# without explicit task_labels.
EXPERTS_LIST = [str(expert_id) for expert_id in range(9)]


class MoLAConfig(PretrainedConfig):
    """Configuration class for the MoLA-LM (mixture-of-LoRA-adapters) model.

    Holds the backbone model reference, the router configuration, and one
    LoRA configuration per task label. Registered with the Hugging Face
    config machinery under ``model_type = "mola_lm"``.
    """

    model_type = "mola_lm"

    def __init__(
        self,
        base_model_name_or_path: str = "Qwen/Qwen2.5-3B-Instruct",
        task_labels: Optional[List[str]] = None,
        router_config: Optional[Dict] = None,
        lora_configs: Optional[Dict[str, Dict]] = None,
        **kwargs,
    ):
        """Initialize a MoLA configuration.

        Args:
            base_model_name_or_path: Hub id or local path of the base model.
            task_labels: Labels of the tasks/experts; defaults to
                ``EXPERTS_LIST`` (the string digits "0"–"8") when ``None``.
            router_config: Router hyperparameters; defaults to ``{}``.
            lora_configs: Mapping from task label to that expert's LoRA
                config dict; defaults to ``{}``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.base_model_name_or_path = base_model_name_or_path
        # Fall back to the module-level default expert ids when unset.
        self.task_labels = task_labels or EXPERTS_LIST
        self.router_config = router_config or {}
        self.lora_configs = lora_configs or {}
        # One LoRA adapter per task label.
        self.num_loras = len(self.task_labels)