File size: 3,758 Bytes
4765ac9
 
 
 
 
 
1f94519
4765ac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.modeling_outputs import CausalLMOutputWithPast

class MeshModel(PreTrainedModel):
    """Causal language model built from a stack of ``MeshLayer`` modules.

    Inputs are embedded (or supplied directly as ``inputs_embeds``), passed
    through ``config.num_hidden_layers`` ``MeshLayer`` modules, normalized by a
    final ``nn.LayerNorm`` and projected to vocabulary logits by a bias-free
    linear ``lm_head``.
    """

    config_class = MeshConfig

    def __init__(self, config: MeshConfig):
        # PreTrainedModel.__init__ already stores ``config`` as ``self.config``.
        super().__init__(config)

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [MeshLayer(config) for _ in range(config.num_hidden_layers)]
        )
        # NOTE(review): this is a LayerNorm, yet its epsilon is read from
        # ``rms_norm_eps`` -- confirm whether an RMSNorm was intended.
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            inputs_embeds: Pre-computed embeddings used instead of
                ``input_ids`` (the embedding table is then skipped).
            labels: Target token ids aligned with the inputs. When given, the
                cross-entropy loss of predicting token ``t + 1`` from position
                ``t`` is computed (label ``-100`` is ignored, the
                ``nn.CrossEntropyLoss`` default).
            return_dict: Whether to return a ``CausalLMOutputWithPast``;
                falls back to ``config.use_return_dict`` when ``None``.

        The remaining arguments (``attention_mask``, ``token_type_ids``,
        ``position_ids``, ``head_mask``, ``encoder_*``, ``past_key_values``,
        ``use_cache``, ``output_attentions``, ``output_hidden_states``) are
        accepted for API compatibility but are currently ignored.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` is truthy,
            otherwise the tuple
            ``(loss, logits, hidden_states, expert_indices_list)`` where
            ``loss`` is ``None`` if no ``labels`` were given.

        Raises:
            ValueError: If both, or neither, of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        # Fall back to the config-level setting when the caller did not choose.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            inputs_embeds = self.embedding(input_ids)
        elif inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        hidden_states = inputs_embeds
        # Each MeshLayer returns a (hidden_states, expert_indices) pair; the
        # indices (presumably expert-routing choices -- MeshLayer is defined
        # elsewhere) are collected per layer for the tuple output below.
        expert_indices_list = []
        for layer in self.layers:
            hidden_states, expert_indices = layer(hidden_states)
            expert_indices_list.append(expert_indices)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Labels may sit on a different device than the lm_head output
            # (e.g. when the model is split across devices).
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            # Flatten using the logits' real width rather than the config
            # value, so the reshape always matches the tensor.
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=None,  # TODO: KV caching is not implemented.
                # NOTE(review): transformers documents ``hidden_states`` as a
                # tuple of per-layer states returned only when
                # ``output_hidden_states`` is set; here it is the final
                # normalized tensor, returned unconditionally.
                hidden_states=hidden_states,
                attentions=None,  # TODO: attention outputs are not implemented.
            )

        # Tuple form: loss stays first because the Trainer reads outputs[0]
        # as the loss; unlike the dict form it also exposes the per-layer
        # expert indices (e.g. for custom callbacks or logging).
        return (loss, logits, hidden_states, expert_indices_list)