from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class

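# Note: MeshConfig, MeshRouter, MeshExpert, NeighborExchange and
# CrossExpertAttention are assumed to be defined elsewhere in this module
# (or imported from it); they are not shown in this excerpt.
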
# Define the main Mesh Layer
class MeshLayer(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.router = MeshRouter(config)
        # Total number of experts is the product of the mesh grid dimensions.
        self.num_experts = config.mesh_grid_size[0] * config.mesh_grid_size[1]
        self.experts = nn.ModuleList([MeshExpert(config) for _ in range(self.num_experts)])
        self.neighbor_exchange = NeighborExchange(config)
        self.cross_expert_attention = CrossExpertAttention(config)

    def forward(self, hidden_states):
        # hidden_states shape: (batch_size, sequence_length, hidden_size)

        # 1. Routing
        topk_weights, topk_indices = self.router(hidden_states)
        # topk_weights shape: (batch_size, sequence_length, k)
        # topk_indices shape: (batch_size, sequence_length, k)

        # Prepare expert inputs: repeat hidden_states for each expert
        # shape: (batch_size, sequence_length, num_experts, hidden_size)
        expanded_hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.num_experts, -1)

        # 2. Expert Computation
        # Compute output for all experts (can be optimized to only compute for selected experts)
        expert_outputs = torch.stack([expert(expanded_hidden_states[:, :, i, :]) for i, expert in enumerate(self.experts)], dim=2)
        # expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)

        # 3. Neighbor Exchange (conceptual implementation needed)
        exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)

        # 4. Cross-Expert Attention (conceptual implementation needed)
        cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)

        # 5. Combine expert outputs based on routing weights
        # Create a tensor to gather the outputs of the selected experts
        # shape: (batch_size, sequence_length, k, hidden_size)
        gathered_outputs = torch.gather(
            cross_attned_expert_outputs,
            dim=2,
            index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size)
        )

        # Apply routing weights: (batch_size, sequence_length, k, 1) * (batch_size, sequence_length, k, hidden_size)
        combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)
        # combined_output shape: (batch_size, sequence_length, hidden_size)

        # Return the combined output and the expert indices for potential visualization
        return combined_output, topk_indices
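

# --- Hedged usage sketch (not part of the original file) -------------------
# A minimal, self-contained illustration of the gather-and-combine step in
# MeshLayer.forward, run on dummy tensors. The shapes below (batch=2, seq=3,
# num_experts=4, k=2, hidden=8) are hypothetical and chosen only for
# demonstration; the real values come from MeshConfig and the router.
def _demo_gather_and_combine():
    batch, seq, num_experts, k, hidden = 2, 3, 4, 2, 8
    expert_outputs = torch.randn(batch, seq, num_experts, hidden)
    # Router-style outputs: k expert indices per token plus normalized weights.
    topk_weights = torch.softmax(torch.randn(batch, seq, k), dim=-1)
    topk_indices = torch.randint(0, num_experts, (batch, seq, k))

    # Gather the outputs of the k selected experts for every token.
    gathered = torch.gather(
        expert_outputs,
        dim=2,
        index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, hidden),
    )  # (batch, seq, k, hidden)

    # Weighted sum over the selected experts, as in MeshLayer.forward.
    combined = (gathered * topk_weights.unsqueeze(-1)).sum(dim=2)
    assert combined.shape == (batch, seq, hidden)
    return combined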